PyTorch Geometric 2 - Message Passing
1. Graph Data in PyG
컴퓨터가 edge를 저장하는 방식 : “indicies” ( shape = 2 x edge 개수 )
따라서, 컴퓨터가 그래프 \(G=(V,E)\)를 표현하기 위해,
- (1) FEATURE Matrix
- (2) INDICIES Matrix
로 나타냄
Pytorch Geometric이 그래프를 정의하는 방식 : \(G = (X,(I,E))\)
- \(X\) : Node feature ( \(\mathrm{X} \in \mathbb{R}^{ \mid V \mid \times F}\) )
- \(I\) : Edge indices ( \(\mathrm{I} \in\{1, \ldots, N\}^{2 \times \mid E \mid }\) )
- \(E\) : Edge features ( \(\mathrm{E} \in \mathbb{R}^{ \mid E \mid \times D}\) )
2. Message Passing scheme
Target Node를 임베딩하는 방식
\(\mathbf{x}_{i}^{(k)}=\gamma^{(k)}\left(x_{i}^{(k-1)}, \square_{j \in N(i)} \phi^{(k)}\left(\mathbf{x}_{i}^{(k-1)}, \mathbf{x}_{j}^{(k-1)}, \mathbf{e}_{j, i}\right)\right)\).
- \(x\) : Node embedding
- \(k\) : layer index
- \(e\) : Edge feature
- Functions
- (1) \(\phi\) : Message function
- (2) \(\square\) : Aggregation function
- ex) uniform, GAT, …
- (3) \(\gamma\) : Update function
3. Message Passing class
Class : torch_geometric.nn.MessagePassing
핵심
- (1) forward
- (2) message
- (3) propagate
Custom Message Passing 구성하기!
MessagePassing
을 상속받아, 새로운 class로 구성- 새롭게 구현(override)할 메소드
- (1)
forward
- (2)
message
- (1)
- (3)
aggr
: 선택 가능한 aggregation 방법- Ex) add (default), mean, max
propagate
-
MessagePassing
class가 가지고 있는 메소드 중 하나로, -
아래의 인자들을 받는다.
- args1 : edge index (
edge_index
) - args2 : node feature (
x
) - args3 : edge feature (
e
)
(
forawrd
도 위와 동일한 인자를 받는다 ) - args1 : edge index (
Example
class MyOwnConv(MessagePassing):
def __init__(self):
super(MyOwnConv, self).__init__(aggr='add')
#super(MyOwnConv, self).__init__(aggr='mean')
#super(MyOwnConv, self).__init__(aggr='max')
def forward(self, x, edge_index, e):
return self.propagate(edge_index, x=x, e=e)
def message(self, x_j, x_i, e):
# x_j,x_i : j,i번째 node의 feature들
# e : edge feature
return x_j * e
4. Message Passing class ( details )
(1) add & flow
torch_geometric.nn.MessagePassing(aggr="add", flow="source_to_target")
-
aggr
: 생성된 메세지들을 aggregate하는 방식- ex) add, mean, max
-
flow
: 메세지의 흐르는 방향- ex) source_to_target, target_to_source
- source_to_target : 주변으로부터 전달 “받음”
- target_to_source : 주변으로 전달 “함”
- ex) source_to_target, target_to_source
-
node_dim
: node dimension-
여기서 차원은, “특징의 개수”를 의미하는 차원이 아니라,
메세지를 전달할 “축”을 의미함
-
(default) 0
-
(2) propagate
위에서도 간단히 언급을 했지만, 보다 자세히 설명하자면…
propagate(edge_index, size=None, **kwargs)
“forward 함수”를,
- ”propagate 함수”를 사용하여 구성하고,
- propagate 함수는 내부적으로 “message / aggregate 함수”와 “update 함수”를 사용한다
( source code )
P = M+A+U
-
P : Propagate
- M : message
- A : aggregate
- U : Update
def propagate():
# (방법 1) MP & AG 동시
if ..........:
out = self.message_and_aggregate(xxx)
# (방법 2) MP & AG 따로
else:
out = self.message(xxx)
out = self.aggregate(out,xxx)
final_output = self.update(out,xxxx)
(3) message
def message(self, x_j : torch.Tensor) -> torch.Tensor:
# write your code here~
return x_j
- propagate 함수 실행 시, message 함수가 호출 됨
- Node feature(들)을 인자로 받는다
(4) update
def update(self, inputs : torch.Tensor) :
# write your code here~
return inputs
- node embedding을 update함
- out = message(x), out = aggregate(out), 를 거쳐서 나온 “aggregated message”를 input으로 받는다
5. ex) GCN
GCN의 Message Passing :
\(\mathbf{x}_{i}^{(k)}=\sum_{j \in \mathcal{N}(i) \cup i} \frac{1}{\sqrt{\operatorname{deg}(i)} \cdot \sqrt{\operatorname{deg}(j)}} \cdot\left(\boldsymbol{\Theta} \cdot \mathbf{x}_{j}^{(k-1)}\right)\).
진행 순서
( 1~3 : message 생성 )
( 4 : message 취합+전파 )
- Self-loop 추가
- neighbor 뿐만 아니라, target node 스스로의 feature도 인풋으로 받기 때문에
- Weight matrix 곱하기
- Normalization coefficient 계산
- Propagate message
- [M] message 함수 ( + normalize 하기 )
- [A] aggregation 함수
- [U] update 함수
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add')
# ( in_channels, out_channels )
self.lin = nn.Linear(in_Channels, out_channels)
def forward(self, x, edge_index):
# x : ( N_v, in_channels )
# edge_index : ( 2, N_e )
#===================================================#
# [Step 1] Adjacency matrix에 self loop 추가
# edge_index : ( 2, N_e + N_v )
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
#===================================================#
# [Step 2] node feature에 weight matrix 곱하기
# x : ( N_v, in_channels )
x = self.lin(X)
# x : ( N_v, out_channels )
#===================================================#
# [Step 3] Compute normalization coef
row, col = edge_index
# row : edge의 왼쪽 ( shape : (N_e + N_v) )
# col : edge의 오른쪽 ( shape : (N_e + N_v) )
#----------------------------------------------------#
deg = degree(row, x.size(0), dtype=x.dtype)
# deg : 모든 node들의 degree ( shape : (N_v) )
deg_inv_sqrt = deg.pow(-0.5)
# deg : ( shape : (N_v) )
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# norm : ( shape : (N_e + N_v) )
#===================================================#
# [Step 4] Propagate message
## 4-1) message 함수 ( + normalize 하기 )
## 4-2) aggregation 함수
## 4-3) update 함수
return self.propagate(edge_index,
size = (x.size(0), x.size(0)),
x = x,
norm = norm)
def message(self, x_j, norm):
## 4-1) message 함수 ( + normalize 하기 )
# norm : (N_e + N_v)
# norm.view(-1,1) : (N_e + N_v, 1)
# x_j : (N_e + N_v, out_channels)
return norm.view(-1, 1) * x_j
def update(self, aggr_out):
## 4-3) update 함수
# aggr_out : (N_v, out_channels)
return aggr_out
conv = GCNConv(in_channels=16, out_channels=32)
x = conv(x=x, edge_index=edge_index)