Python: Commonly Used Decorators

A Quick Word on Decorators

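Decorators are syntactic sugar in Python. You apply one by writing @ above a function definition: the function is passed to the decorator, which typically returns an inner function, conventionally named wrapper(). The wrapper leaves the original function unmodified and adds behavior around it, such as timing a call, validating input, or caching results. For details, see the related post "Python-闭包+装饰器" (closures and decorators). This post walks through some commonly used decorators that can make your code more concise and efficient.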
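
To make the "syntactic sugar" point concrete, here is a minimal sketch showing that the @ line is equivalent to calling the decorator by hand and rebinding the name. The decorator `shout` and the functions `greet` and `greet_plain` are hypothetical names invented for this illustration.

```python
def shout(func):
    """Hypothetical decorator: upper-case whatever the wrapped function returns."""
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper


@shout
def greet(name):
    return f"hello {name}"


def greet_plain(name):
    return f"hello {name}"


# Applying the decorator by hand is what the @ line does for us.
greet_manual = shout(greet_plain)

print(greet("ada"))         # HELLO ADA
print(greet_manual("ada"))  # HELLO ADA
print(greet_plain("ada"))   # hello ada (the original function is untouched)
```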

Common Decorators

Decorators from the standard library, plus Numba's @jit:

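  • @lru_cache caches results keyed by the call arguments, which speeds up repeated calls, especially in recursive functions:

    from functools import lru_cache

    @lru_cache  # the bare form (no parentheses) needs Python 3.8+
    def factorial(n):
        return n * factorial(n-1) if n else 1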
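
The effect of the cache is easier to see on a function that recomputes the same subproblems. The sketch below uses a hypothetical `fib` function (not part of the original post) and inspects the counters that `lru_cache` exposes through `cache_info()`.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded cache; works on Python 3.2+
def fib(n):
    """Naive recursive Fibonacci; exponential without the cache."""
    return n if n < 2 else fib(n - 1) + fib(n - 2)


print(fib(30))  # 832040

info = fib.cache_info()
# Each distinct argument 0..30 is computed once (a miss);
# every other call is answered from the cache (a hit).
print(info.misses, info.hits)
```

`fib.cache_clear()` resets the cache, which is handy in tests.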
    
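  • @jit comes from Numba, a third-party package; the name stands for "Just In Time compilation". It can also speed code up, but the benefit usually shows only on heavy numerical workloads; for simple computations the compilation overhead can make things slower:
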
    from numba import jit
    import random
        
    @jit(nopython=True)
    def monte_carlo_pi(nsamples):
        acc = 0
        for i in range(nsamples):
            x = random.random()
            y = random.random()
            if (x ** 2 + y ** 2) < 1.0:
                acc += 1
        return 4.0 * acc / nsamples
    
  • @dataclass automatically generates methods such as __init__, __repr__, and __eq__, which makes classes quicker to write:

    from dataclasses import dataclass

    @dataclass
    class Food:
        name: str
        unit_price: float
        stock: int = 0

        def stock_value(self) -> float:
            return self.stock * self.unit_price
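
As a rough illustration of what the generated methods give you, the following sketch defines a `Food`-style dataclass and exercises `__init__`, `__repr__`, and `__eq__`. The sample values are made up for the example.

```python
from dataclasses import dataclass


@dataclass
class Food:
    name: str
    unit_price: float
    stock: int = 0  # fields with defaults must come after fields without

    def stock_value(self) -> float:
        return self.stock * self.unit_price


apple = Food("apple", 0.5, stock=10)     # generated __init__
print(apple)                             # Food(name='apple', unit_price=0.5, stock=10)
print(apple == Food("apple", 0.5, 10))   # True: generated __eq__ compares fields
print(apple.stock_value())               # 5.0
```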
    
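  • @functools.wraps removes a side effect of decorating: the decorated name ends up bound to the wrapper, which is a different function. Applying functools.wraps to the wrapper preserves the original function's name and docstring:
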
    from functools import wraps
        
    def my_decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            '''decorator'''
            print('Calling decorated function...')
            return func(*args, **kwargs)
        return wrapper
        
    @my_decorator
    def example():
        """Docstring"""
        print('Called example function')
    print(example.__name__, example.__doc__)
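
To see the side effect that `wraps` removes, it may help to compare two otherwise identical decorators. The names `without_wraps`, `with_wraps`, `first`, and `second` are hypothetical and exist only for this comparison.

```python
from functools import wraps


def without_wraps(func):
    def wrapper(*args, **kwargs):
        """wrapper docstring"""
        return func(*args, **kwargs)
    return wrapper


def with_wraps(func):
    @wraps(func)  # copies __name__, __doc__, etc. from func onto wrapper
    def wrapper(*args, **kwargs):
        """wrapper docstring"""
        return func(*args, **kwargs)
    return wrapper


@without_wraps
def first():
    """first docstring"""


@with_wraps
def second():
    """second docstring"""


print(first.__name__, "|", first.__doc__)    # wrapper | wrapper docstring
print(second.__name__, "|", second.__doc__)  # second | second docstring
```

`wraps` also sets `second.__wrapped__` to the original function, which tools such as `inspect.signature` follow.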
    
  • @staticmethod / @classmethod define static methods (no implicit first argument) and class methods (the class itself is passed as cls):

    class Example:
        @staticmethod
        def our_func(stuff):
            print(stuff)

        @classmethod
        def cls_func(cls, stuff):
            print(stuff)
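
A common reason to reach for these two is a helper that needs no instance state (static method) and an alternative constructor (class method). The `Temperature` and `LabTemperature` classes below are hypothetical examples, not taken from the original post.

```python
class Temperature:
    def __init__(self, celsius):
        self.celsius = celsius

    @staticmethod
    def f_to_c(fahrenheit):
        # No self or cls: just a function namespaced inside the class.
        return (fahrenheit - 32) * 5 / 9

    @classmethod
    def from_fahrenheit(cls, fahrenheit):
        # cls is the class the method was called on, so subclasses get
        # instances of their own type.
        return cls(cls.f_to_c(fahrenheit))


class LabTemperature(Temperature):
    pass


boiling = Temperature.from_fahrenheit(212)
print(boiling.celsius)                                   # 100.0
print(type(LabTemperature.from_fahrenheit(32)).__name__)  # LabTemperature
```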
    
  • @singledispatch provides function overloading: the implementation is selected by the type of the first argument:

    from functools import singledispatch
        
    @singledispatch
    def connect(address):
        print(address)
        
    @connect.register
    def _(addr: str):
        ip, port = addr.split(':')
        print(f'IP:{ip}, port:{port}')
        
    @connect.register
    def _(addr: tuple):
        ip, port = addr
        print(f'IP:{ip}, port:{port}')
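
The dispatch behavior can be checked with a small variant that returns values instead of printing them. This is a sketch under assumptions: `parse_address` is a hypothetical name, and raising `TypeError` in the fallback is one possible design choice, not something the original post prescribes.

```python
from functools import singledispatch


@singledispatch
def parse_address(address):
    # Fallback for types with no registered implementation.
    raise TypeError(f"unsupported address type: {type(address).__name__}")


@parse_address.register
def _(address: str):
    ip, port = address.split(":")
    return ip, int(port)


@parse_address.register
def _(address: tuple):
    ip, port = address
    return ip, int(port)


print(parse_address("127.0.0.1:8080"))     # ('127.0.0.1', 8080)
print(parse_address(("127.0.0.1", 8080)))  # ('127.0.0.1', 8080)
```

Registering through type annotations, as done here, requires Python 3.7 or newer; on older versions the type is passed explicitly, for example `@parse_address.register(str)`.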
    

Some commonly used custom decorators:

  • Reusing Function: common examples are "do_twice" and "repeat(times)":

    import functools
        
    def do_twice(func):
        @functools.wraps(func)
        def wrapper_do_twice(*args, **kwargs):
            func(*args, **kwargs)
            return func(*args, **kwargs)
        return wrapper_do_twice
        
    def repeat(num_times):
        def decorator_repeat(func):
            @functools.wraps(func)
            def wrapper_repeat(*args, **kwargs):
                for _ in range(num_times):
                    value = func(*args, **kwargs)
                return value
            return wrapper_repeat
        return decorator_repeat
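
A usage sketch for a `repeat`-style decorator follows. It is self-contained, so the decorator is restated; initializing `value = None` is an addition made here so that `num_times=0` returns `None` instead of failing, and `record` and `calls` are hypothetical names.

```python
import functools


def repeat(num_times):
    """Decorator factory: call the function num_times, return the last result."""
    def decorator_repeat(func):
        @functools.wraps(func)
        def wrapper_repeat(*args, **kwargs):
            value = None  # assumption: return None when num_times is 0
            for _ in range(num_times):
                value = func(*args, **kwargs)
            return value
        return wrapper_repeat
    return decorator_repeat


calls = []


@repeat(num_times=3)
def record(item):
    calls.append(item)
    return len(calls)


print(record("x"))  # 3
print(calls)        # ['x', 'x', 'x']
```

Because `repeat` takes an argument, there are three levels of nesting: `repeat(3)` returns the actual decorator, which in turn returns the wrapper.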
    
  • Timing Functions: measure how long a function takes to run:

    import functools
    import time
        
    def timer(func):
        """Print the runtime of the decorated function"""
        @functools.wraps(func)
        def wrapper_timer(*args, **kwargs):
            start_time = time.perf_counter()    # high-resolution clock before the call
            value = func(*args, **kwargs)
            end_time = time.perf_counter()      # same clock after the call
            run_time = end_time - start_time    # elapsed seconds
            print(f"Finished {func.__name__!r} in {run_time:.4f} secs")
            return value
        return wrapper_timer
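
Applied to a function, a `timer`-style decorator could be used as sketched below. The function `sum_of_squares` is a hypothetical workload, and the reported time will differ from machine to machine, so no specific figure is shown.

```python
import functools
import time


def timer(func):
    """Print the runtime of the decorated function."""
    @functools.wraps(func)
    def wrapper_timer(*args, **kwargs):
        start_time = time.perf_counter()
        value = func(*args, **kwargs)
        run_time = time.perf_counter() - start_time
        print(f"Finished {func.__name__!r} in {run_time:.4f} secs")
        return value  # the wrapper passes the result through unchanged
    return wrapper_timer


@timer
def sum_of_squares(n):
    return sum(i * i for i in range(n))


total = sum_of_squares(100_000)  # prints a "Finished 'sum_of_squares' in ... secs" line
print(total)
```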
    
  • Debugging Code: print a function's arguments and return value on each call:

    import functools
        
    def debug(func):
        """Print the function signature and return value"""
        @functools.wraps(func)
        def wrapper_debug(*args, **kwargs):
            args_repr = [repr(a) for a in args]                      # positional arguments
            kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]  # keyword arguments as key=value
            signature = ", ".join(args_repr + kwargs_repr)           # joined as they would appear in a call
            print(f"Calling {func.__name__}({signature})")
            value = func(*args, **kwargs)
            print(f"{func.__name__!r} returned {value!r}")           # report the result after the call
            return value
        return wrapper_debug
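
The output of a `debug`-style decorator is easiest to judge on a concrete call. In this sketch `make_greeting` is a hypothetical function added for illustration.

```python
import functools


def debug(func):
    """Print the function signature and return value."""
    @functools.wraps(func)
    def wrapper_debug(*args, **kwargs):
        args_repr = [repr(a) for a in args]
        kwargs_repr = [f"{k}={v!r}" for k, v in kwargs.items()]
        signature = ", ".join(args_repr + kwargs_repr)
        print(f"Calling {func.__name__}({signature})")
        value = func(*args, **kwargs)
        print(f"{func.__name__!r} returned {value!r}")
        return value
    return wrapper_debug


@debug
def make_greeting(name, age=None):
    if age is None:
        return f"Howdy {name}!"
    return f"Whoa {name}! {age} already?"


make_greeting("Ada", age=36)
# Calling make_greeting('Ada', age=36)
# 'make_greeting' returned 'Whoa Ada! 36 already?'
```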
    
  • Slowing Down Code: add a delay before each call:

    import functools
    import time

    def slow_down(func):
        """Sleep 1 second before calling the function"""
        @functools.wraps(func)
        def wrapper_slow_down(*args, **kwargs):
            time.sleep(1)
            return func(*args, **kwargs)
        return wrapper_slow_down
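
A fixed one-second delay is not always what you want. One possible variation, shown here as a sketch rather than as part of the original post, turns `slow_down` into a decorator factory with a hypothetical `seconds` parameter; `countdown` is likewise a made-up example.

```python
import functools
import time


def slow_down(seconds=1.0):
    """Hypothetical variant: sleep `seconds` before each call."""
    def decorator_slow_down(func):
        @functools.wraps(func)
        def wrapper_slow_down(*args, **kwargs):
            time.sleep(seconds)
            return func(*args, **kwargs)
        return wrapper_slow_down
    return decorator_slow_down


@slow_down(seconds=0.05)
def countdown(from_number):
    """Recursive countdown; every level of recursion pays the delay."""
    if from_number < 1:
        print("Liftoff!")
        return
    print(from_number)
    countdown(from_number - 1)


start = time.perf_counter()
countdown(3)  # prints 3, 2, 1, Liftoff! with a pause before each line
elapsed = time.perf_counter() - start
```

Because the recursive call goes through the decorated name, `countdown(3)` sleeps four times, once per call.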
    
  • Registering Plugins: register classes and functions in a registry (common in deep learning projects):

    import random
    PLUGINS = dict()
        
    def register(func):
        """Register a function as a plug-in"""
        PLUGINS[func.__name__] = func
        return func
        
    @register
    def say_hello(name):
        return f"Hello {name}"
        
    @register
    def be_awesome(name):
        return f"Yo {name}, together we are the awesomest!"
        
    def randomly_greet(name):
        greeter, greeter_func = random.choice(list(PLUGINS.items()))
        print(f"Using {greeter!r}")
        return greeter_func(name)
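
In practice a registry is usually queried by name rather than at random, for instance when a config file names the component to use. The sketch below shows that lookup; `greet_with` is a hypothetical helper added for the example.

```python
PLUGINS = {}


def register(func):
    """Register a function as a plug-in and return it unchanged."""
    PLUGINS[func.__name__] = func
    return func


@register
def say_hello(name):
    return f"Hello {name}"


@register
def be_awesome(name):
    return f"Yo {name}, together we are the awesomest!"


def greet_with(plugin_name, name):
    """Hypothetical helper: look a plug-in up by its registered name."""
    return PLUGINS[plugin_name](name)


print(sorted(PLUGINS))                    # ['be_awesome', 'say_hello']
print(greet_with("say_hello", "Ada"))     # Hello Ada
print(PLUGINS["say_hello"] is say_hello)  # True: no wrapper was created
```

Unlike most decorators in this post, `register` returns the original function, so there is no wrapper and no need for `functools.wraps`.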
    
  • Count Calls: count how many times a function has been called:

    import functools
        
    def count_calls(func):
        @functools.wraps(func)
        def wrapper_count_calls(*args, **kwargs):
            wrapper_count_calls.num_calls += 1
            print(f"Call {wrapper_count_calls.num_calls} of {func.__name__!r}")
            return func(*args, **kwargs)
        wrapper_count_calls.num_calls = 0
        return wrapper_count_calls
        
    @count_calls
    def say_whee():
        print("Whee!")
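
The counter lives as an attribute on the wrapper function, so it can be read from outside after a few calls. A self-contained sketch:

```python
import functools


def count_calls(func):
    @functools.wraps(func)
    def wrapper_count_calls(*args, **kwargs):
        wrapper_count_calls.num_calls += 1
        print(f"Call {wrapper_count_calls.num_calls} of {func.__name__!r}")
        return func(*args, **kwargs)
    wrapper_count_calls.num_calls = 0  # state stored on the function object
    return wrapper_count_calls


@count_calls
def say_whee():
    print("Whee!")


say_whee()  # Call 1 of 'say_whee' / Whee!
say_whee()  # Call 2 of 'say_whee' / Whee!
print(say_whee.num_calls)  # 2
```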
    
  • Singletons: the singleton pattern, where a class has only one instance:

    import functools
        
    def singleton(cls):
        """Make a class a Singleton class (only one instance)"""
        @functools.wraps(cls)
        def wrapper_singleton(*args, **kwargs):
            if wrapper_singleton.instance is None:  # "is None" also works for instances that are falsy
                wrapper_singleton.instance = cls(*args, **kwargs)
            return wrapper_singleton.instance
        wrapper_singleton.instance = None
        return wrapper_singleton
        
    @singleton
    class TheOne:
        pass
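
A short check that repeated construction yields the same object is sketched below. The `Config` class and its `name` argument are hypothetical, chosen to show that arguments passed after the first construction are ignored.

```python
import functools


def singleton(cls):
    """Make a class a Singleton class (only one instance)."""
    @functools.wraps(cls)
    def wrapper_singleton(*args, **kwargs):
        if wrapper_singleton.instance is None:
            wrapper_singleton.instance = cls(*args, **kwargs)
        return wrapper_singleton.instance
    wrapper_singleton.instance = None
    return wrapper_singleton


@singleton
class Config:
    def __init__(self, name="default"):
        self.name = name


first = Config("first")
second = Config("second")  # arguments ignored: the instance already exists

print(first is second)  # True
print(second.name)      # first
```

One caveat of this approach: after decoration the name `Config` refers to a function, not a class, so `isinstance(obj, Config)` no longer works.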
    
  • Validate Json: check that the JSON body of a request contains the expected keys:

    from flask import Flask, request, abort
    import functools
    app = Flask(__name__)
        
    def validate_json(*expected_args):                  # names of the required JSON keys
        def decorator_validate_json(func):
            @functools.wraps(func)
            def wrapper_validate_json(*args, **kwargs):
                json_object = request.get_json()
                for expected_arg in expected_args:      # reject the request if any key is missing
                    if expected_arg not in json_object:
                        abort(400)
                return func(*args, **kwargs)
            return wrapper_validate_json
        return decorator_validate_json
        
        
    @app.route("/grade", methods=["POST"])
    @validate_json("student_id")
    def update_grade():
        json_data = request.get_json()
        # Update database.
        return "success!"
    

Selected References

  • Primer on Python Decorators
  • 10 Fabulous Python Decorators

This post is licensed under CC BY 4.0 by the author.
