
PyTorch Semantic Segmentation (2) -------- Building the Model

2024/10/31 9:51:13  Source: https://blog.csdn.net/m0_48095841/article/details/139495815

The classic model is still UNet. You could also train with an off-the-shelf implementation (for example via torch.hub, as sketched below), but to understand the architecture better, we will build it ourselves.
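For reference, here is one way to load a ready-made UNet from torch.hub. This is my example, not from the original post; it uses the community model published as mateuszbuda/brain-segmentation-pytorch on the PyTorch hub, and downloads weights on first run:

import torch

# Load a pretrained UNet from torch.hub (hub entry and keyword arguments
# follow the model's page on pytorch.org/hub).
model = torch.hub.load(
    'mateuszbuda/brain-segmentation-pytorch', 'unet',
    in_channels=3, out_channels=1, init_features=32, pretrained=True,
)
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 1, 256, 256])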

unet.py:

import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    # Conv-BN-ReLU followed by 2x upsampling.
    # (Defined here but not used by MainNet below, which uses UpConcat instead.)
    def __init__(self, in_channel, out_channel):
        super(Up, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.block(x)
        out = F.interpolate(x, scale_factor=2)
        return out


class Down(nn.Module):
    # Encoder block: Conv-BN-ReLU that halves the spatial size
    # when stride=2 (the default).
    def __init__(self, in_channel, out_channel, stride=2):
        super(Down, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    # Decoder block: upsample the deeper feature map by 2x, concatenate it
    # with the encoder feature map of the same resolution (the skip
    # connection), then fuse the two with a pair of 3x3 convolutions.
    def __init__(self, in_channel, out_channel):
        super(UpConcat, self).__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel + out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        in_map2 = self.up(in_map2)                  # upsample the deeper map
        out = torch.cat([in_map1, in_map2], dim=1)  # concat along channels
        return self.conv2(out)


class MainNet(nn.Module):
    def __init__(self, num_classes):
        super(MainNet, self).__init__()
        self.down1 = Down(3, 64, stride=1)   # keeps full resolution
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)
        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)
        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),   # 1x1 conv to class channels
            nn.Sigmoid()
        )

    def forward(self, x):
        feat1 = self.down1(x)       # 3, 512, 512   -> 64, 512, 512
        feat2 = self.down2(feat1)   # 64, 512, 512  -> 128, 256, 256
        feat3 = self.down3(feat2)   # 128, 256, 256 -> 256, 128, 128
        feat4 = self.down4(feat3)   # 256, 128, 128 -> 512, 64, 64
        feat5 = self.down5(feat4)   # 512, 64, 64   -> 1024, 32, 32
        feat4_up = self.up5concat(feat4, feat5)     # -> 512, 64, 64
        feat3_up = self.up4concat(feat3, feat4_up)  # -> 256, 128, 128
        feat2_up = self.up3concat(feat2, feat3_up)  # -> 128, 256, 256
        feat1_up = self.up2concat(feat1, feat2_up)  # -> 64, 512, 512
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)
    out = model(tensor)
    print(out.shape)  # torch.Size([1, 3, 512, 512])

    from torchsummary import torchsummary
    torchsummary.summary(model, (3, 512, 512))

    # from thop import profile
    # flops, params = profile(model, inputs=(tensor,))
    # print("FLOPs=", str(flops / 1e9) + 'G')   # FLOPs= 63.406604288G
    # print("params=", str(params / 1e6) + 'M') # params= 14.127683M
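The listing above stops at shape checking. As a quick orientation, here is a minimal sketch of one training step with this model; the loss, optimizer, and dummy data are my assumptions, not from the original post. Because the head ends in Sigmoid, a probability-based loss such as nn.BCELoss fits directly; for mutually exclusive classes you would more commonly drop the Sigmoid and use nn.CrossEntropyLoss on raw logits. This assumes MainNet from unet.py above is in scope:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MainNet(num_classes=3).to(device)

criterion = nn.BCELoss()  # expects probabilities in [0, 1], matching the Sigmoid head
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # hypothetical choice

# Dummy batch: 2 images and binary per-class masks of the same spatial size.
images = torch.rand(2, 3, 512, 512, device=device)
masks = torch.randint(0, 2, (2, 3, 512, 512), device=device).float()

model.train()
optimizer.zero_grad()
preds = model(images)            # (2, 3, 512, 512), values in (0, 1)
loss = criterion(preds, masks)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")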
