torch.bmm()
torch.matmul()
torch.bmm()强制规定维度和大小相同
torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作
当进行操作的两个tensor都是3D时,两者等同。
torch.bmm()
torch.bmm(input, mat2, out=None) → Tensor
torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。
参数:
input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
output:输出结果
并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。
torch.matmul()
torch.matmul(input, other, out=None) → Tensor
torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
参数:
input,other:两个要进行操作的tensor结构
output:结果
举例:
a = torch.randn(2,3,2)
a = tensor([[[ 0.4198, -1.6376],
[-1.0197, -0.1295],
[-0.2412, 0.2189]],
[[-0.1045, 1.8026],
[-0.5264, -0.9585],
[ 2.4333, -0.3726]]])
b = torch.randn(2,1)
tensor([[-1.1622],
[-0.6326]])
torch.mm(a[0,:],b)
tensor([[0.5481],
[1.2670],
[0.1419]])
torch.mm(a[1,:],b)
tensor([[-1.0190],
[ 1.2181],
[-2.5922]])
c = torch.matmul(a,b)
c[0,:]=tensor([[0.5481],
[1.2670],
[0.1419]])
c[1,:]= tensor([[-1.0190],
[ 1.2181],
[-2.5922]])
今天的文章Pytorch 矩阵相乘分享到此就结束了,感谢您的阅读。
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/11809.html