
CNN Handwritten Digit Recognition

1. Introduction

  1. Handwriting recognition means analyzing an image and identifying the handwritten content it contains.
  2. This project focuses on handwritten digit recognition, using a CNN as the model.

1.1 Steps

  1. Dataset: the MNIST handwritten digit dataset, which contains 60,000 training images and 10,000 test images; each image is 28*28 pixels and there are 10 classes.
  2. Framework: PyTorch is used for both training and testing (a sketch of the device, hyperparameters, and MNIST data loaders that the scripts assume follows this list).
  3. The model reaches a recognition accuracy of 98%.
  4. Model conversion: the trained PyTorch model is converted to ONNX format so it can be used on Android.
  5. The ONNX model is then run from Java code, enabling handwritten digit recognition on Android or in other environments.
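  
The training and test scripts below refer to device, num_classes, num_epochs, learning_rate, train_loader, and test_loader without defining them. The following is a minimal sketch of those definitions, assuming torchvision is available; the batch size, epoch count, and learning rate are illustrative values, not taken from the original project.

import torch
import torchvision
import torchvision.transforms as transforms

# Device and hyperparameters (illustrative values)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# MNIST dataset: downloaded to ./data and converted to tensors
train_dataset = torchvision.datasets.MNIST(root='./data', train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False,
                                          transform=transforms.ToTensor())

# Data loaders used by the training and test loops below
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, shuffle=False)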

1.2 Project Structure

.
├── DNS_tunnel_detect
│   ├── DNS_tunnel_detect.iml
│   ├── README.md
│   ├── bin
│   ├── lib
│   ├── out
│   ├── source
│   └── src
├── cnn_py
│   ├── data
│   ├── main.py
│   └── model
├── model2onnx
│   ├── model
│   ├── model2onnx.py
│   └── test_onnx_model.py
└── 第3集: java落地AI项目案例:cnn手写字体识别.md

1.3 Model Architecture

The first layer consists of a convolution, batch normalization, ReLU activation, and 2*2 max pooling;
the second layer has the same structure but 32 output channels;
the fully connected layer flattens the previous layer's output and acts as the classifier.
Each max-pooling step halves the spatial size, so the 28*28 input becomes 7*7 after two layers, which is why the classifier input is 7*7*32.

import torch
import torch.nn as nn

# Convolutional neural network (two convolutional layers)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out
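
A quick sanity check (a minimal sketch using the ConvNet defined above) confirms the shapes: two 2*2 max-pools reduce the 28*28 input to 7*7, so the flattened tensor has 7*7*32 features and the output has one score per class.

# Feed a dummy batch through the network and inspect the output shape
model = ConvNet(num_classes=10)
dummy = torch.randn(1, 1, 28, 28)   # a batch of one 28x28 grayscale image
out = model(dummy)
print(out.shape)                    # torch.Size([1, 10]) -- one score per digit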

2. Training

model = ConvNet(num_classes).to(device)
print(model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        print(images.size())

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

3. Testing the Model

# Test the model
model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

# Save the model checkpoint
torch.save(model.state_dict(), './model/model.ckpt')


4. Model Conversion

4.1 Exporting the PyTorch model to ONNX

import os
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn

# Same network definition as in section 1.3
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7*7*32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

device = torch.device("cpu")
num_classes = 10
model = ConvNet(num_classes).to(device)
print(model)

# Load the trained checkpoint saved in section 3
model.load_state_dict(torch.load('../cnn_py/model/model.ckpt', map_location=device))

sample_input = torch.rand((1, 1, 28, 28)).to(device)
print(sample_input)

model.eval()
with torch.no_grad():
    outputs = model(sample_input)
    print("output:", outputs)
    _, predicted = torch.max(outputs.data, 1)
    print("predicted:", predicted)

torch.onnx.export(model,
                  sample_input,
                  './model/model.onnx',
                  input_names=["input"],
                  output_names=["output"],
                  export_params=True,        # store the trained parameters in the exported file
                  do_constant_folding=True)  # apply constant-folding optimization during export

torch.cuda.empty_cache()
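
Before handing the file to Java, the exported graph can be validated with the onnx package. This is a minimal sketch, assuming onnx is installed; the path matches the export call above.

import onnx

# Load the exported graph and run the structural checker
onnx_model = onnx.load('./model/model.onnx')
onnx.checker.check_model(onnx_model)
print(onnx.helper.printable_graph(onnx_model.graph))  # human-readable summary of the graph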


4.2 Testing the ONNX model with ONNX Runtime

import os
import warnings
warnings.filterwarnings('ignore')

import onnxruntime
import torch

input_data = torch.rand(1, 1, 28, 28)

session = onnxruntime.InferenceSession("./model/model.onnx")
input_name = session.get_inputs()[0].name
result = session.run([], {input_name: input_data.numpy()})
print("result: ", result)
print(result[0][0])

max_value = max(list(result[0][0]))
predict = list(result[0][0]).index(max_value)
print(predict)
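
To confirm the conversion did not change the model's behavior, the ONNX Runtime output can be compared numerically with the original PyTorch output on the same input. This is a minimal sketch: model refers to the ConvNet loaded in section 4.1, while session, input_name, and input_data are the objects created above.

import numpy as np

# Run the same input through both backends
with torch.no_grad():
    torch_out = model(input_data).numpy()
onnx_out = session.run(None, {input_name: input_data.numpy()})[0]

# The outputs should agree up to small floating-point differences
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)
print("PyTorch and ONNX outputs match")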


5. Running the ONNX Model from Java

  • Requires the ONNX Runtime Java library (published on Maven Central as com.microsoft.onnxruntime:onnxruntime).

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class App {
    public static void main(String[] args) throws Exception {
        String model_path = "./source/model.onnx";
        System.out.println(model_path);

        // Input tensor in NCHW layout: 1 image, 1 channel, 28x28 pixels
        float[][][][] feature = new float[1][1][28][28];
        // Fill the array with dummy values; a real application would load a
        // normalized 28x28 grayscale image here instead
        for (int i = 0; i < 1; i++) {
            for (int j = 0; j < 1; j++) {
                for (int k = 0; k < 28; k++) {
                    for (int l = 0; l < 28; l++) {
                        feature[i][j][k][l] = (i + 1) * (j + 1) * (k + 1) * (l + 1);
                    }
                }
            }
        }
        System.out.println(Arrays.toString(feature));

        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession session = env.createSession(model_path)) {
            Map<String, OnnxTensor> container = new HashMap<>();
            OnnxTensor inputTensor = OnnxTensor.createTensor(env, feature);
            container.put("input", inputTensor);  // "input" matches input_names used in torch.onnx.export

            try (OrtSession.Result result = session.run(container)) {
                OnnxTensor outputTensor = (OnnxTensor) result.get(0);
                float[][] result88 = (float[][]) outputTensor.getValue();
                System.out.println(Arrays.toString(result88));
                // One score per digit (0-9); the index of the largest score is the predicted class
                for (int i = 0; i < result88.length; i++) {
                    for (int j = 0; j < result88[i].length; j++) {
                        System.out.println(result88[i][j]);
                    }
                }
            }
            OnnxValue.close(container);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            System.out.println("all done");
        }
    }
}

6. Summary

  1. Completed the Python training and test scripts for handwritten digit recognition.
  2. Completed the conversion of the trained model to ONNX.
  3. Completed inference with the ONNX model from the Java side.
