# Global, Local, Global & Local
# Settings
import torch
import torch.nn as nn
import torch.nn.functional as F
# Batch size, followed by the widths of the global and local embeddings.
B, z_global_dim, z_local_dim = 64, 128, 32

# Random (B, B) placeholder for the pairwise target matrix.
# NOTE(review): presumably a real DTW matrix replaces this -- confirm.
DTW_matrix = torch.randn(B, B)
# 1. Global: between time series (TS)
# Random (B, z_global_dim) placeholder for the global embeddings.
Z = torch.randn(B, z_global_dim)

# L2-normalise every row so that the matrix product below is the (B, B)
# matrix of pairwise cosine similarities.
Z_norm = F.normalize(Z, p=2, dim=1)
DTW_pred = torch.matmul(Z_norm, Z_norm.T)

# Mean-squared error between the similarity matrix and the target matrix.
# The value is evaluated but not bound to a name (an interactive session
# would simply display it).
mse_loss = nn.MSELoss()
mse_loss(DTW_pred, DTW_matrix)
# Example output (varies with the random inputs): tensor(1.0226)
# 2. Local: within a time series (TS)
# Number of negative sample sets paired with each anchor.
K = 4

# Random placeholders for the local embeddings. The leading dimension is the
# batch size B (64): the rows built from these tensors are later compared
# against the B global embeddings and a (B, B) identity target, so the two
# sizes must agree.
ts_part_ANCHOR = torch.randn(B, z_local_dim)
ts_part_POS = torch.randn(B, z_local_dim)
ts_part_NEG = torch.randn(K, B, z_local_dim)

triplet_loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

# Sum (not average) the triplet margin loss over the K negative sets; the
# integer 0 start value becomes a scalar tensor after the first addition.
triplet_loss = 0
for k in range(K):
    triplet_loss += triplet_loss_fn(
        ts_part_ANCHOR,
        ts_part_POS,
        ts_part_NEG[k],
    )
# 3. Global & Local interaction
# Join each anchor part with its positive part along the feature axis,
# giving rows of width 2 * z_local_dim.
ts_pos_ANCHOR_pos = torch.cat((ts_part_ANCHOR, ts_part_POS), dim=1)

# Project the joined local features into the global embedding space.
linear_model = nn.Linear(2 * z_local_dim, z_global_dim)
aggregated = linear_model(ts_pos_ANCHOR_pos)

# Cosine similarity of every aggregated local row against every global
# embedding, turned into a per-row probability distribution.
aggregated_norm = F.normalize(aggregated, p=2, dim=1)
pred = F.softmax(torch.matmul(aggregated_norm, Z_norm.T), dim=1)

# Binary cross-entropy against the identity matrix: row i is pushed towards
# global embedding i and away from all others.
# NOTE(review): softmax followed by BCELoss is unusual; a cross-entropy loss
# over the rows may be what is intended -- confirm before changing.
bce_loss_fn = nn.BCELoss()
bce_loss = bce_loss_fn(pred, torch.eye(B))