在pytorch中动态调整优化器的学习率方式

作者:FesianXu 时间:2022-08-14 00:30:57 

在深度学习中,经常需要动态调整学习率,以达到更好地训练效果,本文纪录在pytorch中的实现方法,其优化器实例为SGD优化器,其他如Adam优化器同样适用。

一般来说,在以SGD优化器作为基本优化器,然后根据epoch实现学习率指数下降,代码如下:


step = [10,20,30,40]
base_lr = 1e-4
sgd_opt = torch.optim.SGD(model.parameters(), lr=base_lr, nesterov=True, momentum=0.9)
def adjust_lr(epoch):
lr = base_lr * (0.1 ** np.sum(epoch >= np.array(step)))
for params_group in sgd_opt.param_groups:
 params_group['lr'] = lr
return lr

只需要在每个train的epoch之前使用这个函数即可。


for epoch in range(60):
model.train()
adjust_lr(epoch)
for ind, each in enumerate(train_loader):
mat, label = each
...

补充知识:Pytorch框架下应用Bi-LSTM实现汽车评论文本关键词抽取

需要调用的模块及整体Bi-lstm流程


import torch
import pandas as pd
import numpy as np
from tensorflow import keras
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import gensim
from sklearn.model_selection import train_test_split
class word_extract(nn.Module):
def __init__(self,d_model,embedding_matrix):
 super(word_extract, self).__init__()
 self.d_model=d_model
 self.embedding=nn.Embedding(num_embeddings=len(embedding_matrix),embedding_dim=200)
 self.embedding.weight.data.copy_(embedding_matrix)
 self.embedding.weight.requires_grad=False
 self.lstm1=nn.LSTM(input_size=200,hidden_size=50,bidirectional=True)
 self.lstm2=nn.LSTM(input_size=2*self.lstm1.hidden_size,hidden_size=50,bidirectional=True)
 self.linear=nn.Linear(2*self.lstm2.hidden_size,4)

def forward(self,x):
 w_x=self.embedding(x)
 first_x,(first_h_x,first_c_x)=self.lstm1(w_x)
 second_x,(second_h_x,second_c_x)=self.lstm2(first_x)
 output_x=self.linear(second_x)
 return output_x

将文本转换为数值形式


def trans_num(word2idx,text):
text_list=[]
for i in text:
 s=i.rstrip().replace('\r','').replace('\n','').split(' ')
 numtext=[word2idx[j] if j in word2idx.keys() else word2idx['_PAD'] for j in s ]
 text_list.append(numtext)
return text_list

将Gensim里的词向量模型转为矩阵形式,后续导入到LSTM模型中


def establish_word2vec_matrix(model): #负责将数值索引转为要输入的数据
word2idx = {"_PAD": 0} # 初始化 `[word : token]` 字典,后期 tokenize 语料库就是用该词典。
num2idx = {0: "_PAD"}
vocab_list = [(k, model.wv[k]) for k, v in model.wv.vocab.items()]

# 存储所有 word2vec 中所有向量的数组,留意其中多一位,词向量全为 0, 用于 padding
embeddings_matrix = np.zeros((len(model.wv.vocab.items()) + 1, model.vector_size))
for i in range(len(vocab_list)):
 word = vocab_list[i][0]
 word2idx[word] = i + 1
 num2idx[i + 1] = word
 embeddings_matrix[i + 1] = vocab_list[i][1]
embeddings_matrix = torch.Tensor(embeddings_matrix)
return embeddings_matrix, word2idx, num2idx

训练过程


def train(model,epoch,learning_rate,batch_size,x, y, val_x, val_y):
optimizor = optim.Adam(model.parameters(), lr=learning_rate)
data = TensorDataset(x, y)
data = DataLoader(data, batch_size=batch_size)
for i in range(epoch):
 for j, (per_x, per_y) in enumerate(data):
  output_y = model(per_x)
  loss = F.cross_entropy(output_y.view(-1,output_y.size(2)), per_y.view(-1))
  optimizor.zero_grad()
  loss.backward()
  optimizor.step()
  arg_y=output_y.argmax(dim=2)
  fit_correct=(arg_y==per_y).sum()
  fit_acc=fit_correct.item()/(per_y.size(0)*per_y.size(1))
  print('##################################')
  print('第{}次迭代第{}批次的训练误差为{}'.format(i + 1, j + 1, loss), end=' ')
  print('第{}次迭代第{}批次的训练准确度为{}'.format(i + 1, j + 1, fit_acc))
  val_output_y = model(val_x)
  val_loss = F.cross_entropy(val_output_y.view(-1,val_output_y.size(2)), val_y.view(-1))
  arg_val_y=val_output_y.argmax(dim=2)
  val_correct=(arg_val_y==val_y).sum()
  val_acc=val_correct.item()/(val_y.size(0)*val_y.size(1))
  print('第{}次迭代第{}批次的预测误差为{}'.format(i + 1, j + 1, val_loss), end=' ')
  print('第{}次迭代第{}批次的预测准确度为{}'.format(i + 1, j + 1, val_acc))
torch.save(model,'./extract_model.pkl')#保存模型

主函数部分


if __name__ =='__main__':
#生成词向量矩阵
word2vec = gensim.models.Word2Vec.load('./word2vec_model')
embedding_matrix,word2idx,num2idx=establish_word2vec_matrix(word2vec)#输入的是词向量模型
#
train_data=pd.read_csv('./数据.csv')
x=list(train_data['文本'])
# 将文本从文字转化为数值,这部分trans_num函数你需要自己改动去适应你自己的数据集
x=trans_num(word2idx,x)
#x需要先进行填充,也就是每个句子都是一样长度,不够长度的以0来填充,填充词单独分为一类
# #也就是说输入的x是固定长度的数值列表,例如[50,123,1850,21,199,0,0,...]
#输入的y是[2,0,1,0,0,1,3,3,3,3,3,.....]
#填充代码你自行编写,以下部分是针对我的数据集
x=keras.preprocessing.sequence.pad_sequences(
  x,maxlen=60,value=0,padding='post',
)
y=list(train_data['BIO数值'])
y_text=[]
for i in y:
 s=i.rstrip().split(' ')
 numtext=[int(j) for j in s]
 y_text.append(numtext)
y=y_text
y=keras.preprocessing.sequence.pad_sequences(
  y,maxlen=60,value=3,padding='post',
 )
# 将数据进行划分
fit_x,val_x,fit_y,val_y=train_test_split(x,y,train_size=0.8,test_size=0.2)
fit_x=torch.LongTensor(fit_x)
fit_y=torch.LongTensor(fit_y)
val_x=torch.LongTensor(val_x)
val_y=torch.LongTensor(val_y)
#开始应用
w_extract=word_extract(d_model=200,embedding_matrix=embedding_matrix)
train(model=w_extract,epoch=5,learning_rate=0.001,batch_size=50,
  x=fit_x,y=fit_y,val_x=val_x,val_y=val_y)#可以自行改动参数,设置学习率,批次,和迭代次数
w_extract=torch.load('./extract_model.pkl')#加载保存好的模型
pred_val_y=w_extract(val_x).argmax(dim=2)

来源:https://blog.csdn.net/LoseInVain/article/details/87858408

标签:pytorch,优化器,学习率
0
投稿

猜你喜欢

  • python交互式图形编程实例(一)

    2022-11-12 14:44:53
  • Python 可视化matplotlib模块基础知识

    2021-09-09 05:17:45
  • Python图像读写方法对比

    2022-10-07 08:13:46
  • Go语言压缩和解压缩tar.gz文件的方法

    2024-05-21 10:21:46
  • 解决Django layui {{}}冲突的问题

    2023-07-23 15:22:18
  • Python实现提取给定网页内的所有链接

    2022-03-29 19:01:11
  • python 三边测量定位的实现代码

    2023-02-03 08:37:31
  • Python实现二叉排序树与平衡二叉树的示例代码

    2023-01-04 17:29:36
  • 解析:Perl下应当如何连接Access数据库

    2008-11-28 16:40:00
  • python微信跳一跳系列之棋子定位像素遍历

    2023-11-04 01:27:47
  • SQL实现LeetCode(180.连续的数字)

    2024-01-24 13:45:21
  • JS轮播图中缓动函数的封装

    2023-08-22 20:50:11
  • Django商城项目注册功能的实现

    2022-01-19 05:22:36
  • python ElementTree 基本读操作示例

    2022-10-23 07:27:25
  • vue实现移动端图片裁剪上传功能

    2024-05-10 14:15:04
  • 详细介绍Python函数中的默认参数

    2021-02-14 09:41:47
  • Python 如何截取字符函数

    2023-02-08 11:39:04
  • 如何在ADSI中查询用户属性?

    2010-06-17 12:53:00
  • 利用mycat实现mysql数据库读写分离的示例

    2024-01-12 21:55:52
  • MySQL批量SQL插入性能优化详解

    2024-01-21 15:25:59
  • asp之家 网络编程 m.aspxhome.com