Pytorch深度学习addmm()和addmm_()函数用法解析

作者:悲恋花丶无心之人 时间:2021-01-02 04:04:25 

一、函数解释

在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:

Pytorch深度学习addmm()和addmm_()函数用法解析

换句话说,就是需要传入5个参数,mat里的每个元素乘以beta,mat1和mat2进行矩阵乘法(左行乘右列)后再乘以alpha,最后将这2个结果加在一起。但是这样说可能没啥概念,接下来博主为大家写上一段代码,大家就明白了~

def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__
       """
       addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor
       Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
       The matrix :attr:`mat` is added to the final result.
       If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
       :math:`(m \times p)` tensor, then :attr:`mat` must be
       :ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensor
       and :attr:`out` will be a :math:`(n \times p)` tensor.
       :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between
       :attr:`mat1` and :attr`mat2` and the added matrix :attr:`mat` respectively.
       .. math::
           out = \beta\ mat + \alpha\ (mat1_i \mathbin{@} mat2_i)
       For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
       :attr:`alpha` must be real numbers, otherwise they should be integers.
       Args:
           beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
           mat (Tensor): matrix to be added
           alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
           mat1 (Tensor): the first matrix to be multiplied
           mat2 (Tensor): the second matrix to be multiplied
           out (Tensor, optional): the output tensor
       Example::
           >>> M = torch.randn(2, 3)
           >>> mat1 = torch.randn(2, 3)
           >>> mat2 = torch.randn(3, 3)
           >>> torch.addmm(M, mat1, mat2)
           tensor([[-4.8716,  1.4671, -1.3746],
                   [ 0.7573, -3.9555, -2.8681]])
       """
       pass

二、代码范例

1.先摆出代码,大家可以先复制粘贴运行一下,在之后博主会一一讲解

"""
@author:nickhuang1996
"""
import torch
rectangle_height = 3
rectangle_width = 3
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
   for j in range(rectangle_width):
       inputs[i] = i * torch.ones(rectangle_width)
'''
inputs and its transpose
-->inputs   =   tensor([[0., 0., 0.],
                       [1., 1., 1.],
                       [2., 2., 2.]])
-->inputs_t =   tensor([[0., 1., 2.],
                       [0., 1., 2.],
                       [0., 1., 2.]])
'''
print("inputs:\n", inputs)
inputs_t = inputs.t()
print("inputs_t:\n", inputs_t)
'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]
                   =   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                       [0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
'''
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)
print("d:\n", d)
print("e:\n", e)
print("f:\n", f)
print("g:\n", g)
print("g2:\n", g2)
print("h:\n", h)
print("h12:\n", h12)
print("h21:\n", h21)
print("inputs:\n", inputs)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]
                   =    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                        [2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
print("inputs:\n", inputs)

2.其中

inputs是一个3&times;3的矩阵,为

tensor([[0., 0., 0.],
       [1., 1., 1.],
       [2., 2., 2.]])

inputs_t也是一个3&times;3的矩阵,是inputs的转置矩阵,为

tensor([[0., 1., 2.],
       [0., 1., 2.],
       [0., 1., 2.]])

* inputs_t @ inputs_t为

'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]
                   =   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                       [0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
'''

3.代码中a,b,c和d展示的是完全形式,即标明了位置参数和传入参数。可以看到input这个位置参数可以写在函数的前面,即

torch.addmm(input, mat1, mat2) = inputs.addmm(mat1, mat2)

完成的公式为:

1 &times; inputs + 1 &times;(inputs_t @ inputs_t)

'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
a:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
b:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
c:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
d:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])

4.下面的例子更好了说明了input参数的位置可变性,并且beta和alpha都缺省了:

完成的公式为:

1 &times; inputs + 1 &times;(inputs_t @ inputs_t)

'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
e:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
f:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])

5.加一个参数,实际上是添加了beta这个参数

完成的公式为:

g   = 1 &times; inputs + 1 &times;(inputs_t @ inputs_t)

g2 = 2 &times; inputs + 1 &times;(inputs_t @ inputs_t)

'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
g:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],
       [ 2.,  5.,  8.],
       [ 4.,  7., 10.]])

6.再加一个参数,实际上是添加了alpha这个参数

完成的公式为:

h   = 1 &times; inputs + 1 &times;(inputs_t @ inputs_t)

h12 = 1 &times; inputs + 2 &times;(inputs_t @ inputs_t)

h21 = 2 &times; inputs + 1 &times;(inputs_t @ inputs_t)

'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
h:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],
       [ 1.,  7., 13.],
       [ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],
       [ 2.,  5.,  8.],
       [ 4.,  7., 10.]])

7.当然,以上的步骤inputs没有变化,还是为

inputs:
tensor([[0., 0., 0.],
       [1., 1., 1.],
       [2., 2., 2.]])

8.addmm_()的操作和addmm()函数功能相同,区别就是addmm_()有inplace的操作,也就是在原对象基础上进行修改,即把改变之后的变量再赋给原来的变量。例如:

inputs的值变成了改变之后的值,不用再去写 某个变量=addmm_() 了,因为inputs就是改变之后的变量!

*inputs@ inputs_t为

'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]
                   =    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                        [2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''

完成的公式为:

inputs   = 1 &times; inputs - 2 &times;(inputs @ inputs_t)

'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
inputs:
tensor([[  0.,   0.,   0.],
       [  1.,  -5., -11.],
       [  2., -10., -22.]])

三、代码运行结果

inputs:
tensor([[0., 0., 0.],
       [1., 1., 1.],
       [2., 2., 2.]])
inputs_t:
tensor([[0., 1., 2.],
       [0., 1., 2.],
       [0., 1., 2.]])
a:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
b:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
c:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
d:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
e:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
f:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
g:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],
       [ 2.,  5.,  8.],
       [ 4.,  7., 10.]])
h:
tensor([[0., 3., 6.],
       [1., 4., 7.],
       [2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],
       [ 1.,  7., 13.],
       [ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],
       [ 2.,  5.,  8.],
       [ 4.,  7., 10.]])
inputs:
tensor([[0., 0., 0.],
       [1., 1., 1.],
       [2., 2., 2.]])
inputs:
tensor([[  0.,   0.,   0.],
       [  1.,  -5., -11.],
       [  2., -10., -22.]])

来源:https://nickhuang1996.blog.csdn.net/article/details/90638449

标签:Pytorch,深度学习,函数,addmm(),addmm,()
0
投稿

猜你喜欢

  • 《写给大家看的设计书》阅读笔记之色彩

    2009-07-30 12:45:00
  • Windows10安装Oracle19c数据库详细记录(图文详解)

    2024-01-23 20:13:15
  • 如何使用Python修改matplotlib.pyplot.colorbar的位置以对齐主图

    2021-09-28 18:01:30
  • 成功解决ValueError: Supported target types are:('binary', 'multiclass'). Got 'continuous' instead.

    2023-01-24 03:59:00
  • pyecharts结合flask框架的使用

    2022-12-01 18:37:25
  • Python入门教程(五)Python变量的用法

    2021-04-05 11:17:35
  • SQL SERVER 分组求和sql语句

    2024-01-13 04:43:16
  • ASP.NET MVC4入门教程(八):给数据模型添加校验器

    2024-06-05 09:27:38
  • vue3+vite使用jsx和tsx详情

    2024-05-10 14:15:47
  • Layui事件监听的实现(表单和数据表格)

    2024-05-09 15:01:11
  • myFocus 一个KindEditor的焦点图插件

    2023-11-09 03:56:09
  • Python中的高级数据结构详解

    2022-02-20 04:01:09
  • Pycharm+Python+PyQt5使用详解

    2021-08-20 06:39:25
  • Python通过DOM和SAX方式解析XML的应用实例分享

    2023-10-15 10:46:32
  • ubuntu下搭建Go语言(golang)环境

    2024-05-11 09:08:47
  • python通过exifread模块获得图片exif信息的方法

    2023-08-18 05:00:15
  • python自动化测试selenium操作下拉列表实现

    2023-09-06 00:26:50
  • Selenium之模拟登录铁路12306的示例代码

    2022-01-22 17:06:27
  • python调用系统ffmpeg实现视频截图、http发送

    2021-05-20 13:18:53
  • 用python修改excel表某一列内容的操作方法

    2022-01-22 20:51:29
  • asp之家 网络编程 m.aspxhome.com