pytorch 网络参数 weight bias 初始化详解

作者:Ibelievesunshine 时间:2023-08-12 07:43:57 

权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。


# not recommend
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
 m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
 m.weight.data.normal_(1.0, 0.02)
 m.bias.data.fill_(0)

# recommend
def initialize_weights(m):
if isinstance(m, nn.Conv2d):
 m.weight.data.normal_(0, 0.02)
 m.bias.data.zero_()
elif isinstance(m, nn.Linear):
 m.weight.data.normal_(0, 0.02)
 m.bias.data.zero_()

# recommend
def weights_init(m):
if isinstance(m, nn.Conv2d):
 nn.init.xavier_normal_(m.weight.data)
 nn.init.xavier_normal_(m.bias.data)
elif isinstance(m, nn.BatchNorm2d):
 nn.init.constant_(m.weight,1)
 nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
 nn.init.constant_(m.weight,1)
 nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化


# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
   n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
   m.weight.data.normal_(0, math.sqrt(2. / n))
   if m.bias is not None:
     m.bias.data.zero_()
 elif classname.find('BatchNorm') != -1:
   m.weight.data.fill_(1)
   m.bias.data.zero_()
 elif classname.find('Linear') != -1:
   n = m.weight.size(1)
   m.weight.data.normal_(0, 0.01)
   m.bias.data = torch.ones(m.bias.data.size())

# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
 if isinstance(m, nn.Conv2d):
   n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
   m.weight.data.normal_(0, math.sqrt(2. / n))
   if m.bias is not None:
     m.bias.data.zero_()
 elif isinstance(m, nn.BatchNorm2d):
   m.weight.data.fill_(1)
   m.bias.data.zero_()
 elif isinstance(m, nn.BatchNorm1d):
   m.weight.data.fill_(1)
   m.bias.data.zero_()
 elif isinstance(m, nn.Linear):
   nn.init.xavier_uniform_(m.weight.data)
   if m.bias is not None:
     m.bias.data.zero_()

# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay


def separate_bn_prelu_params(model, ignored_params=[]):
 bn_prelu_params = []
 for m in model.modules():
   if isinstance(m, nn.BatchNorm2d):
     ignored_params += list(map(id, m.parameters()))  
     bn_prelu_params += m.parameters()
   if isinstance(m, nn.BatchNorm1d):
     ignored_params += list(map(id, m.parameters()))  
     bn_prelu_params += m.parameters()
   elif isinstance(m, nn.PReLU):
     ignored_params += list(map(id, m.parameters()))
     bn_prelu_params += m.parameters()
 base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([
   {'params': base_params, 'weight_decay': WEIGHT_DECAY},    
   {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
   {'params': bn_prelu_params, 'weight_decay': 0.0}
   ], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay–其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule


def schedule_lr(optimizer):
 for params in optimizer.param_groups:
   params['lr'] /= 10.
 print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)


# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
 paras_only_bn = []        
 paras_no_bn = []
 for layer in model.modules():
   if 'model' in str(layer.__class__):      # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
     continue
   if 'container' in str(layer.__class__):       # 去掉Sequential型的模块
     continue
   else:
     if 'batchnorm' in str(layer.__class__):
       paras_only_bn.extend([*layer.parameters()])
     else:
       paras_no_bn.extend([*layer.parameters()])  # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。


def separate_resnet_bn_paras(model):
 all_parameters = model.parameters()
 paras_only_bn = []

for pname, p in model.named_parameters():
   if pname.find('bn') >= 0:
     paras_only_bn.append(p)

paras_only_bn_id = list(map(id, paras_only_bn))
 paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))

return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了


self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)

2、在forward()中定义


out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

3、对ResNet取model.name_parameters()返回的pname形如:


‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。


self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一


downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

来源:https://blog.csdn.net/Ibelievesunshine/article/details/99478182

标签:pytorch,参数,weight,bias,初始化
0
投稿

猜你喜欢

  • vscode ssh安装librosa处理音频的解决方法

    2022-04-25 04:33:54
  • asp中如何过滤到单引号

    2009-07-05 18:38:00
  • 页面只能打开一次Cooike如何实现

    2024-04-22 22:39:32
  • 几个图片按比例缩放的代码

    2008-02-13 08:51:00
  • Go语言学习之指针的用法详解

    2024-02-12 06:56:10
  • Python和php通信乱码问题解决方法

    2023-03-04 14:50:43
  • sql server如何利用开窗函数over()进行分组统计

    2024-01-16 01:55:36
  • Python编程使用有限状态机识别地址有效性

    2023-09-03 00:14:56
  • django中模板的html自动转意方法

    2023-06-28 15:33:49
  • Oracle 处理json数据的方法

    2024-01-16 15:11:15
  • python通过smpt发送邮件的方法

    2021-06-18 02:50:59
  • 建立三层结构的ASP应用程序

    2009-01-21 19:41:00
  • SQL Server错误代码大全及解释(留着备用)

    2024-01-14 07:08:44
  • Python爬取豆瓣数据实现过程解析

    2022-01-27 09:12:20
  • 如何尽快释放掉Connection对象建立的连接?

    2009-12-16 18:38:00
  • Python爬取新型冠状病毒“谣言”新闻进行数据分析

    2021-06-14 04:47:33
  • windows下Virtualenvwrapper安装教程

    2023-11-08 02:15:20
  • 如何使用pyinstaller打包32位的exe程序

    2021-12-17 10:15:20
  • 由浅入深讲解python中的yield与generator

    2022-08-14 06:26:11
  • python urllib库的使用详解

    2021-06-12 14:42:04
  • asp之家 网络编程 m.aspxhome.com