您的位置:首页 > 健康 > 美食 > pytorch学习(十二):对现有的模型进行修改

pytorch学习(十二):对现有的模型进行修改

2025/3/16 19:54:05 来源:https://blog.csdn.net/weixin_52307528/article/details/141092824  浏览:    关键词:pytorch学习(十二):对现有的模型进行修改

以VGG16为例:

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

特征提取部分(features

  • 卷积层与ReLU激活:网络的前半部分主要由卷积层(Conv2d)和ReLU激活函数(ReLU)交替组成。每个卷积层后都紧跟一个ReLU层,用于引入非线性。这种结构有助于网络学习复杂的特征表示。

  • 卷积层配置

    • 初始阶段,使用64个3x3的卷积核,然后是ReLU激活,接着是另一个3x3卷积核和ReLU激活,之后是一个2x2的最大池化层(MaxPool2d),用于降低特征图的尺寸并增加感受野。
    • 类似地,这个过程在特征图的通道数增加到128、256和512时重复,每次增加通道数后都会跟随几个卷积层和ReLU激活,然后是一个最大池化层。
    • 值得注意的是,在512通道的部分,卷积层和ReLU激活的组合被重复了三次,而没有立即进行池化,这可能是为了进一步增强特征表示。
  • 最大池化层:用于在每个阶段的末尾减少特征图的尺寸,这有助于减少计算量和参数数量,同时保持重要的特征信息。

全连接层部分(classifier

  • 自适应平均池化:在特征提取部分之后,使用了一个自适应平均池化层(AdaptiveAvgPool2d),将特征图的尺寸调整为7x7。这是为了确保无论输入图像的大小如何,全连接层都能接收到固定大小的输入。

  • 全连接层

    • 第一个全连接层(Linear)将7x7x512的特征图展平为25088个特征,并映射到4096个输出特征上。
    • 接着是两个ReLU激活层、两个Dropout层(用于防止过拟合)和另外两个全连接层,最终输出1000个类别的得分(假设是用于ImageNet分类任务)。

可以看到,最后一层的全连接层的输出是1000,那么当我们有例如十分类的问题时候,就需要对网络进行修改。

 

vgg16_true.add_module('add_linear',nn.Linear(1000,10))

运行上述代码,在末尾加一层线性层,也就是全连接层。

还有一种方式是对原有的全连接层进行修改,将1000改为10。

vgg16_false.classifier[6]=nn.Linear(4096,10)

附上所有源代码;

# -*- coding: utf-8 -*-  
# File created on 2024/8/9 
# 作者:酷尔
# 公众号:酷尔计算机import torchvision
from torch import nn
# train_data=torchvision.datasets.ImageNet('./data_imagenet',split='train',download=True,transform=torchvision.transforms.ToTensor())vgg16_false=torchvision.models.vgg16(pretrained=False)
vgg16_true=torchvision.models.vgg16(pretrained=True)# print(vgg16_true)
# import os
#
# # 尝试从环境变量中获取TORCH_HOME
# torch_home = os.getenv('TORCH_HOME', os.path.expanduser('~/.torch'))
# model_cache_dir = os.path.join(torch_home, 'models')
#
# print(f"Model cache directory: {model_cache_dir}")
# 
# # 注意:这个目录可能不直接包含模型文件,因为 PyTorch 可能使用了内部的缓存机制
# # 来管理这些文件,并且它们可能以哈希名存储而不是直接以模型名存储。train_data=torchvision.datasets.CIFAR10('./dataset',train=True,download=True,transform=torchvision.transforms.ToTensor())
vgg16_true.classifier.add_module('add_linear',nn.Linear(1000,10))
# print(vgg16_true)vgg16_false.classifier[6]=nn.Linear(4096,10)
print(vgg16_false)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com