[AI编程从入门到入土] 装饰器decorator

个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

装饰器decorator

  • registration decorator: 只注册, 不改行为
  • wrapper decorator: 添加行为 (使用wraps)
1. registration decorator - 情况1
@xxx
def test():
  pass

等价于

def test():
    pass
test = xxx(test)

test 函数先被创建 -> 然后作为参数传给 xxx -> xxx 返回一个新对象 -> 再覆盖原来的 test

此时xxx接收到的参数是func

2. registration decorator - 情况2
@xxx(abc)
def test():
  pass

等价于

def test():
    pass
test = xxx(abc)(test)

test 函数先被创建 -> 然后作为参数传给 xxx(abc) -> xxx(abc) 返回一个新对象 -> 再覆盖原来的 test

此时xxx接收到的参数是abc

注册表registry

注册表本质是字典, 装饰器的第一个参数是func:

# 创建注册表
registry = {}

# 创建装饰器
def register(func):
    registry[func.__name__] = func
    return func

使用:

@register
def hello():
    print("hello")

@register
def world():
    print("world")

执行后, registry就变成:

{
    "hello": <function hello>,
    "world": <function world>,
}

就可以如此调用:

registry["hello"]()

装饰器工厂

registry = {}

def register(name):

    def decorator(func):
        registry[name] = func
        return func

    return decorator

@register("add")
def func1():
    print("111")
# 等价于: func1 = register("add")(func1)

@register("sub")
def func2():
    print("222")
# 等价于: func2 = register("sub")(func2)

decorator 分类

类型 作用
registration 注册
wrapper 包装行为
cache 缓存
retry 重试
permission 权限
validation 参数校验
singleton 单例
async 异步
transaction 事务
logging 日志
injection 依赖注入

标准 decorator 模板

from functools import wraps

def decorator(func):

    @wraps(func)
    def wrapper(*args, **kwargs):
        
        # before
        result = func(*args, **kwargs)
        # after
        
        return result

    return wrapper
1. 类注册decorator
MODELS = {}

def register(name):

    def decorator(cls):
        MODELS[name] = cls
        return cls

    return decorator

@register("resnet")
class resnet:
    ...

@register("lstm")
class lstm:
    ...
2. 计时器decorator
import time
from functools import wraps

def timer(func):

    @wraps(func)
    def wrapper(*args, **kwargs):

        start = time.time()

        result = func(*args, **kwargs)

        end = time.time()

        print(f"{func.__name__} 耗时: {end - start:.4f}s")

        return result

    return wrapper
3. 参数检查decorator
from functools import wraps

def check_non_negative(func):

    @wraps(func)
    def wrapper(x):

        if x < 0:
            raise ValueError("不能小于0")

        return func(x)

    return wrapper
4. 类decorator
from functools import wraps

def enhance(cls):

    # 动态增加类属性
    cls.version = "1.0"
    cls.author = "byzh"
    cls.category = "AI"

    # 动态增加实例方法
    def info(self):
        print("========== INFO ==========")
        print("class    :", cls.__name__)
        print("name     :", self.name)
        print("version  :", cls.version)
        print("author   :", cls.author)
        print("category :", cls.category)
    cls.info = info

    return cls

AI训练常用decorator

def benchmark(func):

    @wraps(func)
    def wrapper(*args, **kwargs):

        import tracemalloc
        import time

        tracemalloc.start()

        start = time.time()

        result = func(*args, **kwargs)

        current, peak = tracemalloc.get_traced_memory()

        end = time.time()

        print(f"time={end-start}")
        print(f"peak={peak/1024/1024:.2f}MB")

        return result

    return wrapper
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐