Modernize Image Recognition lab
This commit is contained in:
@@ -10,6 +10,10 @@ import java.nio.file.Files;
|
|||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.function.BiConsumer;
|
||||||
|
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -24,10 +28,11 @@ public class ProductRecognitionNewModel {
|
|||||||
// private static final int SAMPLE_SIZE = 10;
|
// private static final int SAMPLE_SIZE = 10;
|
||||||
private static final int RESIZE_WIDTH = 100;
|
private static final int RESIZE_WIDTH = 100;
|
||||||
private static final int RESIZE_HEIGHT = 100;
|
private static final int RESIZE_HEIGHT = 100;
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ProductRecognitionNewModel.class);
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
try {
|
try {
|
||||||
System.out.println("Starting Product Recognition Lab");
|
log.info("Starting Product Recognition Lab");
|
||||||
|
|
||||||
// Step 1: Create necessary directories
|
// Step 1: Create necessary directories
|
||||||
File projectDir = new File("ImageRecognitionLab");
|
File projectDir = new File("ImageRecognitionLab");
|
||||||
@@ -49,15 +54,14 @@ public class ProductRecognitionNewModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
System.out.printf("No images found in directory: %s%n", imagesDir);
|
log.info("No images found in directory: {}", (Object) imageFiles);
|
||||||
System.out.println("Please add some product images to the 'images' folder.");
|
log.info("Please add some product images to the 'images' folder.");
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Product recognition completed successfully!");
|
log.info("Product recognition completed successfully!");
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
System.err.println("Error occurred: " + e.getMessage());
|
log.error("Error occurred: {}", e.getMessage(), e);
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,12 +69,12 @@ public class ProductRecognitionNewModel {
|
|||||||
* Process a single image file
|
* Process a single image file
|
||||||
*/
|
*/
|
||||||
private static void processImage(File imageFile, Path outputDir) throws IOException {
|
private static void processImage(File imageFile, Path outputDir) throws IOException {
|
||||||
System.out.println("Analyzing image: " + imageFile.getName());
|
log.info("Analyzing image: {}", imageFile.getName());
|
||||||
|
|
||||||
// Step 1: Load the image
|
// Step 1: Load the image
|
||||||
BufferedImage originalImage = ImageIO.read(imageFile);
|
BufferedImage originalImage = ImageIO.read(imageFile);
|
||||||
if (originalImage == null) {
|
if (originalImage == null) {
|
||||||
System.err.println("Failed to load image: " + imageFile.getName());
|
log.error("Failed to load image: {}", imageFile.getName());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,11 +85,11 @@ public class ProductRecognitionNewModel {
|
|||||||
List<Prediction> predictions = classifyImage(features, imageFile.getName());
|
List<Prediction> predictions = classifyImage(features, imageFile.getName());
|
||||||
|
|
||||||
// Step 4: Print the results
|
// Step 4: Print the results
|
||||||
System.out.println("Top 5 predictions for " + imageFile.getName() + ":");
|
log.info("Top 5 predictions for {}:", imageFile.getName());
|
||||||
for (Prediction p : predictions) {
|
for (Prediction p : predictions) {
|
||||||
System.out.printf("%-30s: %.2f%%\n", p.getLabel(), p.getProbability() * 100);
|
System.out.printf("%-30s: %.2f%%\n", p.label(), p.probability() * 100);
|
||||||
}
|
}
|
||||||
System.out.println("-----------------------------------------\n");
|
log.info("-----------------------------------------\n");
|
||||||
|
|
||||||
// Step 5: Save the results to a file
|
// Step 5: Save the results to a file
|
||||||
saveResultsToFile(imageFile.getName(), predictions, outputDir);
|
saveResultsToFile(imageFile.getName(), predictions, outputDir);
|
||||||
@@ -218,6 +222,8 @@ public class ProductRecognitionNewModel {
|
|||||||
*/
|
*/
|
||||||
private static List<Prediction> classifyImage(Map<String, Double> features, String filename) {
|
private static List<Prediction> classifyImage(Map<String, Double> features, String filename) {
|
||||||
List<Prediction> predictions = new ArrayList<>();
|
List<Prediction> predictions = new ArrayList<>();
|
||||||
|
BiConsumer<Integer,Double> adjustProbability = (index, probability) ->
|
||||||
|
predictions.set(index, predictions.get(index).probability(probability));
|
||||||
|
|
||||||
// Using filename for simulation, since this is just a demonstration
|
// Using filename for simulation, since this is just a demonstration
|
||||||
filename = filename.toLowerCase();
|
filename = filename.toLowerCase();
|
||||||
@@ -241,42 +247,42 @@ public class ProductRecognitionNewModel {
|
|||||||
|
|
||||||
// Dark colors with high edge density might be electronic devices
|
// Dark colors with high edge density might be electronic devices
|
||||||
if (avgRed < 0.5 && avgGreen < 0.5 && avgBlue < 0.5 && edgeDensity > 0.1) {
|
if (avgRed < 0.5 && avgGreen < 0.5 && avgBlue < 0.5 && edgeDensity > 0.1) {
|
||||||
predictions.get(0).setProbability(0.6); // laptop
|
adjustProbability.accept(0, 0.6);
|
||||||
predictions.get(4).setProbability(0.3); // smartphone
|
adjustProbability.accept(4, 0.3);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Blue tones might suggest water bottles
|
// Blue tones might suggest water bottles
|
||||||
if (avgBlue > avgRed && avgBlue > avgGreen) {
|
if (avgBlue > avgRed && avgBlue > avgGreen) {
|
||||||
predictions.get(1).setProbability(0.7); // water bottle
|
adjustProbability.accept(1, 0.7); // water bottle
|
||||||
}
|
}
|
||||||
|
|
||||||
// High uniformity might suggest solid objects like mugs
|
// High uniformity might suggest solid objects like mugs
|
||||||
if (textureUniformity > 0.1 && avgRed > 0.3) {
|
if (textureUniformity > 0.1 && avgRed > 0.3) {
|
||||||
predictions.get(2).setProbability(0.65); // coffee mug
|
adjustProbability.accept(2, 0.65); // coffee mug
|
||||||
}
|
}
|
||||||
|
|
||||||
// Medium brightness with texture might be books
|
// Medium brightness with texture might be books
|
||||||
if (avgRed > 0.3 && avgRed < 0.7 && textureUniformity < 0.1) {
|
if (avgRed > 0.3 && avgRed < 0.7 && textureUniformity < 0.1) {
|
||||||
predictions.get(3).setProbability(0.55); // book
|
adjustProbability.accept(3, 0.55); // book
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override with filename-based simulated results for this demo
|
// Override with filename-based simulated results for this demo
|
||||||
if (filename.contains("laptop")) {
|
if (filename.contains("laptop")) {
|
||||||
predictions.get(0).setProbability(0.92); // laptop
|
adjustProbability.accept(0,0.92); // laptop
|
||||||
predictions.get(4).setProbability(0.05); // smartphone
|
adjustProbability.accept(4,0.05); // smartphone
|
||||||
} else if (filename.contains("bottle") || filename.contains("water")) {
|
} else if (filename.contains("bottle") || filename.contains("water")) {
|
||||||
predictions.get(1).setProbability(0.89); // water bottle
|
adjustProbability.accept(1,0.89); // water bottle
|
||||||
} else if (filename.contains("mug") || filename.contains("coffee")) {
|
} else if (filename.contains("mug") || filename.contains("coffee")) {
|
||||||
predictions.get(2).setProbability(0.94); // coffee mug
|
adjustProbability.accept(2,0.94); // coffee mug
|
||||||
} else if (filename.contains("book")) {
|
} else if (filename.contains("book")) {
|
||||||
predictions.get(3).setProbability(0.91); // book
|
adjustProbability.accept(3,0.91); // book
|
||||||
} else if (filename.contains("phone") || filename.contains("smartphone")) {
|
} else if (filename.contains("phone") || filename.contains("smartphone")) {
|
||||||
predictions.get(4).setProbability(0.95); // smartphone
|
adjustProbability.accept(4,0.95); // smartphone
|
||||||
predictions.get(0).setProbability(0.03); // laptop
|
adjustProbability.accept(0,0.03); // laptop
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sort by probability
|
// Sort by probability
|
||||||
predictions.sort((a, b) -> Double.compare(b.getProbability(), a.getProbability()));
|
predictions.sort((a, b) -> Double.compare(b.probability(), a.probability()));
|
||||||
|
|
||||||
return predictions;
|
return predictions;
|
||||||
}
|
}
|
||||||
@@ -296,9 +302,7 @@ public class ProductRecognitionNewModel {
|
|||||||
writer.write("-----------------------------------------\n");
|
writer.write("-----------------------------------------\n");
|
||||||
|
|
||||||
for (Prediction p : results) {
|
for (Prediction p : results) {
|
||||||
writer.write(String.format("%-30s: %.2f%%\n",
|
writer.write(String.format("%-30s: %.2f%%\n", p.label(), p.probability() * 100));
|
||||||
p.getLabel(),
|
|
||||||
p.getProbability() * 100));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.write("\n\nImage Analysis:\n");
|
writer.write("\n\nImage Analysis:\n");
|
||||||
@@ -309,7 +313,7 @@ public class ProductRecognitionNewModel {
|
|||||||
writer.write("thousands of product images.\n");
|
writer.write("thousands of product images.\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.println("Results saved to file: " + outputFile);
|
log.info("Results saved to file: {}", outputFile);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -345,32 +349,13 @@ public class ProductRecognitionNewModel {
|
|||||||
ImageIO.write(processed, "jpg", outputFile.toFile());
|
ImageIO.write(processed, "jpg", outputFile.toFile());
|
||||||
|
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
System.err.println("Error saving processed image: " + e.getMessage());
|
log.error("Error saving processed image: {}", e.getMessage(),e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private record Prediction(String label, double probability) {
|
||||||
* A simple class to hold a prediction label and its probability
|
Prediction probability(double probability) {
|
||||||
*/
|
return new Prediction(label, probability);
|
||||||
private static class Prediction {
|
|
||||||
private final String label;
|
|
||||||
private double probability;
|
|
||||||
|
|
||||||
public Prediction(String label, double probability) {
|
|
||||||
this.label = label;
|
|
||||||
this.probability = probability;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getLabel() {
|
|
||||||
return label;
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getProbability() {
|
|
||||||
return probability;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setProbability(double probability) {
|
|
||||||
this.probability = probability;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,8 +1,5 @@
|
|||||||
package com.example.img;
|
package com.example.img;
|
||||||
|
|
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
|
||||||
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
|
|
||||||
|
|
||||||
import ai.djl.ModelException;
|
import ai.djl.ModelException;
|
||||||
import ai.djl.inference.Predictor;
|
import ai.djl.inference.Predictor;
|
||||||
import ai.djl.modality.Classifications;
|
import ai.djl.modality.Classifications;
|
||||||
@@ -15,17 +12,19 @@ import ai.djl.repository.zoo.Criteria;
|
|||||||
import ai.djl.repository.zoo.ZooModel;
|
import ai.djl.repository.zoo.ZooModel;
|
||||||
import ai.djl.translate.TranslateException;
|
import ai.djl.translate.TranslateException;
|
||||||
import ai.djl.translate.Translator;
|
import ai.djl.translate.Translator;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
import java.io.FileInputStream;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
|
||||||
import org.springframework.core.io.Resource;
|
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
|
|
||||||
@SpringBootApplication
|
@SpringBootApplication
|
||||||
public class ProductRecognitionPreTrainedModel {
|
public class ProductRecognitionPreTrainedModel {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ProductRecognitionPreTrainedModel.class);
|
||||||
|
|
||||||
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
public static void main(String[] args) throws IOException, ModelException, TranslateException {
|
||||||
// Path to your image file
|
// Path to your image file
|
||||||
String imagePath = "ImageRecognitionLab/images/pill_bottle.png";
|
String imagePath = "ImageRecognitionLab/images/pill_bottle.png";
|
||||||
@@ -37,14 +36,11 @@ public class ProductRecognitionPreTrainedModel {
|
|||||||
|
|
||||||
try (InputStream is = new FileInputStream(imagePath)) {
|
try (InputStream is = new FileInputStream(imagePath)) {
|
||||||
|
|
||||||
// Load image
|
|
||||||
Image img = ImageFactory.getInstance().fromInputStream(is);
|
Image img = ImageFactory.getInstance().fromInputStream(is);
|
||||||
|
|
||||||
// Run prediction
|
|
||||||
Classifications predictions = predict(img);
|
Classifications predictions = predict(img);
|
||||||
|
|
||||||
// Print results
|
log.info("Top 5 Predictions:");
|
||||||
System.out.println("Top 5 Predictions:");
|
|
||||||
predictions.topK(5).forEach(System.out::println);
|
predictions.topK(5).forEach(System.out::println);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user