import torch

a = torch.randn(3, 200, 200)
print(a.dtype)

b = a.type(torch.float16)
print(b.dtype)

c = a.type(torch.int32)
print(c.dtype)

d = a.type(torch.long)
print(d.dtype)

e = a.type(torch.float32)
print(e.dtype)

在这里插入图片描述

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐