PyTorch搭建LSTM实现时间序列负荷预测
作者:Cyril_KI 时间:2023-08-18 09:10:09
I. 前言
在上一篇文章深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)中,我详细地解释了如何利用PyTorch来搭建一个LSTM模型,本篇文章的主要目的是搭建一个LSTM模型用于时间序列预测。
系列文章:
PyTorch搭建LSTM实现多变量多步长时序负荷预测
PyTorch搭建LSTM实现多变量时序负荷预测
PyTorch深度学习LSTM从input输入到Linear输出
PyTorch搭建双向LSTM实现时间序列负荷预测
II. 数据处理
数据集为某个地区某段时间内的电力负荷数据,除了负荷以外,还包括温度、湿度等信息。
本篇文章暂时不考虑其它变量,只考虑用历史负荷来预测未来负荷。
本文中,我们根据前24个时刻的负荷下一时刻的负荷。有关多变量预测请参考:PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)。
def load_data(file_name):
global MAX, MIN
df = pd.read_csv('data/new_data/' + file_name, encoding='gbk')
columns = df.columns
df.fillna(df.mean(), inplace=True)
MAX = np.max(df[columns[1]])
MIN = np.min(df[columns[1]])
df[columns[1]] = (df[columns[1]] - MIN) / (MAX - MIN)
return df
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
def nn_seq(file_name, B):
print('处理数据:')
data = load_data(file_name)
load = data[data.columns[1]]
load = load.tolist()
load = torch.FloatTensor(load).view(-1)
data = data.values.tolist()
seq = []
for i in range(len(data) - 24):
train_seq = []
train_label = []
for j in range(i, i + 24):
train_seq.append(load[j])
train_label.append(load[i + 24])
train_seq = torch.FloatTensor(train_seq).view(-1)
train_label = torch.FloatTensor(train_label).view(-1)
seq.append((train_seq, train_label))
# print(seq[:5])
Dtr = seq[0:int(len(seq) * 0.7)]
Dte = seq[int(len(seq) * 0.7):len(seq)]
train_len = int(len(Dtr) / B) * B
test_len = int(len(Dte) / B) * B
Dtr, Dte = Dtr[:train_len], Dte[:test_len]
train = MyDataset(Dtr)
test = MyDataset(Dte)
Dtr = DataLoader(dataset=train, batch_size=B, shuffle=False, num_workers=0)
Dte = DataLoader(dataset=test, batch_size=B, shuffle=False, num_workers=0)
return Dtr, Dte
上面代码用了DataLoader来对原始数据进行处理,最终得到了batch_size=B的数据集Dtr和Dte,Dtr为训练集,Dte为测试集。
III. LSTM模型
这里采用了深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)中的模型:
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_size = output_size
self.num_directions = 1 # 单向LSTM
self.batch_size = batch_size
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 24)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 24, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 24, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 24, 64)
pred = self.linear(output) # pred(150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 24, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
IV. 训练
def LSTM_train(name, b):
Dtr, Dte = nn_seq(file_name=name, B=b)
input_size, hidden_size, num_layers, output_size = 1, 64, 5, 1
model = LSTM(input_size, hidden_size, num_layers, output_size, batch_size=b).to(device)
loss_function = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 训练
epochs = 15
cnt = 0
for i in range(epochs):
cnt = 0
print('当前', i)
for (seq, label) in Dtr:
cnt += 1
seq = seq.to(device)
label = label.to(device)
y_pred = model(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if cnt % 100 == 0:
print('epoch', i, ':', cnt - 100, '~', cnt, loss.item())
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, LSTM_PATH)
一共训练了 * :
V. 测试
def test(name, b):
global MAX, MIN
Dtr, Dte = nn_seq(file_name=name, B=b)
pred = []
y = []
print('loading model...')
input_size, hidden_size, num_layers, output_size = 1, 64, 5, 1
model = LSTM(input_size, hidden_size, num_layers, output_size, batch_size=b).to(device)
model.load_state_dict(torch.load(LSTM_PATH)['model'])
model.eval()
print('predicting...')
for (seq, target) in Dte:
target = list(chain.from_iterable(target.data.tolist()))
y.extend(target)
seq = seq.to(device)
seq_len = seq.shape[1]
seq = seq.view(model.batch_size, seq_len, 1) # (5, 24, 1)
with torch.no_grad():
y_pred = model(seq)
y_pred = list(chain.from_iterable(y_pred.data.tolist()))
pred.extend(y_pred)
y, pred = np.array(y), np.array(pred)
y = (MAX - MIN) * y + MIN
pred = (MAX - MIN) * pred + MIN
print('accuracy:', get_mape(y, pred))
# plot
x = [i for i in range(1, 151)]
x_smooth = np.linspace(np.min(x), np.max(x), 600)
y_smooth = make_interp_spline(x, y[0:150])(x_smooth)
plt.plot(x_smooth, y_smooth, c='green', marker='*', ms=1, alpha=0.75, label='true')
y_smooth = make_interp_spline(x, pred[0:150])(x_smooth)
plt.plot(x_smooth, y_smooth, c='red', marker='o', ms=1, alpha=0.75, label='pred')
plt.grid(axis='y')
plt.legend()
plt.show()
MAPE为6.07%:
VI. 源码及数据
源码及数据我放在了GitHub上,LSTM-Load-Forecasting
来源:https://blog.csdn.net/Cyril_KI/article/details/122569775
标签:PyTorch,LSTM,时间序列,负荷预测
![](/images/zang.png)
![](/images/jiucuo.png)
猜你喜欢
ASP.NET Core中的Configuration配置一
2024-06-05 09:32:59
![](https://img.aspxhome.com/file/2023/9/122779_0s.jpg)
django 发送手机验证码的示例代码
2023-05-07 03:42:37
![](https://img.aspxhome.com/file/2023/7/120377_0s.png)
Python学习笔记之变量与转义符
2022-12-20 23:21:18
![](https://img.aspxhome.com/file/2023/9/72769_0s.png)
Golang空接口与类型断言的实现
2024-04-27 15:39:21
PHP简单实现冒泡排序的方法
2024-06-07 15:45:49
如何解决Oracle EBS R12 - 以Excel查看输出格式为“文本”的请求时乱码
2024-01-22 01:17:55
利用Pytorch实现ResNet网络构建及模型训练
2022-02-24 19:57:59
Centos 6.4 安装Python 2.7 python-pip的详细步骤
2023-12-01 10:21:03
Vue项目的网络请求代理到封装步骤详解
2024-04-30 10:23:55
![](https://img.aspxhome.com/file/2023/4/130244_0s.png)
Python入门教程(二)Python快速上手
2023-10-16 08:54:09
![](https://img.aspxhome.com/file/2023/0/135070_0s.png)
python采集百度百科的方法
2023-01-12 03:48:35
Python字符串拼接的几种方法整理
2021-05-08 18:55:47
![](https://img.aspxhome.com/file/2023/5/64735_0s.jpg)
PHP使用星号替代用户名手机和邮箱的实现代码
2023-11-17 05:50:02
pycharm中使用request和Pytest进行接口测试的方法
2022-06-30 03:34:46
![](https://img.aspxhome.com/file/2023/9/90599_0s.png)
彻底卸载MySQL的方法分享
2024-01-23 11:49:08
SQL Server查询条件IN中能否使用变量的示例详解
2024-01-15 17:55:55
![](https://img.aspxhome.com/file/2023/3/83173_0s.png)
Python使用百度api做人脸对比的方法
2023-08-18 12:52:24
![](https://img.aspxhome.com/file/2023/7/68467_0s.jpg)
Python集中化管理平台Ansible介绍与YAML简介
2023-09-23 12:52:34
![](https://img.aspxhome.com/file/2023/6/88026_0s.png)
python下读取公私钥做加解密实例详解
2022-04-17 03:39:00
Python中使用第三方库xlutils来追加写入Excel文件示例
2022-05-23 10:04:11