您的位置:首页 > 房产 > 家装 > 昇思25天学习打卡营第21天|ResNet50迁移学习

昇思25天学习打卡营第21天|ResNet50迁移学习

2024/10/7 2:30:05 来源:https://blog.csdn.net/main_h_/article/details/140420487  浏览:    关键词:昇思25天学习打卡营第21天|ResNet50迁移学习

在实际应用场景中,由于训练数据集不足,所以很少有人会从头开始训练整个网络。普遍的做法是,在一个非常大的基础数据集上训练得到一个预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务中。

一些通用的东西,可能会有大量的数据用来训练,但是大多数的行业应用,基本很难找到大量的数据提供给模型训练。如果是之前的分类模型,缺乏足够的数据很难得到一个比较理想的模型。由此,迁移学习产生了,带着强大的光环。她可以应用于这种数据量不足的场合和应用。

在这里插入图片描述
通过简单的添加狼和狗的数据集训练以后,可以看到训练后的模型对于狼狗的预测准确率还是不错的。
可以多次运行

import matplotlib.pyplot as plt
import mindspore as msdef visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为狼狗分类数2head = nn.Dense(in_channels, 2)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()labels = data["label"].asnumpy()class_name = {0: "dogs", 1: "wolves"}# 预测图像类别output = model.predict(ms.Tensor(data['image']))pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()
visualize_model(best_ckpt_path, dataset_val)

看每次的结果来看,基本可以准确预测。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com