pytorch的torch.add()torch.split()函数

import torch
# outputs是一个[batch, seq, 40]维的tensor,把outputs分割成两个[batch, seq, 20]的tensor,并每个元素求平均值
add = torch.add(*torch.split(outputs, 20, dim=2)) / 2
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐