深入理解 Python Mock 库:Mock,patch 和 patch.object 的实现原理
Python 中 unittest.mock
是我经常用到的一个库,它提供了非常方便的 mock 功能,可以帮助我们写出更好的单元测试。本文不是介绍 unittest.mock
的使用,而是探讨它的实现原理,从零开始实现 unittest.mock
中的 Mock
,patch
和 patch.object
,帮助我们更好地理解它们的工作原理。
如何 mock 一个函数
假如有一个函数 random_boolean
,实现如下:
import random
def random_boolean(threshold=0.5):
return random.random() < threshold
我现在想给这个函数写一个单元测试,测试它的返回值是否符合预期,应该怎么做呢?
这时候可以 mock random.random()
,让它总是返回小于 0.5 的值,测试 random_boolean
的返回值是否为 True
。
from unittest.mock import patch
from my_module import random_boolean
def test_random_boolean():
with patch('random.random', new=lambda: random.uniform(0, 0.5)):
assert random_boolean() == True
接下来看看如何实现 patch
,让它完成上述的功能呢?
实现 patch
import importlib
class Patch:
def __init__(self, target, new):
self.target = target
self.new = new
self.original = None
def __enter__(self):
# Split the target into module and attribute
parts = self.target.split('.')
module_name = '.'.join(parts[:-1])
attr_name = parts[-1]
# Import the module and get the original attribute
module = importlib.import_module(module_name)
self.original = getattr(module, attr_name)
# Replace the original attribute with the new one
setattr(module, attr_name, self.new)
return self.new
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the original attribute when exiting the with block
parts = self.target.split('.')
module_name = '.'.join(parts[:-1])
attr_name = parts[-1]
module = importlib.import_module(module_name)
setattr(module, attr_name, self.original)
def patch(target, new):
return Patch(target, new)
实现的核心思路是,通过 importlib.import_module
导入需要 patch 的模块(random
),然后通过 setattr
将模块中的需要 patch 的属性(random 模块中的 random()
函数)替换为新的值(lambda: random.uniform(0, 0.5)
)。替换只在 with
语句块中生效,with
语句块结束后,再将原来的值赋回去。
如何 mock 一个类方法
假设现在有这样一个类:
class ProductionClass:
def method(self):
return self.something(1, 2, 3)
def something(self, a, b, c):
pass
我想 mock something
方法,可以用 unittest.mock.patch.object
来实现:
from unittest.mock import patch
from my_module import ProductionClass
def test_method():
with patch.object(ProductionClass, 'something', new=lambda self, a, b, c: 3) as mock_method:
assert ProductionClass().method() == 3
实现 patch.object
class PatchObject:
def __init__(self, target, attr, new):
self.target = target
self.attr = attr
self.new = new
self.original = None
def __enter__(self):
# Get the original attribute
self.original = getattr(self.target, self.attr)
# Replace the original attribute with the new one
setattr(self.target, self.attr, self.new)
return self.new
def __exit__(self, exc_type, exc_val, exc_tb):
# Restore the original attribute when exiting the with block
setattr(self.target, self.attr, self.original)
def patch_object(target, attr, new):
return PatchObject(target, attr, new)
实现的核心思路是,通过 getattr
获取需要 patch 的属性(something
方法),然后通过 setattr
将属性替换为新的值(lambda self, a, b, c: 3
)。替换只在 with
语句块中生效,with
语句块结束后,再将原来的值赋回去。
实现 Mock
class Mock:
def __init__(self, return_value=None):
self._methods = {}
self.return_value = return_value
self.call_count = 0
self.call_args = None
def __getattr__(self, name):
if name not in self._methods:
self._methods[name] = Mock()
return self._methods[name]
def __setattr__(self, name, value):
if isinstance(value, Mock):
self._methods[name] = value
else:
super().__setattr__(name, value)
def __call__(self, *args, **kwargs):
self.call_count += 1
self.call_args = (args, kwargs)
if self.return_value is None:
return Mock()
return self.return_value
Mock 类的主要工作原理是,当我们试图访问它的一个属性时,如果这个属性不存在,那么它就会创建一个新的 Mock 实例并返回。这样,我们就可以无限制地访问它的属性,每个属性都是一个新的 Mock 实例。当我们调用 Mock 实例时,它会记录调用的次数和参数,并返回一个预设的值或者一个新的 Mock 实例。
通过上述的讨论,我们深入地了解了 Python 的 mock 库的实现原理。我们了解了如何通过修改模块的属性来模拟函数或者方法的行为,以及如何通过 Mock 类来模拟对象的行为。然而,实际的 mock 库的实现要比我们讨论的更复杂,它还包括了很多其他的特性,例如 side_effect、call_args_list 等。如果你对这个主题感兴趣,我鼓励你去阅读 Python 官方文档或者 mock 库的源代码,以获取更深入的理解。