
How To Dynamically Index The Tensor In Pytorch?

For example, I have a tensor: tensor = torch.rand(12, 512, 768), and an index list, say: [0,2,3,400,5,32,7,8,321,107,100,511]. I wish to select 1 element out of the 512 elements along the second dimension.

Solution 1:

There is also a way to do this using only PyTorch, avoiding an explicit loop, with fancy indexing and torch.split:

import torch

tensor = torch.rand(12, 512, 768)

# the indices to select along dim 1
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert the list to a tensor
idx_tensor = torch.tensor(idx_list)

# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)

When you call tensor[:, idx_tensor, :] you get a tensor of shape (12, len(idx_list), 768), where the second dimension depends on the number of indices.

Using torch.split, this tensor is then split into a list of tensors, each of shape (12, 1, 768).

So finally list_of_tensors contains tensors of the shape:

[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]
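As a side note (a sketch, not from the original answer): torch.index_select along dim 1 produces the same (12, len(idx_list), 768) tensor as the fancy-indexing expression above, so it can be used interchangeably before the split:

```python
import torch

tensor = torch.rand(12, 512, 768)
idx_tensor = torch.tensor([0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511])

# index_select along dim=1 is equivalent to tensor[:, idx_tensor, :]
selected = torch.index_select(tensor, dim=1, index=idx_tensor)
list_of_tensors = selected.split(1, dim=1)

print(selected.shape)            # torch.Size([12, 12, 768])
print(len(list_of_tensors))      # 12
print(list_of_tensors[0].shape)  # torch.Size([12, 1, 768])
```

Both forms return a new tensor (a copy), since advanced indexing cannot be expressed as a view.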

Solution 2:

Yes, you can directly slice it using the index and then use torch.unsqueeze() to promote the 2D tensor to 3D:

# inputs
In [6]: tensor = torch.rand(12, 512, 768)
In [7]: idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]

# slice using the index and then put a singleton dimension along axis 1
In [8]: for idx in idx_list:
   ...:     sampled_tensor = torch.unsqueeze(tensor[:, idx, :], 1)
   ...:     print(sampled_tensor.shape)
   ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

Alternatively, if you want even terser code and don't want to use torch.unsqueeze(), index with a single-element list, which keeps the singleton dimension:

In [11]: for idx in idx_list:
    ...:     sampled_tensor = tensor[:, [idx], :]
    ...:     print(sampled_tensor.shape)
    ...:     
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])
torch.Size([12, 1, 768])

Note: there's no need for a for loop if you only want this slice for a single idx from idx_list.
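One possible reading of the (truncated) question, since the index list has exactly 12 entries matching the batch dimension, is that a different column should be picked for each batch row, i.e. result[i] == tensor[i, idx[i], :]. Under that assumption, a sketch using arange-based indexing or torch.gather gives a (12, 768) result directly, with no loop:

```python
import torch

tensor = torch.rand(12, 512, 768)
idx = torch.tensor([0, 2, 3, 400, 5, 32, 7, 8, 321, 107, 100, 511])

# one index per batch row: result[i] == tensor[i, idx[i], :]
result = tensor[torch.arange(12), idx, :]
print(result.shape)  # torch.Size([12, 768])

# the same selection with torch.gather: build an index of shape (12, 1, 768)
gathered = tensor.gather(1, idx.view(12, 1, 1).expand(12, 1, 768)).squeeze(1)
print(gathered.shape)  # torch.Size([12, 768])
```

Both expressions select the same rows; gather is handy when the index tensor itself is computed on the fly (e.g. from an argmax).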

