PyTorch: converting integer labels to one-hot encoding
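import torch

num_class = 5  # number of classes
N = 3          # number of samples
# N random integer labels in [0, num_class)
tensor = torch.randint(0, num_class, [N])
print(tensor)
# all-zero (N, num_class) matrix of dtype long
one_hot = torch.zeros(N, num_class).long()
# for each row i, write src[i][0] (= 1) into column tensor[i]
one_hot.scatter_(dim=1, index=tensor.unsqueeze(dim=1), src=torch.ones(N, num_class).long())
print(one_hot)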
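For comparison, here is a minimal sketch of the same conversion done two other ways. The first uses the scalar form of `scatter_`, which avoids building a full tensor of ones. The second uses PyTorch's built-in `torch.nn.functional.one_hot`. The fixed labels `[3, 0, 4]` are an assumption, chosen so the result is deterministic instead of random.

```python
import torch
import torch.nn.functional as F

num_class = 5
labels = torch.tensor([3, 0, 4])  # assumed fixed labels for a reproducible result

# scatter_ with a scalar value: for each row i, write 1 at column labels[i]
via_scatter = torch.zeros(len(labels), num_class, dtype=torch.long)
via_scatter.scatter_(1, labels.unsqueeze(1), 1)

# built-in helper: returns an int64 tensor of shape (N, num_classes)
via_builtin = F.one_hot(labels, num_classes=num_class)

print(via_scatter)
print(torch.equal(via_scatter, via_builtin))  # True
```

`F.one_hot` returns an int64 tensor. Call `.float()` on it if a floating-point one-hot matrix is needed, for example as the target of a loss that expects floats.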
