Pytorch模型
import torchvision
import torch
import matplotlib.pyplot as plt# 1. 加载 FashionMNIST 数据集
test_data = torchvision.datasets.FashionMNIST(root="data", # 数据存储目录train=False, # 使用测试数据集download=True, # 如果数据集不存在则下载transform=torchvision.transforms.ToTensor() # 将图像转换为 Tensor 格式
)# 2. 创建 DataLoader 用于批量加载数据
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)# 3. 定义类别名称
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 4. 获取测试集的一个批次数据(图像和标签)
img, label = next(iter(test_loader))# 5. 对图像进行处理并转换为 NumPy 数组
img = img.reshape(28, 28).numpy()# 6. 使用 matplotlib 显示图像
plt.figure()
plt.imshow(img, cmap='gray') # 显示图像,使用灰度色图
plt.colorbar() # 显示色条
plt.grid(False) # 关闭网格
plt.show()# 7. 打印对应标签的类别名称
print("Class:", class_names[label[0]])
将模型转化成了RelayIR
Relay IR 作为 TVM 的中间表示(IR),处于模型代码和硬件特定代码之间。它抽象了具体硬件(如 CPU、GPU、TPU、FPGA)底层的实现细节,提供了一个统一的计算图表示。
将 TorchScript 模型转换为 TVM Relay IR:
mod, params = relay.frontend.from_pytorch(script_mod, ['dataA', img.reshape(1, 784).shape)])
-
relay.frontend.from_pytorch
用于将 PyTorch 模型转换为 TVM 的 Relay 中间表示(IR)。它接受 TorchScript 模型和输入的形状信息作为参数。 -
['dataA', img.reshape(1, 784).shape]
是输入张量的名称和形状。这里的dataA
是模型输入的名称(可以自定义),img.reshape(1, 784).shape
是输入数据的形状(假设模型是一个基于图像的分类模型,输入是一个 28x28 的图像)。
TE
AutoTVM
TE+Schedule
TIR
Machine code
编译:编译成Machine Code