Temporal Neighborhood Coding (TNC)

Tonekaboni, Sana, Danny Eytan, and Anna Goldenberg. "Unsupervised representation learning for time series with temporal neighborhood coding." arXiv preprint arXiv:2106.00750 (2021).


references :

  • https://github.com/sanatonek/TNC_representation_learning
  • https://arxiv.org/abs/2106.00750


import torch
import torch.nn as nn


1. Encoder ( for “simulation” & “har” dataset )

class RnnEncoder(torch.nn.Module):
    """Recurrent encoder that turns a window of a time series into one vector.

    The window is fed through a GRU, LSTM or vanilla RNN (optionally
    bidirectional); the recurrent output at the last time index is projected
    to ``encoding_size`` by a linear layer.

    Args:
        hidden_size: Hidden units per direction of the recurrent cell.
        in_channel: Number of input features per time step.
        encoding_size: Size of the returned encoding.
        cell_type: One of ``'GRU'``, ``'LSTM'`` or ``'RNN'``.
        num_layers: Number of stacked recurrent layers.
        device: Device the sub-modules are created on.
        dropout: Dropout probability passed to the recurrent cell.
        bidirectional: Whether the recurrent cell runs in both directions.

    Raises:
        ValueError: If ``cell_type`` is not one of the supported names.
    """

    # Supported recurrent cells, keyed by the public ``cell_type`` name.
    _CELLS = {'GRU': torch.nn.GRU, 'LSTM': torch.nn.LSTM, 'RNN': torch.nn.RNN}

    def __init__(self, hidden_size, in_channel, encoding_size,
                 cell_type='GRU', num_layers=1, device='cpu', dropout=0, bidirectional=True):
        super().__init__()
        if cell_type not in self._CELLS:
            raise ValueError('Cell type not defined, must be one of the following {GRU, LSTM, RNN}')

        self.hidden_size = hidden_size
        self.in_channel = in_channel
        self.num_layers = num_layers
        self.cell_type = cell_type
        self.encoding_size = encoding_size
        self.bidirectional = bidirectional
        self.device = device

        # The projection is created before the recurrent cell so that seeded
        # initialisation matches the historical parameter values.
        num_directions = int(self.bidirectional) + 1
        self.nn = torch.nn.Linear(self.hidden_size * num_directions, self.encoding_size).to(self.device)
        self.rnn = self._CELLS[cell_type](
            input_size=self.in_channel, hidden_size=self.hidden_size, num_layers=num_layers,
            batch_first=False, dropout=dropout, bidirectional=bidirectional).to(self.device)

    def forward(self, x):
        """Encode a batch of windows.

        Args:
            x: Tensor indexed as (batch, channel, time).

        Returns:
            Tensor of shape (batch, encoding_size). With a batch of size 1 the
            batch dimension is squeezed away and the result is
            (encoding_size,), as in the original implementation.
        """
        # Follow the device the parameters currently live on: the module may
        # have been moved with ``.to(...)`` after construction, in which case
        # the ``device`` given to ``__init__`` is stale.
        device = self.nn.weight.device

        x = x.permute(2, 0, 1).to(device)  # (B, C, L) -> (L, B, C)

        num_directions = int(self.bidirectional) + 1
        state_shape = (self.num_layers * num_directions, x.shape[1], self.hidden_size)
        h_0 = torch.zeros(state_shape, dtype=x.dtype, device=device)
        if self.cell_type == 'LSTM':
            # The LSTM additionally needs an initial cell state.
            past = (h_0, torch.zeros(state_shape, dtype=x.dtype, device=device))
        else:
            past = h_0

        out, _ = self.rnn(x, past)  # out = (L, B, num_directions * hidden_size)
        # NOTE(review): for a bidirectional cell, out[-1] is the last time
        # index, where the reverse direction has processed only one step.
        # Kept as-is to preserve the encoder's outputs.
        encodings = self.nn(out[-1].squeeze(0))
        return encodings


# Smoke test: encode a random batch with a single-layer bidirectional LSTM.
B, C_in, L = 64, 1, 100  # batch size, input channels, window length

lstm_enc = RnnEncoder(
    hidden_size=32,
    in_channel=1,
    encoding_size=120,
    cell_type='LSTM',
    num_layers=1,
    device='cpu',
    dropout=0,
    bidirectional=True,
)

input = torch.randn(B, C_in, L)
output = lstm_enc(input)

# One shape per line: the input first, then the encoding.
print(input.shape, output.shape, sep='\n')
torch.Size([64, 1, 100])
torch.Size([64, 120])


2. State Classifier

Batch Norm + FC layer

class StateClassifier(torch.nn.Module):
    """Batch-normalise the input features, then map them to logits.

    Args:
        input_size: Number of features in each input row.
        output_size: Number of logits produced per row.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
        self.normalize = torch.nn.BatchNorm1d(input_size)

        # Linear head whose default weight init is replaced by Xavier uniform.
        head = torch.nn.Linear(input_size, output_size)
        torch.nn.init.xavier_uniform_(head.weight)
        self.nn = head

    def forward(self, x):
        """Return the logits computed from the batch-normalised ``x``."""
        return self.nn(self.normalize(x))


# Smoke test: classify a random batch of 16-dimensional feature rows.
sc = StateClassifier(input_size=16, output_size=32)

B, C_in = 64, 16  # batch size, feature dimension

input = torch.randn(B, C_in)
output = sc(input)

# One shape per line: the input first, then the logits.
print(input.shape, output.shape, sep='\n')
torch.Size([64, 16])
torch.Size([64, 32])


3. Encoder ( for “waveform” dataset )

class WFEncoder(nn.Module):
    """1-D convolutional encoder for fixed-length, multi-channel windows.

    Six convolution blocks (with three max-pool stages) are followed by a
    two-layer fully connected head that produces the encoding. Optionally a
    dropout + linear classifier is stacked on top of the encoding.

    Args:
        encoding_size: Size of the encoding produced by the fully connected head.
        classify: If true, ``forward`` returns class logits instead of the
            encoding.
        n_classes: Number of classes; required when ``classify`` is true.
        in_channel: Number of input channels (default 2, the previously
            hard-coded value).
        window_size: Length of the input windows (default 128, which yields
            the previously hard-coded flatten size of 3840).

    Raises:
        ValueError: If ``classify`` is true and ``n_classes`` is None, or if
            ``window_size`` is too short to survive the pooling stages.
    """

    def __init__(self, encoding_size, classify=False, n_classes=None,
                 in_channel=2, window_size=128):
        super(WFEncoder, self).__init__()

        self.encoding_size = encoding_size
        self.n_classes = n_classes
        self.classify = classify
        self.in_channel = in_channel
        self.window_size = window_size
        self.classifier = None
        if self.classify:
            if self.n_classes is None:
                raise ValueError('Need to specify the number of output classes for te encoder')
            self.classifier = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(self.encoding_size, self.n_classes)
            )
            nn.init.xavier_uniform_(self.classifier[1].weight)

        # The first convolution (kernel 4, padding 1) shortens the window by
        # one step; every other convolution preserves the length and each of
        # the three max-pools halves it (rounding down).
        feature_len = (window_size - 1) // 8
        if feature_len < 1:
            raise ValueError('window_size is too small for the three pooling stages')

        self.features = nn.Sequential(
            nn.Conv1d(in_channel, 64, kernel_size=4, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(64, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(128, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.BatchNorm1d(256, eps=0.001),
            nn.MaxPool1d(kernel_size=2, stride=2)
            )

        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * feature_len, 2048),  # 3840 for the default window of 128
            nn.ELU(inplace=True),
            nn.BatchNorm1d(2048, eps=0.001),
            nn.Linear(2048, self.encoding_size)
        )

    def forward(self, x):
        """Encode ``x`` and, when ``classify`` is set, classify the encoding.

        Args:
            x: Tensor of shape (B, in_channel, window_size).

        Returns:
            (B, n_classes) logits when ``classify`` is true, otherwise the
            (B, encoding_size) encoding.
        """
        x = self.features(x)  # defaults: (B,2,128) -> (B,256,15)
        x = x.view(x.size(0), -1)  # defaults: (B,256,15) -> (B,3840)
        encoding = self.fc(x)  # -> (B, encoding_size)
        if self.classify:
            return self.classifier(encoding)  # -> (B, n_classes)
        return encoding


# Smoke test: encoder-only mode returns one 100-dimensional encoding per window.
wave_form_enc = WFEncoder(
    encoding_size=100,
    classify=False,
    n_classes=None,
)

B, C_in, L = 64, 2, 128  # batch size, input channels, window length

input = torch.randn(B, C_in, L)
output = wave_form_enc(input)

# One shape per line: the input first, then the encoding.
print(input.shape, output.shape, sep='\n')
torch.Size([64, 2, 128])
torch.Size([64, 100])


# Smoke test: classifier mode returns one logit per class for each window.
wave_form_enc = WFEncoder(
    encoding_size=100,
    classify=True,
    n_classes=10,
)

B, C_in, L = 64, 2, 128  # batch size, input channels, window length

input = torch.randn(B, C_in, L)
output = wave_form_enc(input)

# One shape per line: the input first, then the class logits.
print(input.shape, output.shape, sep='\n')
torch.Size([64, 2, 128])
torch.Size([64, 10])


4. Discriminator

output : logit ( raw score before the sigmoid, i.e. log-odds — not a log-probability )

  • will be the input of torch.nn.BCEWithLogitsLoss()
class Discriminator(torch.nn.Module):
    """Score whether two encodings belong to the same temporal neighbourhood.

    The two encodings are concatenated and passed through a two-layer MLP
    (Linear -> ReLU -> Dropout -> Linear) that emits one raw score per pair.
    No sigmoid is applied here.

    Args:
        input_size: Dimension of each individual encoding.
        device: Stored on the instance; this class does not otherwise use it.
    """

    def __init__(self, input_size, device):
        super().__init__()
        self.device = device
        self.input_size = input_size

        hidden_size = 4 * input_size
        layers = [
            torch.nn.Linear(2 * input_size, hidden_size),
            torch.nn.ReLU(inplace=True),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_size, 1),
        ]
        self.model = torch.nn.Sequential(*layers)

        # Xavier-initialise the weights of both linear layers, first to last.
        for linear in (self.model[0], self.model[3]):
            torch.nn.init.xavier_uniform_(linear.weight)

    def forward(self, x1, x2):
        """Return one raw score (logit) per (x1, x2) pair as a 1-D tensor.

        Both inputs are concatenated along their last dimension.
        """
        pair = torch.cat((x1, x2), dim=-1)
        scores = self.model(pair)  # (B, 1)
        return scores.view(-1)  # (B,)


# Smoke test: score random pairs of 32-dimensional encodings.
D = Discriminator(input_size=32, device='cpu')

B = 64
in_dim = 32

input1 = torch.randn((B, in_dim))
input2 = torch.randn((B, in_dim))
output = D(input1, input2)

# Print the shape of a tensor actually fed to D; the stale `input` left over
# from an earlier snippet has a different shape.
print(input1.shape)
print(output.shape)
print(output[0:5])
torch.Size([64, 32])
torch.Size([64])
tensor([-0.0554, -1.1055, -0.4377,  2.1171,  0.6736], grad_fn=<SliceBackward0>)


5. Run

def epoch_run(loader, disc_model, encoder, device, w=0, optimizer=None, train=True):
    """Run one pass over ``loader`` and return the mean loss and accuracy.

    Each batch yields an anchor window ``x_t``, neighbouring (positive)
    windows ``x_p`` and non-neighbouring windows ``x_n``; a fourth item is
    ignored. All three are encoded, and the discriminator scores the
    (anchor, positive) and (anchor, non-neighbour) pairs. The loss is

        (BCE(d_p, 1) + w * BCE(d_n, 1) + (1 - w) * BCE(d_n, 0)) / 2

    so ``w`` is the weight with which non-neighbours are treated as positives.

    Args:
        loader: Iterable of ``(x_t, x_p, x_n, _)`` tensor batches. ``x_t`` is
            3-D (batch, features, length); ``x_p`` / ``x_n`` presumably are
            (batch, mc_sample, features, length) — TODO confirm with the
            dataset.
        disc_model: Module called as ``disc_model(z_a, z_b)`` returning logits.
        encoder: Module mapping a batch of windows to encodings.
        device: Device both models and all batches are moved to.
        w: Weight of the "non-neighbour is actually positive" loss term.
        optimizer: Optimizer stepped once per batch; required when ``train``.
        train: If true, run in training mode and update the parameters;
            otherwise run in eval mode without tracking gradients.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over the batches.

    Raises:
        ValueError: If ``train`` is true but no optimizer is given.
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    if train and optimizer is None:
        raise ValueError('An optimizer is required when train=True')

    encoder.train(train)
    disc_model.train(train)

    loss_fn = torch.nn.BCEWithLogitsLoss()
    encoder.to(device)
    disc_model.to(device)
    epoch_loss = 0
    epoch_acc = 0
    batch_count = 0

    for x_t, x_p, x_n, _ in loader:
        # x_t : anchor
        # x_p : positive ( = neighbor )
        # x_n : negative ( = non-neighbor )
        mc_sample = x_p.shape[1]
        _, f_size, len_size = x_t.shape
        x_p = x_p.reshape((-1, f_size, len_size))
        x_n = x_n.reshape((-1, f_size, len_size))
        # Repeat every anchor mc_sample times in a row so that row i of the
        # anchors lines up with row i of the flattened positives/negatives.
        x_t = torch.repeat_interleave(x_t, mc_sample, dim=0)

        # generate labels
        neighbors = torch.ones(len(x_p), device=device)
        non_neighbors = torch.zeros(len(x_n), device=device)

        # Skip building the autograd graph when only evaluating.
        with torch.set_grad_enabled(train):
            z_t = encoder(x_t.to(device))
            z_p = encoder(x_p.to(device))
            z_n = encoder(x_n.to(device))

            d_p = disc_model(z_t, z_p)
            d_n = disc_model(z_t, z_n)

            p_loss = loss_fn(d_p, neighbors)  # PU -> Positive
            n_loss = loss_fn(d_n, non_neighbors)  # PU -> U (Negative)
            n_loss_u = loss_fn(d_n, neighbors)  # PU -> U (Positive)
            loss = (p_loss + w * n_loss_u + (1 - w) * n_loss) / 2

        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Accuracy uses the pre-update scores of this batch.
        p_acc = torch.sum(torch.sigmoid(d_p) > 0.5).item() / len(z_p)
        n_acc = torch.sum(torch.sigmoid(d_n) < 0.5).item() / len(z_n)
        epoch_acc = epoch_acc + (p_acc + n_acc) / 2
        epoch_loss += loss.item()
        batch_count += 1
    return epoch_loss / batch_count, epoch_acc / batch_count

Categories: ,

Updated: