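CNN Handwriting Recognition
1. Introduction
- Handwriting recognition means analyzing an image to identify the handwritten characters it contains.
- This project focuses on handwritten digit recognition and uses a CNN as the model.
1.1 Steps
- Dataset: the MNIST handwritten digit dataset, which contains 60,000 training images and 10,000 test images. Each image is 28×28 pixels, and there are 10 classes.
- The Python framework is PyTorch, which is used for both training and testing.
- Recognition accuracy is 98%.
- Model conversion: the PyTorch model is converted to ONNX format so that it can be used on Android.
- Inference runs in Java code, which enables handwritten digit recognition on Android or in other environments.
1.2 Project structure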
.
├── DNS_tunnel_detect
│ ├── DNS_tunnel_detect.iml
│ ├── README.md
│ ├── bin
│ ├── lib
│ ├── out
│ ├── source
│ └── src
├── cnn_py
│ ├── data
│ ├── main.py
│ └── model
├── model2onnx
│ ├── model
│ ├── model2onnx.py
│ └── test_onnx_model.py
└── 第3集: java落地AI项目案例:cnn手写字体识别.md
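1.3 Model structure
The first layer consists of convolution, batch normalization, ReLU activation, and max pooling, with 16 output channels;
the second layer has the same structure but 32 output channels;
the fully connected layer flattens the output of the previous layer and acts as the classifier.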
import torch
import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
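The `7*7*32` input size of the fully connected layer follows from the layer settings in `ConvNet`. As a minimal sketch (the helper `conv_out` is a hypothetical name introduced here, not part of the project), the standard output-size formula for convolution and pooling windows can be applied layer by layer to a 28×28 MNIST image:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution or pooling window (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 28      # MNIST images are 28x28
channels = 1   # single grayscale channel

for out_channels in (16, 32):
    # Conv2d(kernel_size=5, stride=1, padding=2) keeps the spatial size.
    size = conv_out(size, kernel=5, stride=1, padding=2)
    # MaxPool2d(kernel_size=2, stride=2) halves the spatial size.
    size = conv_out(size, kernel=2, stride=2)
    channels = out_channels
    print(f"after layer with {channels} channels: {channels} x {size} x {size}")

fc_in_features = channels * size * size
print("fc input features:", fc_in_features)
```

The spatial size goes 28 → 14 → 7, so the flattened tensor has 32 × 7 × 7 = 1568 values. If the kernel size, padding, or input resolution changes, the first argument of `nn.Linear` has to change with it.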
2. Training
model = ConvNet(num_classes).to(device)
print(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        print(images.size())

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
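To make the printed loss values easier to read, here is a small pure-Python sketch of what `nn.CrossEntropyLoss` computes for a single sample: log-softmax over the raw scores followed by the negative log-likelihood of the true class. The function name `cross_entropy` and the example scores are illustrative assumptions; PyTorch additionally averages this value over the batch by default.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


# Illustrative scores for a 3-class problem where class 0 is correct.
loss = cross_entropy([2.0, 1.0, 0.1], target=0)
print(f"confident and correct: {loss:.4f}")

# A model that gives all 10 digits the same score has loss ln(10).
uniform_loss = cross_entropy([0.0] * 10, target=3)
print(f"uniform over 10 classes: {uniform_loss:.4f}")
```

A freshly initialized 10-class model therefore tends to start near ln(10) ≈ 2.30, and the loss printed every 100 steps should fall well below that as training progresses.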
3. Testing the model
# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), './model/model.ckpt')
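The comment on `model.eval()` mentions that batch normalization switches from mini-batch statistics to moving averages. The following is a simplified one-dimensional sketch of that behavior, not PyTorch's implementation: `TinyBatchNorm` is a hypothetical class, it omits the learnable scale and shift, and it assumes PyTorch's default momentum of 0.1.

```python
import math


class TinyBatchNorm:
    """Simplified 1-D batch norm without learnable scale/shift."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            n = len(batch)
            mean = sum(batch) / n
            # The biased variance normalizes the current batch.
            var = sum((x - mean) ** 2 for x in batch) / n
            # The unbiased variance feeds the moving average.
            unbiased = var * n / (n - 1) if n > 1 else var
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbiased
        else:
            # Eval mode: use the statistics accumulated during training.
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = TinyBatchNorm()
train_out = bn([1.0, 2.0, 3.0])   # normalized with this batch's own mean/variance
bn.training = False               # what model.eval() toggles on each module
eval_out = bn([1.0, 2.0, 3.0])    # normalized with the running statistics
print(train_out)
print(eval_out)
```

The same input gives different outputs in the two modes, which is why evaluation and ONNX export should both happen after `model.eval()`; otherwise a single-image batch would be normalized by its own statistics.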
4. Model conversion
4.1 Converting the PyTorch model to ONNX
import os
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

device = torch.device("cpu")
num_classes = 10
model = ConvNet(num_classes).to(device)
print(model)

model.load_state_dict(torch.load('../cnn_py/model/model.ckpt', map_location=device))

sample_input = torch.rand((1, 1, 28, 28)).to(device)
print(sample_input)

model.eval()
with torch.no_grad():
    outputs = model(sample_input)
    print("output:", outputs)
    _, predicted = torch.max(outputs.data, 1)
    print("predicted:", predicted)

torch.onnx.export(model,
                  sample_input,
                  './model/model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  export_params=True,        # store the trained parameters in the exported file
                  do_constant_folding=True)  # apply constant-folding optimization
torch.cuda.empty_cache()
4.2 Testing the ONNX model with onnxruntime
import os
import warnings
warnings.filterwarnings('ignore')

import onnxruntime
import torch

input_data = torch.rand(1,1,28,28)
session = onnxruntime.InferenceSession("./model/model.onnx")
input_name = session.get_inputs()[0].name
result = session.run([], {input_name: input_data.numpy()})
print("result: ",result)
print(result[0][0])
max_value = max(list(result[0][0]))
predict = list(result[0][0]).index(max_value)
print(predict)
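The check above feeds the model random values. For a real digit image, the input has to be prepared the same way as during training. The source does not show the training transforms, so the sketch below assumes that only scaling to the range 0..1 was applied (as `transforms.ToTensor()` would do); if the training pipeline also normalized the data, the same normalization would be needed here. `to_model_input` and `argmax` are hypothetical helper names.

```python
def to_model_input(pixels):
    """Convert 28 rows of 28 grayscale values (0..255) to a [1][1][28][28] nested list.

    Assumption: training used only 0..1 scaling, with no mean/std normalization.
    """
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28x28 image")
    # Outer two list levels are the batch and channel dimensions.
    return [[[[value / 255.0 for value in row] for row in pixels]]]


def argmax(scores):
    """Index of the largest score, i.e. the predicted digit."""
    return max(range(len(scores)), key=lambda i: scores[i])


# Illustrative image: black background with one white vertical stroke.
image = [[255 if col == 14 else 0 for col in range(28)] for _ in range(28)]
batch = to_model_input(image)
print(len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(argmax([0.1, 2.5, -1.0]))
```

To pass `batch` to `session.run`, it can be wrapped as `numpy.array(batch, dtype=numpy.float32)`. Note that MNIST digits are light strokes on a dark background, so photos of dark ink on white paper would need to be inverted first.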
5. Running inference with the ONNX model in Java
- The onnxruntime library must be installed.
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtUtil;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class App {
    public static void main(String[] args) throws Exception {
        String model_path = "./source/model.onnx";
        System.out.println(model_path);

        float[][][][] feature = new float[1][1][28][28];
        // Initialize the array elements with placeholder values
        for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 1; j++) {
                for (int k = 0; k < 28; k++) {
                    for (int l = 0; l < 28; l++) {
                        feature[i][j][k][l] = (i + 1) * (j + 1) * (k + 1) * (l + 1);
                    }
                }
            }
        }
        // deepToString prints the nested values; Arrays.toString would only print array references
        System.out.println(Arrays.deepToString(feature));

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(model_path)) {
            Map<String, OnnxTensor> container = new HashMap<>();
            OnnxTensor inputTensor = OnnxTensor.createTensor(env, feature);
            // "input" matches input_names in the export script
            container.put("input", inputTensor);
            try (OrtSession.Result result = session.run(container)) {
                OnnxTensor outputTensor = (OnnxTensor) result.get(0);
                float[][] result88 = (float[][]) outputTensor.getValue();
                System.out.println(Arrays.deepToString(result88));
                // Print the score for each of the 10 digit classes
                for (int i = 0; i < result88.length; i++) {
                    for (int j = 0; j < result88[i].length; j++) {
                        System.out.println(result88[i][j]);
                    }
                }
            }
            OnnxValue.close(container);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            System.out.println("all done");
        }
    }
}
6. Summary
- Trained and tested the handwritten digit model with Python scripts
- Converted the model to ONNX
- Ran inference with the ONNX model in Java