pytorch中index_select() 用法案例与解析

index_select(input, dim, index)
功能:在指定的维度dim上选取数据,不如选取某些行,列

参数介绍

  • 第一个参数input是要索引查找的对象
  • 第二个参数dim是要查找的维度,因为通常情况下我们使用的都是二维张量,所以可以简单的记忆: 0代表,1代表
  • 第三个参数index是你要索引的序列,它是一个tensor对象

下面简单的看几个案例
首先简单的创建一个矩阵

x = torch.rand(5,4)
print(x)

tensor([[0.6198, 0.4874, 0.2826, 0.1908],
        [0.3884, 0.1720, 0.8688, 0.1023],
        [0.3972, 0.6469, 0.4800, 0.9155],
        [0.7255, 0.8646, 0.4741, 0.2681],
        [0.6407, 0.3080, 0.5546, 0.7326]])

如果我们想要查看矩阵的第一列信息
最简单的方法就是 直接用切片取值

print(x[:,1]) 
tensor([0.4874, 0.1720, 0.6469, 0.8646, 0.3080])

如果使用index_select()方法则如下,三种得到的结果是一样的
注意:
这里的 dim 参数为 1 代表列,
input参数根据具体情况来写:
如果是torch.index_select那么就像下面的第三条语句,需要写上查找的对象x

print(x.index_select(1,torch.tensor([1]))) # 第 1 列
print(x.index_select(1,torch.tensor(1))) # 第 1 列
print(torch.index_select(x, 1,torch.tensor([1]))) # 第 1 列
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])
tensor([[0.4874],
        [0.1720],
        [0.6469],
        [0.8646],
        [0.3080]])

再比如查找0,1列

print(x.index_select(0,torch.tensor([0,1]))) # 0,1行
print(x.index_select(1,torch.tensor([0,1]))) # 0,1列
tensor([[0.6198, 0.4874, 0.2826, 0.1908],
        [0.3884, 0.1720, 0.8688, 0.1023]])
tensor([[0.6198, 0.4874],
        [0.3884, 0.1720],
        [0.3972, 0.6469],
        [0.7255, 0.8646],
        [0.6407, 0.3080]])

我们可以创建一个多维的数据来试试
首先创建一个三维的矩阵 dim为 0, 1, 2

x = torch.linspace(1,24,24).view(2,3,4)
print(x)
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])

0维的 0,1,2

print(x.index_select(0,torch.tensor([0,1,2])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])

1维的0 1

print(x.index_select(1,torch.tensor([0,1])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]]])
tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 7.,  8.],
         [ 9., 10.]],

        [[13., 14.],
         [15., 16.]],

        [[19., 20.],
         [21., 22.]]])

2维的 0 1

print(x.index_select(2,torch.tensor([0,1])))
tensor([[[ 1.,  2.],
         [ 3.,  4.],
         [ 5.,  6.]],

        [[ 7.,  8.],
         [ 9., 10.],
         [11., 12.]],

        [[13., 14.],
         [15., 16.],
         [17., 18.]],

        [[19., 20.],
         [21., 22.],
         [23., 24.]]])
Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐