Temporal Neighborhood Coding (TNC)
Tonekaboni, Sana, Danny Eytan, and Anna Goldenberg. "Unsupervised representation learning for time series with temporal neighborhood coding." arXiv preprint arXiv:2106.00750 (2021).
References:
- https://github.com/sanatonek/TNC_representation_learning
- https://arxiv.org/abs/2106.00750
import torch
import torch.nn as nn
1. Encoder (for the “simulation” & “HAR” datasets)
class RnnEncoder(torch.nn.Module):
def __init__(self, hidden_size, in_channel, encoding_size,
cell_type='GRU', num_layers=1, device='cpu', dropout=0, bidirectional=True):
super(RnnEncoder, self).__init__()
self.hidden_size = hidden_size
self.in_channel = in_channel
self.num_layers = num_layers
self.cell_type = cell_type
self.encoding_size = encoding_size
self.bidirectional = bidirectional
self.device = device
self.nn = torch.nn.Linear(self.hidden_size*(int(self.bidirectional) + 1), self.encoding_size).to(self.device)
if cell_type=='GRU':
self.rnn = torch.nn.GRU(input_size=self.in_channel, hidden_size=self.hidden_size, num_layers=num_layers,
batch_first=False, dropout=dropout, bidirectional=bidirectional).to(self.device)
elif cell_type=='LSTM':
self.rnn = torch.nn.LSTM(input_size=self.in_channel, hidden_size=self.hidden_size, num_layers=num_layers,
batch_first=False, dropout=dropout, bidirectional=bidirectional).to(self.device)
else:
            raise ValueError('Cell type not defined, must be one of the following {GRU, LSTM}')
def forward(self, x):
        x = x.permute(2, 0, 1)  # (B, C_in, L) -> (L, B, C_in) for a batch_first=False RNN
if self.cell_type=='GRU':
past = torch.zeros(self.num_layers * (int(self.bidirectional) + 1),
x.shape[1],
self.hidden_size).to(self.device)
elif self.cell_type=='LSTM':
h_0 = torch.zeros(self.num_layers * (int(self.bidirectional) + 1),
x.shape[1],
self.hidden_size).to(self.device)
c_0 = torch.zeros(self.num_layers * (int(self.bidirectional) + 1),
x.shape[1],
self.hidden_size).to(self.device)
past = (h_0, c_0)
out, _ = self.rnn(x.to(self.device), past) # out = (L, B, num_directions * hidden_size)
        encodings = self.nn(out[-1])  # last time step: (B, num_directions * hidden_size) -> (B, encoding_size)
return encodings
lstm_enc = RnnEncoder(hidden_size=32, in_channel=1, encoding_size=120,
cell_type='LSTM', num_layers=1, device='cpu', dropout=0, bidirectional=True)
B = 64
C_in = 1
L = 100
input = torch.randn((B, C_in, L))
output = lstm_enc(input)
print(input.shape)
print(output.shape)
torch.Size([64, 1, 100])
torch.Size([64, 120])
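The LSTM example above only exercises one branch of forward; the default GRU branch can be checked the same way. A minimal sketch, where the channel count, hidden size, and window length are arbitrary choices, not values from the repo:
gru_enc = RnnEncoder(hidden_size=100, in_channel=3, encoding_size=10,
                     cell_type='GRU', num_layers=1, device='cpu', dropout=0, bidirectional=True)
x = torch.randn((64, 3, 50))   # (B, C_in, L); arbitrary example shapes
z = gru_enc(x)
print(x.shape)                 # expected: torch.Size([64, 3, 50])
print(z.shape)                 # expected: torch.Size([64, 10])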
2. State Classifier
Batch Norm + FC layer
class StateClassifier(torch.nn.Module):
def __init__(self, input_size, output_size):
super(StateClassifier, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.normalize = torch.nn.BatchNorm1d(self.input_size)
self.nn = torch.nn.Linear(self.input_size, self.output_size)
torch.nn.init.xavier_uniform_(self.nn.weight)
def forward(self, x):
x = self.normalize(x)
logits = self.nn(x)
return logits
sc = StateClassifier(input_size = 16, output_size= 32)
B = 64
C_in = 16
input = torch.randn((B, C_in))
output = sc(input)
print(input.shape)
print(output.shape)
torch.Size([64, 16])
torch.Size([64, 32])
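The StateClassifier is a simple probe (BatchNorm + linear layer) intended to be trained on top of the learned encodings for downstream classification. A minimal sketch of that wiring, assuming the RnnEncoder above, a frozen encoder, and a hypothetical number of states n_states (all sizes below are illustrative):
n_states = 4                                   # hypothetical number of latent states
encoder = RnnEncoder(hidden_size=32, in_channel=1, encoding_size=16, cell_type='GRU')
classifier = StateClassifier(input_size=16, output_size=n_states)
ce_loss = torch.nn.CrossEntropyLoss()

x = torch.randn((64, 1, 100))                  # (B, C_in, L) windows
y = torch.randint(0, n_states, (64,))          # hypothetical per-window state labels
with torch.no_grad():                          # keep the encoder frozen; only the probe is trained
    z = encoder(x)
logits = classifier(z)                         # (B, n_states)
loss = ce_loss(logits, y)
loss.backward()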
3. Encoder (for the “waveform” dataset)
class WFEncoder(nn.Module):
def __init__(self, encoding_size, classify=False, n_classes=None):
super(WFEncoder, self).__init__()
self.encoding_size = encoding_size
self.n_classes = n_classes
self.classify = classify
        self.classifier = None
if self.classify:
if self.n_classes is None:
                raise ValueError('Need to specify the number of output classes for the encoder')
else:
self.classifier = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(self.encoding_size, self.n_classes)
)
nn.init.xavier_uniform_(self.classifier[1].weight)
self.features = nn.Sequential(
nn.Conv1d(2, 64, kernel_size=4, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(64, eps=0.001),
nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(64, eps=0.001),
nn.MaxPool1d(kernel_size=2, stride=2),
nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(128, eps=0.001),
# nn.Dropout(),
nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(128, eps=0.001),
nn.MaxPool1d(kernel_size=2, stride=2),
# nn.Dropout(0.5),
nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(256, eps=0.001),
# nn.Dropout(0.5),
nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1),
nn.ELU(inplace=True),
nn.BatchNorm1d(256, eps=0.001),
nn.MaxPool1d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(
nn.Dropout(0.5),
nn.Linear(3840, 2048),
nn.ELU(inplace=True),
nn.BatchNorm1d(2048, eps=0.001),
nn.Linear(2048, self.encoding_size)
)
def forward(self, x):
        x = self.features(x)          # (B, 2, 128) -> (B, 256, 15)
        x = x.view(x.size(0), -1)     # (B, 256, 15) -> (B, 3840)
        encoding = self.fc(x)         # (B, 3840) -> (B, encoding_size)
if self.classify:
c = self.classifier(encoding) # (B,100) -> (B, num_cls)
return c
else:
return encoding
wave_form_enc = WFEncoder(encoding_size=100, classify=False, n_classes=None)
B = 64
C_in = 2
L = 128
input = torch.randn((B, C_in, L))
output = wave_form_enc(input)
print(input.shape)
print(output.shape)
torch.Size([64, 2, 128])
torch.Size([64, 100])
wave_form_enc = WFEncoder(encoding_size=100, classify=True, n_classes=10)
B = 64
C_in = 2
L = 128
input = torch.randn((B, C_in, L))
output = wave_form_enc(input)
print(input.shape)
print(output.shape)
torch.Size([64, 2, 128])
torch.Size([64, 10])
4. Discriminator
Output: a raw logit (an unnormalized score), not a log-probability
- it is fed directly into torch.nn.BCEWithLogitsLoss(), which applies the sigmoid internally (see the quick check after the usage example below)
class Discriminator(torch.nn.Module):
def __init__(self, input_size, device):
super(Discriminator, self).__init__()
self.device = device
self.input_size = input_size
self.model = torch.nn.Sequential(torch.nn.Linear(2*self.input_size, 4*self.input_size),
torch.nn.ReLU(inplace=True),
torch.nn.Dropout(0.5),
torch.nn.Linear(4*self.input_size, 1))
torch.nn.init.xavier_uniform_(self.model[0].weight)
torch.nn.init.xavier_uniform_(self.model[3].weight)
def forward(self, x1, x2):
"""
        Returns the logit (pre-sigmoid score) that x1 & x2 lie in the SAME NEIGHBORHOOD.
"""
x_all = torch.cat([x1, x2], -1)
p = self.model(x_all) # (B,1)
return p.view((-1,)) # (B)
D = Discriminator(input_size=32, device='cpu')
B = 64
in_dim = 32
input1 = torch.randn((B, in_dim))
input2 = torch.randn((B, in_dim))
output = D(input1, input2)
print(input1.shape)
print(output.shape)
print(output[0:5])
torch.Size([64, 32])
torch.Size([64])
tensor([-0.0554, -1.1055, -0.4377, 2.1171, 0.6736], grad_fn=<SliceBackward0>)
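The printed values above are raw logits. As a quick check (using only the torch calls already imported above), BCEWithLogitsLoss applied to a logit d with target 1 equals -log(sigmoid(d)), so sigmoid(d) plays the role of the probability that the pair lies in the same neighborhood:
d = output[0:5].detach()                      # discriminator logits from the example above
bce = torch.nn.BCEWithLogitsLoss(reduction='none')
print(bce(d, torch.ones_like(d)))             # per-pair loss for the "same neighborhood" target
print(-torch.log(torch.sigmoid(d)))           # same values (up to floating-point rounding)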
5. Run
def epoch_run(loader, disc_model, encoder, device, w=0, optimizer=None, train=True):
if train:
encoder.train()
disc_model.train()
else:
encoder.eval()
disc_model.eval()
loss_fn = torch.nn.BCEWithLogitsLoss()
encoder.to(device)
disc_model.to(device)
epoch_loss = 0
epoch_acc = 0
batch_count = 0
for x_t, x_p, x_n, _ in loader:
# x_t : anchor
# x_p : positive ( = neighbor )
# x_n : negative ( = non-neighbor )
mc_sample = x_p.shape[1]
batch_size, f_size, len_size = x_t.shape
x_p = x_p.reshape((-1, f_size, len_size))
x_n = x_n.reshape((-1, f_size, len_size))
        x_t = torch.repeat_interleave(x_t, mc_sample, dim=0)  # repeat each anchor to pair it with its mc_sample positives/negatives
# generate labels
neighbors = torch.ones((len(x_p))).to(device)
non_neighbors = torch.zeros((len(x_n))).to(device)
z_t = encoder(x_t.to(device))
z_p = encoder(x_p.to(device))
z_n = encoder(x_n.to(device))
d_p = disc_model(z_t, z_p)
d_n = disc_model(z_t, z_n)
        p_loss = loss_fn(d_p, neighbors)            # neighbor pairs are true positives
        n_loss = loss_fn(d_n, non_neighbors)        # non-neighbor pairs scored as negatives
        n_loss_u = loss_fn(d_n, neighbors)          # non-neighbor pairs scored as positives
        # Positive-Unlabeled weighting: non-neighbors are really unlabeled, so with weight w
        # they are counted as positives and with weight (1-w) as negatives
        loss = (p_loss + w*n_loss_u + (1-w)*n_loss)/2
if train:
optimizer.zero_grad()
loss.backward()
optimizer.step()
p_acc = torch.sum(torch.nn.Sigmoid()(d_p) > 0.5).item() / len(z_p)
n_acc = torch.sum(torch.nn.Sigmoid()(d_n) < 0.5).item() / len(z_n)
epoch_acc = epoch_acc + (p_acc+n_acc)/2
epoch_loss += loss.item()
batch_count += 1
return epoch_loss/batch_count, epoch_acc/batch_count
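epoch_run expects batches of (x_t, x_p, x_n, _) where x_t is (B, C, L) and x_p / x_n stack mc_sample neighboring / non-neighboring windows per anchor as (B, mc_sample, C, L). A minimal smoke test with random tensors standing in for the real TNC neighborhood sampler; the shapes, the w value, and the learning rate below are arbitrary assumptions, not taken from the paper:
B, C, L_len, mc_sample = 8, 1, 100, 5
x_t = torch.randn(B, C, L_len)                 # anchor windows
x_p = torch.randn(B, mc_sample, C, L_len)      # neighboring windows per anchor
x_n = torch.randn(B, mc_sample, C, L_len)      # non-neighboring windows per anchor
dummy_loader = [(x_t, x_p, x_n, None)]         # one fake batch; the real loader comes from the TNC samplers

encoder = RnnEncoder(hidden_size=32, in_channel=C, encoding_size=16, cell_type='GRU')
disc = Discriminator(input_size=16, device='cpu')
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(disc.parameters()), lr=1e-3)

loss, acc = epoch_run(dummy_loader, disc, encoder, device='cpu', w=0.05, optimizer=optimizer, train=True)
print(loss, acc)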