编写基本的单元测试来验证代码的行为。
使用的库:unittest
单元测试框架
python的unittest
库的基本单元测试框架可以表示为:
import unittestclass XXXTests(unittest.TestCase): # 第一个测试集@classmethoddef setUpClass(self):...self.x, self.y = _get_xy() # example funcdef test_a(self): # 第一个单元测试...def test_b(self): # 第二个单元测试self.assertEqual(self.x, self.y)class YYYTests(unittest.TestCase): # 第二个测试集...if __name__ == '__main__':unittest.main()
简单来说,派生的每个
TestCase
子类,表示一个单元测试集,其中的每个test_xxx
函数,都是一个单独的单元测试。
组织单元测试时,可以按照层级式的划分来进行。在单个文件中,单元测试有两个层级:TestCase类派生,以及TestCase下的test_xxx
函数。unittest的单元测试对函数名有要求,必须以test_
开头,才能被当作一个测试函数存在。每个单元测试内可以存在多个assert
函数。
单元测试可以直接写在源码文件中,在if __name__ == '__main__':
时调用unittest.main()
,使得只有在main直接运行源码文件才会进行单元测试。此外,通过__all__
变量也能进一步约束对源码文件的导入行为,通过from xxx import *
的形式导入该源码文件时,只能使用__all__
中的成员。
__all__ = ['UsefulClass']
class UsefulClass():def __init__()### ---
# unittests
### ---
class ClassTests(unittest.TestCase):def test_UsefulClass():[some test code]
assert方法
在单元测试中,通过调用.assertXXX
方法来自动验证某些关键信息。例如
# 输出形状
class ModelTests(unittest.TestCase):def test_asserts(self):x, y = get_xy()self.assertEqual(x, y)self.assertTrue(x==y)...
通常来讲,只要传入的参数有对应的重写运算符,就可以简单地调用.assertXXX
方法来做验证。
使用单元测试检查深度学习组件的行为
可以通过简单检查输出shape
的方法来测试深度学习组件的行为,例如
import torch
import torch.nn as nn
import unittest
from resnet50 import ResNet50class ConvTests(unittest.TestCase):def test_Conv2d(self):B, C, H, W = 5, 128, 28, 28x = torch.rand((B, C, H, W))conv = nn.Conv2d(in_channels=C, out_channels=C*2, kernel_size=3, stride=2, padding=1)out = conv(x)self.assertEqual(out.shape, torch.Size([B, C*2, H//2, W//2]))def test_ResNet50(self):input = _get_input()model = ResNet50()out = model(input)self.assertEqual(out.shape, torch.Size([shape values]))