Chapter 7. Adversarially Regularized (Variational) Graph Autoencoders

( Reference : https://www.youtube.com/watch?v=JtDgmmQ60x8&list=PLGMXrbDNfqTzqxB1IGgimuhtfAhGd8lHF )

1) Adversarially Regularized GAE/VGAE ( ARGA/ARGVA )

goal : make the latent vectors ( \(Z\) ) lie in a space that is as continuous as possible!


2) ARGA/ARGVA vs GAE/GVAE

GAE/GVAE

[ figure : GAE/GVAE architecture ]


ARGA/ARGVA

[ figure : ARGA/ARGVA architecture ]


So then, what is the real data, and what is the fake data?

  • real data : \(Z' \sim p(Z)\)
    • samples drawn from the standard normal distribution \(N(0,1)\)
  • fake data : \(Z \sim q(Z)\)
    • the latent representations of the graph's embedded nodes ( see the sketch below )
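
Concretely, a minimal sketch of the two kinds of samples ( using the model, data, and latent_size objects defined in the sections below ):

import torch

z_real = torch.randn(data.num_nodes, latent_size)         # real : Z' ~ p(Z) = N(0, 1)
z_fake = model.encode(data.x, data.train_pos_edge_index)  # fake : Z ~ q(Z)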


Loss Function :

  • \(\min_{\mathcal{G}} \max_{\mathcal{D}} \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}}[\log \mathcal{D}(\mathbf{z})]+\mathbb{E}_{\mathbf{x} \sim p(\mathbf{x})}[\log (1-\mathcal{D}(\mathcal{G}(\mathbf{X}, \mathbf{A})))]\)
    • the standard GAN minimax game : the discriminator \(\mathcal{D}\) learns to tell prior samples from encoder outputs, while the encoder \(\mathcal{G}(\mathbf{X}, \mathbf{A})\) ( the generator ) learns to fool it
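
In training, this minimax objective is split into the two losses used in section 6 below. A minimal sketch of both sides, assuming a discriminator module disc ( this mirrors the discriminator_loss / reg_loss called later; EPS is a small constant for numerical safety ):

import torch

EPS = 1e-15

def discriminator_loss(disc, z):
    # D's side ( the max ) : prior samples should score "real", encoder output "fake"
    real = torch.sigmoid(disc(torch.randn_like(z)))  # Z' ~ p(Z) = N(0, 1)
    fake = torch.sigmoid(disc(z.detach()))           # Z ~ q(Z), no gradient to the encoder
    return -torch.log(real + EPS).mean() - torch.log(1 - fake + EPS).mean()

def reg_loss(disc, z):
    # G's side ( the min ) : train the encoder so that D judges z as "real"
    return -torch.log(torch.sigmoid(disc(z)) + EPS).mean()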


3) Encoder ( variational ver )

( identical to the encoder of GVAE )

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class VEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VEncoder, self).__init__()
        # shared first GCN layer, then two heads for mu and log(std)
        self.conv1 = GCNConv(in_channels, 2 * out_channels, cached=True)
        self.conv_mu = GCNConv(2 * out_channels, out_channels, cached=True)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels, cached=True)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        # parameters of q(z|X, A); ARGVA does the reparameterization internally
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)
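
A quick shape check of the two heads ( the node/feature counts here are hypothetical ):

x = torch.randn(100, 16)                      # 100 nodes with 16 features
edge_index = torch.randint(0, 100, (2, 400))  # random edges, only for this check
mu, logstd = VEncoder(16, 32)(x, edge_index)
assert mu.shape == logstd.shape == (100, 32)  # one 32-dim Gaussian per node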


4) Discriminator

class Discriminator(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(Discriminator, self).__init__()
        # plain 3-layer MLP on the latent vectors
        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.lin3 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)  # raw logit; sigmoid is applied inside ARGA's loss functions
        return x
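
For intuition, a small usage sketch ( dimensions match the modeling section below; the input here is just random noise ):

disc = Discriminator(in_channels=32, hidden_channels=64, out_channels=1)
z_prior = torch.randn(100, 32)         # 100 latent vectors sampled from N(0, 1)
p_real = torch.sigmoid(disc(z_prior))  # per-vector probability of being "real"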


5) Modeling

Creating the model

from torch_geometric.nn import ARGVA

latent_size = 32
encoder = VEncoder(data.num_features,
                   out_channels=latent_size)
discriminator = Discriminator(in_channels=latent_size,
                              hidden_channels=64,
                              out_channels=1)

# ARGVA wires together the variational encoder, the (default inner-product)
# decoder, and the discriminator
model = ARGVA(encoder, discriminator)
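
train() below updates the two players with separate optimizers. A minimal setup ( Adam and these learning rates are assumptions, in the spirit of the official PyG ARGVA example ):

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.001)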


6) Train

The difference from training a plain GAE/GVAE is that an adversarial training step is added.

def train():
    model.train()
    encoder_optimizer.zero_grad()

    z = model.encode(data.x, data.train_pos_edge_index)
    #-------------------------------------------------------#
    # (1) update DISCRIMINATOR ( several steps per encoder step )
    for i in range(5):
        discriminator_optimizer.zero_grad()
        # z is detached inside discriminator_loss, so only the discriminator is updated
        discriminator_loss = model.discriminator_loss(z)
        discriminator_loss.backward()
        discriminator_optimizer.step()
    #-------------------------------------------------------#
    # (2) update ENCODER
    ## 1) regularizer loss : make the discriminator judge z as "real"
    loss = model.reg_loss(z)
    ## 2) reconstruction loss
    loss = loss + model.recon_loss(z, data.train_pos_edge_index)
    ## 3) KL-divergence loss
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()

    encoder_optimizer.step()
    #-------------------------------------------------------#
    return loss
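
A typical loop over train() then looks like this ( the number of epochs is an arbitrary choice ):

for epoch in range(1, 151):
    loss = train()
    print(f'Epoch {epoch:03d}, loss: {float(loss):.4f}')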
