LSTNet 코드 리뷰

( 논문 리뷰 : https://seunghan96.github.io/ts/ts15/)


CNN

  • short-term local dependency patterns (among variables)


LSTM

  • long-term patterns for time series trends


Etc

  • leverages a traditional autoregressive model to tackle the scale-insensitivity problem of the neural network components


import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """LSTNet-style forecaster: CNN -> GRU (+ skip-GRU) plus a linear highway.

    Components built in ``__init__``:
      * ``conv1``   - Conv2d over (time, series) with kernel (Ck, m); no pooling.
      * ``GRU1``    - GRU over the convolved sequence.
      * ``GRUskip`` - GRU over every ``skip``-th step (only when ``skip > 0``).
      * ``linear1`` - maps the concatenated hidden states to ``m`` outputs.
      * ``highway`` - Linear(hw, 1) applied per series on the last ``hw`` raw
                      time steps (only when ``highway_window > 0``).

    Args:
        args: object exposing ``cuda``, ``window``, ``hidRNN``, ``hidCNN``,
            ``hidSkip``, ``CNN_kernel``, ``skip``, ``highway_window``,
            ``dropout`` and ``output_fun`` ('sigmoid', 'tanh', or anything
            else for no output activation).
        data: object exposing ``m``, the number of time series.
    """

    def __init__(self, args, data):
        super(Model, self).__init__()
        self.use_cuda = args.cuda
        self.P = args.window        # input window length (time steps)
        self.m = data.m             # number of time series
        self.hidR = args.hidRNN
        self.hidC = args.hidCNN
        self.hidS = args.hidSkip
        self.Ck = args.CNN_kernel
        self.skip = args.skip
        # Number of steps fed to the skip-GRU. It is used as a tensor
        # dimension in forward(), so it must be an int: floor division, not
        # true division. Guarded so that skip == 0 (skip-GRU disabled) does
        # not divide by zero.
        # NOTE(review): if P - Ck < skip this is 0 and forward() cannot build
        # the skip sequence - confirm callers never configure that.
        self.pt = (self.P - self.Ck) // self.skip if self.skip > 0 else 0
        self.hw = args.highway_window

        self.conv1 = nn.Conv2d(1, self.hidC, kernel_size=(self.Ck, self.m))
        self.GRU1 = nn.GRU(self.hidC, self.hidR)
        self.dropout = nn.Dropout(p=args.dropout)

        if self.skip > 0:
            self.GRUskip = nn.GRU(self.hidC, self.hidS)
            self.linear1 = nn.Linear(self.hidR + self.skip * self.hidS, self.m)
        else:
            self.linear1 = nn.Linear(self.hidR, self.m)

        if self.hw > 0:
            self.highway = nn.Linear(self.hw, 1)

        # torch.sigmoid / torch.tanh replace the deprecated F.sigmoid / F.tanh.
        self.act = None
        if args.output_fun == 'sigmoid':
            self.act = torch.sigmoid
        if args.output_fun == 'tanh':
            self.act = torch.tanh

    def forward(self, x):
        """Compute one forecast per series for each sample in the batch.

        Args:
            x: tensor of shape (batch, P, m) - dim 1 is time (the highway
               branch slices the last ``hw`` entries of dim 1), dim 2 is the
               series index.

        Returns:
            Tensor of shape (batch, m), passed through ``self.act`` if set.
        """
        batch_size = x.size(0)

        # ------------------------------------------------------------- #
        # 1. CNN (without pooling)
        # (batch, P, m) -> (batch, 1, P, m) -> conv -> (batch, hidC, P-Ck+1, 1)
        c = x.view(-1, 1, self.P, self.m)
        c = F.relu(self.conv1(c))
        c = self.dropout(c)
        out_cnn = torch.squeeze(c, 3)   # (batch, hidC, P-Ck+1)

        # ------------------------------------------------------------- #
        # 2-1. RNN: nn.GRU expects (seq, batch, feature); keep final hidden.
        out_rnn = out_cnn.permute(2, 0, 1).contiguous()
        _, out_rnn = self.GRU1(out_rnn)
        out_rnn = self.dropout(torch.squeeze(out_rnn, 0))   # (batch, hidR)

        # ------------------------------------------------------------- #
        # 2-2. a) RNN-skip: fold the last pt*skip steps into `skip`
        # interleaved sub-sequences of length pt and run them as extra
        # batch entries through GRUskip.
        if self.skip > 0:
            out_skip = out_cnn[:, :, -self.pt * self.skip:].contiguous()
            out_skip = out_skip.view(batch_size, self.hidC, self.pt, self.skip)
            out_skip = out_skip.permute(2, 0, 3, 1).contiguous()
            out_skip = out_skip.view(self.pt, batch_size * self.skip, self.hidC)
            _, out_skip = self.GRUskip(out_skip)
            out_skip = out_skip.view(batch_size, self.skip * self.hidS)
            out_skip = self.dropout(out_skip)
            out_rnn = torch.cat((out_rnn, out_skip), 1)

        result = self.linear1(out_rnn)

        # ------------------------------------------------------------- #
        # 2-2. b) Highway: one Linear(hw, 1) shared across series, applied
        # to the last hw raw inputs of each series and added to the output.
        if self.hw > 0:
            out_hw = x[:, -self.hw:, :]
            out_hw = out_hw.permute(0, 2, 1).contiguous().view(-1, self.hw)
            out_hw = self.highway(out_hw)
            out_hw = out_hw.view(-1, self.m)
            result = result + out_hw

        if self.act is not None:
            result = self.act(result)
        return result

Categories: ,

Updated: