Posts Python-单元测试
Post
Cancel

Python-单元测试

前言

单元测试(unit test),对于程序规范开发,尤其是测试驱动开发(TDD),非常重要。单元测试,其核心是:编写测试来验证某一个模块的功能正确性,一般会指定输入,验证输出是否符合预期。 对Python而言,常用的测试相关库有:unitest/nose/pytest等。这里重点介绍下unitest,为Python内置库,模仿PyUnit写的。

基本使用(unittest)

简单看下 unittest 库的简单使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import unittest

# 将要被测试的排序函数
def sort(arr):
    l = len(arr)
    for i in range(0, l):
        for j in range(i + 1, l):
            if arr[i] >= arr[j]:
                tmp = arr[i]
                arr[i] = arr[j]
                arr[j] = tmp


# 编写子类继承unittest.TestCase
class TestSort(unittest.TestCase):

   # 以test开头的函数将会被测试
   def test_sort(self):
        arr = [3, 4, 1, 5, 6]
        sort(arr)
        # assert 结果跟我们期待的一样
        self.assertEqual(arr, [1, 3, 4, 5, 6])

if __name__ == '__main__':
    ## 如果在Jupyter下,请用如下方式运行单元测试
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

    ## 如果是命令行下运行,则:
    ## unittest.main()

## 输出
..
----------------------------------------------------------------------
Ran 2 tests in 0.002s

OK

分析下上述代码: TestSort 类继承自 unittest.TestCase ,然后以 test 开头作为测试函数,进行测试。而测试函数内部通常的断言函数: assertEqual()/assertTrue()/assertFalse()/assertRaise()

mock

mock 为单元测试中最核心的一环, mock 的核心是:通过一个虚假对象,来替代被测试函数或者模块需要对象。 mock 的应用场景在于:部分模块测试依赖于其他模块,这样就可以通过 mock 来创建一些虚假的对象,以便为后续模块做测试。 Python mock 主要使用 mock/MagicMock 对象来实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import unittest
from unittest.mock import MagicMock

class A(unittest.TestCase):
    def m1(self):
        val = self.m2()
        self.m3(val)

    def m2(self):
        pass

    def m3(self, val):
        pass

    def test_m1(self):
        a = A()
        a.m2 = MagicMock(return_value="custom_val")
        a.m3 = MagicMock()
        a.m1()
        self.assertTrue(a.m2.called) #验证m2被call过
        a.m3.assert_called_with("custom_val") #验证m3被指定参数call过

if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

## 输出
..
----------------------------------------------------------------------
Ran 2 tests in 0.002s

OK

分析上述代码, test_m1 实现的是对 m1() 函数的测试,而 m1() 函数调用了 m2/m3 函数。 mock 这里做的工作便是将 m2 函数替换为返回具体数值的”custom_val”, m3 替换为空函数。 上述 mock 的使用是一种最简单的方法,只能规定返回的value。另外可以通过 Mock Side Effect 来对 mock 函数进行输入/输出进行设计:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from unittest.mock import MagicMock
def side_effect(arg):
    if arg < 0:
        return 1
    else:
        return 2
mock = MagicMock()
mock.side_effect = side_effect

mock(-1)
1

mock(1)
2

上述 Mock Side Effect 便实现了一个简单根据输入是否小于0来返回不同的值。

patch

patch 实际是提供了一种非常方便的 mock 方法,可以通过 decoration/context manager 模式快速mock所需的函数或者模块,如对默认的 sort 函数进行 mock ,以此可以设计对应的 return_value/side_effect

1
2
3
4
5
6
from unittest.mock import patch

@patch('sort')
def test_sort(self, mock_sort):
    ...
    ...

patch 也可以对类成员函数进行 mock ,如某类复杂的初始化函数可以mock为 None 的函数,避免复杂的初始化:

1
2
with patch.object(A, '__init__', lambda x: None):
      

高质量单元测试

如何设计高质量的单元测试?一些比较关键的参数:

  • 覆盖率( Test Coverage ) : 衡量代码中语句被cover的百分比,可以说,提高代码模块的覆盖率,一定程度上等同于提高代码的正确性
  • 模块化:核心是从测试角度去开发代码,去思考如何模块化代码。

比如,一个 前处理/sort/后处理 的代码,简单描述如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def work(arr):
    # pre process
    ...
    ...
    # sort
    l = len(arr)
    for i in range(0, l):
        for j in range(i + 1, j):
            if arr[i] >= arr[j]:
                tmp = arr[i]
                arr[i] = arr[j]
                arr[j] = tmp
    # post process
    ...
    ...
    Return arr

但是上述代码明显很难进行模块化测试,而模块化代码后可以设计为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def preprocess(arr):
    ...
    ...
    return arr

def sort(arr):
    ...
    ...
    return arr

def postprocess(arr):
    ...
    return arr

def work(self):
    arr = preprocess(arr)
    arr = sort(arr)
    arr = postprocess(arr)
    return arr

基于上述开发代码可以设计对应的单元测试如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from unittest.mock import patch

def test_preprocess(self):
    ...

def test_sort(self):
    ...

def test_postprocess(self):
    ...

@patch('%s.preprocess')
@patch('%s.sort')
@patch('%s.postprocess')
def test_work(self,mock_post_process, mock_sort, mock_preprocess):
    work()
    self.assertTrue(mock_post_process.called)
    self.assertTrue(mock_sort.called)
    self.assertTrue(mock_preprocess.called)
This post is licensed under CC BY 4.0 by the author.

Python-上下文管理器

Transformer在CV领域的应用与部署