Java图像识别完整示例

我将为您提供一个使用Java实现图像识别的完整案例,这里使用TensorFlow Java API和预训练的深度学习模型来实现图像分类。
环境准备
Maven依赖 (pom.xml)
<dependencies>
<!-- TensorFlow Java API -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.15.0</version>
</dependency>
<!-- 图像处理 -->
<dependency>
<groupId>net.coobird</groupId>
<artifactId>thumbnailator</artifactId>
<version>0.4.8</version>
</dependency>
<!-- JSON处理 -->
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.6</version>
</dependency>
</dependencies>
图像识别核心类
package com.example.imagerecognition;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.types.UInt8;
import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferByte;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
/**
 * Image classifier backed by a frozen TensorFlow 1.x graph and a plain-text label file
 * (one label per line; line i names output class i).
 *
 * <p>The graph is imported once in the constructor and released by {@link #close()}. Every
 * {@code recognize} call opens and closes its own {@link Session} and tensors, so nothing but the
 * graph outlives a call.
 */
public class ImageRecognizer implements AutoCloseable {
    /** Edge length in pixels of the square model input (224 for MobileNet v1 1.0_224). */
    private static final int INPUT_SIZE = 224;

    /** Input operation name used by the two-argument constructor. */
    private static final String DEFAULT_INPUT_OP = "input";

    /**
     * Output operation name used by the two-argument constructor.
     * NOTE(review): MobileNet v1 frozen graphs usually expose
     * "MobilenetV1/Predictions/Reshape_1" instead; pass the real name via the four-argument
     * constructor.
     */
    private static final String DEFAULT_OUTPUT_OP = "output";

    private final List<String> labels;
    private final Graph graph;
    private final String inputOp;
    private final String outputOp;

    /**
     * Loads a frozen graph and its labels using the default operation names "input"/"output".
     *
     * @param modelPath  path to the frozen GraphDef (.pb) file
     * @param labelsPath path to the label file, one label per line
     * @throws IOException if either file cannot be read
     */
    public ImageRecognizer(String modelPath, String labelsPath) throws IOException {
        this(modelPath, labelsPath, DEFAULT_INPUT_OP, DEFAULT_OUTPUT_OP);
    }

    /**
     * Loads a frozen graph and its labels, feeding/fetching the given operations.
     *
     * @param modelPath  path to the frozen GraphDef (.pb) file
     * @param labelsPath path to the label file, one label per line
     * @param inputOp    name of the graph operation that receives the image tensor
     * @param outputOp   name of the graph operation whose first output holds the class scores
     * @throws IOException              if either file cannot be read
     * @throws IllegalArgumentException if the model bytes are not a valid GraphDef
     */
    public ImageRecognizer(String modelPath, String labelsPath, String inputOp, String outputOp)
            throws IOException {
        this.inputOp = Objects.requireNonNull(inputOp, "inputOp");
        this.outputOp = Objects.requireNonNull(outputOp, "outputOp");
        byte[] graphDef = readAllBytes(Paths.get(modelPath));
        this.labels = readLabels(labelsPath);
        Graph loaded = new Graph();
        try {
            loaded.importGraphDef(graphDef);
        } catch (RuntimeException e) {
            // Release the native graph rather than leaking it when the import fails.
            loaded.close();
            throw e;
        }
        this.graph = loaded;
    }

    /**
     * Classifies the image stored at {@code imagePath}.
     *
     * @param imagePath path to an image file readable by {@link ImageIO}
     * @return one result per output class, sorted by descending confidence
     * @throws IOException if the file cannot be read or no ImageIO reader supports its format
     */
    public List<RecognitionResult> recognize(String imagePath) throws IOException {
        BufferedImage image = ImageIO.read(Paths.get(imagePath).toFile());
        if (image == null) {
            // ImageIO.read signals "no suitable reader" with null rather than an exception.
            throw new IOException("Unsupported or unreadable image file: " + imagePath);
        }
        return recognize(image);
    }

    /**
     * Classifies an in-memory image.
     *
     * @param image image to classify; it is resized to the model input size
     * @return one result per output class, sorted by descending confidence
     * @throws IOException kept for source compatibility with existing callers
     */
    public List<RecognitionResult> recognize(BufferedImage image) throws IOException {
        Objects.requireNonNull(image, "image");
        // Every native resource is scoped to this call and closed even if inference fails.
        try (Tensor<Float> imageTensor = preprocessImage(image);
             Session session = new Session(graph);
             Tensor<?> output = session.runner()
                     .feed(inputOp, imageTensor)
                     .fetch(outputOp)
                     .run()
                     .get(0)) {
            return parseResult(output.expect(Float.class));
        }
    }

    /**
     * Resizes the image to INPUT_SIZE x INPUT_SIZE and converts it to a [1, H, W, 3] float tensor.
     */
    private Tensor<Float> preprocessImage(BufferedImage image) {
        BufferedImage resized =
                new BufferedImage(INPUT_SIZE, INPUT_SIZE, BufferedImage.TYPE_3BYTE_BGR);
        java.awt.Graphics2D g = resized.createGraphics();
        try {
            // Bilinear interpolation avoids the aliasing of the default nearest-neighbour scaling.
            g.setRenderingHint(java.awt.RenderingHints.KEY_INTERPOLATION,
                    java.awt.RenderingHints.VALUE_INTERPOLATION_BILINEAR);
            g.drawImage(image, 0, 0, INPUT_SIZE, INPUT_SIZE, null);
        } finally {
            g.dispose();
        }
        return Tensors.create(convertImageToFloat(resized));
    }

    /**
     * Converts an image to a {@code [1][height][width][3]} array of RGB values scaled to [-1, 1].
     * The [-1, 1] scaling is the MobileNet-style normalization.
     */
    private float[][][][] convertImageToFloat(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        float[][][][] result = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // getRGB returns packed 0xAARRGGBB regardless of the image's storage layout.
                int rgb = image.getRGB(x, y);
                result[0][y][x][0] = ((rgb >> 16) & 0xFF) / 127.5f - 1.0f; // R
                result[0][y][x][1] = ((rgb >> 8) & 0xFF) / 127.5f - 1.0f;  // G
                result[0][y][x][2] = (rgb & 0xFF) / 127.5f - 1.0f;         // B
            }
        }
        return result;
    }

    /**
     * Turns a {@code [1, numClasses]} score tensor into labelled results sorted by score.
     *
     * <p>The class count is taken from the tensor itself. Classes without a label line are named
     * {@code class_<index>}, and surplus labels are ignored, so a label file that is slightly off
     * no longer crashes the copy.
     *
     * @throws IllegalStateException if the output is not shaped {@code [1, numClasses]}
     */
    private List<RecognitionResult> parseResult(Tensor<Float> tensor) {
        long[] shape = tensor.shape();
        if (shape.length != 2 || shape[0] != 1) {
            throw new IllegalStateException(
                    "Expected model output of shape [1, numClasses] but got " + Arrays.toString(shape));
        }
        int numClasses = (int) shape[1];
        float[][] probabilities = tensor.copyTo(new float[1][numClasses]);
        List<RecognitionResult> results = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            String label = i < labels.size() ? labels.get(i) : "class_" + i;
            results.add(new RecognitionResult(label, probabilities[0][i]));
        }
        // Highest confidence first.
        results.sort((a, b) -> Float.compare(b.getConfidence(), a.getConfidence()));
        return results;
    }

    /**
     * Reads the label file (UTF-8), trimming each line. Line order defines the class index.
     */
    private List<String> readLabels(String labelsPath) throws IOException {
        List<String> result = new ArrayList<>();
        for (String line : Files.readAllLines(Paths.get(labelsPath))) {
            result.add(line.trim());
        }
        return result;
    }

    /**
     * Reads the whole file into memory.
     */
    private byte[] readAllBytes(Path path) throws IOException {
        return Files.readAllBytes(path);
    }

    /**
     * Releases the native TensorFlow graph. The recognizer must not be used afterwards.
     */
    @Override
    public void close() {
        graph.close();
    }

    /**
     * Immutable pair of a class label and its model score.
     */
    public static class RecognitionResult {
        private final String label;
        private final float confidence;

        public RecognitionResult(String label, float confidence) {
            this.label = label;
            this.confidence = confidence;
        }

        public String getLabel() { return label; }

        public float getConfidence() { return confidence; }

        /**
         * Formats as "label: NN.NN%". This assumes the score is a probability in [0, 1].
         */
        @Override
        public String toString() {
            return String.format("%s: %.2f%%", label, confidence * 100);
        }
    }
}
主程序示例
package com.example.imagerecognition;
import javax.swing.*;
import java.awt.*;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.io.IOException;
import java.util.List;
/**
 * Swing front-end for {@link ImageRecognizer}.
 *
 * <p>The user picks an image file. The demo shows a preview and the top-5 predictions. When the
 * model or label file is missing, it falls back to canned demo output.
 */
public class ImageRecognitionDemo {
    /** Edge length in pixels of the square preview area. */
    private static final int PREVIEW_SIZE = 400;

    private JFrame frame;
    private JLabel imageLabel;
    private JTextArea resultArea;
    private JButton selectButton;
    /** Null when the model files are absent and simulated output is shown instead. */
    private final ImageRecognizer recognizer;

    /**
     * Loads the recognizer when both model files exist, then builds and shows the window.
     *
     * @throws IOException if the model files exist but cannot be read
     */
    public ImageRecognitionDemo() throws IOException {
        // The model files must be downloaded separately.
        String modelPath = "models/mobilenet_v1_1.0_224_frozen.pb";
        String labelsPath = "models/labels.txt";
        // Both the graph and its labels are required. Otherwise fall back to simulation
        // instead of failing in the constructor.
        if (!new File(modelPath).exists() || !new File(labelsPath).exists()) {
            System.out.println("模型文件不存在,使用模拟识别功能");
            recognizer = null;
        } else {
            recognizer = new ImageRecognizer(modelPath, labelsPath);
        }
        initUI();
    }

    /** Builds the window: buttons on top, preview in the centre, results at the bottom. */
    private void initUI() {
        frame = new JFrame("Java图像识别演示");
        frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        frame.setLayout(new BorderLayout(10, 10));

        imageLabel = new JLabel("请选择图片", SwingConstants.CENTER);
        imageLabel.setPreferredSize(new Dimension(PREVIEW_SIZE, PREVIEW_SIZE));
        imageLabel.setBorder(BorderFactory.createLineBorder(Color.GRAY));

        resultArea = new JTextArea(10, 40);
        resultArea.setEditable(false);
        resultArea.setFont(new Font("Monospaced", Font.PLAIN, 12));
        JScrollPane scrollPane = new JScrollPane(resultArea);

        JPanel buttonPanel = new JPanel();
        selectButton = new JButton("选择图片");
        JButton quitButton = new JButton("退出");
        selectButton.addActionListener(e -> selectAndRecognizeImage());
        quitButton.addActionListener(e -> System.exit(0));
        buttonPanel.add(selectButton);
        buttonPanel.add(quitButton);

        frame.add(imageLabel, BorderLayout.CENTER);
        frame.add(scrollPane, BorderLayout.SOUTH);
        frame.add(buttonPanel, BorderLayout.NORTH);
        frame.pack();
        frame.setLocationRelativeTo(null);
        frame.setVisible(true);
    }

    /**
     * Lets the user choose an image, previews it, and classifies it off the event thread.
     */
    private void selectAndRecognizeImage() {
        JFileChooser fileChooser = new JFileChooser();
        fileChooser.setFileFilter(new javax.swing.filechooser.FileNameExtensionFilter(
                "图片文件", "jpg", "jpeg", "png", "gif", "bmp"));
        if (fileChooser.showOpenDialog(frame) != JFileChooser.APPROVE_OPTION) {
            return;
        }
        File selectedFile = fileChooser.getSelectedFile();
        showPreview(selectedFile);

        if (recognizer == null) {
            simulateRecognition();
            return;
        }

        // Inference can be slow, so it runs on a SwingWorker thread to keep the UI responsive.
        // The button is disabled so results from overlapping runs cannot interleave.
        selectButton.setEnabled(false);
        resultArea.setText("识别中...");
        String imagePath = selectedFile.getAbsolutePath();
        new SwingWorker<List<ImageRecognizer.RecognitionResult>, Void>() {
            @Override
            protected List<ImageRecognizer.RecognitionResult> doInBackground() throws Exception {
                return recognizer.recognize(imagePath);
            }

            @Override
            protected void done() {
                // done() runs on the EDT, so Swing components may be updated here.
                try {
                    displayResults(get());
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    resultArea.setText("识别失败: " + e.getMessage());
                } catch (java.util.concurrent.ExecutionException e) {
                    // Unwrap so I/O errors and TensorFlow runtime errors are reported alike.
                    Throwable cause = e.getCause() != null ? e.getCause() : e;
                    resultArea.setText("识别失败: " + cause.getMessage());
                } finally {
                    selectButton.setEnabled(true);
                }
            }
        }.execute();
    }

    /**
     * Shows the file scaled to fit the preview area, preserving its aspect ratio.
     */
    private void showPreview(File file) {
        ImageIcon icon = new ImageIcon(file.getAbsolutePath());
        int width = icon.getIconWidth();
        int height = icon.getIconHeight();
        if (width <= 0 || height <= 0) {
            // Non-positive dimensions mean the image could not be decoded. Clearing the icon
            // shows the label text again.
            imageLabel.setIcon(null);
            return;
        }
        // A negative dimension tells getScaledInstance to keep the original aspect ratio.
        Image scaled = width >= height
                ? icon.getImage().getScaledInstance(PREVIEW_SIZE, -1, Image.SCALE_SMOOTH)
                : icon.getImage().getScaledInstance(-1, PREVIEW_SIZE, Image.SCALE_SMOOTH);
        imageLabel.setIcon(new ImageIcon(scaled));
    }

    /** Prints at most the first five results, which arrive sorted by confidence. */
    private void displayResults(List<ImageRecognizer.RecognitionResult> results) {
        StringBuilder sb = new StringBuilder();
        sb.append("=== 识别结果 ===\n\n");
        sb.append("前5个最可能的类别:\n");
        int count = Math.min(5, results.size());
        for (int i = 0; i < count; i++) {
            ImageRecognizer.RecognitionResult result = results.get(i);
            sb.append(String.format("%d. %s\n", i + 1, result));
        }
        resultArea.setText(sb.toString());
    }

    /** Shows fixed placeholder results plus instructions for enabling real recognition. */
    private void simulateRecognition() {
        String[] labels = {"猫", "狗", "花", "汽车", "建筑"};
        double[] confidences = {0.85, 0.72, 0.68, 0.45, 0.30};
        StringBuilder sb = new StringBuilder();
        sb.append("=== 模拟识别结果 ===\n\n");
        sb.append("注意:这是模拟数据,需要模型文件才能真实识别\n\n");
        sb.append("前5个最可能的类别:\n");
        for (int i = 0; i < labels.length; i++) {
            sb.append(String.format("%d. %s: %.2f%%\n",
                    i + 1, labels[i], confidences[i] * 100));
        }
        sb.append("\n---\n要启用真实识别,请下载MobileNet模型文件:\n");
        sb.append("1. 从TensorFlow官网下载mobilenet_v1_1.0_224_frozen.pb\n");
        sb.append("2. 下载对应的labels.txt\n");
        sb.append("3. 放到models/目录下");
        resultArea.setText(sb.toString());
    }

    /** Creates the demo window on the Swing event dispatch thread. */
    public static void main(String[] args) {
        SwingUtilities.invokeLater(() -> {
            try {
                new ImageRecognitionDemo();
            } catch (IOException | RuntimeException e) {
                // Surface init failures (unreadable or invalid model) to the user, not only stderr.
                e.printStackTrace();
                JOptionPane.showMessageDialog(null, "初始化失败: " + e.getMessage(),
                        "错误", JOptionPane.ERROR_MESSAGE);
            }
        });
    }
}
简单图像识别(不使用深度学习框架)
如果不想使用TensorFlow,这里提供一个简单的基于像素比较的颜色识别器:
package com.example.imagerecognition;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.*;
import java.util.List;
/**
 * Dependency-free colour "recognizer".
 *
 * <p>It samples pixels on a regular grid and buckets each sample into a named colour using fixed
 * RGB thresholds. It reports the share of samples per colour.
 */
public class SimpleColorRecognizer {
    /** Default distance in pixels between sampled rows and columns. */
    private static final int DEFAULT_SAMPLE_STEP = 10;

    /**
     * Estimates the main colours of an image, sampling every 10th pixel in each direction.
     *
     * @param imagePath path to an image file readable by {@link ImageIO}
     * @return colour shares in percent, sorted by descending share
     * @throws IOException if the file cannot be read or its format is unsupported
     */
    public static List<ColorResult> recognizeMainColors(String imagePath) throws IOException {
        return recognizeMainColors(imagePath, DEFAULT_SAMPLE_STEP);
    }

    /**
     * Estimates the main colours of an image, sampling every {@code sampleStep}-th pixel.
     *
     * @param imagePath  path to an image file readable by {@link ImageIO}
     * @param sampleStep distance in pixels between sampled rows and columns (1 = every pixel)
     * @return colour shares in percent, sorted by descending share
     * @throws IOException              if the file cannot be read or its format is unsupported
     * @throws IllegalArgumentException if {@code sampleStep < 1}
     */
    public static List<ColorResult> recognizeMainColors(String imagePath, int sampleStep)
            throws IOException {
        if (sampleStep < 1) {
            throw new IllegalArgumentException("sampleStep must be >= 1, got " + sampleStep);
        }
        BufferedImage image = ImageIO.read(new File(imagePath));
        if (image == null) {
            // ImageIO.read returns null, rather than throwing, when no reader supports the format.
            throw new IOException("Unsupported or unreadable image file: " + imagePath);
        }
        return analyzeColors(image, sampleStep);
    }

    /** Counts the colour name of every sampled pixel and converts the counts to percentages. */
    private static List<ColorResult> analyzeColors(BufferedImage image, int step) {
        int width = image.getWidth();
        int height = image.getHeight();

        Map<String, Integer> colorCount = new HashMap<>();
        for (int y = 0; y < height; y += step) {
            for (int x = 0; x < width; x += step) {
                colorCount.merge(getColorName(new Color(image.getRGB(x, y))), 1, Integer::sum);
            }
        }

        // Pixel (0, 0) is always sampled, so the total is at least 1.
        int totalSamples = colorCount.values().stream().mapToInt(Integer::intValue).sum();

        List<ColorResult> results = new ArrayList<>(colorCount.size());
        for (Map.Entry<String, Integer> entry : colorCount.entrySet()) {
            double percentage = (double) entry.getValue() / totalSamples * 100;
            results.add(new ColorResult(entry.getKey(), percentage));
        }
        results.sort((a, b) -> Double.compare(b.getPercentage(), a.getPercentage()));
        return results;
    }

    /**
     * Maps an RGB colour to a coarse colour name using simple thresholds. The first matching rule
     * wins. Anything unmatched is "其他" (other).
     */
    private static String getColorName(Color color) {
        int red = color.getRed();
        int green = color.getGreen();
        int blue = color.getBlue();

        if (red > 200 && green > 200 && blue > 200) return "白色";
        if (red < 50 && green < 50 && blue < 50) return "黑色";
        if (red > 200 && green < 100 && blue < 100) return "红色";
        if (red < 100 && green > 200 && blue < 100) return "绿色";
        if (red < 100 && green < 100 && blue > 200) return "蓝色";
        if (red > 200 && green > 200 && blue < 100) return "黄色";
        if (red > 200 && green < 100 && blue > 200) return "紫色";
        if (red < 100 && green > 200 && blue > 200) return "青色";
        // Orange: saturated red with a medium green component, e.g. (255, 165, 0). The old rule
        // sent this to "brown" and mislabelled dark reds as orange.
        if (red > 200 && green >= 100 && green <= 200 && blue < 100) return "橙色";
        // Brown: darker red-dominant tone with green at roughly 35-75% of red and little blue,
        // e.g. saddle brown (139, 69, 19).
        if (red >= 100 && red <= 200
                && green * 100 >= red * 35 && green * 100 <= red * 75
                && blue < green) return "棕色";
        return "其他";
    }

    /** Immutable pair of a colour name and its share of the samples in percent. */
    public static class ColorResult {
        private final String colorName;
        private final double percentage;

        public ColorResult(String colorName, double percentage) {
            this.colorName = colorName;
            this.percentage = percentage;
        }

        public String getColorName() { return colorName; }

        public double getPercentage() { return percentage; }

        @Override
        public String toString() {
            return String.format("%s: %.1f%%", colorName, percentage);
        }
    }

    /** Analyses the image given as the first argument, or "test.jpg" when none is given. */
    public static void main(String[] args) {
        String imagePath = args.length > 0 ? args[0] : "test.jpg";
        try {
            List<ColorResult> results = recognizeMainColors(imagePath);
            System.out.println("图像主要颜色分析:");
            results.forEach(System.out::println);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
使用说明
下载模型文件
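- 访问 TensorFlow Models 下载 MobileNet v1 预训练模型(如 mobilenet_v1_1.0_224),解压后取出冻结图 mobilenet_v1_1.0_224_frozen.pb
- 下载与该模型对应的标签文件,保存为 labels.txt;标签行数应与模型输出的类别数一致
- 将两个文件放到 models/ 目录下,并确认代码中使用的输入/输出节点名("input"、"output")与模型一致——MobileNet v1 冻结图的输出节点通常为 MobilenetV1/Predictions/Reshape_1,可用 Netron 等工具查看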
运行程序
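```
# 编译(三个源文件都位于包 com.example.imagerecognition,需一起编译)
javac -cp "lib/*" -d out src/com/example/imagerecognition/*.java
# 运行(Windows 用 ";" 分隔类路径,Linux/macOS 用 ":")
java -cp "lib/*;out" com.example.imagerecognition.ImageRecognitionDemo
```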
注意事项
- 模型兼容性:确保TensorFlow模型版本与Java API版本匹配。本例使用 TensorFlow 1.x Java API(1.15.0),只能通过 importGraphDef 加载 TensorFlow 1.x 格式的冻结图(.pb)
- 性能优化:处理大图像时可使用缩略图来提高速度
- 内存管理:及时释放Tensor和Session资源
- 异常处理:添加适当的异常处理机制
这个案例提供了完整的图像识别实现,从基础的图像处理到使用深度学习模型进行真实识别,您可以根据需求选择合适的实现方式。