LSTNet Code Review
(Paper review: https://seunghan96.github.io/ts/ts15/)
CNN
- captures short-term local dependency patterns (among variables)
RNN (GRU)
- captures long-term patterns in the time-series trends
AR (Highway)
- leverages a traditional autoregressive component to tackle the scale-insensitivity problem; the neural and AR parts are combined as sketched below
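Putting the three parts together, the paper forms the final prediction by simply adding the neural output (CNN → GRU / skip-GRU → dense layer) and the AR output. A minimal sketch of that decomposition, in the paper's notation rather than the code's variable names:

$$
\hat{\mathbf{Y}}_t = \mathbf{h}^{D}_t + \mathbf{h}^{L}_t
$$

where $\mathbf{h}^{D}_t$ is the dense-layer output on top of the recurrent states and $\mathbf{h}^{L}_t$ is the autoregressive (highway) forecast. In the code below this corresponds to `result = self.linear1(out_rnn)` plus the `self.highway` output.
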
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self, args, data):
        super(Model, self).__init__()
        self.use_cuda = args.cuda
        self.P = args.window          # window size (T)
        self.m = data.m               # number of time series (n)
        self.hidR = args.hidRNN       # hidden size of the GRU
        self.hidC = args.hidCNN       # number of CNN filters
        self.hidS = args.hidSkip      # hidden size of the skip-GRU
        self.Ck = args.CNN_kernel     # CNN kernel height (along time)
        self.skip = args.skip         # skip length (seasonal period, e.g. 24 for hourly data)
        self.pt = (self.P - self.Ck) // self.skip  # integer division so the view() in forward() gets an int
        self.hw = args.highway_window # AR (highway) window size
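        # Note: pt is the sequence length fed to the skip-GRU for each "phase"
        # (each offset modulo `skip`): the conv output has P - Ck + 1 time steps,
        # and only the last pt * skip of them are used in forward().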
        # CNN over (time, variable): the kernel spans all m series at once
        self.conv1 = nn.Conv2d(1, self.hidC, kernel_size=(self.Ck, self.m))
        self.GRU1 = nn.GRU(self.hidC, self.hidR)
        self.dropout = nn.Dropout(p=args.dropout)
        if self.skip > 0:
            self.GRUskip = nn.GRU(self.hidC, self.hidS)
            # the final dense layer sees both the GRU state and the skip-GRU states
            self.linear1 = nn.Linear(self.hidR + self.skip * self.hidS, self.m)
        else:
            self.linear1 = nn.Linear(self.hidR, self.m)
        if self.hw > 0:
            # AR component: one linear map from the last hw values of each series
            self.highway = nn.Linear(self.hw, 1)
        self.act = None
        if args.output_fun == 'sigmoid':
            self.act = torch.sigmoid   # F.sigmoid / F.tanh are deprecated
        if args.output_fun == 'tanh':
            self.act = torch.tanh
    def forward(self, x):
        # x : (bs, T, n)
        ## ---- bs = batch size
        ## ---- T  = window size
        ## ---- n  = number of time series
        batch_size = x.size(0)
        # -----------------------------------------------#
        # 1. CNN (without pooling)
        # before : x : (bs, T, n)
        # after  : c : (bs, 1, T, n)
        c = x.view(-1, 1, self.P, self.m)
        c = F.relu(self.conv1(c))
        c = self.dropout(c)
        out_cnn = torch.squeeze(c, 3)
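        # Shape check: conv1 has kernel (Ck, m), so
        #   conv1(c)      : (bs, hidC, T - Ck + 1, 1)
        #   squeeze(c, 3) : (bs, hidC, T - Ck + 1)
        # i.e. one hidC-dimensional feature vector per remaining time step.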
        # -----------------------------------------------#
        # 2-1. RNN
        out_rnn = out_cnn.permute(2, 0, 1).contiguous()
        _, out_rnn = self.GRU1(out_rnn)
        out_rnn = self.dropout(torch.squeeze(out_rnn, 0))
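        # permute(2, 0, 1) gives (seq_len, bs, hidC), the layout nn.GRU expects
        # with batch_first=False (the default). Only the final hidden state h_n
        # of shape (1, bs, hidR) is kept and squeezed to (bs, hidR).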
        # -----------------------------------------------#
        # 2-2. a) RNN-skip
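        # Idea of the skip-GRU: split the conv outputs into `skip` phases
        # (offsets modulo `skip`) so that time steps exactly `skip` apart become
        # consecutive within one sequence; a GRU over each phase then captures
        # long-range periodic patterns (e.g. daily) with a short recurrence.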
        if self.skip > 0:
            # keep only the last pt * skip conv outputs
            out_skip = out_cnn[:, :, -self.pt * self.skip:].contiguous()
            out_skip = out_skip.view(batch_size, self.hidC, self.pt, self.skip)
            out_skip = out_skip.permute(2, 0, 3, 1).contiguous()
            out_skip = out_skip.view(self.pt, batch_size * self.skip, self.hidC)
            _, out_skip = self.GRUskip(out_skip)
            out_skip = out_skip.view(batch_size, self.skip * self.hidS)
            out_skip = self.dropout(out_skip)
            out_rnn = torch.cat((out_rnn, out_skip), 1)  # (bs, hidR + skip * hidS)
        result = self.linear1(out_rnn)                   # (bs, n)
        # -----------------------------------------------#
        # 2-2. b) Highway
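        # The highway is the AR component: for each of the m series
        # independently, a single linear map takes its last hw raw values to one
        # prediction. Because it operates on the raw input scale, it addresses
        # the scale-insensitivity issue of the nonlinear part.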
        if self.hw > 0:
            out_hw = x[:, -self.hw:, :]                                      # (bs, hw, n)
            out_hw = out_hw.permute(0, 2, 1).contiguous().view(-1, self.hw)  # (bs * n, hw)
            out_hw = self.highway(out_hw)                                    # (bs * n, 1)
            out_hw = out_hw.view(-1, self.m)                                 # (bs, n)
            result = result + out_hw
        if self.act:
            result = self.act(result)
        return result
```
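To sanity-check the shapes, here is a minimal smoke test (not part of the original repo): `args` and `data` are stand-ins built with `SimpleNamespace`, carrying exactly the attributes the class reads above; the values are only illustrative, and in the real repo `args` comes from `argparse` and `data` from the data loader.

```python
# Minimal smoke test for the Model class above (illustrative values only).
from types import SimpleNamespace

import torch

args = SimpleNamespace(
    cuda=False, window=24 * 7, hidRNN=100, hidCNN=100, hidSkip=5,
    CNN_kernel=6, skip=24, highway_window=24, dropout=0.2, output_fun='sigmoid',
)
data = SimpleNamespace(m=8)                # pretend we have 8 time series

model = Model(args, data)
x = torch.randn(16, args.window, data.m)   # (bs, T, n)
y_hat = model(x)
print(y_hat.shape)                         # torch.Size([16, 8])
```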