pytorch gather_nd_discourage的用法

pytorch gather_nd_discourage的用法Gathersvaluesalonganaxisspecifiedbydim.Fora2-Dtensortheoutputisspecifiedby:out[i][j]=input[index[i][j]][j]#if

Gathers values along an axis specified by dim.

For a 2-D tensor the output is specified by:

out[i][j]= input[index[i][j]][j] # if dim == 0 out[i][j] = input[i][index[i][j]] # if dim == 1

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

Parameters

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
  • out (Tensor, optional) – the destination tensor
  • sparse_grad (bool,optional) – If True, gradient w.r.t. input will be a sparse tensor.

Example:

>>> import torch as t >>> a = t.arange(0,20).view(4,5) >>> a tensor([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15]]) #选取对角线的元素 >>> index = t.LongTensor([[0,1,2,3]]) >>> a.gather(0,index) tensor([[ 0, 5, 10, 15]])

举个例子

import torch a = torch.arange(0,20).view(4,5) >>> a tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18, 19]]) 

index3的shape为2×4

index3 = torch.LongTensor([[0,1,2,3],[2,3,4,0]]) index3 >>> tensor([[0, 1, 2, 3], [2, 3, 4, 0]])

具体:

index[0][0]=0 index[0][1]=1 index[0][2]=2 index[0][3]=3 index[1][0]=2 index[1][1]=3 index[1][2]=4 index[1][3]=0

  • dim =1

那么

a.gather(1,index3) >>> tensor([[0, 1, 2, 3], [7, 8, 9, 5]])

out 的shape 与index的shape 一致,为2×4

根据 out[i][j] = input[i][index[i][j]]

out[0][0]=input[0][0]=0 out[0][1]=input[0][1]=1 out[0][2]=input[0][2]=2 out[0][3]=input[0][3]=3 out[1][0]=input[1][2]=7 out[1][1]=input[1][3]=8 out[1][2]=input[1][4]=9 out[1][3]=input[1][0]=5

  • dim=0

index5的shape,为4×1

index5 = torch.LongTensor([[0,1,2,3]]).t() index5 >>> tensor([[0], [1], [2], [3]])

具体

index[0][0]=0 index[1][0]=1 index[2][0]=2 index[3][0]=3

那么

a.gather(0,index5) >>> tensor([[ 0], [ 5], [10], [15]])

out 的shape 与index的shape 一致,为4×1

out[0][0]=input[0][0]=0 out[1][0]=input[1][0]=5 out[2][0]=input[2][0]=10 out[3][0]=input[3][0]=15


参考:

利用pytorch中的gather函数取出矩阵中的元素_耳东鹿其-CSDN博客​blog.csdn.net

79303a91ae9b8f46b2bb970f4dcfeb60.png

今天的文章
pytorch gather_nd_discourage的用法分享到此就结束了,感谢您的阅读。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/60244.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注