特征 ndarray 转 tensor torch.Tensor(train_examples[0][0])报错:
TypeError: new(): data must be a sequence (got numpy.float64)

需要改成:
torch.Tensor(train_examples[0][0].reshape(1, n_feature))

在这里插入图片描述

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐