文章目录
- 1 搭建环境
- 1.1 下载 libtorch c++ 版本
- 1.2 下载数字数据集
- 1.3 目录结构
- 2 源码目录
- 2.1 CMakeLists.txt文件
- 2.2 训练集主程序 main.cpp
- 2.2 dataset.h
- 2.3 model.cpp
- 2.4 识别主程序
- 3 运行结果
- 3.1 训练集程序运行
- 3.2 识别程序
1 搭建环境
1.1 下载 libtorch c++ 版本
由于我当前在虚拟机中运行,只能使用 CPU 而没有使用 GPU,因此下载 CPU 版本的 libtorch:
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.0.0+cpu.zip
1.2 下载数字数据集
wget https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
wget https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
gunzip t10k-images-idx3-ubyte.gz
gunzip t10k-labels-idx1-ubyte.gz
gunzip train-images-idx3-ubyte.gz
gunzip train-labels-idx1-ubyte.gz
注意:训练程序按相对路径 data/MNIST/raw 读取这四个解压后的文件,因此需要把它们放到运行目录下的 data/MNIST/raw/ 中。
1.3 目录结构
xhome@ubuntu:~/project_ai$ ls
data image libtorch mnist_demo
2 源码目录
xhome@ubuntu:~/project_ai/mnist_demo$ ls
build CMakeLists.txt src
xhome@ubuntu:~/project_ai/mnist_demo/src$ ls
dataset.h main.cpp model.cpp model.h predict.cpp
2.1 CMakeLists.txt文件
xhome@ubuntu:~/project_ai/mnist_demo$ cat CMakeLists.txt
# Build script for the MNIST LibTorch demo: a training program (mnist_demo)
# and a single-image prediction program (mnist_predict).
#
# CMake versions below 3.5 are rejected by current CMake releases.
cmake_minimum_required(VERSION 3.10)
project(mnist_demo)

# LibTorch 2.x documents C++17 for consumers of its headers.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Location of the unpacked LibTorch distribution. The default keeps the
# original path; override it without editing this file via
#   cmake -DLIBTORCH_DIR=/path/to/libtorch ..
set(LIBTORCH_DIR "/home/xhome/project_ai/libtorch" CACHE PATH "Path to the unpacked LibTorch distribution")
# Append rather than overwrite, so prefixes given on the command line still work.
list(APPEND CMAKE_PREFIX_PATH "${LIBTORCH_DIR}")

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

# TORCH_CXX_FLAGS holds the compile flags LibTorch asks consumers to use
# (e.g. its _GLIBCXX_USE_CXX11_ABI setting). Set it before the targets are
# defined so that it is obviously in effect for them.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

# Training program.
add_executable(mnist_demo
    src/main.cpp
    src/model.cpp
)

# Prediction program.
add_executable(mnist_predict
    src/predict.cpp
    src/model.cpp
)

# Both programs share the headers in src/ and the same dependencies.
foreach(target mnist_demo mnist_predict)
    target_include_directories(${target} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
    target_link_libraries(${target} ${TORCH_LIBRARIES} ${OpenCV_LIBS} pthread)
endforeach()
2.2 训练集主程序 main.cpp
xhome@ubuntu:~/project_ai/mnist_demo/src$ cat main.cpp
#include <torch/torch.h>
#include "model.h"
#include "dataset.h"
#include <iostream>
#include <iomanip>
#include <chrono>// 训练一个epoch
template<typename DataLoader>
void train(Net& model,DataLoader& data_loader,torch::optim::Optimizer& optimizer,size_t epoch,size_t dataset_size) {model.train();size_t batch_idx = 0;size_t processed = 0;float last_loss = 0.0f; // 添加这行来存储最后一个loss值for (auto& batch : data_loader) {auto data = batch.data;auto target = batch.target;optimizer.zero_grad();auto output = model.forward(data);auto loss = torch::nll_loss(output, target);loss.backward();optimizer.step();processed += batch.data.size(0);last_loss = loss.template item<float>(); // 保存最后一个loss值if (batch_idx++ % 100 == 0) {std::cout << "\rTrain Epoch: " << epoch << " [" << processed << "/" << dataset_size<< " (" << std::fixed << std::setprecision(1)<< (100.0 * processed / dataset_size)<< "%)] Loss: " << std::fixed << std::setprecision(4)<< last_loss << std::flush;}}// 在epoch结束时显示最终进度std::cout << "\rTrain Epoch: " << epoch << " [" << dataset_size << "/" << dataset_size<< " (100.0%)] Loss: " << std::fixed << std::setprecision(4)<< last_loss << std::endl;
}// 测试模型
template<typename DataLoader>
void test(Net& model,DataLoader& data_loader,size_t dataset_size) {model.eval();double test_loss = 0;int32_t correct = 0;torch::NoGradGuard no_grad;for (const auto& batch : data_loader) {auto data = batch.data;auto target = batch.target;auto output = model.forward(data);test_loss += torch::nll_loss(output, target, /*weight=*/{},torch::Reduction::Sum).template item<float>();auto pred = output.argmax(1);correct += pred.eq(target).sum().template item<int64_t>();}test_loss /= dataset_size;double accuracy = 100.0 * correct / dataset_size;std::cout << "Test set: Average loss: " << std::fixed << std::setprecision(4)<< test_loss << ", Accuracy: " << correct << "/"<< dataset_size << " (" << std::fixed << std::setprecision(1)<< accuracy << "%)\n";
}int main() {try {// 设置随机种子torch::manual_seed(1);// 使用CPUtorch::Device device(torch::kCPU);std::cout << "Training on CPU." << std::endl;// 创建模型auto model = std::make_shared<Net>();model->to(device);// 创建数据集const std::string data_path = "data/MNIST/raw";std::cout << "Loading training dataset..." << std::endl;auto train_dataset = MNISTDataset(data_path, MNISTDataset::Mode::kTrain).map(torch::data::transforms::Stack<>());std::cout << "Loading test dataset..." << std::endl;auto test_dataset = MNISTDataset(data_path, MNISTDataset::Mode::kTest).map(torch::data::transforms::Stack<>());// 获取并存储数据集大小const size_t train_size = train_dataset.size().value();const size_t test_size = test_dataset.size().value();std::cout << "Training dataset size: " << train_size << std::endl;std::cout << "Test dataset size: " << test_size << std::endl;// 创建数据加载器auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(train_dataset),torch::data::DataLoaderOptions().batch_size(32).workers(1));auto test_loader = torch::data::make_data_loader(std::move(test_dataset),torch::data::DataLoaderOptions().batch_size(100).workers(1));// 创建优化器torch::optim::SGD optimizer(model->parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));// 记录开始时间auto start_time = std::chrono::high_resolution_clock::now();// 训练循环const int num_epochs = 10;for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {std::cout << "\nEpoch " << epoch << "/" << num_epochs << std::endl;train(*model, *train_loader, optimizer, epoch, train_size);test(*model, *test_loader, test_size);}// 计算总训练时间auto end_time = std::chrono::high_resolution_clock::now();auto duration = std::chrono::duration_cast<std::chrono::minutes>(end_time - start_time);std::cout << "\nTraining completed in " << duration.count() << " minutes" << std::endl;// 保存模型torch::save(model, "mnist_cnn.pt");std::cout << "Model saved to mnist_cnn.pt" << std::endl;} catch (const std::exception& e) {std::cerr << "Error: " << e.what() << 
std::endl;return -1;}return 0;
}
2.2 dataset.h
#ifndef __DATA_SET_HEAD_H
#define __DATA_SET_HEAD_H#include <torch/torch.h>
#include <string>
#include <vector>
#include <fstream>
#include <iostream>class MNISTDataset : public torch::data::Dataset<MNISTDataset> {
public:enum Mode { kTrain, kTest };explicit MNISTDataset(const std::string& root, Mode mode = Mode::kTrain) {std::string prefix = mode == Mode::kTrain ? "train" : "t10k";// 读取图像std::string image_file = root + "/" + prefix + "-images-idx3-ubyte";std::cout << "Loading images from: " << image_file << std::endl;images = read_images(image_file);// 读取标签std::string label_file = root + "/" + prefix + "-labels-idx1-ubyte";std::cout << "Loading labels from: " << label_file << std::endl;labels = read_labels(label_file);std::cout << "Dataset size: " << images.size(0) << std::endl;}torch::data::Example<> get(size_t index) override {return {images[index], labels[index]};}torch::optional<size_t> size() const override {return images.size(0);}private:torch::Tensor images, labels;torch::Tensor read_images(const std::string& path) {std::ifstream file(path, std::ios::binary);if (!file) {throw std::runtime_error("Cannot open file: " + path);}int32_t magic_number = 0, n_images = 0, n_rows = 0, n_cols = 0;file.read(reinterpret_cast<char*>(&magic_number), sizeof(magic_number));file.read(reinterpret_cast<char*>(&n_images), sizeof(n_images));file.read(reinterpret_cast<char*>(&n_rows), sizeof(n_rows));file.read(reinterpret_cast<char*>(&n_cols), sizeof(n_cols));// 转换字节序magic_number = __builtin_bswap32(magic_number);n_images = __builtin_bswap32(n_images);n_rows = __builtin_bswap32(n_rows);n_cols = __builtin_bswap32(n_cols);// 读取图像数据std::vector<uint8_t> buffer(n_images * n_rows * n_cols);file.read(reinterpret_cast<char*>(buffer.data()), buffer.size());auto tensor = torch::from_blob(buffer.data(), {n_images, n_rows, n_cols}, torch::kUInt8).clone();// 添加通道维度并归一化tensor = tensor.unsqueeze(1).to(torch::kFloat32).div(255.0);return tensor;}torch::Tensor read_labels(const std::string& path) {std::ifstream file(path, std::ios::binary);if (!file) {throw std::runtime_error("Cannot open file: " + path);}int32_t magic_number = 0, n_labels = 0;file.read(reinterpret_cast<char*>(&magic_number), 
sizeof(magic_number));file.read(reinterpret_cast<char*>(&n_labels), sizeof(n_labels));magic_number = __builtin_bswap32(magic_number);n_labels = __builtin_bswap32(n_labels);std::vector<uint8_t> buffer(n_labels);file.read(reinterpret_cast<char*>(buffer.data()), n_labels);return torch::from_blob(buffer.data(), {n_labels}, torch::kUInt8).clone().to(torch::kLong);}
};#endif
2.3 model.cpp
#include "model.h"Net::Net() {// 第一个卷积块: 1->32 通道, 3x3 卷积核conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 32, 3).stride(1).padding(1)));batch_norm1 = register_module("batch_norm1",torch::nn::BatchNorm2d(32));// 第二个卷积块: 32->64 通道, 3x3 卷积核conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).stride(1).padding(1)));batch_norm2 = register_module("batch_norm2",torch::nn::BatchNorm2d(64));// 全连接层fc1 = register_module("fc1", torch::nn::Linear(7 * 7 * 64, 128));fc2 = register_module("fc2", torch::nn::Linear(128, 10));// Dropout层dropout = register_module("dropout",torch::nn::Dropout(torch::nn::DropoutOptions(0.25)));
}torch::Tensor Net::forward(torch::Tensor x) {// 第一个卷积块x = conv1->forward(x);x = batch_norm1->forward(x);x = torch::relu(x);x = torch::max_pool2d(x, 2);// 第二个卷积块x = conv2->forward(x);x = batch_norm2->forward(x);x = torch::relu(x);x = torch::max_pool2d(x, 2);// 展平x = x.view({-1, 7 * 7 * 64});// 全连接层x = torch::relu(fc1->forward(x));x = dropout->forward(x);x = fc2->forward(x);return torch::log_softmax(x, 1);
}
2.4 识别主程序
#include <torch/torch.h>
#include <opencv2/opencv.hpp>

#include <algorithm>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include "model.h"

/// Turns a picture of a single handwritten digit into a model input of shape
/// (1, 1, 28, 28), float32, with the digit at 1.0 on a 0.0 background.
///
/// Assumes a dark digit on a light background. Every intermediate stage is
/// written to step*.jpg in the working directory.
/// @throws std::runtime_error if the image cannot be read or contains no digit.
torch::Tensor preprocess_image(const std::string& image_path) {
    cv::Mat image = cv::imread(image_path, cv::IMREAD_GRAYSCALE);
    if (image.empty()) {
        throw std::runtime_error("Error: Could not read image: " + image_path);
    }
    std::cout << "Original image size: " << image.size() << std::endl;

    // Binarize with an automatically chosen (Otsu) threshold: pixels above it
    // become 255, the rest 0, so a dark digit ends up black here.
    cv::Mat binary;
    cv::threshold(image, binary, 0, 255, cv::THRESH_BINARY | cv::THRESH_OTSU);

    // 3x3 elliptical morphological closing to tidy up the stroke shapes.
    cv::Mat morph;
    const cv::Mat kernel = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(3, 3));
    cv::morphologyEx(binary, morph, cv::MORPH_CLOSE, kernel);

    // Locate the digit as the bounding box of the largest external contour.
    // findContours traces non-zero regions, so search the inverted image.
    const cv::Mat digit_mask = 255 - morph;
    std::vector<std::vector<cv::Point>> contours;
    cv::findContours(digit_mask, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);

    cv::Rect box;
    double max_area = 0.0;
    for (const auto& contour : contours) {
        const double area = cv::contourArea(contour);
        if (area > max_area) {
            max_area = area;
            box = cv::boundingRect(contour);
        }
    }
    if (box.area() == 0) {
        // Otherwise the empty crop below fails inside cv::resize with an
        // opaque OpenCV assertion.
        throw std::runtime_error("Error: No digit found in image: " + image_path);
    }

    // Pad the tight crop into a centered square (so resizing preserves the
    // aspect ratio) plus a margin of 1/4 of the longer side, filled with
    // background. Padding, rather than enlarging the ROI, keeps the result
    // square even when the digit touches the image border. clone() detaches
    // the ROI: on a sub-matrix, copyMakeBorder would otherwise take the border
    // from the surrounding image instead of the constant.
    const cv::Mat digit = morph(box).clone();
    const int side = std::max(box.width, box.height);
    const int margin = side / 4;
    const int extra_w = side - box.width;
    const int extra_h = side - box.height;
    cv::Mat cropped;
    cv::copyMakeBorder(digit, cropped,
                       margin + extra_h / 2, margin + extra_h - extra_h / 2,
                       margin + extra_w / 2, margin + extra_w - extra_w / 2,
                       cv::BORDER_CONSTANT, cv::Scalar(255));

    // Fit the digit into 20x20, then add a 4-pixel background border -> 28x28.
    cv::Mat resized;
    cv::resize(cropped, resized, cv::Size(20, 20), 0, 0, cv::INTER_AREA);
    cv::Mat padded;
    cv::copyMakeBorder(resized, padded, 4, 4, 4, 4, cv::BORDER_CONSTANT, cv::Scalar(255));

    // The training images are light digits on a dark background, so invert.
    cv::Mat inverted;
    cv::bitwise_not(padded, inverted);

    // Light blur followed by a fixed threshold for clean binary edges.
    cv::Mat blurred;
    cv::GaussianBlur(inverted, blurred, cv::Size(3, 3), 0.5);
    cv::Mat binarized;
    cv::threshold(blurred, binarized, 127, 255, cv::THRESH_BINARY);

    // Scale to [0, 1], matching the /255 normalization of the training data.
    cv::Mat float_img;
    binarized.convertTo(float_img, CV_32F, 1.0 / 255.0);

    cv::imwrite("step1_binary.jpg", binary);
    cv::imwrite("step2_morph.jpg", morph);
    cv::imwrite("step3_cropped.jpg", cropped);
    cv::imwrite("step4_resized.jpg", resized);
    cv::imwrite("step5_padded.jpg", padded);
    cv::imwrite("step6_inverted.jpg", inverted);
    cv::imwrite("step7_final.jpg", binarized);
    std::cout << "Preprocessing steps saved as images" << std::endl;

    // from_blob requires contiguous data of the declared size; clone() copies
    // it out before float_img is released.
    CV_Assert(float_img.isContinuous() && float_img.rows == 28 && float_img.cols == 28);
    return torch::from_blob(float_img.data, {1, 1, 28, 28}, torch::kFloat32).clone();
}

/// Prints the first channel of the first image in `tensor` (N, C, H, W) as
/// ASCII art, two characters per pixel, darker-to-brighter as "  ", "..",
/// "**", "##", "@@".
void display_tensor(const torch::Tensor& tensor) {
    std::cout << "\nProcessed image (ASCII art):\n";
    // One contiguous float plane read via an accessor, instead of one item()
    // call per pixel.
    const torch::Tensor plane = tensor[0][0].to(torch::kFloat32).contiguous();
    auto pixels = plane.accessor<float, 2>();
    for (int64_t i = 0; i < plane.size(0); ++i) {
        for (int64_t j = 0; j < plane.size(1); ++j) {
            const float pixel = pixels[i][j];
            // Every cell is two characters wide so the rows stay aligned.
            if (pixel < 0.2f) std::cout << "  ";
            else if (pixel < 0.4f) std::cout << "..";
            else if (pixel < 0.6f) std::cout << "**";
            else if (pixel < 0.8f) std::cout << "##";
            else std::cout << "@@";
        }
        std::cout << std::endl;
    }
}

/// Classifies the digit in argv[1] with the trained model (argv[2], default
/// mnist_cnn.pt) and prints the prediction and per-class probabilities.
int main(int argc, char* argv[]) {
    if (argc < 2) {
        std::cerr << "Usage: " << argv[0] << " <image_path> [model_path]" << std::endl;
        return 1;
    }

    try {
        torch::Device device(torch::kCPU);

        const std::string model_path = argc > 2 ? argv[2] : "mnist_cnn.pt";
        auto model = std::make_shared<Net>();
        torch::load(model, model_path);
        model->to(device);
        model->eval();  // use batch-norm running statistics, disable dropout

        auto input = preprocess_image(argv[1]).to(device);
        display_tensor(input);

        torch::NoGradGuard no_grad;
        auto output = model->forward(input);
        // The model returns log-probabilities; softmax turns them into probabilities.
        auto probabilities = torch::softmax(output, 1);
        auto prediction = output.argmax(1);

        std::cout << "\nPredicted digit: " << prediction.item<int64_t>() << std::endl;

        std::cout << "\nProbabilities for each digit:" << std::endl;
        for (int i = 0; i < 10; ++i) {
            std::cout << "Digit " << i << ": " << std::fixed << std::setprecision(4)
                      << probabilities[0][i].item<float>() * 100 << "%" << std::endl;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return -1;
    }
    return 0;
}
3 运行结果
3.1 训练集程序运行
xhome@ubuntu:~/project_ai/mnist_demo/build$ ./mnist_demo
Training on CPU.
Loading training dataset...
Loading images from: data/MNIST/raw/train-images-idx3-ubyte
Loading labels from: data/MNIST/raw/train-labels-idx1-ubyte
Dataset size: 60000
Loading test dataset...
Loading images from: data/MNIST/raw/t10k-images-idx3-ubyte
Loading labels from: data/MNIST/raw/t10k-labels-idx1-ubyte
Dataset size: 10000
Training dataset size: 60000
Test dataset size: 10000

Epoch 1/10
Train Epoch: 1 [60000/60000 (100.0%)] Loss: 0.0716
Test set: Average loss: 0.0504, Accuracy: 9840/10000 (98.4%)

Epoch 2/10
Train Epoch: 2 [60000/60000 (100.0%)] Loss: 0.0070
Test set: Average loss: 0.0420, Accuracy: 9863/10000 (98.6%)

Epoch 3/10
Train Epoch: 3 [60000/60000 (100.0%)] Loss: 0.0337
Test set: Average loss: 0.0352, Accuracy: 9889/10000 (98.9%)

Epoch 4/10
Train Epoch: 4 [60000/60000 (100.0%)] Loss: 0.0009
Test set: Average loss: 0.0325, Accuracy: 9893/10000 (98.9%)

Epoch 5/10
Train Epoch: 5 [60000/60000 (100.0%)] Loss: 0.0131
Test set: Average loss: 0.0277, Accuracy: 9907/10000 (99.1%)

Epoch 6/10
Train Epoch: 6 [60000/60000 (100.0%)] Loss: 0.0219
Test set: Average loss: 0.0270, Accuracy: 9910/10000 (99.1%)

Epoch 7/10
Train Epoch: 7 [60000/60000 (100.0%)] Loss: 0.0033
Test set: Average loss: 0.0254, Accuracy: 9917/10000 (99.2%)

Epoch 8/10
Train Epoch: 8 [60000/60000 (100.0%)] Loss: 0.0033
Test set: Average loss: 0.0247, Accuracy: 9915/10000 (99.2%)

Epoch 9/10
Train Epoch: 9 [60000/60000 (100.0%)] Loss: 0.0561
Test set: Average loss: 0.0221, Accuracy: 9926/10000 (99.3%)

Epoch 10/10
Train Epoch: 10 [60000/60000 (100.0%)] Loss: 0.0003
Test set: Average loss: 0.0248, Accuracy: 9914/10000 (99.1%)

Training completed in 1 minutes
Model saved to mnist_cnn.pt
3.2 识别程序
xhome@ubuntu:~/project_ai/mnist_demo/build$ ./mnist_predict ../../image/1.png
Original image size: [456 x 479]
Preprocessing steps saved as images

Processed image (ASCII art):
@@@@@@ @@@@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@@@@@ 

Predicted digit: 1

Probabilities for each digit:
Digit 0: 0.0047%
Digit 1: 94.2735%
Digit 2: 0.0737%
Digit 3: 0.0274%
Digit 4: 0.0054%
Digit 5: 2.7657%
Digit 6: 0.0065%
Digit 7: 0.0035%
Digit 8: 2.7485%
Digit 9: 0.0911%
xhome@ubuntu:~/project_ai/mnist_demo/build$ ./mnist_predict ../../image/5.png
Original image size: [508 x 512]
Preprocessing steps saved as images

Processed image (ASCII art):
@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@ @@@@@@@@@@@@@@@@@@ @@@@@@@@@@@@@@ @@@@@@@@ 

Predicted digit: 5

Probabilities for each digit:
Digit 0: 0.0118%
Digit 1: 0.0003%
Digit 2: 0.0079%
Digit 3: 0.3717%
Digit 4: 0.0000%
Digit 5: 84.2349%
Digit 6: 0.0197%
Digit 7: 0.0005%
Digit 8: 13.8340%
Digit 9: 1.5192%