class Extractor(nn.Module):

    def __init__(self):
        super(Extractor, self).__init__()

        self.mlp = nn.Linear(768, 1024),
        self.flatten = nn.Flatten()

    def forward(self, x):

        x = self.mlp(x)
        x = self.flatten(x)
        return x

跑深度学习代码的时候突然总报这个错误,后来才发现是代码后面不小心加了逗号。

把逗号去掉就可以了。

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐