Pytorch 矩阵相乘

Pytorch 矩阵相乘torch.bmm()torch.matmul()torch.bmm()强制规定维度和大小相同torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作当进行操作的两个tensor都是3D时,两者等同。torch.bmm()官网:https://pytorch.org/docs/stable/torch.html#torch.bmmtorch.bmm(input,mat2,out=None)→Tensortorch.bmm()是ten…

torch.bmm()

torch.matmul()

torch.bmm()强制规定维度和大小相同

torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作

当进行操作的两个tensor都是3D时,两者等同。

torch.bmm()

torch.bmm(input, mat2, out=None) → Tensor

 torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。

参数:

input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。

output:输出结果

并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。
 

torch.matmul()
torch.matmul(input, other, out=None) → Tensor

 torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。

参数:

input,other:两个要进行操作的tensor结构

output:结果

举例:

a = torch.randn(2,3,2)

a = tensor([[[ 0.4198, -1.6376],
         [-1.0197, -0.1295],
         [-0.2412,  0.2189]],
        [[-0.1045,  1.8026],
         [-0.5264, -0.9585],
         [ 2.4333, -0.3726]]])

b = torch.randn(2,1)

tensor([[-1.1622],
        [-0.6326]])

torch.mm(a[0,:],b)

tensor([[0.5481],
        [1.2670],
        [0.1419]])

torch.mm(a[1,:],b)

tensor([[-1.0190],
        [ 1.2181],
        [-2.5922]])

c = torch.matmul(a,b)

c[0,:]=tensor([[0.5481],
        [1.2670],
        [0.1419]])

c[1,:]= tensor([[-1.0190],
        [ 1.2181],
        [-2.5922]])

今天的文章Pytorch 矩阵相乘分享到此就结束了,感谢您的阅读。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
如需转载请保留出处:https://bianchenghao.cn/11809.html

(0)
编程小号编程小号

相关推荐

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注