当前位置:网站首页>25.时间序列预测实战
25.时间序列预测实战
2022-08-04 07:03:00 【派大星的最爱海绵宝宝】
时间序列预测实战
[b,50,1],b为1时,可以理解为只送入一条曲线,每一条曲线有50点的数据,每个点数据都是实数。
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
start是随机的,是每次开始的起点。
我们需要完成的功能是,对于一条曲线,给出红色部分时,要求预测出蓝色部分曲线。
x是给定的0到48的部分,y需要预测出1到49的部分。
Train
out[b,seq_len,h]
h[b,1,h]
hidden_prev是h0,最开始是一个batch,一层,h是10。
我们将output和y之间进行一个MSE求误差,根据这个误差进行网络的更新。
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
Test
先将预测值做一个空的数组。
x[1,seq,1]。
每次的input等于pred出来的点,每次只画一个点,最后进行串联。
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
结果


代码
import numpy
import torch
from torch import nn
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
num_time_steps = 50
input_size=1
hidden_size=16
output_size=1
lr = 0.01
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.rnn=nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True
)
self.linear=nn.Linear(hidden_size,output_size)
def forward(self,x,hidden_prev):
out,hidden_prev=self.rnn(x,hidden_prev)
#[1,seq,h]->[seq,h]
out=out.view(-1,hidden_size)
out=self.linear(out) #[seq,h]->[seq,1]
out=out.unsqueeze(dim=0) #->[1,seq,-1]
return out,hidden_prev
def main():
model=Net()
criteon=nn.MSELoss()
optimizer=optim.Adam(model.parameters(),lr)
hidden_prev=torch.zeros(1,1,hidden_size)
for iter in range(6000):
start = np.random.randint(3, size=1)[0]
time_step = np.linspace(start, start + 10, num_time_steps)
data = np.sin(time_step)
data = data.reshape(num_time_steps, 1)
x = torch.tensor(data[:-1]).float().view(1, num_time_steps - 1, 1)
y = torch.tensor(data[1:]).float().view(1, num_time_steps - 1, 1)
output,hidden_prev=model(x,hidden_prev)
hidden_prev=hidden_prev.detach()
loss=criteon(output,y)
model.zero_grad()
loss.backward()
optimizer.step()
if iter %100 ==0:
print("Iteration:{} loss:{}".format(iter,loss.item()))
predictions=[]
input=x[:,0,:]
for _ in range(x.shape[1]):
input=input.view(1,1,1)
(pred,hidden_prev)=model(input,hidden_prev)
input=pred
predictions.append(pred.detach().numpy().ravel()[0])
x=x.data.numpy().ravel()
y=y.data.numpy()
plt.scatter(time_step[:-1],x.ravel(),s=90)
plt.plot(time_step[:-1],x.ravel())
plt.scatter(time_step[1:],predictions)
plt.show()
if __name__ == '__main__':
main()
边栏推荐
猜你喜欢

【愚公系列】2022年07月 Go教学课程 027-深拷贝和浅拷贝

一天学会JDBC03:Statement的用法

中职网络安全竞赛C模块MS17-010批量扫描

西门子PLC1200与fanuc机器人进行profibus通讯
![[Paper Notes] - Low Illumination Image Enhancement - Supervised - RetinexNet - 2018-BMVC](/img/54/685fb2620aa53416437943705d3d38.png)
[Paper Notes] - Low Illumination Image Enhancement - Supervised - RetinexNet - 2018-BMVC

fanuc机器人IO分配报警信号分配无效

GIS数据与CAD数据间带属性字段互相转换还原工具,解决ArcGIS等软件进行GIS数据转CAD数据无法保留属性字段问题

babylon 里面加gltf 模型

SystemVerilog-条件(三元)运算符

LeetCode 97. 交错字符串
随机推荐
关于我写的循环遍历
MAML principle explanation and code implementation
MMDeploy部署实战系列【第四章】:onnx,tensorrt模型推理
Verilog“七宗罪”
反序列化字符逃逸漏洞之
登录拦截实现过程
千古第一文人苏轼的众CP
SQL如何从字符串截取指定字符(LEFT、MID、RIGHT三大函数)
FCN - the originator of semantic segmentation (based on tf-Kersa reproduction code)
两日总结八
分布式计算实验2 线程池
中断和异常的处理与抢占式多任务
IDEA中创建编写JSP
分布式计算实验1 负载均衡
likeshop外卖点餐系统开源啦100%开源无加密
ConstraintSet of animation of ContrstrainLayout
中职网络安全竞赛C模块MS17-010批量扫描
「PHP基础知识」转换数据类型
数据特征预处理——缺失值的查看方式及处理
redis---分布式锁存在的问题及解决方案(Redisson)