pytorch网络模型构建场景的问题介绍

作者:mingqian_chu 时间:2022-07-24 22:38:42 

记录使用pytorch构建网络模型过程遇到的点

1. 网络模型构建中的问题

1.1 输入变量是Tensor张量

各个模块和网络模型的输入,一定要是tensor 张量;

可以用一个列表存放多个张量。

如果是张量维度不够,需要升维度,

可以先使用 torch.unsqueeze(dim = expected)

然后再使用torch.cat(dim ) 进行拼接;

需要传递梯度的数据,禁止使用numpy, 也禁止先使用numpy,然后再转换成张量的这种情况出现;

这是因为pytorch的机制是只有是 Tensor张量的类型,才会有梯度等属性值,如果是numpy这些类别,这些变量并会丢失其梯度值。

1.2 __init__()方法使用

class ex:
   def __init__(self):
       pass

__init__方法必须接受至少一个参数即self,

Python中,self是指向该对象本身的一个引用,

通过在类的内部使用self变量,

类中的方法可以访问自己的成员变量,简单来说,self.varname的意义为”访问该对象的varname属性“

当然,__init__()中可以封装任意的程序逻辑,这是允许的,init()方法还接受任意多个其他参数,允许在初始化时提供一些数据,例如,对于刚刚的worker类,可以这样写:

class worker:
   def __init__(self,name,pay):
       self.name=name
       self.pay=pay

这样,在创建worker类的对象时,必须提供name和pay两个参数:

b=worker('Jim',5000)

Python会自动调用worker.init()方法,并传递参数。

细节参考这里init方法

1.3 内置函数setattr()

此时,可以使用python自带的内置函数 setattr(),和对应的getattr()

setattr(object, name, value)

object – 对象。

name – 字符串,对象属性。

value – 属性值。

对已存在的属性进行赋值:
>>>class A(object):
...     bar = 1
... 
>>> a = A()
>>> getattr(a, 'bar')          # 获取属性 bar 值
1
>>> setattr(a, 'bar', 5)       # 设置属性 bar 值
>>> a.bar
5
如果属性不存在会创建一个新的对象属性,并对属性赋值:

>>>class A():
...     name = "runoob"
... 
>>> a = A()
>>> setattr(a, "age", 28)
>>> print(a.age)
28
>>>

setattr() 语法

setattr(object, name, value)

object – 对象。

name – 字符串,对象属性。

value – 属性值。

1.4 网络模型的构建

注意到,在python的 __init__() 函数中,self 本身就是该类的对象的一个引用,即self是指向该对象本身的一个引用,

利用上述这一点,当在神经网络中,

需要给多个属性进行实例化时,

且这多个属性使用的是同一个类进行实例化.

则使用 setattr(self, string, object1) 添加属性;

class Temporal_GroupTrans(nn.Module):
   def __init__(self,   num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True):
       super(Temporal_GroupTrans, self).__init__()
       conv_block = Basic_slide_conv()
       for i in range( num_groups):
           setattr(self, "group" + str(i), conv_block)
       # 自定义transformer模型的初始化, CustomTransformerModel() 在该类中传入初始化模型的参数,
       # nip:512输入序列中,每个列向量的编码维度,16:注意力头的个数
       # 600:中间mlp 隐藏层的维数,  6: 堆叠transforEncode编码模块的个数;
       self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)

则使用 getattr(self, string, object1) 获取属性;

trans_input_sequence = []
       for i in range(0, num_groups, ):
           #  每组语谱图的大小是一个 (bt, ch,96,12)的矩阵,组与组之间没有重叠;
           cur_group = x[:, :, :, 12 * i:12 * (i + 1)]
           # VARIABLE_fun = "self.group"   # 每一组,与之对应的卷积模块;
           # cur_fun = eval(VARIABLE_fun + str(i ))
           cur_fun = getattr(self, 'group'+str(i))
           cur_group_out = cur_fun(cur_group).unsqueeze(dim=1)  # [bt,1, 512]
           trans_input_sequence.append(cur_group_out)

来源:https://blog.csdn.net/chumingqian/article/details/129417691

标签:pytorch,网络模型,构建
0
投稿

猜你喜欢

  • Python实战之设计一个多功能办公小工具

    2023-05-26 02:54:11
  • 详解django中视图函数的FBV和CBV

    2022-05-26 01:11:10
  • Python进度条tqdm的用法详解

    2022-09-03 00:27:35
  • Python入门篇之字符串

    2022-01-09 00:31:17
  • 在Django框架中设置语言偏好的教程

    2022-01-18 22:47:35
  • ASP.NET中的几种弹出框提示基本实现方法

    2023-07-13 00:23:50
  • 可能是史上最细的python中import详解

    2023-01-07 18:40:07
  • mysql中普通索引和唯一索引的效率对比

    2010-12-08 16:03:00
  • python中使用多线程改进flask案例

    2022-11-07 05:44:55
  • Python中垃圾回收和del语句详解

    2023-12-20 01:02:55
  • Python机器学习之逻辑回归

    2023-11-18 14:19:02
  • pip安装Python库时遇到的问题及解决方法

    2023-06-20 14:00:01
  • Linux下python3.7.0安装教程

    2021-07-30 05:25:58
  • PyTorch使用GPU训练的两种方法实例

    2023-09-21 08:11:40
  • 基于PHP+Ajax实现表单验证的详解

    2023-11-14 12:52:43
  • 利用Python求阴影部分的面积实例代码

    2021-10-05 15:18:03
  • 对Django的restful用法详解(自带的增删改查)

    2023-11-12 07:45:04
  • python机器学习实战之树回归详解

    2022-02-05 04:37:10
  • pytorch绘制曲线的方法

    2022-09-03 06:43:54
  • 详解Python定时器Timer的使用及示例

    2021-01-25 23:54:34
  • asp之家 网络编程 m.aspxhome.com