pytorch 如何使用batch训练lstm网络
作者:king的江鸟 时间:2023-10-18 04:46:02
batch的lstm
# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
torch.manual_seed(1)
# 准备数据的阶段
def prepare_sequence(seq, to_ix):
idxs = [to_ix[w] for w in seq]
return torch.tensor(idxs, dtype=torch.long)
with open("/home/lstm_train.txt", encoding='utf8') as f:
train_data = []
word = []
label = []
data = f.readline().strip()
while data:
data = data.strip()
SP = data.split(' ')
if len(SP) == 2:
word.append(SP[0])
label.append(SP[1])
else:
if len(word) == 100 and 'I-PRO' in label:
train_data.append((word, label))
word = []
label = []
data = f.readline()
word_to_ix = {}
for sent, _ in train_data:
for word in sent:
if word not in word_to_ix:
word_to_ix[word] = len(word_to_ix)
tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])
# 词向量的维度
EMBEDDING_DIM = 128
# 隐藏层的单元数
HIDDEN_DIM = 128
# 批大小
batch_size = 10
class LSTMTagger(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
super(LSTMTagger, self).__init__()
self.hidden_dim = hidden_dim
self.batch_size = batch_size
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
# The LSTM takes word embeddings as inputs, and outputs hidden states
# with dimensionality hidden_dim.
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
# The linear layer that maps from hidden state space to tag space
self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
def forward(self, sentence):
embeds = self.word_embeddings(sentence)
# input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
lstm_out, _ = self.lstm(embeds)
tag_space = self.hidden2tag(lstm_out)
scores = F.log_softmax(tag_space, dim=2)
return scores
def predict(self, sentence):
embeds = self.word_embeddings(sentence)
lstm_out, _ = self.lstm(embeds)
tag_space = self.hidden2tag(lstm_out)
scores = F.log_softmax(tag_space, dim=2)
return scores
loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)
data_set_word = []
data_set_label = []
for data_tuple in train_data:
data_set_word.append(data_tuple[0])
data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=batch_size, # mini batch size
shuffle=True, #
num_workers=2, # 多线程来读数据
)
# 训练过程
for epoch in range(200):
for step, (batch_x, batch_y) in enumerate(loader):
# 梯度清零
model.zero_grad()
tag_scores = model(batch_x)
# 计算损失
tag_scores = tag_scores.view(-1, tag_scores.shape[2])
batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
loss = loss_function(tag_scores, batch_y)
print(loss)
# 后向传播
loss.backward()
# 更新参数
optimizer.step()
# 测试过程
with torch.no_grad():
inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
print(inputs)
tag_scores = model.predict(inputs)
print(tag_scores.shape)
print(torch.argmax(tag_scores, dim=2))
补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别
看代码吧~
import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
train = True, # 载入训练集
transform=transforms.ToTensor(), # 把数据变成tensor类型
download = True # 下载
)
# 测试集
test_data = datasets.MNIST(root="./",
train = False,
transform=transforms.ToTensor(),
download = True
)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
inputs,labels = data
print(inputs.shape)
print(labels.shape)
break
# 定义网络结构
class LSTM(nn.Module):
def __init__(self):
super(LSTM,self).__init__()# 初始化
self.lstm = torch.nn.LSTM(
input_size = 28, # 表示输入特征的大小
hidden_size = 64, # 表示lstm模块的数量
num_layers = 1, # 表示lstm隐藏层的层数
batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
)
self.out = torch.nn.Linear(in_features=64,out_features=10)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self,x):
# (batch,seq_len,feature)
x = x.view(-1,28,28)
# output:(batch,seq_len,hidden_size)包含每个序列的输出结果
# 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
# h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
# c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
output,(h_n,c_n) = self.lstm(x)
output_in_last_timestep = h_n[-1,:,:]
x = self.out(output_in_last_timestep)
x = self.softmax(x)
return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():
# 模型的训练状态
model.train()
for i,data in enumerate(train_loader):
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
loss = mse_loss(out,labels)
# 梯度清零
optimizer.zero_grad()
# 计算梯度
loss.backward()
# 修改权值
optimizer.step()
def test():
# 模型的测试状态
model.eval()
correct = 0 # 测试集准确率
for i,data in enumerate(test_loader):
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 获得最大值,以及最大值所在的位置
_,predicted = torch.max(out,1)
# 预测正确的数量
correct += (predicted==labels).sum()
print("Test acc:{0}".format(correct.item()/len(test_data)))
correct = 0
for i,data in enumerate(train_loader): # 训练集准确率
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 获得最大值,以及最大值所在的位置
_,predicted = torch.max(out,1)
# 预测正确的数量
correct += (predicted==labels).sum()
print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):
print("epoch:",epoch)
train()
test()
来源:https://blog.csdn.net/weixin_40939578/article/details/104462188
标签:pytorch,batch,lstm
0
投稿
猜你喜欢
python调用动态链接库的基本过程详解
2023-05-31 13:24:00
Go 语言进阶单元测试示例详解
2024-02-07 18:17:06
Mysql覆盖索引详解
2024-01-14 06:54:29
浅谈Python在pycharm中的调试(debug)
2023-05-04 15:33:20
python实现本地批量ping多个IP的方法示例
2023-12-19 02:36:36
django将图片保存到mysql数据库并展示在前端页面的实现
2024-01-26 06:59:44
JS内部事件机制之单线程原理
2024-05-03 15:58:24
Python实现PS滤镜碎片特效功能示例
2021-04-25 01:35:31
python输入、数据类型转换及运算符方式
2021-08-09 19:20:17
运行asp.net程序 报错:磁盘空间不足
2011-11-03 17:16:22
python使用pip安装SciPy、SymPy、matplotlib教程
2022-03-05 01:46:12
*.HTC 文件的简单介绍
2008-11-24 17:36:00
PyTorch梯度下降反向传播
2021-05-15 17:06:14
用户的期望以及背后真正的需求
2009-06-19 12:39:00
PHP autoload使用方法及步骤详解
2023-08-22 13:05:44
rollup打包vue组件并发布到npm的方法
2024-05-22 10:43:32
python模糊图片过滤的方法
2022-07-01 04:37:16
Android通过PHP服务器实现登录功能
2023-07-02 07:08:58
实例讲解Python爬取网页数据
2023-01-10 03:55:05
git版本库创建拓展添加文件到版本库教程
2022-08-11 09:44:32