Gathers values along an axis specified by dim.
For a 2-D tensor the output is specified by:
out[i][j]= input[index[i][j]][j] # if dim == 0 out[i][j] = input[i][index[i][j]] # if dim == 1
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Parameters
- input (Tensor) – the source tensor
- dim (int) – the axis along which to index
- index (LongTensor) – the indices of elements to gather
- out (Tensor, optional) – the destination tensor
- sparse_grad (bool,optional) – If
True
, gradient w.r.t.input
will be a sparse tensor.
Example:
>>> import torch as t >>> a = t.arange(0,20).view(4,5) >>> a tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]]) #选取对角线的元素 >>> index = t.LongTensor([[0,1,2,3]]) >>> a.gather(0,index) tensor([[ 0, 5, 10, 15]])
举个例子
import torch a = torch.arange(0,20).view(4,5) >>> a tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]])
index3的shape为2×4
index3 = torch.LongTensor([[0,1,2,3],[2,3,4,0]]) index3 >>> tensor([[0, 1, 2, 3], [2, 3, 4, 0]])
具体:
index[0][0]=0 index[0][1]=1 index[0][2]=2 index[0][3]=3 index[1][0]=2 index[1][1]=3 index[1][2]=4 index[1][3]=0
- dim =1
那么
a.gather(1,index3) >>> tensor([[0, 1, 2, 3], [7, 8, 9, 5]])
out 的shape 与index的shape 一致,为2×4
根据 out[i][j] = input[i][index[i][j]]
out[0][0]=input[0][0]=0 out[0][1]=input[0][1]=1 out[0][2]=input[0][2]=2 out[0][3]=input[0][3]=3 out[1][0]=input[1][2]=7 out[1][1]=input[1][3]=8 out[1][2]=input[1][4]=9 out[1][3]=input[1][0]=5
- dim=0
index5的shape,为4×1
index5 = torch.LongTensor([[0,1,2,3]]).t() index5 >>> tensor([[0], [1], [2], [3]])
具体
index[0][0]=0 index[1][0]=1 index[2][0]=2 index[3][0]=3
那么
a.gather(0,index5) >>> tensor([[ 0], [ 5], [10], [15]])
out 的shape 与index的shape 一致,为4×1
out[0][0]=input[0][0]=0 out[1][0]=input[1][0]=5 out[2][0]=input[2][0]=10 out[3][0]=input[3][0]=15
参考:
利用pytorch中的gather函数取出矩阵中的元素_耳东鹿其-CSDN博客blog.csdn.net
今天的文章
pytorch gather_nd_discourage的用法分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/60244.html