1、当使用如下代码,想保存一个tensor的一部分时


large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small, 'small.pt')
loaded_small = torch.load('small.pt')
loaded_small.storage().size()
# 999

最后保存的结果却不是,large[0:5],而是整个large,
想要解决这个问题,需要加一个使用tensorclone函数。

small.clone()

完整保存代码

large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small.clone(), 'small.pt')  # saves a clone of small
loaded_small = torch.load('small.pt')
loaded_small.storage().size()
# 5
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐