当前位置:网站首页>PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
2022-06-26 06:43:00 【Cyril_KI】
I. 前言
关于LSTM的具体原理可以参考:人工智能教程。除了LSTM以外,这个网站还囊括了其他大多机器学习以及深度学习模型的具体讲解,配图生动,简单易懂。
前面已经写了很多关于时间序列预测的文章:
- 深入理解PyTorch中LSTM的输入和输出(从input输入到Linear输出)
- PyTorch搭建LSTM实现时间序列预测(负荷预测)
- PyTorch搭建LSTM实现多变量时间序列预测(负荷预测)
- PyTorch搭建双向LSTM实现时间序列预测(负荷预测)
- PyTorch搭建LSTM实现多变量多步长时间序列预测(一):直接多输出
- PyTorch搭建LSTM实现多变量多步长时间序列预测(二):单步滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(三):多模型单步预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(四):多模型滚动预测
- PyTorch搭建LSTM实现多变量多步长时间序列预测(五):seq2seq
- PyTorch中实现LSTM多步长时间序列预测的几种方法总结(负荷预测)
- PyTorch-LSTM时间序列预测中如何预测真正的未来值
- PyTorch搭建LSTM实现多变量输入多变量输出时间序列预测(多任务学习)
- PyTorch搭建ANN实现时间序列预测(风速预测)
- PyTorch搭建CNN实现时间序列预测(风速预测)
- PyTorch搭建CNN-LSTM混合模型实现多变量多步长时间序列预测(负荷预测)
上面所有文章一共采用了LSTM、ANN以及CNN三种模型来分别进行时间序列预测。众所周知,CNN提取特征的能力非常强,因此现在不少论文将CNN和LSTM结合起来进行时间序列预测。本文将利用PyTorch来搭建一个简单的CNN-LSTM混合模型实现负荷预测。
II. CNN-LSTM
CNN-LSTM模型搭建如下:
class CNN_LSTM(nn.Module):
def __init__(self, args):
super(CNN_LSTM, self).__init__()
self.args = args
self.relu = nn.ReLU(inplace=True)
# (batch_size=30, seq_len=24, input_size=7) ---> permute(0, 2, 1)
# (30, 7, 24)
self.conv = nn.Sequential(
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3),
nn.ReLU(),
nn.MaxPool1d(kernel_size=3, stride=1)
)
# (batch_size=30, out_channels=32, seq_len-4=20) ---> permute(0, 2, 1)
# (30, 20, 32)
self.lstm = nn.LSTM(input_size=args.out_channels, hidden_size=args.hidden_size,
num_layers=args.num_layers, batch_first=True)
self.fc = nn.Linear(args.hidden_size, args.output_size)
def forward(self, x):
x = x.permute(0, 2, 1)
x = self.conv(x)
x = x.permute(0, 2, 1)
x, _ = self.lstm(x)
x = self.fc(x)
x = x[:, -1, :]
return x
可以看到,该CNN-LSTM由一层一维卷积+LSTM组成。
通过PyTorch搭建CNN实现时间序列预测(风速预测)我们知道,一维卷积的原始定义如下:
nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
本文模型的一维卷积定义:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
这里in_channels的概念相当于自然语言处理中的embedding,因此输入通道数为7,表示负荷+其他6个环境变量;out_channels的可以随意设置,本文设置为32;kernel_size设置为3。
PyTorch中一维卷积的输入尺寸为:
input(batch_size, input_size, seq_len)=(30, 7, 24)
而经过数据处理后得到的数据维度为:
input(batch_size, seq_len, input_size)=(30, 24, 7)
因此,我们需要进行维度交换:
x = x.permute(0, 2, 1)
交换后的输入数据将符合CNN的输入。
一维卷积中卷积操作是针对seq_len维度进行的,也就是(30, 7, 24)中的最后一个维度。因此,经过:
nn.Conv1d(in_channels=args.in_channels, out_channels=args.out_channels, kernel_size=3)
后,数据维度将变为:
(30, 32, 24-3+1)=(30, 32, 22)
第一维度的batch_size不变,第二维度的input_size将由in_channels=7变成out_channels=32,第三维度进行卷积变成22。
然后经过一个最大池化变成:
(30, 32, 22-3+1)=(30, 32, 20)
此时的(30, 32, 20)将作为LSTM的输入。由于在LSTM中我们设置了batch_first=True,因此LSTM能够接收的输入维度为:
input(batch_size, seq_len, input_size)
而经卷积池化后得到的数据维度为:
input(batch_size=30, input_size=32, seq_len=20)
因此,同样需要进行维度交换:
x = x.permute(0, 2, 1)
然后就是比较常规的LSTM输入输出的,不再细说。
因此,完整的forward函数如下所示:
def forward(self, x):
x = x.permute(0, 2, 1)
x = self.conv(x)
x = x.permute(0, 2, 1)
x, _ = self.lstm(x)
x = self.fc(x)
x = x[:, -1, :]
return x
III. 代码实现
3.1 数据处理
我们根据前24个时刻的负荷以及该时刻的环境变量来预测接下来4个时刻的负荷,这里采用了直接多输出策略,调整output_size即可调整输出步长。
代码实现:
def nn_seq(args):
seq_len, B, num = args.seq_len, args.batch_size, args.output_size
print('data processing...')
dataset = load_data()
# split
train = dataset[:int(len(dataset) * 0.6)]
val = dataset[int(len(dataset) * 0.6):int(len(dataset) * 0.8)]
test = dataset[int(len(dataset) * 0.8):len(dataset)]
m, n = np.max(train[train.columns[1]]), np.min(train[train.columns[1]])
def process(data, batch_size, step_size):
load = data[data.columns[1]]
data = data.values.tolist()
load = (load - n) / (m - n)
load = load.tolist()
seq = []
for i in range(0, len(data) - seq_len - num, step_size):
train_seq = []
train_label = []
for j in range(i, i + seq_len):
x = [load[j]]
for c in range(2, 8):
x.append(data[j][c])
train_seq.append(x)
for j in range(i + seq_len, i + seq_len + num):
train_label.append(load[j])
train_seq = torch.FloatTensor(train_seq)
train_label = torch.FloatTensor(train_label).view(-1)
seq.append((train_seq, train_label))
# print(seq[-1])
seq = MyDataset(seq)
seq = DataLoader(dataset=seq, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
return seq
Dtr = process(train, B, step_size=1)
Val = process(val, B, step_size=1)
Dte = process(test, B, step_size=num)
return Dtr, Val, Dte, m, n
3.2 模型训练/测试
和前面一致:
def train(args, Dtr, Val, path):
model = CNN_LSTM(args).to(args.device)
loss_function = nn.MSELoss().to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print('training...')
epochs = 50
min_epochs = 10
best_model = None
min_val_loss = 5
for epoch in range(epochs):
train_loss = []
for batch_idx, (seq, target) in enumerate(Dtr, 0):
seq, target = seq.to(args.device), target.to(args.device)
optimizer.zero_grad()
y_pred = model(seq)
loss = loss_function(y_pred, target)
train_loss.append(loss.item())
loss.backward()
optimizer.step()
# validation
val_loss = get_val_loss(args, model, Val)
if epoch + 1 >= min_epochs and val_loss < min_val_loss:
min_val_loss = val_loss
best_model = copy.deepcopy(model)
print('epoch {:03d} train_loss {:.8f} val_loss {:.8f}'.format(epoch, np.mean(train_loss), val_loss))
model.train()
state = {
'model': best_model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, path)
def test(args, Dte, path, m, n):
print('loading model...')
model = CNN_LSTM(args).to(args.device)
model.load_state_dict(torch.load(path)['model'])
model.eval()
pred = []
y = []
for batch_idx, (seq, target) in enumerate(Dte, 0):
seq = seq.to(args.device)
with torch.no_grad():
target = list(chain.from_iterable(target.tolist()))
y.extend(target)
y_pred = model(seq)
y_pred = list(chain.from_iterable(y_pred.data.tolist()))
pred.extend(y_pred)
y, pred = np.array(y), np.array(pred)
y = (m - n) * y + n
pred = (m - n) * pred + n
print('mape:', get_mape(y, pred))
# plot
x = [i for i in range(1, 151)]
x_smooth = np.linspace(np.min(x), np.max(x), 900)
y_smooth = make_interp_spline(x, y[150:300])(x_smooth)
plt.plot(x_smooth, y_smooth, c='green', marker='*', ms=1, alpha=0.75, label='true')
y_smooth = make_interp_spline(x, pred[150:300])(x_smooth)
plt.plot(x_smooth, y_smooth, c='red', marker='o', ms=1, alpha=0.75, label='pred')
plt.grid(axis='y')
plt.legend()
plt.show()
3.3 实验结果
前24个时刻预测未来4个时刻,MAPE为7.41%:
IV. 源码及数据
后续考虑公开~
边栏推荐
- Load balancer does not have available server for client: userService问题解决
- Gof23 - prototype mode
- MYSQL索引不生效的原因
- I use flask to write the website "II"
- Container with the most water
- LightGBM--调参笔记
- [micro service series] protocol buffer dynamic analysis
- Connexion et déconnexion TCP, détails du diagramme de migration de l'état
- What is data mining?
- 连接数服务器数据库报:错误号码2003Can‘t connect to MySQL server on ‘服务器地址‘(10061)
猜你喜欢

Load balancer does not have available server for client: userService问题解决

LabVIEW Arduino TCP/IP远程智能家居系统(项目篇—5)
New generation engineers teach you how to play with alluxio + ml (Part 1)

遇到女司机业余开滴滴,日入500!

Vulnerability discovery - API interface service vulnerability probe type utilization and repair

TCP連接與斷開,狀態遷移圖詳解

营销技巧:相比较讲产品的优点,更有效的是要向客户展示使用效果

数据湖架构之Hudi编译篇

LabVIEW Arduino tcp/ip remote smart home system (project part-5)

连接数服务器数据库报:错误号码2003Can‘t connect to MySQL server on ‘服务器地址‘(10061)
随机推荐
How can an enterprise successfully complete cloud migration?
Differences, advantages and disadvantages between synchronous communication and asynchronous communication
Pagoda server setup and database remote connection
My SQL (II)
Pytorch mixing accuracy principle and how to start this method
数据湖架构之Hudi编译篇
[micro service series] protocol buffer dynamic analysis
Spark3.3.0 source code compilation supplement - Crazy certificate problem
连接数服务器数据库报:错误号码2003Can‘t connect to MySQL server on ‘服务器地址‘(10061)
STM 32 使用cube 生成TIM触发ADC并通过DMA传输的问题
typescript的type
营销技巧:相比较讲产品的优点,更有效的是要向客户展示使用效果
STM32F1与STM32CubeIDE编程实例-热敏传感器驱动
Interviewer: what is the difference between a test plan and a test plan?
Failed to configure a DataSource: ‘url‘ attribute is not specified and no embedded datasource could
MYSQL(三)
Unsatisfied dependency expressed through field ‘baseMapper‘; nested exceptio
Go语言学习笔记 1.1
SHOW语句用法补充
Gof23 - abstract factory pattern