定义模型

net = nn.Sequential(
    nn.Linear(28 * 28, 400),
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
).cuda()

1. torch自身方法获取参数量

total = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))  

2. torchsummary库获取模型的参数量

from torchsummary import summary

summary(net, input_size=(784,))

3. thop库获取模型参数量和计算量

from thop import profile, clever_format


myinput = torch.zeros((1, 1, 784)).cuda()
flops, params = profile(net, inputs=myinput)
flops, params = clever_format([flops, params], "%.3f")
print(flops, params)

分析:

这个模型参数量为415.31K,但是浮点计算量为414.6K,因为ReLU等层不需要浮点计算

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐