用一段小程序来说明
先看torch.nn.ReLU()
import torch
input = torch.Tensor([[1.0, 2.0],
                    [-1.0, 3.0]])
class demo(torch.nn.Module):
    def __init__(self):
        super(demo, self).__init__()
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        return self.relu(x)
model = demo()
print(model(input))

>>
tensor([[1., 2.],
        [0., 3.]])

torch.nn.ReLU必须在继承torch.nn.Module的情况下使用

再看torch.nn.functional.relu()
import torch
input = torch.Tensor([[1.0, 2.0],
                    [-1.0, 3.0]])
print(torch.nn.functional.relu(input))

>>
tensor([[1., 2.],
        [0., 3.]])

torch.nn.functional.relu更像是独立的一个函数,可以自由使用

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐