经典的分割模型仍然是 UNet。也可以直接使用 torch 生态中现成的 UNet 实现来训练,但为了更好地理解其结构,这里选择自己搭建。
unet.py:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Up(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU, then upsample spatially by a factor of 2.

    Not used by ``MainNet``; kept as a standalone building block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.block(x)
        # F.interpolate defaults to nearest-neighbour interpolation.
        return F.interpolate(x, scale_factor=2)


class Down(nn.Module):
    """Encoder stage: Conv3x3 (padding 1) -> BatchNorm -> ReLU.

    With the default ``stride=2`` a side of length H becomes ``(H - 1) // 2 + 1``
    (i.e. halved, rounding up); ``stride=1`` keeps the spatial size unchanged.
    """

    def __init__(self, in_channel, out_channel, stride=2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.block(x)


class UpConcat(nn.Module):
    """Decoder stage: upsample the deeper map, concatenate the skip map, fuse.

    Args:
        in_channel: channel count of the deeper feature map (``in_map2``).
        out_channel: channel count of the skip feature map (``in_map1``);
            also the channel count of the output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel + out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.ReLU6(inplace=True),
        )

    def forward(self, in_map1, in_map2):
        """Fuse skip map ``in_map1`` with the deeper map ``in_map2``.

        Returns a tensor with ``out_channel`` channels and the spatial size of
        ``in_map1``.
        """
        in_map2 = self.up(in_map2)
        skip_size = in_map1.shape[-2:]
        if in_map2.shape[-2:] != skip_size:
            # A stride-2 encoder stage maps an odd side H to (H + 1) // 2, so a
            # plain x2 upsample overshoots by one and torch.cat would fail.
            # Align to the skip connection's size instead.
            in_map2 = F.interpolate(in_map2, size=skip_size, mode="nearest")
        out = torch.cat([in_map1, in_map2], dim=1)
        return self.conv2(out)


class MainNet(nn.Module):
    """U-Net style segmentation network.

    Encoder: five ``Down`` stages, 3 -> 64 -> 128 -> 256 -> 512 -> 1024 channels.
    The first stage keeps the resolution; the other four each halve it.
    Decoder: four ``UpConcat`` stages with skip connections, back to 64
    channels at input resolution. Head: 1x1 conv + sigmoid.

    Args:
        num_classes: number of output channels.

    Input ``(N, 3, H, W)`` -> output ``(N, num_classes, H, W)``. Each channel
    passes through an independent sigmoid, so values lie in [0, 1].
    """

    def __init__(self, num_classes):
        super().__init__()
        self.down1 = Down(3, 64, stride=1)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        self.down5 = Down(512, 1024)

        self.up5concat = UpConcat(1024, 512)
        self.up4concat = UpConcat(512, 256)
        self.up3concat = UpConcat(256, 128)
        self.up2concat = UpConcat(128, 64)

        self.head = nn.Sequential(
            nn.Conv2d(64, num_classes, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Shape comments (C, H, W) assume a (N, 3, 512, 512) input.
        feat1 = self.down1(x)      # 64,   512, 512
        feat2 = self.down2(feat1)  # 128,  256, 256
        feat3 = self.down3(feat2)  # 256,  128, 128
        feat4 = self.down4(feat3)  # 512,  64,  64
        feat5 = self.down5(feat4)  # 1024, 32,  32

        feat4_up = self.up5concat(feat4, feat5)     # 512, 64,  64
        feat3_up = self.up4concat(feat3, feat4_up)  # 256, 128, 128
        feat2_up = self.up3concat(feat2, feat3_up)  # 128, 256, 256
        feat1_up = self.up2concat(feat1, feat2_up)  # 64,  512, 512
        return self.head(feat1_up)


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = torch.zeros((1, 3, 512, 512)).to(device)
    model = MainNet(num_classes=3).to(device)

    # Sanity-check forward pass. Shape reporting lives here, not in forward().
    with torch.no_grad():
        out = model(tensor)
    print("out:", out.shape)

    from torchsummary import summary
    summary(model, (3, 512, 512))

    # NOTE(review): an earlier thop.profile run on a 1x3x512x512 input recorded
    # FLOPs=63.406604288G, params=14.127683M. Re-measure: a hand count of the
    # current layers gives roughly 18.8M parameters, so those figures may
    # predate this architecture.