pytorch中index_select() 用法案例与解析
pytorch中index_select() 用法案例与解析index_select(input, dim, index)功能:在指定的维度dim上选取数据,不如选取某些行,列参数介绍第一个参数input是要索引查找的对象第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表行,1代表列第三个参数index是你要索引的序列,它是一个tens...
·
pytorch中index_select() 用法案例与解析
index_select(input, dim, index)
功能:在指定的维度dim上选取数据,不如选取某些行,列
参数介绍
- 第一个参数
input是要索引查找的对象 - 第二个参数
dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆:0代表行,1代表列 - 第三个参数
index是你要索引的序列,它是一个tensor对象
下面简单的看几个案例
首先简单的创建一个矩阵
x = torch.rand(5,4)
print(x)
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
[0.3884, 0.1720, 0.8688, 0.1023],
[0.3972, 0.6469, 0.4800, 0.9155],
[0.7255, 0.8646, 0.4741, 0.2681],
[0.6407, 0.3080, 0.5546, 0.7326]])
如果我们想要查看矩阵的第一列信息
最简单的方法就是 直接用切片取值
print(x[:,1])
tensor([0.4874, 0.1720, 0.6469, 0.8646, 0.3080])
如果使用index_select()方法则如下,三种得到的结果是一样的
注意:
这里的 dim 参数为 1 代表列,input参数根据具体情况来写:
如果是torch.index_select那么就像下面的第三条语句,需要写上查找的对象x
print(x.index_select(1,torch.tensor([1]))) # 第 1 列
print(x.index_select(1,torch.tensor(1))) # 第 1 列
print(torch.index_select(x, 1,torch.tensor([1]))) # 第 1 列
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
tensor([[0.4874],
[0.1720],
[0.6469],
[0.8646],
[0.3080]])
再比如查找0,1列
print(x.index_select(0,torch.tensor([0,1]))) # 0,1行
print(x.index_select(1,torch.tensor([0,1]))) # 0,1列
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
[0.3884, 0.1720, 0.8688, 0.1023]])
tensor([[0.6198, 0.4874],
[0.3884, 0.1720],
[0.3972, 0.6469],
[0.7255, 0.8646],
[0.6407, 0.3080]])
我们可以创建一个多维的数据来试试
首先创建一个三维的矩阵 dim为 0, 1, 2
x = torch.linspace(1,24,24).view(2,3,4)
print(x)
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])
0维的 0,1,2
print(x.index_select(0,torch.tensor([0,1,2])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])
1维的0 1
print(x.index_select(1,torch.tensor([0,1])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]]])
tensor([[[ 1., 2.],
[ 3., 4.]],
[[ 7., 8.],
[ 9., 10.]],
[[13., 14.],
[15., 16.]],
[[19., 20.],
[21., 22.]]])
2维的 0 1
print(x.index_select(2,torch.tensor([0,1])))
tensor([[[ 1., 2.],
[ 3., 4.],
[ 5., 6.]],
[[ 7., 8.],
[ 9., 10.],
[11., 12.]],
[[13., 14.],
[15., 16.],
[17., 18.]],
[[19., 20.],
[21., 22.],
[23., 24.]]])
更多推荐




所有评论(0)