pytorch 如何使用batch训练lstm网络

作者:king的江鸟 时间:2023-10-18 04:46:02 

batch的lstm


# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
torch.manual_seed(1)

# 准备数据的阶段
def prepare_sequence(seq, to_ix):
   idxs = [to_ix[w] for w in seq]
   return torch.tensor(idxs, dtype=torch.long)

with open("/home/lstm_train.txt", encoding='utf8') as f:
   train_data = []
   word = []
   label = []
   data = f.readline().strip()
   while data:
       data = data.strip()
       SP = data.split(' ')
       if len(SP) == 2:
           word.append(SP[0])
           label.append(SP[1])
       else:
           if len(word) == 100 and 'I-PRO' in label:
               train_data.append((word, label))
           word = []
           label = []
       data = f.readline()

word_to_ix = {}
for sent, _ in train_data:
   for word in sent:
       if word not in word_to_ix:
           word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
   train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])

# 词向量的维度
EMBEDDING_DIM = 128

# 隐藏层的单元数
HIDDEN_DIM = 128

# 批大小
batch_size = 10  
class LSTMTagger(nn.Module):

def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
       super(LSTMTagger, self).__init__()
       self.hidden_dim = hidden_dim
       self.batch_size = batch_size
       self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

# The LSTM takes word embeddings as inputs, and outputs hidden states
       # with dimensionality hidden_dim.
       self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

# The linear layer that maps from hidden state space to tag space
       self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

def forward(self, sentence):
       embeds = self.word_embeddings(sentence)
       # input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
       lstm_out, _ = self.lstm(embeds)
       tag_space = self.hidden2tag(lstm_out)
       scores = F.log_softmax(tag_space, dim=2)
       return scores

def predict(self, sentence):
       embeds = self.word_embeddings(sentence)
       lstm_out, _ = self.lstm(embeds)
       tag_space = self.hidden2tag(lstm_out)
       scores = F.log_softmax(tag_space, dim=2)
       return scores

loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)

data_set_word = []
data_set_label = []
for data_tuple in train_data:
   data_set_word.append(data_tuple[0])
   data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
   dataset=torch_dataset,  # torch TensorDataset format
   batch_size=batch_size,  # mini batch size
   shuffle=True,  #
   num_workers=2,  # 多线程来读数据
)

# 训练过程
for epoch in range(200):
   for step, (batch_x, batch_y) in enumerate(loader):
       # 梯度清零
       model.zero_grad()
       tag_scores = model(batch_x)

# 计算损失
       tag_scores = tag_scores.view(-1, tag_scores.shape[2])
       batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
       loss = loss_function(tag_scores, batch_y)
       print(loss)
       # 后向传播
       loss.backward()

# 更新参数
       optimizer.step()

# 测试过程
with torch.no_grad():
   inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
   print(inputs)
   tag_scores = model.predict(inputs)
   print(tag_scores.shape)
   print(torch.argmax(tag_scores, dim=2))

补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别

看代码吧~


import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader

# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
                           train = True, # 载入训练集
                           transform=transforms.ToTensor(), # 把数据变成tensor类型
                           download = True # 下载
                          )
# 测试集
test_data = datasets.MNIST(root="./",
                           train = False,
                           transform=transforms.ToTensor(),
                           download = True
                          )

# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)

for i,data in enumerate(train_loader):
   inputs,labels = data
   print(inputs.shape)
   print(labels.shape)
   break

# 定义网络结构
class LSTM(nn.Module):
   def __init__(self):
       super(LSTM,self).__init__()# 初始化
       self.lstm = torch.nn.LSTM(
           input_size = 28, # 表示输入特征的大小
           hidden_size = 64, # 表示lstm模块的数量
           num_layers = 1, # 表示lstm隐藏层的层数
           batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
       )
       self.out = torch.nn.Linear(in_features=64,out_features=10)
       self.softmax = torch.nn.Softmax(dim=1)
   def forward(self,x):
       # (batch,seq_len,feature)
       x = x.view(-1,28,28)
       # output:(batch,seq_len,hidden_size)包含每个序列的输出结果
       # 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
       # h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
       # c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
       output,(h_n,c_n) = self.lstm(x)
       output_in_last_timestep = h_n[-1,:,:]
       x = self.out(output_in_last_timestep)
       x = self.softmax(x)
       return x

# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降

# 定义模型训练和测试的方法
def train():
   # 模型的训练状态
   model.train()
   for i,data in enumerate(train_loader):
       # 获得一个批次的数据和标签
       inputs,labels = data
       # 获得模型预测结果(64,10)
       out = model(inputs)
       # 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
       loss = mse_loss(out,labels)
       # 梯度清零
       optimizer.zero_grad()
       # 计算梯度
       loss.backward()
       # 修改权值
       optimizer.step()

def test():
   # 模型的测试状态
   model.eval()
   correct = 0 # 测试集准确率
   for i,data in enumerate(test_loader):
       # 获得一个批次的数据和标签
       inputs,labels = data
       # 获得模型预测结果(64,10)
       out = model(inputs)
       # 获得最大值,以及最大值所在的位置
       _,predicted = torch.max(out,1)
       # 预测正确的数量
       correct += (predicted==labels).sum()
   print("Test acc:{0}".format(correct.item()/len(test_data)))

correct = 0
   for i,data in enumerate(train_loader): # 训练集准确率
       # 获得一个批次的数据和标签
       inputs,labels = data
       # 获得模型预测结果(64,10)
       out = model(inputs)
       # 获得最大值,以及最大值所在的位置
       _,predicted = torch.max(out,1)
       # 预测正确的数量
       correct += (predicted==labels).sum()
   print("Train acc:{0}".format(correct.item()/len(train_data)))

# 训练
for epoch in range(10):
   print("epoch:",epoch)
   train()
   test()

pytorch 如何使用batch训练lstm网络

来源:https://blog.csdn.net/weixin_40939578/article/details/104462188

标签:pytorch,batch,lstm
0
投稿

猜你喜欢

  • python调用动态链接库的基本过程详解

    2023-05-31 13:24:00
  • Go 语言进阶单元测试示例详解

    2024-02-07 18:17:06
  • Mysql覆盖索引详解

    2024-01-14 06:54:29
  • 浅谈Python在pycharm中的调试(debug)

    2023-05-04 15:33:20
  • python实现本地批量ping多个IP的方法示例

    2023-12-19 02:36:36
  • django将图片保存到mysql数据库并展示在前端页面的实现

    2024-01-26 06:59:44
  • JS内部事件机制之单线程原理

    2024-05-03 15:58:24
  • Python实现PS滤镜碎片特效功能示例

    2021-04-25 01:35:31
  • python输入、数据类型转换及运算符方式

    2021-08-09 19:20:17
  • 运行asp.net程序 报错:磁盘空间不足

    2011-11-03 17:16:22
  • python使用pip安装SciPy、SymPy、matplotlib教程

    2022-03-05 01:46:12
  • *.HTC 文件的简单介绍

    2008-11-24 17:36:00
  • PyTorch梯度下降反向传播

    2021-05-15 17:06:14
  • 用户的期望以及背后真正的需求

    2009-06-19 12:39:00
  • PHP autoload使用方法及步骤详解

    2023-08-22 13:05:44
  • rollup打包vue组件并发布到npm的方法

    2024-05-22 10:43:32
  • python模糊图片过滤的方法

    2022-07-01 04:37:16
  • Android通过PHP服务器实现登录功能

    2023-07-02 07:08:58
  • 实例讲解Python爬取网页数据

    2023-01-10 03:55:05
  • git版本库创建拓展添加文件到版本库教程

    2022-08-11 09:44:32
  • asp之家 网络编程 m.aspxhome.com