
MindSpore 25-Day Study Camp, Day 5 | Applied Practice / Computer Vision / FCN Image Semantic Segmentation

2024/10/6 20:36:01 Source: https://blog.csdn.net/guojun0718/article/details/140434149

Reflections

As mentioned before, this course takes patience, and today proved it: running the whole notebook end to end takes well over an hour. Of course, you can go do other things in the meantime and let the model train on its own, but seeing it through to the final result requires patience and the right expectations.

This lesson begins the real applications: image segmentation. It introduces how to use FCN, one tool for image segmentation, and walks through a hands-on experiment. Following the experiment steps produces the expected segmentation results, and afterwards you should be able to put the tool to practical use.

Check-in screenshot:

FCN Image Semantic Segmentation

Fully Convolutional Networks (FCN) is a framework for image semantic segmentation proposed by Jonathan Long et al. of UC Berkeley in the 2015 paper Fully Convolutional Networks for Semantic Segmentation [1].

FCN was the first fully convolutional network to make pixel-level predictions end to end.

fcn-1

Semantic Segmentation

Before introducing FCN in detail, let us first define semantic segmentation:

Image semantic segmentation is a key part of image understanding in image processing and machine vision, and an important branch of AI. It is widely applied in face recognition, object detection, medical imaging, satellite image analysis, autonomous-driving perception, and other fields.

The goal of semantic segmentation is to classify every pixel in an image. Unlike an ordinary classification task, which outputs a single class label, a semantic segmentation model outputs an image the same size as its input, where each output pixel holds the class of the corresponding input pixel. "Semantic" in the image domain refers to the content of the image, an understanding of what the picture shows. Below are some examples of semantic segmentation:

fcn-2

Model Overview

FCN is mainly used in the field of image segmentation. It is an end-to-end segmentation method and the seminal work in applying deep learning to image semantic segmentation. Through pixel-level prediction it directly produces a label map the same size as the original image. Because FCN discards the fully connected layers and replaces them with convolutions, so that every layer in the network is convolutional, it is called a fully convolutional network.

A fully convolutional network relies mainly on three techniques:

  1. Convolutionalization (Convolutional)

    VGG-16 serves as the FCN backbone. VGG-16 takes a 224*224 RGB image as input and outputs 1000 prediction scores. It accepts only fixed-size inputs, discards spatial coordinates, and produces non-spatial outputs. VGG-16 contains three fully connected layers, and a fully connected layer can also be viewed as a convolution whose kernel covers its entire input region. Converting the fully connected layers to convolutions turns the one-dimensional, non-spatial output into a two-dimensional map, from which a heatmap over the input image can be generated.

    fcn-3

  2. Upsampling (Upsample)

    The convolution and pooling operations shrink the feature maps, so to obtain a dense prediction at the original image size, the resulting feature maps must be upsampled. The upsampling deconvolution (transposed convolution) is initialized with bilinear-interpolation weights and then refined via backpropagation, so the network learns a nonlinear upsampling. Performing the upsampling inside the network allows end-to-end learning through backpropagation of the pixel-wise loss.

    fcn-4

  3. Skip connections (Skip Layer)

    Upsampling only the final feature map back to the original image size yields a segmentation with an effective prediction stride of 32 pixels, called FCN-32s. Because that final feature map is very small and loses too much detail, a skip structure combines the final, more global prediction with shallower-layer predictions so the result captures more local detail. The stride-32 prediction (FCN-32s) is upsampled 2x and fused (added) with a prediction made from the pool4 layer (stride 16); this network is called FCN-16s. That fused prediction is then upsampled 2x again and fused with a prediction from the pool3 layer; this network is called FCN-8s. The skip structure thus combines deep, global information with shallow, local information.

    fcn-5
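Technique 1 above, treating a fully connected layer as a convolution that covers its whole input, can be checked numerically. The sketch below uses plain NumPy with toy sizes of our own choosing (not the actual VGG-16 dimensions): an FC weight matrix reshaped into convolution kernels gives the same output on the original input size, and on a larger input the same weights produce a spatial heatmap.

```python
import numpy as np

# Hypothetical toy sizes: an "FC" layer over a 2x2 feature map with 3 channels.
rng = np.random.default_rng(0)
fc_w = rng.normal(size=(5, 3 * 2 * 2))        # 5 output units
feat = rng.normal(size=(3, 2, 2))             # fixed-size input the FC expects

fc_out = fc_w @ feat.reshape(-1)              # ordinary fully connected output

# The same weights viewed as 5 convolution kernels of shape (3, 2, 2):
conv_w = fc_w.reshape(5, 3, 2, 2)

def valid_conv(x, w):
    """Naive 'valid' cross-correlation: x (C,H,W), w (O,C,kh,kw) -> (O,H',W')."""
    o, c, kh, kw = w.shape
    h, wd = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((o, h, wd))
    for i in range(h):
        for j in range(wd):
            patch = x[:, i:i + kh, j:j + kw]
            out[:, i, j] = (w * patch).sum(axis=(1, 2, 3))
    return out

# On the original input size, the conv output is 1x1 and equals the FC output.
assert np.allclose(valid_conv(feat, conv_w)[:, 0, 0], fc_out)

# On a larger input, the same "convolutionalized" layer yields a heatmap.
big = rng.normal(size=(3, 4, 4))
print(valid_conv(big, conv_w).shape)          # (5, 3, 3)
```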

Network Characteristics

  1. A fully convolutional (fully conv) network without fully connected (fc) layers, so it adapts to inputs of any size.
  2. Deconvolution (deconv) layers that enlarge the data, enabling fine-grained outputs.
  3. A skip structure that combines results from layers of different depths, ensuring both robustness and accuracy.
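The deconvolution layers above are, in the original FCN paper, initialized with bilinear-interpolation weights (the network code in this tutorial uses `xavier_uniform` instead). A common way to build such a bilinear kernel, sketched here with a helper name of our own:

```python
import numpy as np

def bilinear_kernel(size):
    """2D bilinear interpolation kernel of shape (size, size), peaked at the center."""
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return ((1 - abs(og[0] - center) / factor) *
            (1 - abs(og[1] - center) / factor))

# size=4 matches the kernel_size=4, stride=2 transposed convolutions used later
k = bilinear_kernel(4)
print(k)
```

Tiling this kernel along the diagonal of a transposed convolution's weight tensor makes the layer start out as exact bilinear upsampling, which training can then refine.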

Data Processing

Before starting the experiment, make sure the local Python environment and MindSpore are installed.

[1]:

%%capture captured_output
# The environment comes with mindspore==2.2.14 preinstalled; to switch versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

[2]:

# Check the current mindspore version
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

[3]:

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)
file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 152MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset

[3]:

'./dataset'

Data Preprocessing

Because the images in the PASCAL VOC 2012 dataset mostly differ in resolution, they cannot be stacked into a single tensor, so the inputs must be standardized before being fed to the network.

Data Loading

Mix the PASCAL VOC 2012 dataset with the SBD dataset.

[4]:

import numpy as np
import cv2
import mindspore.dataset as ds
class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):
        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale  # the random-scale range must be valid
    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out
    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset
# Parameters for creating the dataset
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21
# Instantiate the Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset = dataset.get_dataset()

Visualizing the Training Set

Run the following code to inspect images loaded from the dataset (normalization was already applied during preprocessing).

[5]:

import numpy as np
import matplotlib.pyplot as plt
plt.figure(figsize=(16, 8))
# Display samples from the training set
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # Convert the image to HWC format for display
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

Network Construction

Network Flow

The FCN pipeline is shown in the figure below:

  1. The input image passes through pool1, halving its size to 1/2 of the original.
  2. After pool2, the size becomes 1/4 of the original.
  3. pool3, pool4, and pool5 further reduce the size to 1/8, 1/16, and 1/32 of the original.
  4. The conv6-7 convolutions keep the output at 1/32 of the original size.
  5. FCN-32s applies a final deconvolution so the output image matches the input size.
  6. FCN-16s deconvolves the conv7 output 2x to 1/16 of the original size, fuses it with the pool4 feature map, then deconvolves up to the original size.
  7. FCN-8s deconvolves the conv7 output 4x, deconvolves the pool4 feature map 2x, takes the pool3 feature map as well, fuses all three, and deconvolves up to the original size.

fcn-6
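The size bookkeeping in the steps above can be verified with a few lines of Python (a sketch for a 512x512 input, the crop size used in this experiment; channel counts omitted):

```python
size = 512                          # crop_size used in this experiment
feat = {f"pool{i}": size // 2 ** i for i in range(1, 6)}
print(feat)                         # pool1..pool5 -> 256, 128, 64, 32, 16

up2 = feat["pool5"] * 2             # conv7 score upsampled 2x -> 32, matches pool4
assert up2 == feat["pool4"]
up4 = up2 * 2                       # fused map upsampled 2x -> 64, matches pool3
assert up4 == feat["pool3"]
assert up4 * 8 == size              # final 8x deconvolution restores the input size
```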

Build the FCN-8s network with the following code.

[6]:

import mindspore.nn as nn
class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512,
                      kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096,
                      kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096,
                      kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')
    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

Training Preparation

Load Part of the VGG-16 Pretrained Weights

FCN uses VGG-16 as its backbone for image encoding. Use the code below to load part of the pretrained weights of the VGG-16 model.

[7]:

from download import download
from mindspore import load_checkpoint, load_param_into_net
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt (513.2 MB)
file_sizes: 100%|█████████████████████████████| 538M/538M [00:02<00:00, 225MB/s]
Successfully downloaded file to fcn8s_vgg16_pretrain.ckpt

Loss Function

Semantic segmentation classifies each pixel of an image, so it is still a classification problem. We therefore choose cross-entropy loss to measure the difference between the FCN output and the mask, using mindspore.nn.CrossEntropyLoss() as the loss function.
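What a per-pixel cross-entropy with `ignore_index=255` computes can be sketched in plain NumPy (toy shapes of our own choosing; the actual training relies on MindSpore's implementation):

```python
import numpy as np

def pixel_cross_entropy(logits, labels, ignore_index=255):
    """Mean cross-entropy over pixels, skipping pixels labeled ignore_index.
    logits: (N, C, H, W) raw scores; labels: (N, H, W) integer class ids."""
    z = logits - logits.max(axis=1, keepdims=True)           # stabilized softmax
    log_prob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    mask = labels != ignore_index                            # valid pixels only
    safe = np.where(mask, labels, 0)                         # placeholder index
    picked = np.take_along_axis(log_prob, safe[:, None, :, :], axis=1)[:, 0]
    return -(picked * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 21, 4, 4))                      # 21 VOC classes
labels = rng.integers(0, 21, size=(1, 4, 4))
labels[0, 0, 0] = 255                                        # this pixel is ignored
print(float(pixel_cross_entropy(logits, labels)))
```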

Custom Evaluation Metrics

This part evaluates the trained model. To set up notation, assume there are $k+1$ classes (from $L_0$ to $L_k$, one of which is an empty class or background), and let $p_{ij}$ denote the number of pixels that belong to class $i$ but are predicted as class $j$. Then $p_{ii}$ counts the true positives, while $p_{ij}$ and $p_{ji}$ are interpreted as false positives and false negatives respectively, even though both are sums of false positives and false negatives.

  • Pixel Accuracy (PA): the simplest metric, the proportion of correctly labeled pixels among all pixels.

$$PA = \frac{\sum_{i=0}^{k} p_{ii}}{\sum_{i=0}^{k}\sum_{j=0}^{k} p_{ij}}$$

  • Mean Pixel Accuracy (MPA): a simple improvement over PA. Compute, for each class, the proportion of its pixels that are correctly classified, then average over all classes.

$$MPA = \frac{1}{k+1}\sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k} p_{ij}}$$

  • Mean Intersection over Union (MIoU): the standard metric for semantic segmentation. It computes the ratio of the intersection to the union of two sets, which in semantic segmentation are the ground truth and the predicted segmentation. The ratio can be rewritten as the number of true positives (the intersection) over the sum of true positives, false negatives, and false positives (the union). IoU is computed per class and then averaged.

$$MIoU = \frac{1}{k+1}\sum_{i=0}^{k}\frac{p_{ii}}{\sum_{j=0}^{k} p_{ij} + \sum_{j=0}^{k} p_{ji} - p_{ii}}$$

  • Frequency Weighted Intersection over Union (FWIoU): an improvement over MIoU that weights each class's IoU by the frequency with which the class appears.

$$FWIoU = \frac{1}{\sum_{i=0}^{k}\sum_{j=0}^{k} p_{ij}}\sum_{i=0}^{k}\frac{\left(\sum_{j=0}^{k} p_{ij}\right) p_{ii}}{\sum_{j=0}^{k} p_{ij} + \sum_{j=0}^{k} p_{ji} - p_{ii}}$$
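All four formulas can be evaluated from a single confusion matrix, which is exactly how the metric classes below work. A small NumPy sketch with a made-up 3-class matrix:

```python
import numpy as np

# A toy 3-class confusion matrix (rows: ground truth i, columns: prediction j),
# i.e. entry [i, j] is p_ij from the formulas above.
cm = np.array([[50.,  2.,  3.],
               [ 4., 40.,  1.],
               [ 6.,  0., 30.]])

tp = np.diag(cm)                                            # p_ii, true positives
pa = tp.sum() / cm.sum()                                    # Pixel Accuracy
mpa = np.nanmean(tp / cm.sum(axis=1))                       # Mean Pixel Accuracy
iou = tp / (cm.sum(axis=1) + cm.sum(axis=0) - tp)           # per-class IoU
miou = np.nanmean(iou)                                      # Mean IoU
freq = cm.sum(axis=1) / cm.sum()                            # class frequencies
fwiou = (freq[freq > 0] * iou[freq > 0]).sum()              # Frequency Weighted IoU
print(pa, mpa, miou, fwiou)
```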

[8]:

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train
class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy
class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy
class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou
class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class
    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix
    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)
    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

Model Training

After loading the pretrained VGG-16 parameters, instantiate the loss function and optimizer, compile the network with the Model interface, and train the FCN-8s network.

[9]:

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model
device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)
train_batch_size = 4
num_classes = 21
# Build the model
net = FCN8s(n_class=21)
# Load the VGG-16 pretrained parameters
load_vgg16()
# Compute the learning rate schedule
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs
lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])
# Define the loss function
loss = nn.CrossEntropyLoss(ignore_index=255)
# Define the optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Define the loss scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = mindspore.amp.DynamicLossScaleManager(scale_factor, scale_window)
# Initialize the Model
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
# Configure checkpoint saving
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)
epoch: 1 step: 1, loss is 3.0555766
epoch: 1 step: 2, loss is 3.011279
epoch: 1 step: 3, loss is 2.949507
epoch: 1 step: 4, loss is 2.797522
epoch: 1 step: 5, loss is 2.680241
epoch: 1 step: 6, loss is 2.6088614
epoch: 1 step: 7, loss is 2.016888
epoch: 1 step: 8, loss is 2.3972445
epoch: 1 step: 9, loss is 2.3787959
epoch: 1 step: 10, loss is 2.7982762
...
epoch: 1 step: 316, loss is 3.3530662
epoch: 1 step: 317, loss is 1.4104733
epoch: 1 step: 318, loss is 1.3837914
epoch: 1 step: 319, loss is 1.3505452
epoch: 1 step: 320, loss is 1.5230938
epoch: 1 step: 321, loss is 1.7830057
epoch: 1 step: 322, loss is 1.6079345
epoch: 1 step: 323, loss is 1.5249194
epoch: 1 step: 324, loss is 1.7795275
epoch: 1 step: 325, loss is 1.5974939
epoch: 1 step: 326, loss is 1.8992445
epoch: 1 step: 327, loss is 1.8355097
epoch: 1 step: 328, loss is 1.7659526
epoch: 1 step: 329, loss is 0.860772
epoch: 1 step: 330, loss is 0.9947982
epoch: 1 step: 331, loss is 1.4873415
epoch: 1 step: 332, loss is 1.8094522
epoch: 1 step: 333, loss is 2.6617954
epoch: 1 step: 334, loss is 1.2003756
epoch: 1 step: 335, loss is 1.1727043
epoch: 1 step: 336, loss is 1.8665714
epoch: 1 step: 337, loss is 1.2992274
epoch: 1 step: 338, loss is 1.4808803
epoch: 1 step: 339, loss is 1.8172686
epoch: 1 step: 340, loss is 1.4347913
epoch: 1 step: 341, loss is 1.7422496
epoch: 1 step: 342, loss is 2.092747
epoch: 1 step: 343, loss is 1.3389716
epoch: 1 step: 344, loss is 1.9055254
epoch: 1 step: 345, loss is 1.3471433
epoch: 1 step: 346, loss is 3.1216211
epoch: 1 step: 347, loss is 2.3377044
epoch: 1 step: 348, loss is 1.7978139
epoch: 1 step: 349, loss is 1.6264254
epoch: 1 step: 350, loss is 1.5625271
epoch: 1 step: 351, loss is 1.4955949
epoch: 1 step: 352, loss is 1.3828536
epoch: 1 step: 353, loss is 1.7420508
epoch: 1 step: 354, loss is 1.6943358
epoch: 1 step: 355, loss is 1.2388207
epoch: 1 step: 356, loss is 1.8323805
epoch: 1 step: 357, loss is 1.0489881
epoch: 1 step: 358, loss is 0.97317934
epoch: 1 step: 359, loss is 2.8636708
epoch: 1 step: 360, loss is 1.6132842
epoch: 1 step: 361, loss is 0.98375934
epoch: 1 step: 362, loss is 1.487464
epoch: 1 step: 363, loss is 2.4894373
epoch: 1 step: 364, loss is 1.0135264
epoch: 1 step: 365, loss is 1.879363
epoch: 1 step: 366, loss is 1.331053
epoch: 1 step: 367, loss is 1.1489205
epoch: 1 step: 368, loss is 1.226815
epoch: 1 step: 369, loss is 1.7478739
epoch: 1 step: 370, loss is 1.354594
epoch: 1 step: 371, loss is 1.8299588
epoch: 1 step: 372, loss is 1.2977899
epoch: 1 step: 373, loss is 2.5712938
epoch: 1 step: 374, loss is 1.0928872
epoch: 1 step: 375, loss is 1.6891974
epoch: 1 step: 376, loss is 2.0215597
epoch: 1 step: 377, loss is 1.1820699
epoch: 1 step: 378, loss is 1.0994648
epoch: 1 step: 379, loss is 1.9773788
epoch: 1 step: 380, loss is 1.4613835
epoch: 1 step: 381, loss is 1.8520812
epoch: 1 step: 382, loss is 1.7668234
epoch: 1 step: 383, loss is 0.8450322
epoch: 1 step: 384, loss is 1.5982153
epoch: 1 step: 385, loss is 1.2754704
epoch: 1 step: 386, loss is 2.315125
epoch: 1 step: 387, loss is 0.7088874
epoch: 1 step: 388, loss is 1.8649011
epoch: 1 step: 389, loss is 1.2458955
epoch: 1 step: 390, loss is 0.82165086
epoch: 1 step: 391, loss is 0.9822279
epoch: 1 step: 392, loss is 1.443359
epoch: 1 step: 393, loss is 1.0647573
epoch: 1 step: 394, loss is 0.83201784
epoch: 1 step: 395, loss is 1.8173921
epoch: 1 step: 396, loss is 0.64982957
epoch: 1 step: 397, loss is 2.2437513
epoch: 1 step: 398, loss is 0.6847647
epoch: 1 step: 399, loss is 0.89103276
epoch: 1 step: 400, loss is 1.4338744
epoch: 1 step: 401, loss is 1.8985468
epoch: 1 step: 402, loss is 1.0379727
epoch: 1 step: 403, loss is 2.4956627
epoch: 1 step: 404, loss is 1.0091558
epoch: 1 step: 405, loss is 1.1559733
epoch: 1 step: 406, loss is 1.5052986
epoch: 1 step: 407, loss is 2.4269326
epoch: 1 step: 408, loss is 1.0364431
epoch: 1 step: 409, loss is 2.1860657
epoch: 1 step: 410, loss is 1.7374128
epoch: 1 step: 411, loss is 2.1806374
epoch: 1 step: 412, loss is 1.0271499
epoch: 1 step: 413, loss is 1.4027704
epoch: 1 step: 414, loss is 1.3922296
epoch: 1 step: 415, loss is 1.3656495
epoch: 1 step: 416, loss is 1.0574238
epoch: 1 step: 417, loss is 1.1328485
epoch: 1 step: 418, loss is 1.7255224
epoch: 1 step: 419, loss is 0.7683136
epoch: 1 step: 420, loss is 1.2930223
epoch: 1 step: 421, loss is 1.350565
epoch: 1 step: 422, loss is 2.4475327
epoch: 1 step: 423, loss is 1.7392212
epoch: 1 step: 424, loss is 1.357533
epoch: 1 step: 425, loss is 1.6967012
epoch: 1 step: 426, loss is 1.5918407
epoch: 1 step: 427, loss is 1.1265403
epoch: 1 step: 428, loss is 1.240321
epoch: 1 step: 429, loss is 1.5559906
epoch: 1 step: 430, loss is 1.2735507
epoch: 1 step: 431, loss is 1.7454333
epoch: 1 step: 432, loss is 1.1046945
epoch: 1 step: 433, loss is 1.2196519
epoch: 1 step: 434, loss is 1.7430303
epoch: 1 step: 435, loss is 0.83525527
epoch: 1 step: 436, loss is 0.7670407
epoch: 1 step: 437, loss is 1.5120745
epoch: 1 step: 438, loss is 2.3453658
epoch: 1 step: 439, loss is 1.670254
epoch: 1 step: 440, loss is 1.9222342
epoch: 1 step: 441, loss is 1.5817118
epoch: 1 step: 442, loss is 1.779689
epoch: 1 step: 443, loss is 1.3573812
epoch: 1 step: 444, loss is 1.6718678
epoch: 1 step: 445, loss is 1.9022393
epoch: 1 step: 446, loss is 1.2964453
epoch: 1 step: 447, loss is 1.8454082
epoch: 1 step: 448, loss is 1.7099202
epoch: 1 step: 449, loss is 1.2723054
epoch: 1 step: 450, loss is 1.3928919
epoch: 1 step: 451, loss is 2.593335
epoch: 1 step: 452, loss is 0.957349
epoch: 1 step: 453, loss is 1.1754612
epoch: 1 step: 454, loss is 2.1085362
epoch: 1 step: 455, loss is 0.8977358
epoch: 1 step: 456, loss is 1.029734
epoch: 1 step: 457, loss is 1.7968893
epoch: 1 step: 458, loss is 0.6207457
epoch: 1 step: 459, loss is 1.3950763
epoch: 1 step: 460, loss is 1.5005456
epoch: 1 step: 461, loss is 1.16712
epoch: 1 step: 462, loss is 1.4269016
epoch: 1 step: 463, loss is 1.1122184
epoch: 1 step: 464, loss is 1.3946863
epoch: 1 step: 465, loss is 1.0945292
epoch: 1 step: 466, loss is 1.7415801
epoch: 1 step: 467, loss is 1.197471
epoch: 1 step: 468, loss is 0.9762926
epoch: 1 step: 469, loss is 1.7352037
epoch: 1 step: 470, loss is 2.9586048
epoch: 1 step: 471, loss is 0.7965474
epoch: 1 step: 472, loss is 1.2332946
epoch: 1 step: 473, loss is 1.2858999
epoch: 1 step: 474, loss is 2.675951
epoch: 1 step: 475, loss is 1.5099132
epoch: 1 step: 476, loss is 1.1351485
epoch: 1 step: 477, loss is 1.4003451
epoch: 1 step: 478, loss is 1.3482885
epoch: 1 step: 479, loss is 1.6599878
epoch: 1 step: 480, loss is 2.0790308
epoch: 1 step: 481, loss is 1.1344154
epoch: 1 step: 482, loss is 1.1183742
epoch: 1 step: 483, loss is 1.3431644
epoch: 1 step: 484, loss is 1.891571
epoch: 1 step: 485, loss is 1.8390507
epoch: 1 step: 486, loss is 0.68230194
epoch: 1 step: 487, loss is 2.4264948
epoch: 1 step: 488, loss is 1.1321049
epoch: 1 step: 489, loss is 2.3570054
epoch: 1 step: 490, loss is 1.5638088
epoch: 1 step: 491, loss is 1.0281556
epoch: 1 step: 492, loss is 1.179501
epoch: 1 step: 493, loss is 1.7425754
epoch: 1 step: 494, loss is 1.216838
epoch: 1 step: 495, loss is 1.313088
epoch: 1 step: 496, loss is 1.8575542
epoch: 1 step: 497, loss is 1.1804881
epoch: 1 step: 498, loss is 0.7338756
epoch: 1 step: 499, loss is 1.0781401
epoch: 1 step: 500, loss is 1.1411486
epoch: 1 step: 501, loss is 1.0012753
epoch: 1 step: 502, loss is 2.480338
epoch: 1 step: 503, loss is 2.0434027
epoch: 1 step: 504, loss is 1.9330958
epoch: 1 step: 505, loss is 1.3525292
epoch: 1 step: 506, loss is 1.3453496
epoch: 1 step: 507, loss is 1.1349841
epoch: 1 step: 508, loss is 1.3275613
epoch: 1 step: 509, loss is 1.3702389
epoch: 1 step: 510, loss is 1.5459789
epoch: 1 step: 511, loss is 1.4093105
epoch: 1 step: 512, loss is 1.3102345
epoch: 1 step: 513, loss is 1.4180427
epoch: 1 step: 514, loss is 1.0469306
epoch: 1 step: 515, loss is 1.3003435
epoch: 1 step: 516, loss is 0.9984727
epoch: 1 step: 517, loss is 1.0238106
epoch: 1 step: 518, loss is 1.5167332
epoch: 1 step: 519, loss is 1.3844187
epoch: 1 step: 520, loss is 1.3757316
epoch: 1 step: 521, loss is 1.3516656
epoch: 1 step: 522, loss is 1.8678075
epoch: 1 step: 523, loss is 2.3297002
epoch: 1 step: 524, loss is 1.9795201
epoch: 1 step: 525, loss is 1.1288508
epoch: 1 step: 526, loss is 1.6564837
epoch: 1 step: 527, loss is 1.7353119
epoch: 1 step: 528, loss is 1.5148716
epoch: 1 step: 529, loss is 1.8389469
epoch: 1 step: 530, loss is 1.5839268
epoch: 1 step: 531, loss is 1.1102291
epoch: 1 step: 532, loss is 1.6192364
epoch: 1 step: 533, loss is 1.5492799
epoch: 1 step: 534, loss is 0.66978854
epoch: 1 step: 535, loss is 1.0883406
epoch: 1 step: 536, loss is 1.467099
epoch: 1 step: 537, loss is 1.4951227
epoch: 1 step: 538, loss is 1.22861
epoch: 1 step: 539, loss is 1.8916509
epoch: 1 step: 540, loss is 1.3318044
epoch: 1 step: 541, loss is 1.1179008
epoch: 1 step: 542, loss is 0.86763227
epoch: 1 step: 543, loss is 1.3380413
epoch: 1 step: 544, loss is 0.877036
epoch: 1 step: 545, loss is 1.8263694
epoch: 1 step: 546, loss is 1.9933003
epoch: 1 step: 547, loss is 3.0388918
epoch: 1 step: 548, loss is 1.1376984
epoch: 1 step: 549, loss is 1.7338281
epoch: 1 step: 550, loss is 1.1078362
epoch: 1 step: 551, loss is 1.7842574
epoch: 1 step: 552, loss is 1.847948
epoch: 1 step: 553, loss is 1.3383169
epoch: 1 step: 554, loss is 2.253529
epoch: 1 step: 555, loss is 1.4470934
epoch: 1 step: 556, loss is 2.5758104
epoch: 1 step: 557, loss is 1.2049958
epoch: 1 step: 558, loss is 1.2582105
epoch: 1 step: 559, loss is 1.1200962
epoch: 1 step: 560, loss is 1.4020823
epoch: 1 step: 561, loss is 1.0559015
epoch: 1 step: 562, loss is 1.6117083
epoch: 1 step: 563, loss is 1.5042626
epoch: 1 step: 564, loss is 0.60846215
epoch: 1 step: 565, loss is 1.3206059
epoch: 1 step: 566, loss is 0.9752561
epoch: 1 step: 567, loss is 2.8149579
epoch: 1 step: 568, loss is 1.4749202
epoch: 1 step: 569, loss is 1.5989314
epoch: 1 step: 570, loss is 1.9049196
epoch: 1 step: 571, loss is 2.2102985
epoch: 1 step: 572, loss is 1.076511
epoch: 1 step: 573, loss is 1.2375902
epoch: 1 step: 574, loss is 2.0209324
epoch: 1 step: 575, loss is 1.6674724
epoch: 1 step: 576, loss is 1.8105242
epoch: 1 step: 577, loss is 1.6968787
epoch: 1 step: 578, loss is 1.2098968
epoch: 1 step: 579, loss is 1.1359947
epoch: 1 step: 580, loss is 0.74574137
epoch: 1 step: 581, loss is 1.3472857
epoch: 1 step: 582, loss is 1.4352955
epoch: 1 step: 583, loss is 1.0112945
epoch: 1 step: 584, loss is 1.3712543
epoch: 1 step: 585, loss is 1.2365879
epoch: 1 step: 586, loss is 2.063514
epoch: 1 step: 587, loss is 2.341364
epoch: 1 step: 588, loss is 1.5827311
epoch: 1 step: 589, loss is 1.9994929
epoch: 1 step: 590, loss is 1.7015483
epoch: 1 step: 591, loss is 1.8882269
epoch: 1 step: 592, loss is 1.769226
epoch: 1 step: 593, loss is 1.5512
epoch: 1 step: 594, loss is 1.9845947
epoch: 1 step: 595, loss is 0.9146105
epoch: 1 step: 596, loss is 1.8261771
epoch: 1 step: 597, loss is 1.2112893
epoch: 1 step: 598, loss is 1.3299814
epoch: 1 step: 599, loss is 3.7674086
epoch: 1 step: 600, loss is 1.8343416
epoch: 1 step: 601, loss is 1.2081834
epoch: 1 step: 602, loss is 1.3641579
epoch: 1 step: 603, loss is 0.9953077
epoch: 1 step: 604, loss is 1.2975556
epoch: 1 step: 605, loss is 2.4187882
epoch: 1 step: 606, loss is 1.3597724
epoch: 1 step: 607, loss is 1.2239795
epoch: 1 step: 608, loss is 2.1632843
epoch: 1 step: 609, loss is 1.3404518
epoch: 1 step: 610, loss is 1.459078
epoch: 1 step: 611, loss is 1.462854
epoch: 1 step: 612, loss is 1.3525069
epoch: 1 step: 613, loss is 1.8743639
epoch: 1 step: 614, loss is 2.0211759
epoch: 1 step: 615, loss is 2.3340921
epoch: 1 step: 616, loss is 0.80409557
epoch: 1 step: 617, loss is 1.748388
epoch: 1 step: 618, loss is 1.6724974
epoch: 1 step: 619, loss is 1.7130597
epoch: 1 step: 620, loss is 1.7452227
epoch: 1 step: 621, loss is 1.202181
epoch: 1 step: 622, loss is 0.91535544
epoch: 1 step: 623, loss is 1.2121412
epoch: 1 step: 624, loss is 1.3355671
epoch: 1 step: 625, loss is 0.94627005
epoch: 1 step: 626, loss is 1.1966462
epoch: 1 step: 627, loss is 1.1274178
epoch: 1 step: 628, loss is 1.3206655
epoch: 1 step: 629, loss is 1.4013119
epoch: 1 step: 630, loss is 1.4011741
epoch: 1 step: 631, loss is 1.5077602
epoch: 1 step: 632, loss is 0.8348478
epoch: 1 step: 633, loss is 1.5395391
epoch: 1 step: 634, loss is 0.7114434
epoch: 1 step: 635, loss is 2.5204458
epoch: 1 step: 636, loss is 1.8323419
epoch: 1 step: 637, loss is 1.4461967
epoch: 1 step: 638, loss is 1.1745355
epoch: 1 step: 639, loss is 2.300219
epoch: 1 step: 640, loss is 0.993178
epoch: 1 step: 641, loss is 1.4015422
epoch: 1 step: 642, loss is 0.88615584
epoch: 1 step: 643, loss is 2.5094633
epoch: 1 step: 644, loss is 1.1049241
epoch: 1 step: 645, loss is 1.364984
epoch: 1 step: 646, loss is 1.2748551
epoch: 1 step: 647, loss is 1.0834305
epoch: 1 step: 648, loss is 1.1728472
epoch: 1 step: 649, loss is 1.7560409
epoch: 1 step: 650, loss is 1.2263372
epoch: 1 step: 651, loss is 1.1797118
epoch: 1 step: 652, loss is 1.5146737
epoch: 1 step: 653, loss is 0.6696658
epoch: 1 step: 654, loss is 1.5268859
epoch: 1 step: 655, loss is 1.3210653
epoch: 1 step: 656, loss is 0.6579518
epoch: 1 step: 657, loss is 0.9740461
epoch: 1 step: 658, loss is 1.0518541
epoch: 1 step: 659, loss is 2.487858
epoch: 1 step: 660, loss is 1.2870302
epoch: 1 step: 661, loss is 2.1053555
epoch: 1 step: 662, loss is 2.42398
epoch: 1 step: 663, loss is 1.3543024
epoch: 1 step: 664, loss is 1.5404183
epoch: 1 step: 665, loss is 1.4579546
epoch: 1 step: 666, loss is 1.4293582
epoch: 1 step: 667, loss is 1.5389482
epoch: 1 step: 668, loss is 1.7665721
epoch: 1 step: 669, loss is 1.5622315
epoch: 1 step: 670, loss is 0.7570114
epoch: 1 step: 671, loss is 1.8234335
epoch: 1 step: 672, loss is 2.2982135
epoch: 1 step: 673, loss is 0.67956674
epoch: 1 step: 674, loss is 0.9968495
epoch: 1 step: 675, loss is 2.2092586
epoch: 1 step: 676, loss is 0.9612413
epoch: 1 step: 677, loss is 1.4292934
epoch: 1 step: 678, loss is 1.2439922
epoch: 1 step: 679, loss is 2.4466763
epoch: 1 step: 680, loss is 1.1736733
epoch: 1 step: 681, loss is 1.35983
epoch: 1 step: 682, loss is 2.0899472
epoch: 1 step: 683, loss is 1.3609959
epoch: 1 step: 684, loss is 1.4911706
epoch: 1 step: 685, loss is 1.340486
epoch: 1 step: 686, loss is 1.1201618
epoch: 1 step: 687, loss is 1.5683829
epoch: 1 step: 688, loss is 1.630942
epoch: 1 step: 689, loss is 1.5108232
epoch: 1 step: 690, loss is 1.0738064
epoch: 1 step: 691, loss is 0.9212303
epoch: 1 step: 692, loss is 0.7916537
epoch: 1 step: 693, loss is 0.6979817
epoch: 1 step: 694, loss is 2.3420534
epoch: 1 step: 695, loss is 0.87849545
epoch: 1 step: 696, loss is 1.519859
epoch: 1 step: 697, loss is 0.75940263
epoch: 1 step: 698, loss is 1.9717739
epoch: 1 step: 699, loss is 1.2356297
epoch: 1 step: 700, loss is 0.89456916
epoch: 1 step: 701, loss is 1.1256486
epoch: 1 step: 702, loss is 1.1596085
epoch: 1 step: 703, loss is 1.3044443
epoch: 1 step: 704, loss is 2.1377606
epoch: 1 step: 705, loss is 2.295943
epoch: 1 step: 706, loss is 1.8033284
epoch: 1 step: 707, loss is 1.6320504
epoch: 1 step: 708, loss is 0.9701102
epoch: 1 step: 709, loss is 1.7464141
epoch: 1 step: 710, loss is 1.5633839
epoch: 1 step: 711, loss is 1.6021218
epoch: 1 step: 712, loss is 2.2535717
epoch: 1 step: 713, loss is 1.4982432
epoch: 1 step: 714, loss is 1.2432361
epoch: 1 step: 715, loss is 0.9069972
epoch: 1 step: 716, loss is 1.9144322
epoch: 1 step: 717, loss is 1.0706913
epoch: 1 step: 718, loss is 1.6145098
epoch: 1 step: 719, loss is 1.6405451
epoch: 1 step: 720, loss is 1.4849119
epoch: 1 step: 721, loss is 1.8228263
epoch: 1 step: 722, loss is 1.1857268
epoch: 1 step: 723, loss is 0.97221446
epoch: 1 step: 724, loss is 1.4749986
epoch: 1 step: 725, loss is 1.1818168
epoch: 1 step: 726, loss is 1.9165746
epoch: 1 step: 727, loss is 1.1358266
epoch: 1 step: 728, loss is 1.5439519
epoch: 1 step: 729, loss is 1.5858732
epoch: 1 step: 730, loss is 1.5767627
epoch: 1 step: 731, loss is 2.3513443
epoch: 1 step: 732, loss is 1.9786195
epoch: 1 step: 733, loss is 1.2790738
epoch: 1 step: 734, loss is 1.9605577
epoch: 1 step: 735, loss is 1.0936247
epoch: 1 step: 736, loss is 1.2805703
epoch: 1 step: 737, loss is 1.119981
epoch: 1 step: 738, loss is 1.6804544
epoch: 1 step: 739, loss is 0.8024065
epoch: 1 step: 740, loss is 1.904917
epoch: 1 step: 741, loss is 0.93471247
epoch: 1 step: 742, loss is 0.80442107
epoch: 1 step: 743, loss is 1.6238126
epoch: 1 step: 744, loss is 0.8951113
epoch: 1 step: 745, loss is 1.2869899
epoch: 1 step: 746, loss is 1.0973653
epoch: 1 step: 747, loss is 0.8711085
epoch: 1 step: 748, loss is 1.0454528
epoch: 1 step: 749, loss is 2.1668606
epoch: 1 step: 750, loss is 2.5275462
epoch: 1 step: 751, loss is 1.0681784
epoch: 1 step: 752, loss is 2.933765
epoch: 1 step: 753, loss is 1.4255711
epoch: 1 step: 754, loss is 1.191608
epoch: 1 step: 755, loss is 1.4095919
epoch: 1 step: 756, loss is 1.142974
epoch: 1 step: 757, loss is 1.2644136
epoch: 1 step: 758, loss is 1.9802994
epoch: 1 step: 759, loss is 1.5662123
epoch: 1 step: 760, loss is 1.1805629
epoch: 1 step: 761, loss is 2.0095615
epoch: 1 step: 762, loss is 1.6808025
epoch: 1 step: 763, loss is 1.0026639
epoch: 1 step: 764, loss is 1.3832494
epoch: 1 step: 765, loss is 2.052904
epoch: 1 step: 766, loss is 1.373161
epoch: 1 step: 767, loss is 1.43591
epoch: 1 step: 768, loss is 1.4153485
epoch: 1 step: 769, loss is 0.79898584
epoch: 1 step: 770, loss is 1.3049191
epoch: 1 step: 771, loss is 0.90859413
epoch: 1 step: 772, loss is 1.9727582
epoch: 1 step: 773, loss is 1.1863426
epoch: 1 step: 774, loss is 3.7333474
epoch: 1 step: 775, loss is 1.0496565
epoch: 1 step: 776, loss is 1.3426573
epoch: 1 step: 777, loss is 1.3528662
epoch: 1 step: 778, loss is 1.6478986
epoch: 1 step: 779, loss is 1.3920547
epoch: 1 step: 780, loss is 1.8827794
epoch: 1 step: 781, loss is 1.3012863
epoch: 1 step: 782, loss is 1.9876854
epoch: 1 step: 783, loss is 2.5755439
epoch: 1 step: 784, loss is 2.1082442
epoch: 1 step: 785, loss is 1.5048356
epoch: 1 step: 786, loss is 1.3922017
epoch: 1 step: 787, loss is 1.8908904
epoch: 1 step: 788, loss is 2.599439
epoch: 1 step: 789, loss is 2.303181
epoch: 1 step: 790, loss is 1.3750371
epoch: 1 step: 791, loss is 1.8312877
epoch: 1 step: 792, loss is 2.2078962
epoch: 1 step: 793, loss is 1.5333905
epoch: 1 step: 794, loss is 1.3166717
epoch: 1 step: 795, loss is 1.1350552
epoch: 1 step: 796, loss is 1.6482236
epoch: 1 step: 797, loss is 2.3583062
epoch: 1 step: 798, loss is 1.3941627
epoch: 1 step: 799, loss is 2.0419552
epoch: 1 step: 800, loss is 0.90875214
epoch: 1 step: 801, loss is 0.9101978
epoch: 1 step: 802, loss is 1.8258369
epoch: 1 step: 803, loss is 2.5883722
epoch: 1 step: 804, loss is 1.2424961
epoch: 1 step: 805, loss is 2.3432624
epoch: 1 step: 806, loss is 1.3616704
epoch: 1 step: 807, loss is 2.2598088
epoch: 1 step: 808, loss is 1.1630859
epoch: 1 step: 809, loss is 1.7900795
epoch: 1 step: 810, loss is 1.8361194
epoch: 1 step: 811, loss is 1.3335819
epoch: 1 step: 812, loss is 2.1354847
epoch: 1 step: 813, loss is 1.8652128
epoch: 1 step: 814, loss is 1.2343959
epoch: 1 step: 815, loss is 1.3479736
epoch: 1 step: 816, loss is 1.1582621
epoch: 1 step: 817, loss is 1.0801764
epoch: 1 step: 818, loss is 1.8511257
epoch: 1 step: 819, loss is 1.0402728
epoch: 1 step: 820, loss is 1.8012185
epoch: 1 step: 821, loss is 2.335875
epoch: 1 step: 822, loss is 1.5279088
epoch: 1 step: 823, loss is 1.0979595
epoch: 1 step: 824, loss is 1.8266139
epoch: 1 step: 825, loss is 1.9584131
epoch: 1 step: 826, loss is 1.0054066
epoch: 1 step: 827, loss is 1.8476264
epoch: 1 step: 828, loss is 1.1876247
epoch: 1 step: 829, loss is 2.058835
epoch: 1 step: 830, loss is 1.6776699
epoch: 1 step: 831, loss is 1.271608
epoch: 1 step: 832, loss is 1.4422956
epoch: 1 step: 833, loss is 1.2951993
epoch: 1 step: 834, loss is 0.9828206
epoch: 1 step: 835, loss is 1.0924329
epoch: 1 step: 836, loss is 1.4952904
epoch: 1 step: 837, loss is 1.4025043
epoch: 1 step: 838, loss is 1.9040225
epoch: 1 step: 839, loss is 1.401846
epoch: 1 step: 840, loss is 1.0887309
epoch: 1 step: 841, loss is 1.8236426
epoch: 1 step: 842, loss is 1.2837555
epoch: 1 step: 843, loss is 1.400443
epoch: 1 step: 844, loss is 1.1982963
epoch: 1 step: 845, loss is 1.1106625
epoch: 1 step: 846, loss is 2.2019417
epoch: 1 step: 847, loss is 1.3425461
epoch: 1 step: 848, loss is 1.4576385
epoch: 1 step: 849, loss is 0.957398
epoch: 1 step: 850, loss is 0.8842392
epoch: 1 step: 851, loss is 1.1657537
epoch: 1 step: 852, loss is 1.3742013
epoch: 1 step: 853, loss is 1.1035196
epoch: 1 step: 854, loss is 1.4083391
epoch: 1 step: 855, loss is 2.0467982
epoch: 1 step: 856, loss is 0.8769012
epoch: 1 step: 857, loss is 2.2031364
epoch: 1 step: 858, loss is 1.0501319
epoch: 1 step: 859, loss is 1.5099256
epoch: 1 step: 860, loss is 2.230537
epoch: 1 step: 861, loss is 2.3482902
epoch: 1 step: 862, loss is 1.0410942
epoch: 1 step: 863, loss is 1.8426225
epoch: 1 step: 864, loss is 1.3824652
epoch: 1 step: 865, loss is 2.2829404
epoch: 1 step: 866, loss is 2.2329688
epoch: 1 step: 867, loss is 1.4724648
epoch: 1 step: 868, loss is 1.8667552
epoch: 1 step: 869, loss is 0.9998264
epoch: 1 step: 870, loss is 1.0731683
epoch: 1 step: 871, loss is 1.2733049
epoch: 1 step: 872, loss is 1.4793786
epoch: 1 step: 873, loss is 0.9667675
epoch: 1 step: 874, loss is 1.292985
epoch: 1 step: 875, loss is 0.760757
epoch: 1 step: 876, loss is 1.3204714
epoch: 1 step: 877, loss is 2.2355037
epoch: 1 step: 878, loss is 1.3717364
epoch: 1 step: 879, loss is 1.1240346
epoch: 1 step: 880, loss is 1.1605307
epoch: 1 step: 881, loss is 1.7609496
epoch: 1 step: 882, loss is 1.0830563
epoch: 1 step: 883, loss is 2.1481783
epoch: 1 step: 884, loss is 1.4446073
epoch: 1 step: 885, loss is 1.3095372
epoch: 1 step: 886, loss is 0.7894097
epoch: 1 step: 887, loss is 0.778172
epoch: 1 step: 888, loss is 1.3542448
epoch: 1 step: 889, loss is 0.93296754
epoch: 1 step: 890, loss is 1.2241555
epoch: 1 step: 891, loss is 0.99201494
epoch: 1 step: 892, loss is 1.1960433
epoch: 1 step: 893, loss is 0.84757686
epoch: 1 step: 894, loss is 0.6043976
epoch: 1 step: 895, loss is 0.93262863
epoch: 1 step: 896, loss is 0.9961176
epoch: 1 step: 897, loss is 0.7792142
epoch: 1 step: 898, loss is 1.2804202
epoch: 1 step: 899, loss is 1.8603704
epoch: 1 step: 900, loss is 1.8095875
epoch: 1 step: 901, loss is 2.464622
epoch: 1 step: 902, loss is 2.01489
epoch: 1 step: 903, loss is 1.1549289
epoch: 1 step: 904, loss is 1.5530778
epoch: 1 step: 905, loss is 1.6928029
epoch: 1 step: 906, loss is 1.8224834
epoch: 1 step: 907, loss is 1.2855705
epoch: 1 step: 908, loss is 1.268188
epoch: 1 step: 909, loss is 1.2265819
epoch: 1 step: 910, loss is 1.0300361
epoch: 1 step: 911, loss is 1.946483
epoch: 1 step: 912, loss is 1.3053142
epoch: 1 step: 913, loss is 1.6834987
epoch: 1 step: 914, loss is 2.0741272
epoch: 1 step: 915, loss is 1.184945
epoch: 1 step: 916, loss is 0.81467485
epoch: 1 step: 917, loss is 1.139117
epoch: 1 step: 918, loss is 1.4476173
epoch: 1 step: 919, loss is 1.0973291
epoch: 1 step: 920, loss is 2.1329618
epoch: 1 step: 921, loss is 2.5329778
epoch: 1 step: 922, loss is 1.1555674
epoch: 1 step: 923, loss is 1.1908537
epoch: 1 step: 924, loss is 2.0221212
epoch: 1 step: 925, loss is 1.318033
epoch: 1 step: 926, loss is 1.0614426
epoch: 1 step: 927, loss is 1.35431
epoch: 1 step: 928, loss is 1.1131707
epoch: 1 step: 929, loss is 1.9753772
epoch: 1 step: 930, loss is 1.8827976
epoch: 1 step: 931, loss is 1.6199201
epoch: 1 step: 932, loss is 0.89537066
epoch: 1 step: 933, loss is 0.87524974
epoch: 1 step: 934, loss is 1.3639565
epoch: 1 step: 935, loss is 1.1188961
epoch: 1 step: 936, loss is 2.2195475
epoch: 1 step: 937, loss is 0.6986935
epoch: 1 step: 938, loss is 1.270506
epoch: 1 step: 939, loss is 1.9118234
epoch: 1 step: 940, loss is 0.9057662
epoch: 1 step: 941, loss is 1.3038914
epoch: 1 step: 942, loss is 0.9594154
epoch: 1 step: 943, loss is 1.4642558
epoch: 1 step: 944, loss is 1.4990832
epoch: 1 step: 945, loss is 0.91882765
epoch: 1 step: 946, loss is 1.0566436
epoch: 1 step: 947, loss is 0.9932025
epoch: 1 step: 948, loss is 1.7192651
epoch: 1 step: 949, loss is 1.3222369
epoch: 1 step: 950, loss is 1.9712995
epoch: 1 step: 951, loss is 1.9556195
epoch: 1 step: 952, loss is 1.8104154
epoch: 1 step: 953, loss is 0.94096905
epoch: 1 step: 954, loss is 1.4459616
epoch: 1 step: 955, loss is 1.4283258
epoch: 1 step: 956, loss is 0.9211304
epoch: 1 step: 957, loss is 1.4047655
epoch: 1 step: 958, loss is 1.0073998
epoch: 1 step: 959, loss is 1.5877329
epoch: 1 step: 960, loss is 0.9423544
epoch: 1 step: 961, loss is 1.6313819
epoch: 1 step: 962, loss is 1.0350293
epoch: 1 step: 963, loss is 2.198759
epoch: 1 step: 964, loss is 1.2830672
epoch: 1 step: 965, loss is 1.1649876
epoch: 1 step: 966, loss is 1.0637901
epoch: 1 step: 967, loss is 1.3201073
epoch: 1 step: 968, loss is 1.0360725
epoch: 1 step: 969, loss is 1.642942
epoch: 1 step: 970, loss is 0.7525851
epoch: 1 step: 971, loss is 2.8349874
epoch: 1 step: 972, loss is 0.97390026
epoch: 1 step: 973, loss is 1.4190799
epoch: 1 step: 974, loss is 1.3180606
epoch: 1 step: 975, loss is 1.0637872
epoch: 1 step: 976, loss is 1.9322275
epoch: 1 step: 977, loss is 1.8632646
epoch: 1 step: 978, loss is 0.98935014
epoch: 1 step: 979, loss is 1.5272365
epoch: 1 step: 980, loss is 1.7695069
epoch: 1 step: 981, loss is 1.1859069
epoch: 1 step: 982, loss is 1.591681
epoch: 1 step: 983, loss is 1.6642396
epoch: 1 step: 984, loss is 1.17676
epoch: 1 step: 985, loss is 1.6031128
epoch: 1 step: 986, loss is 0.74796176
epoch: 1 step: 987, loss is 2.1651971
epoch: 1 step: 988, loss is 2.2072072
epoch: 1 step: 989, loss is 0.9877745
epoch: 1 step: 990, loss is 0.81926566
epoch: 1 step: 991, loss is 1.2299742
epoch: 1 step: 992, loss is 1.1739465
epoch: 1 step: 993, loss is 1.5589671
epoch: 1 step: 994, loss is 2.7460325
epoch: 1 step: 995, loss is 1.5822508
epoch: 1 step: 996, loss is 0.71706843
epoch: 1 step: 997, loss is 1.5346807
epoch: 1 step: 998, loss is 1.5048069
epoch: 1 step: 999, loss is 1.0560807
epoch: 1 step: 1000, loss is 1.599892
epoch: 1 step: 1001, loss is 1.1897614
epoch: 1 step: 1002, loss is 1.9205663
epoch: 1 step: 1003, loss is 1.6471921
epoch: 1 step: 1004, loss is 0.9387723
epoch: 1 step: 1005, loss is 1.3150833
epoch: 1 step: 1006, loss is 1.3438991
epoch: 1 step: 1007, loss is 1.5021478
epoch: 1 step: 1008, loss is 1.4890593
epoch: 1 step: 1009, loss is 1.2778269
epoch: 1 step: 1010, loss is 0.6882769
epoch: 1 step: 1011, loss is 0.9347951
epoch: 1 step: 1012, loss is 1.3499761
epoch: 1 step: 1013, loss is 0.9955309
epoch: 1 step: 1014, loss is 1.1137868
epoch: 1 step: 1015, loss is 1.8241159
epoch: 1 step: 1016, loss is 1.8701798
epoch: 1 step: 1017, loss is 2.7673414
epoch: 1 step: 1018, loss is 1.4919015
epoch: 1 step: 1019, loss is 1.9045514
epoch: 1 step: 1020, loss is 0.7976063
epoch: 1 step: 1021, loss is 1.1083182
epoch: 1 step: 1022, loss is 1.2022533
epoch: 1 step: 1023, loss is 0.90036416
epoch: 1 step: 1024, loss is 2.163247
epoch: 1 step: 1025, loss is 1.2025884
epoch: 1 step: 1026, loss is 1.6793993
epoch: 1 step: 1027, loss is 0.98231167
epoch: 1 step: 1028, loss is 2.4958415
epoch: 1 step: 1029, loss is 0.82211256
epoch: 1 step: 1030, loss is 1.3351381
epoch: 1 step: 1031, loss is 1.5578023
epoch: 1 step: 1032, loss is 1.1692554
epoch: 1 step: 1033, loss is 1.1860975
epoch: 1 step: 1034, loss is 1.9775729
epoch: 1 step: 1035, loss is 1.0032942
epoch: 1 step: 1036, loss is 1.4271852
epoch: 1 step: 1037, loss is 1.8364547
epoch: 1 step: 1038, loss is 1.4331493
epoch: 1 step: 1039, loss is 1.723385
epoch: 1 step: 1040, loss is 1.0946501
epoch: 1 step: 1041, loss is 1.0817194
epoch: 1 step: 1042, loss is 1.712061
epoch: 1 step: 1043, loss is 1.1384515
epoch: 1 step: 1044, loss is 0.80386084
epoch: 1 step: 1045, loss is 0.7549771
epoch: 1 step: 1046, loss is 1.0141368
epoch: 1 step: 1047, loss is 1.2612225
epoch: 1 step: 1048, loss is 1.5343955
epoch: 1 step: 1049, loss is 2.06873
epoch: 1 step: 1050, loss is 2.208825
epoch: 1 step: 1051, loss is 1.5479069
epoch: 1 step: 1052, loss is 1.3816451
epoch: 1 step: 1053, loss is 1.1809994
epoch: 1 step: 1054, loss is 1.7838947
epoch: 1 step: 1055, loss is 1.8119993
epoch: 1 step: 1056, loss is 2.1126428
epoch: 1 step: 1057, loss is 1.7385744
epoch: 1 step: 1058, loss is 1.8985868
epoch: 1 step: 1059, loss is 1.0533351
epoch: 1 step: 1060, loss is 1.3286018
...
epoch: 1 step: 1141, loss is 1.379223
epoch: 1 step: 1142, loss is 1.5438745
epoch: 1 step: 1143, loss is 1.9404602
Train epoch time: 651031.996 ms, per step time: 569.582 ms

Because FCN needs a large amount of training data and many training epochs, only a single epoch on a small dataset is run here to demonstrate how the loss converges. Below, a pre-trained checkpoint is used for model evaluation and for showing inference results.
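The per-step losses printed above are quite noisy, which can obscure the downward trend. A simple moving average makes convergence easier to judge; the sketch below uses synthetic loss values as a stand-in for the logged ones:

```python
import numpy as np

def moving_average(losses, window=50):
    """Smooth a noisy per-step loss curve with a simple moving average."""
    losses = np.asarray(losses, dtype=float)
    kernel = np.ones(window) / window
    return np.convolve(losses, kernel, mode="valid")

# synthetic stand-in for the logged losses: decaying trend plus noise
steps = np.arange(1000)
noisy = 2.0 * np.exp(-steps / 400) + 0.5 \
    + 0.3 * np.random.default_rng(0).standard_normal(1000)
smooth = moving_average(noisy)
print(smooth[0] > smooth[-1])  # the smoothed curve trends downward
```

Plotting `smooth` instead of the raw losses gives a much clearer picture of whether training is still improving.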

Model Evaluation

[10]:

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"
# Download the pre-trained checkpoint
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
# Instantiate the dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)
file_sizes: 100%|███████████████████████████| 1.08G/1.08G [00:05<00:00, 201MB/s]
Successfully downloaded file to FCN8s.ckpt

[10]:

{'pixel accuracy': 0.9729141093358019, 'mean pixel accuracy': 0.941011879124398, 'mean IoU': 0.894833164943202, 'frequency weighted IoU': 0.9477743138691888}
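All four metrics reported by `model.eval` above derive from a per-class confusion matrix accumulated over the evaluation set. A minimal NumPy sketch of the standard formulas (this reimplements the textbook definitions, not MindSpore's `PixelAccuracy`/`MeanIntersectionOverUnion` classes themselves):

```python
import numpy as np

def segmentation_metrics(hist):
    """Compute standard segmentation metrics from a confusion matrix.

    hist[i, j] = number of pixels with ground-truth class i predicted as class j.
    """
    tp = np.diag(hist)                                   # correctly classified pixels per class
    pixel_acc = tp.sum() / hist.sum()                    # overall pixel accuracy
    per_class_acc = tp / hist.sum(axis=1)                # recall per class
    mean_pixel_acc = np.nanmean(per_class_acc)
    iou = tp / (hist.sum(axis=1) + hist.sum(axis=0) - tp)  # intersection / union per class
    mean_iou = np.nanmean(iou)
    freq = hist.sum(axis=1) / hist.sum()                 # class frequency weights
    fw_iou = (freq[freq > 0] * iou[freq > 0]).sum()
    return pixel_acc, mean_pixel_acc, mean_iou, fw_iou

# toy 2-class confusion matrix: 90 correct and 5 confused pixels per class
hist = np.array([[90.0, 5.0],
                 [5.0, 90.0]])
print(segmentation_metrics(hist))
```

For the toy matrix both classes have IoU 90 / (95 + 95 − 90) = 0.9, so mean IoU and frequency-weighted IoU coincide at 0.9.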

Model Inference

Use the trained network to visualize inference results.

[11]:

import numpy as np
import matplotlib.pyplot as plt
net = FCN8s(n_class=num_classes)
# Load the trained checkpoint weights
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# Visualize inference results (top row: input images, bottom row: predicted masks)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([eval_batch_size, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
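Note that `plt.imshow(res[i])` colors the predicted class indices with Matplotlib's default colormap. To assign fixed per-class colors instead, one can index a palette array with the prediction mask. A sketch with a made-up three-class palette (the colors below are illustrative only; PASCAL VOC defines its own fixed 21-class palette):

```python
import numpy as np

# Hypothetical 3-class palette; row i is the RGB color for class index i.
PALETTE = np.array([[0, 0, 0],      # class 0: background -> black
                    [128, 0, 0],    # class 1 -> dark red
                    [0, 128, 0]],   # class 2 -> dark green
                   dtype=np.uint8)

def colorize(mask):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB image."""
    return PALETTE[mask]

mask = np.array([[0, 1],
                 [2, 1]])
rgb = colorize(mask)
print(rgb.shape)  # (2, 2, 3)
```

The resulting `rgb` array can be passed to `plt.imshow` directly, giving consistent colors across images regardless of which classes appear in each one.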

Summary

The core contribution of FCN is replacing fully connected layers with convolutional ones, letting the network learn end-to-end pixel-wise segmentation. Compared with traditional patch-based CNN segmentation, FCN has two clear advantages: first, it accepts input images of arbitrary size, with no requirement that training and test images share a fixed resolution; second, it is far more efficient, avoiding the redundant storage and repeated convolution computation caused by processing overlapping pixel patches.

At the same time, FCN also leaves room for improvement:

First, the results are still not fine-grained enough. Although 8x upsampling is much better than 32x, the upsampled output remains blurry and overly smooth, especially at object boundaries; the network is insensitive to fine detail in the image. Second, classifying each pixel independently does not fully model the relationships between pixels (such as discontinuity and similarity). It omits the spatial regularization step used in typical pixel-classification segmentation pipelines and therefore lacks spatial consistency.

References

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully Convolutional Networks for Semantic Segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.

[13]:

 
import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'guojun0718')
2024-07-15 03:24:23 guojun0718
