Pytorch Geometric Review 3 - GVAE

(Reference: https://www.youtube.com/c/DeepFindr/videos)


Graph VAE (GVAE) in PyTorch

The model below encodes a molecular graph into one latent vector per molecule and decodes it back into atom-type logits and upper-triangular (triu) edge-type logits.

import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import Set2Set
from torch_geometric.nn import BatchNorm
from config import SUPPORTED_ATOMS, SUPPORTED_EDGES, MAX_MOLECULE_SIZE, ATOMIC_NUMBERS
from utils import graph_representation_to_molecule, to_one_hot
from tqdm import tqdm


class GVAE(nn.Module):
    def __init__(self, feature_size):
        super(GVAE, self).__init__()
        self.encoder_embedding_size = 64
        self.edge_dim = 11
        self.latent_embedding_size = 128
        self.num_edge_types = len(SUPPORTED_EDGES) 
        self.num_atom_types = len(SUPPORTED_ATOMS)
        self.max_num_atoms = MAX_MOLECULE_SIZE 
        self.decoder_hidden_neurons = 512
        #----------------------------------------------------------------#
        # Encoder layers
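        # TransformerConv: 4 attention heads whose outputs are averaged
        # (concat=False); beta=True adds a gated residual connection and
        # edge_dim lets edge features enter the attention computation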
        self.conv1 = TransformerConv(feature_size, 
                                    self.encoder_embedding_size, 
                                    heads=4, 
                                    concat=False,
                                    beta=True,
                                    edge_dim=self.edge_dim)
        self.bn1 = BatchNorm(self.encoder_embedding_size)
        self.conv2 = TransformerConv(self.encoder_embedding_size, 
                                    self.encoder_embedding_size, 
                                    heads=4, 
                                    concat=False,
                                    beta=True,
                                    edge_dim=self.edge_dim)
        self.bn2 = BatchNorm(self.encoder_embedding_size)
        self.conv3 = TransformerConv(self.encoder_embedding_size, 
                                    self.encoder_embedding_size, 
                                    heads=4, 
                                    concat=False,
                                    beta=True,
                                    edge_dim=self.edge_dim)
        self.bn3 = BatchNorm(self.encoder_embedding_size)
        self.conv4 = TransformerConv(self.encoder_embedding_size, 
                                    self.encoder_embedding_size, 
                                    heads=4, 
                                    concat=False,
                                    beta=True,
                                    edge_dim=self.edge_dim)
        
        #----------------------------------------------------------------#
        # Pooling layers
        self.pooling = Set2Set(self.encoder_embedding_size, processing_steps=4)

        # Mean & Log variance
        self.mu_transform = Linear(self.encoder_embedding_size*2, self.latent_embedding_size)
        self.logvar_transform = Linear(self.encoder_embedding_size*2, self.latent_embedding_size)
        
        #----------------------------------------------------------------#
        # Decoder layers
        # --- Shared layers
        self.linear_1 = Linear(self.latent_embedding_size, self.decoder_hidden_neurons)
        self.linear_2 = Linear(self.decoder_hidden_neurons, self.decoder_hidden_neurons)

        # --- Atom decoding (outputs a matrix: (max_num_atoms) * (# atom_types + "none"-type))   
        atom_output_dim = self.max_num_atoms*(self.num_atom_types + 1)
        self.atom_decode = Linear(self.decoder_hidden_neurons, atom_output_dim)

        # --- Edge decoding (outputs a triu tensor: (max_num_atoms*(max_num_atoms-1)/2) * (#edge_types + 1))
        edge_output_dim = int(((self.max_num_atoms * (self.max_num_atoms - 1)) / 2) * (self.num_edge_types + 1))
        self.edge_decode = Linear(self.decoder_hidden_neurons, edge_output_dim)
        

    def encode(self, x, edge_attr, edge_index, batch_index):
        # step 1) Pass through the GNN layers
        x = self.bn1(self.conv1(x, edge_index, edge_attr).relu())
        x = self.bn2(self.conv2(x, edge_index, edge_attr).relu())
        x = self.bn3(self.conv3(x, edge_index, edge_attr).relu())
        x = self.conv4(x, edge_index, edge_attr).relu()
        
        # step 2) Pool into a graph representation
        graph_vector = self.pooling(x, batch_index)
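        # Set2Set returns a vector of size 2 * in_channels per graph,
        # which is why the mu/logvar layers take encoder_embedding_size*2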

        # step 3) Obtain mean & log variance
        mu = self.mu_transform(graph_vector)
        logvar = self.logvar_transform(graph_vector)
        return mu, logvar

    def decode_graph(self, graph_z):  
        # Pass through shared layers
        z = self.linear_1(graph_z).relu()
        z = self.linear_2(z).relu()
        # Decode atom types
        atom_logits = self.atom_decode(z)
        # Decode edge types
        edge_logits = self.edge_decode(z)

        return atom_logits, edge_logits


    def decode(self, z, batch_index):
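        # z holds one latent vector per molecule: shape (num_graphs, latent_embedding_size)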
        node_logits = []
        triu_logits = []
        # Iterate over molecules in batch
        for graph_id in torch.unique(batch_index):
            # Get latent vector for this graph
            graph_z = z[graph_id]

            # Recover graph from latent vector
            atom_logits, edge_logits = self.decode_graph(graph_z)

            # Store per graph results
            node_logits.append(atom_logits)
            triu_logits.append(edge_logits)

        # Concatenate all outputs of the batch
        node_logits = torch.cat(node_logits)
        triu_logits = torch.cat(triu_logits)
        return triu_logits, node_logits


    def reparameterize(self, mu, logvar):
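        # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        # which keeps the sampling step differentiable w.r.t. mu and logvar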
        if self.training:
            # Get standard deviation: logvar = log(sigma^2), so std = exp(0.5 * logvar)
            std = torch.exp(0.5 * logvar)
            # Returns random numbers from a normal distribution
            eps = torch.randn_like(std)
            # Return sampled values
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, x, edge_attr, edge_index, batch_index):
        # Encode the molecule
        mu, logvar = self.encode(x, edge_attr, edge_index, batch_index)
        # Sample latent vector (one per graph in the batch)
        z = self.reparameterize(mu, logvar)
        # Decode latent vector into original molecule
        triu_logits, node_logits = self.decode(z, batch_index)

        return triu_logits, node_logits, mu, logvar

    
    def sample_mols(self, num=10000):
        print("Sampling molecules ... ")

        n_valid = 0
        # Sample molecules and check if they are valid
        for _ in tqdm(range(num)):
            # Sample latent space
            z = torch.randn(1, self.latent_embedding_size)

            # Get model output (this could also be batched)
            dummy_batch_index = torch.tensor([0])
            triu_logits, node_logits = self.decode(z, dummy_batch_index)

            # Reshape triu predictions 
            edge_matrix_shape = (int((MAX_MOLECULE_SIZE * (MAX_MOLECULE_SIZE - 1))/2), len(SUPPORTED_EDGES) + 1) 
            triu_preds_matrix = triu_logits.reshape(edge_matrix_shape)
            triu_preds = torch.argmax(triu_preds_matrix, dim=1)

            # Reshape node predictions
            node_matrix_shape = (MAX_MOLECULE_SIZE, (len(SUPPORTED_ATOMS) + 1)) 
            node_preds_matrix = node_logits.reshape(node_matrix_shape)
            # argmax over the real atom classes only (exclude the trailing "none" class)
            node_preds = torch.argmax(node_preds_matrix[:, :len(SUPPORTED_ATOMS)], dim=1)
            
            # Get atomic numbers 
            node_preds_one_hot = to_one_hot(node_preds, options=ATOMIC_NUMBERS)
            atom_numbers_dummy = torch.Tensor(ATOMIC_NUMBERS).repeat(node_preds_one_hot.shape[0], 1)
            atom_types = torch.masked_select(atom_numbers_dummy, node_preds_one_hot.bool())
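            # masked_select keeps, for each node, the atomic number at its
            # one-hot position (class index -> atomic number)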

            # Attempt to create valid molecule
            smiles, _ = graph_representation_to_molecule(atom_types, triu_preds.float())

            # A "." in the SMILES string indicates a disconnected fragment
            if smiles and "." not in smiles:
                print("Successfully generated: ", smiles)
                n_valid += 1    
        return n_valid
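
A minimal smoke test for the model above (a sketch only: NUM_NODE_FEATURES, the 3-atom dummy_graph helper, and the random features are assumptions for illustration, not values from config; it assumes the GVAE class is defined in the same module):

import torch
from torch_geometric.data import Data, Batch

NUM_NODE_FEATURES = 30   # hypothetical; must match feature_size passed to GVAE
NUM_EDGE_FEATURES = 11   # must match self.edge_dim

model = GVAE(feature_size=NUM_NODE_FEATURES)
model.eval()  # in eval mode, reparameterize() returns mu directly

def dummy_graph():
    # A 3-atom molecule with a single bond (stored in both directions)
    x = torch.randn(3, NUM_NODE_FEATURES)
    edge_index = torch.tensor([[0, 1], [1, 0]])
    edge_attr = torch.randn(2, NUM_EDGE_FEATURES)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

batch = Batch.from_data_list([dummy_graph(), dummy_graph()])
triu_logits, node_logits, mu, logvar = model(
    batch.x, batch.edge_attr, batch.edge_index, batch.batch)
print(triu_logits.shape, node_logits.shape, mu.shape)

The forward pass only returns logits, so training still needs a VAE-style loss. A common choice, sketched here (gvae_loss, triu_targets, node_targets, and kl_weight are hypothetical names; the targets are assumed to be class-index tensors laid out the same way as the logits), combines reconstruction cross-entropy with a KL term:

import torch.nn.functional as F

def gvae_loss(triu_logits, node_logits, triu_targets, node_targets,
              mu, logvar, kl_weight=0.5):
    # Reshape flat logits back into (positions, classes) matrices
    edge_logits = triu_logits.reshape(-1, len(SUPPORTED_EDGES) + 1)
    atom_logits = node_logits.reshape(-1, len(SUPPORTED_ATOMS) + 1)

    # Reconstruction loss: cross-entropy over edge and atom classes
    recon = F.cross_entropy(edge_logits, triu_targets) \
          + F.cross_entropy(atom_logits, node_targets)

    # KL divergence between N(mu, sigma^2) and the standard normal prior
    kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

    return recon + kl_weight * kld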
