Global, Local, Global & Local

Settings

import torch
import torch.nn as nn
import torch.nn.functional as F


# Sizes shared by the loss sketches in this file.
B = 64              # number of rows in Z and side length of DTW_matrix
z_global_dim = 128  # width of each row of the global embedding Z
z_local_dim = 32    # width of each local (anchor / positive / negative) embedding

# Random (B, B) stand-in for the regression target of the global loss.
# NOTE(review): presumably a real pairwise DTW matrix replaces this -- confirm.
DTW_matrix = torch.randn(B, B)


1. Global: between time series (TS)

# Global embeddings: a random (B, z_global_dim) placeholder.
Z = torch.randn(B, z_global_dim)
# L2-normalise every row so the dot products below are cosine similarities.
Z_norm = F.normalize(Z, p=2, dim=1)
# (B, B) matrix of pairwise cosine similarities between the rows of Z.
DTW_pred = torch.matmul(Z_norm, Z_norm.t())


# Mean-squared error between the predicted similarity matrix and DTW_matrix.
mse_loss = nn.MSELoss()
# Bare expression: displayed as cell output in a notebook, discarded when run
# as a plain script.
mse_loss(DTW_pred, DTW_matrix)
Output (the exact value varies between runs because the inputs are random and unseeded): tensor(1.0226)


2. Local: within a time series (TS)

K = 4  # number of negative sets compared against each anchor / positive pair

# Random placeholders for the local embeddings. The leading dimension is B
# rather than a literal 64, so these tensors stay aligned with Z_norm and
# torch.eye(B), which the global & local interaction loss combines them with.
ts_part_ANCHOR = torch.randn(B, z_local_dim)
ts_part_POS = torch.randn(B, z_local_dim)
ts_part_NEG = torch.randn(K, B, z_local_dim)  # one (B, z_local_dim) slab per negative set

triplet_loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

# Sum (not average) of the triplet margin loss over the K negative sets;
# iterating a tensor yields its slices along dim 0, i.e. ts_part_NEG[k].
triplet_loss = sum(
    triplet_loss_fn(ts_part_ANCHOR, ts_part_POS, negatives)
    for negatives in ts_part_NEG
)


3. Global & Local interaction

# Join anchor and positive local embeddings along the feature axis, giving
# rows of width 2 * z_local_dim.
ts_pos_ANCHOR_pos = torch.cat((ts_part_ANCHOR, ts_part_POS), dim=-1)

# Project the concatenated local features to the global embedding width and
# L2-normalise each row, so the products with Z_norm are cosine similarities.
linear_model = nn.Linear(in_features=2 * z_local_dim, out_features=z_global_dim)
aggregated = linear_model(ts_pos_ANCHOR_pos)
aggregated_norm = F.normalize(aggregated, p=2, dim=1)


# Row i: softmax over the similarities between aggregated row i and every
# row of Z_norm.
pred = F.softmax(torch.matmul(aggregated_norm, Z_norm.t()), dim=1)

# The identity target marks entry (i, i) as the only positive in row i.
bce_loss_fn = nn.BCELoss()
bce_loss = bce_loss_fn(pred, torch.eye(B))

Categories: ,

Updated: