您的位置:首页 > 文旅 > 旅游 > 数字广东网络建设有限公司_广东广州自己建网站公司_华与华营销策划公司_多地优化完善疫情防控措施

数字广东网络建设有限公司_广东广州自己建网站公司_华与华营销策划公司_多地优化完善疫情防控措施

2024/12/23 15:31:48 来源:https://blog.csdn.net/qq_64603703/article/details/142220221  浏览:    关键词:数字广东网络建设有限公司_广东广州自己建网站公司_华与华营销策划公司_多地优化完善疫情防控措施
数字广东网络建设有限公司_广东广州自己建网站公司_华与华营销策划公司_多地优化完善疫情防控措施

目录

一、了解MINIST数据集

1、什么是MNIST

2、查看MNIST由来

二、实操代码

1、下载训练数据集

2、下载测试数据集

运行结果:

3、展示手写数字图片

运行结果:

4、打包图片

运行结果:

5、判断当前pytorch使用的设备

1)torch.cuda.is_available()

2)torch.backends.mps.is_available()

3)MPS

运行结果:


一、了解MNIST数据集

1、什么是MNIST

        MNIST是一种基于神经网络的手写数字识别算法。它是LeCun等人在1998年提出的,是深度学习领域的里程碑之一。MNIST数据集包含了大量的手写数字图片,MNIST算法通过训练神经网络,可以有效地识别这些手写数字。MINIST算法在计算机视觉和模式识别中有广泛的应用,被认为是机器学习领域的经典问题之一。

        MNIST包含70,000张手写数字图像,其中60,000张用于训练,10,000张用于测试

        所有的图像都是灰度的,大小为28x28像素的,并且居中的,以减少预处理和加快运行。

2、查看MINIST由来

        进入下列网页,即可查看

https://yann.lecun.com/exdb/mnist/icon-default.png?t=O83Ahttps://yann.lecun.com/exdb/mnist/        打卡即可得到下列画面:

此时可知道这个MINIST数据集中训练集和测试集所占大小等等:

二、实操代码

1、下载训练数据集

训练数据集包含训练用的手写数字图片及其对应的标签

import torch 
print(torch.__version__)   # 查看torch版本号from torch import nn  # 导入神经网络模块,提供了构建网络所需的各种层
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,它可以将数据集封装成适合批处理的数据加载器。
from torchvision import datasets   # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor   一个数据转换操作,用于将图片数据转换为PyTorch张量(Tensor)。PyTorch中的模型只能接受张量作为输入。training_data = datasets.MNIST(   # 跳转到函数的内部源代码,pycharm 按下ctrl +鼠标点击root='data',  # 指定数据集下载后储存的根目录train=True,  # 表示下载的是训练集,如需下载测试集则更改为False即可download=True,   # 表示如果本地没有数据集,则自动下载,有则不再下载transform=ToTensor()   # 指定一个数据转换操作,即将下载的图片转换为pytorch张量tensor,因为pytorch模型只能处理张量类型的数据
)

        将代码和下列测试集代码一起运行。

2、下载测试数据集

        只需将训练数据集中的train参数结果更改为False

test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()  
)  # NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。print(len(training_data))   # 打印训练集数据条数
print(len(test_data))   # 打印测试集数据条数
运行结果:

3、展示手写数字图片

from matplotlib import pyplot as plt   # 导入绘图库
figure = plt.figure()   # 设置一个空白画布
for i in range(9):img,label = training_data[i+59000]   # 提取第59000张图片开始,共9张,返回图片及其对应的标签值figure.add_subplot(3,3,i+1)   # 在画布创建3行3列的小窗口,通过遍历的值i来确定每个画布展示的图片plt.title(label)   # 设置每个窗口的标题,设置标签为上述返回的标签值plt.axis('off')   # 取消画布中的坐标轴的图像plt.imshow(img.squeeze(),cmap='gray')   # plt.imshow()将NumPy数组data中的数据显示为图像,并在图形窗口中,a = img.squeeze()   # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1,则张量不会改变。
plt.show()

        最后一步img.squeeze降低维度是因为遍历出来的图像有一个冗余的维度没有用,如下所示,维度为1,图像大小为28x28像素的。

运行结果:

4、打包图片

train_dataloader = DataLoader(training_data,batch_size=64)  # 调用上述定义的DataLoader打包库,将训练集的图片和标签,64张图片为一个包,
test_dataloader = DataLoader(test_data,batch_size=64)   # 将测试集的图片和标签,每64张打包成一份
for x,y in test_dataloader:# x是表示打包好的每一个数据包,其形状为[64,1,28,28],64表示批次大小,1表示通道数为1,即灰度图,28表示图像的宽高像素值# y表示每个图片标签print(f"shape of x[N,C,H,W]:{x.shape}")   # 打印图片形状print(f"shape of y:{y.shape}{y.dtype}")   # 打印标签的形状和数据类型break  # 跳出并终止循环,表示只遍历一个包的数据情况
运行结果:

5、判断当前pytorch使用的设备

"""判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""  # 返回cuda,mps,cpu,
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化。CUDA驱动软件的功能:pytorch能够去执行cuda的命令,cuda通过GPU指令集
# 神经网络的模型也需要传入到GPU,1个batchsize的数据集也需要传入到GPU,才可以进行训练。
        1)torch.cuda.is_available()

                 检查CUDA是否在当前系统上可用。CUDA是NVIDIA的并行计算平台和编程模型,它允许软件利用NVIDIA图形处理单元(GPU)进行加速计算。如果CUDA可用,这意味着你的系统有NVIDIA GPU,并且PyTorch已经配置为可以使用CUDA。

        2)torch.backends.mps.is_available()

                检查MPS是否可用。请注意,这个检查通常只在Apple Silicon Macs上返回True

        3)MPS

                MPS是Apple提供的一套高性能图形和计算框架,专门设计用于Apple Silicon Macs上的Metal API。虽然MPS不直接对应于PyTorch的CUDA,但PyTorch从1.8版本开始增加了对Apple Silicon Macs的支持,通过MPS后端进行加速。

        表示如果torch.cuda.is_available()返回的是True则返回cuda,即当前使用的设备是cuda,如果返回False即执行下面的判断语句,即如果torch.backends.mps.is_available()返回的是True则返回mps,即当前使用的是苹果设备的mps,反之则使用的是cpu设备来计算。

运行结果:

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com