简单聊聊装饰器
装饰器是 Python 提供的语法糖:通过 @ 的方式使用,它接收一个函数作为参数,返回的是内部定义的函数(通常命名为 wrapper())。装饰器不会修改传入函数本身的代码,而是在其外层添加一些额外的功能,常见的如:统计运行时间、检查输入合法性、提供缓存等。详情可参照:Python-闭包+装饰器。本文主要介绍一些常用的装饰器,让自己的代码看上去更简洁高效吧。
常用装饰器介绍
Python 自带(以及常用库提供)的装饰器用法:
@lru_cache(来自 functools):基于缓存加速代码,尤其适用于递归结构——已经计算过的结果会被直接复用,避免重复计算:
1 2 3
from functools import lru_cache


@lru_cache
def factorial(n):
    """Return n! for a non-negative integer n.

    Results are memoized by lru_cache, so repeated calls and the recursive
    sub-calls reuse values that were already computed.

    Raises:
        ValueError: if n is negative. Without this guard the recursion would
            never reach the base case and would end in RecursionError.
    """
    if n < 0:
        raise ValueError("factorial() is not defined for negative values")
    return n * factorial(n - 1) if n else 1
@jit(来自第三方库 numba):jit 全称 “Just-In-Time compilation”(即时编译),也能加速代码运行,但通常只有在计算量较大时才能体现优势;由于函数首次调用时需要先编译,简单的计算反而可能变慢:
1 2 3 4 5 6 7 8 9 10 11 12
import random

from numba import jit


@jit(nopython=True)
def monte_carlo_pi(nsamples):
    """Estimate pi by Monte Carlo sampling.

    Draws ``nsamples`` points (x, y) with random.random() and counts how many
    fall strictly inside the unit circle (x^2 + y^2 < 1). Four times that
    fraction approximates pi. nsamples is expected to be positive; 0 would
    divide by zero.
    """
    inside = 0
    for _ in range(nsamples):
        px = random.random()
        py = random.random()
        # Point lies inside the quarter unit circle.
        if (px ** 2 + py ** 2) < 1.0:
            inside += 1
    return 4.0 * inside / nsamples
@dataclass:dataclass 装饰器能够根据类中带类型注解的字段自动生成 __init__、__repr__、__eq__ 等方法,方便类的创建:
1 2 3 4 5 6 7 8 9 10
from dataclasses import dataclass


@dataclass
class Food:
    """A stock item; @dataclass generates __init__, __repr__ and __eq__.

    Fields are declared one per line as annotated class attributes. A
    comma-separated field list on one line is not valid dataclass syntax.
    """

    name: str
    unit_price: float
    stock: int = 0  # fields with defaults must follow fields without them

    def stock_value(self) -> float:
        """Return the total value of the items in stock."""
        return self.stock * self.unit_price
@functools.wraps:用于消除装饰器的副作用——被装饰的函数实际上被替换成了另一个函数(wrapper),其名称和 docstring 也随之改变;使用 functools 的 wraps 后,能够保留原有函数的名称和 docstring:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
from functools import wraps


def my_decorator(func):
    """Wrap func so that every call prints a notice first.

    @wraps copies func's metadata (such as __name__ and __doc__) onto the
    wrapper, so the decorated function still reports its own name and
    docstring.
    """
    @wraps(func)
    def inner(*args, **kwargs):
        '''decorator'''
        print('Calling decorated function...')
        result = func(*args, **kwargs)
        return result

    return inner


@my_decorator
def example():
    """Docstring"""
    print('Called example function')


# Because of @wraps, this shows the original name/docstring, not the wrapper's.
print(example.__name__, example.__doc__)
@staticmethod/@classmethod:静态方法/类方法。静态方法不接收 self 或 cls 参数;类方法的第一个参数 cls 是类本身:
1 2 3 4 5 6 7 8
class Example:
    """Demonstrates @staticmethod and @classmethod."""

    @staticmethod
    def our_func(stuff):
        """Print stuff; a static method receives neither self nor cls."""
        print(stuff)

    @classmethod
    def cls_func(cls, stuff):
        """Print stuff; a class method receives the class itself as cls."""
        # The parameter must be named `stuff` to match the body; a mismatched
        # name makes every call raise NameError.
        print(stuff)
@singledispatch:根据第一个参数的类型分派到不同的实现,可用于实现类似函数重载的效果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
from functools import singledispatch


@singledispatch
def connect(address):
    """Print an address; type-specific overloads are registered below.

    This generic version is the fallback for any argument type that has no
    registered overload.
    """
    print(address)


@connect.register(str)
def _connect_from_str(addr):
    # A "host:port" string: split it into its two parts.
    ip, port = addr.split(':')
    print(f'IP:{ip}, port:{port}')


@connect.register(tuple)
def _connect_from_tuple(addr):
    # A (host, port) pair: unpack it directly.
    ip, port = addr
    print(f'IP:{ip}, port:{port}')
常用的一些自定义 *装饰器*:
Reusing Functions 重复执行函数,常用的有 “do_twice”、“repeat(num_times)” 等:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import functools


def do_twice(func):
    """Call the decorated function twice per invocation; return the second result."""
    @functools.wraps(func)
    def wrapper_do_twice(*args, **kwargs):
        func(*args, **kwargs)
        return func(*args, **kwargs)
    return wrapper_do_twice


def repeat(num_times):
    """Decorator factory: call the decorated function ``num_times`` times.

    The wrapper returns the value of the last call. When ``num_times`` <= 0
    the function is never called and the wrapper returns None.
    """
    def decorator_repeat(func):
        @functools.wraps(func)
        def wrapper_repeat(*args, **kwargs):
            # Initialise first so that num_times <= 0 returns None instead of
            # raising UnboundLocalError at the return below.
            value = None
            for _ in range(num_times):
                value = func(*args, **kwargs)
            return value
        return wrapper_repeat
    return decorator_repeat
Timing Functions 计算函数耗时:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import functools
import time


def timer(func):
    """Decorator that reports how long each call to ``func`` takes.

    The elapsed time, measured with time.perf_counter(), is printed after the
    call returns. The wrapped function's return value is passed through
    unchanged. If ``func`` raises, nothing is printed.
    """
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        t0 = time.perf_counter()
        value = func(*args, **kwargs)
        run_time = time.perf_counter() - t0
        print(f"Finished {func.__name__!r} in {run_time:.4f} secs")
        return value

    return wrapper_timer
Debugging Code 打印函数的调用参数和返回值,便于调试:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import functools


def debug(func):
    """Decorator that prints each call's arguments and its return value.

    Before the call it prints ``Calling name(arg1, ..., key=value)``, using the
    repr of every argument. After the call it prints the repr of the result,
    which is then returned unchanged.
    """
    @functools.wraps(func)
    def wrapper_debug(*args, **kwargs):
        positional = list(map(repr, args))
        keyword = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(positional + keyword)
        print(f"Calling {func.__name__}({signature})")
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")
        return value

    return wrapper_debug
Slowing Down Code 提供延时功能:
1 2 3 4 5 6 7
import functools
import time


def slow_down(func=None, *, rate=1):
    """Sleep ``rate`` seconds (default 1) before each call of the function.

    Works both bare (``@slow_down``) and with arguments
    (``@slow_down(rate=0.5)``).

    Args:
        func: the function to decorate. It is None when the decorator is used
            with keyword arguments, in which case a decorator is returned.
        rate: delay in seconds passed to time.sleep before every call.
    """
    def decorator_slow_down(fn):
        @functools.wraps(fn)
        def wrapper_slow_down(*args, **kwargs):
            time.sleep(rate)
            return fn(*args, **kwargs)
        return wrapper_slow_down

    if func is None:
        # Called as @slow_down(rate=...): return the real decorator.
        return decorator_slow_down
    return decorator_slow_down(func)
Registering Plugins 实现类和函数的注册(深度学习项目常见):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
import random

# Registry of plug-in functions, keyed by function name.
PLUGINS = {}


def register(func):
    """Store func in PLUGINS under its __name__ and return it unchanged."""
    PLUGINS[func.__name__] = func
    return func


@register
def say_hello(name):
    return f"Hello {name}"


@register
def be_awesome(name):
    return f"Yo {name}, together we are the awesomest!"


def randomly_greet(name):
    """Greet name using a randomly chosen registered plug-in."""
    entries = list(PLUGINS.items())
    greeter, greeter_func = random.choice(entries)
    print(f"Using {greeter!r}")
    return greeter_func(name)
Count Calls 统计函数调用次数:
1 2 3 4 5 6 7 8 9 10 11 12 13 14
import functools


def count_calls(func):
    """Decorator that counts and announces every call of ``func``.

    The running total is stored on the wrapper as the ``num_calls``
    attribute, so callers can read it (for example ``say_whee.num_calls``).
    """
    @functools.wraps(func)
    def wrapper_count_calls(*args, **kwargs):
        wrapper_count_calls.num_calls = wrapper_count_calls.num_calls + 1
        print(f"Call {wrapper_count_calls.num_calls} of {func.__name__!r}")
        return func(*args, **kwargs)

    # The counter starts at zero before the first call.
    wrapper_count_calls.num_calls = 0
    return wrapper_count_calls


@count_calls
def say_whee():
    print("Whee!")
Singletons 单例模式:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
import functools


def singleton(cls):
    """Make a class a Singleton class (only one instance).

    The first call constructs the instance with the given arguments. Every
    later call returns that same object and ignores its arguments.

    Note: the decorated name is bound to a wrapper function rather than the
    class itself, so e.g. ``isinstance(obj, TheOne)`` will not work with it.
    """
    @functools.wraps(cls)
    def wrapper_singleton(*args, **kwargs):
        # Compare with None instead of testing truthiness: an instance whose
        # __bool__/__len__ makes it falsy must still be reused, not rebuilt.
        if wrapper_singleton.instance is None:
            wrapper_singleton.instance = cls(*args, **kwargs)
        return wrapper_singleton.instance

    wrapper_singleton.instance = None
    return wrapper_singleton


@singleton
class TheOne:
    pass
Validate JSON 验证请求中的 JSON 数据是否包含必需字段(以 Flask 为例):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
import functools

from flask import Flask, request, abort

app = Flask(__name__)


def validate_json(*expected_args):
    """Decorator factory: reject requests whose JSON body lacks required keys.

    Args:
        *expected_args: key names that must be present in the top-level JSON
            object of the request body.

    The wrapped view runs only when ``request.get_json()`` yields a JSON
    object (dict) containing every expected key. Otherwise the request is
    aborted with HTTP 400.
    """
    def decorator_validate_json(func):
        @functools.wraps(func)
        def wrapper_validate_json(*args, **kwargs):
            json_object = request.get_json()
            # Only a JSON object can carry named fields. For an array, `in`
            # would test list membership; for a string, a substring match;
            # for a number or null it would raise TypeError (HTTP 500).
            if not isinstance(json_object, dict):
                abort(400)
            for expected_arg in expected_args:
                if expected_arg not in json_object:
                    abort(400)
            return func(*args, **kwargs)
        return wrapper_validate_json
    return decorator_validate_json


# Decorators apply bottom-up, so @app.route must be outermost: Flask then
# registers the validating wrapper rather than the bare view.
@app.route("/grade", methods=["POST"])
@validate_json("student_id")
def update_grade():
    json_data = request.get_json()  # placeholder: unused in this example
    # Update database.
    return "success!"