[ 7. Graph Neural Networks 2 ]

( Reference : CS224W: Machine Learning with Graphs )


Contents

  • 7-1. Introduction
  • 7-2. Single GNN Layer
  • 7-3. Types of GNN Layers
  • 7-4. General GNN Layers
  • 7-5. Stacking GNN Layers


7-1. Introduction

GNN Layer : consists of 2 main parts

  • 1) message COMPUTATION
  • 2) message AGGREGATION

[Figure: a single GNN layer = message computation + message aggregation]


After composing a single GNN layer…

how do we connect multiple GNN layers?

  • 3) Layer Connectivity

[Figure: connecting multiple GNN layers]


After that… we will talk about…

  • 4) Graph Augmentation
    • 4-1) Graph FEATURE augmentation
    • 4-2) Graph STRUCTURE augmentation


Lastly, we will cover

  • 5) Learning Objective


Total Framework :

[Figure: the overall GNN framework : (1) message, (2) aggregation, (3) layer connectivity, (4) graph augmentation, (5) learning objective]


7-2. Single GNN Layer

A single GNN layer consists of 2 steps

  • 1) Message computation
    • how to turn each neighboring node into a message (embedding)…
  • 2) Message aggregation
    • and how to combine those messages
    • ( of course, the target node itself can also be an input! )

[Figure: the two steps of a single GNN layer]


1) Message Computation

\(\mathbf{m}_{u}^{(l)}=\operatorname{MSG}^{(l)}\left(\mathbf{h}_{u}^{(l-1)}\right)\).

  • function \(\text{MSG}\)?
    • ex) Linear Layer : \(\mathbf{m}_{u}^{(l)}=\mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\)


2) Message Aggregation

\(\mathbf{h}_{v}^{(l)}=\mathrm{AGG}^{(l)}\left(\left\{\mathbf{m}_{u}^{(l)}, u \in N(v)\right\}\right)\).

  • function \(\text{AGG}\)?

    • ex) sum / mean / max

      \(\mathbf{h}_{v}^{(l)}=\operatorname{Sum}\left(\left\{\mathbf{m}_{u}^{(l)}, u \in N(v)\right\}\right)\).
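Putting the two steps together, here is a minimal sketch in PyTorch ( assumed; the dense 0/1 adjacency matrix `adj`, the class name, and variable names are illustrative, not from the lecture ) :

```python
import torch
import torch.nn as nn

class SimpleGNNLayer(nn.Module):
    """One GNN layer = message computation + message aggregation (sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # MSG : m_u = W h_u

    def forward(self, h, adj):
        # h   : (num_nodes, in_dim) node embeddings h^(l-1)
        # adj : (num_nodes, num_nodes) 0/1 adjacency matrix (assumed dense)
        m = self.W(h)     # 1) message computation, for every node at once
        return adj @ m    # 2) message aggregation : sum over u in N(v)
```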


Problem ( Issue ) : the information of the target node ITSELF can be lost!

Solution :

  • 1) message computation

    • compute a separate message for the target node itself
      • ex) neighbor node : \(\mathbf{m}_{u}^{(l)}=\mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\)
      • ex) target node : \(\mathbf{m}_{v}^{(l)}=\mathbf{B}^{(l)} \mathbf{h}_{v}^{(l-1)}\)
  • 2) message aggregation

    • ex) concatenation / summation :

      [Figure: combining the neighbors' messages with the target node's own message via concatenation or summation]
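A sketch of this fix, extending the layer above ( hypothetical names; summation is used as the combine step ) :

```python
import torch
import torch.nn as nn

class SelfIncludingGNNLayer(nn.Module):
    """GNN layer that also keeps the target node's own information (sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # neighbors : m_u = W h_u
        self.B = nn.Linear(in_dim, out_dim, bias=False)  # target    : m_v = B h_v

    def forward(self, h, adj):
        neigh_msg = adj @ self.W(h)  # sum of the neighbors' messages
        self_msg = self.B(h)         # the target node's own message
        return neigh_msg + self_msg  # combine by summation (concat also possible)
```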


7-3. Types of GNN Layers

1) GCN

  • expression 1)

    \(\mathbf{h}_{v}^{(l)}=\sigma\left(\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\mathbf{h}_{u}^{(l-1)}}{\mid N(v) \mid}\right)\)


  • expression 2) as “message COMPUTATION” + “message AGGREGATION”

    [Figure: the same GCN layer split into the two steps]

    • meaning :
      • 1) message COMPUTATION : \(\mathbf{m}_{u}^{(l)}=\frac{1}{\mid N(v) \mid} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\).
      • 2) message AGGREGATION : \(\mathbf{h}_{v}^{(l)}=\sigma\left(\operatorname{Sum}\left(\left\{\mathbf{m}_{u}^{(l)}, u \in N(v)\right\}\right)\right)\).
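A minimal GCN-layer sketch following this decomposition ( PyTorch assumed; ReLU stands in for \(\sigma\), and a dense adjacency matrix is used for simplicity ) :

```python
import torch
import torch.nn as nn

class GCNLayer(nn.Module):
    """GCN : degree-normalized messages, summed over neighbors (sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, h, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)  # |N(v)| for each node v
        m = self.W(h)                                    # message : W h_u
        return torch.relu(adj @ m / deg)                 # sigma( sum_u m_u / |N(v)| )
```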


2) GraphSAGE

GCN vs GraphSAGE

  • GCN : \(\mathbf{h}_{v}^{(l)}=\sigma\left(\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\mathbf{h}_{u}^{(l-1)}}{\mid N(v) \mid }\right)\)
  • GraphSAGE : \(\mathbf{h}_{v}^{(l)}=\sigma\left(\mathbf{W}^{(l)} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{(l-1)}, \operatorname{AGG}\left(\left\{\mathbf{h}_{u}^{(l-1)}, \forall u \in N(v)\right\}\right)\right)\right)\).


To decompose GraphSAGE …

  • 1) message COMPUTATION : aggregate from node neighbors

    • \(\mathbf{h}_{N(v)}^{(l)} \leftarrow \mathrm{AGG}\left(\left\{\mathbf{h}_{u}^{(l-1)}, \forall u \in N(v)\right\}\right)\).

    • AGG function :

      • 1) mean : \(\mathrm{AGG}=\sum_{u \in N(v)} \frac{\mathbf{h}_{u}^{(l-1)}}{\mid N(v) \mid}\)

      • 2) pool : \(\mathrm{AGG}=\operatorname{Mean}\left(\left\{\operatorname{MLP}\left(\mathbf{h}_{u}^{(l-1)}\right), \forall u \in N(v)\right\}\right)\).

      • 3) lstm : \(\mathrm{AGG}=\operatorname{LSTM}\left(\left[\mathbf{h}_{u}^{(l-1)}, \forall u \in \pi(N(v))\right]\right)\) ( \(\pi\) : a random permutation of the neighbors )

  • 2) message AGGREGATION : aggregate over the node itself

    • \(\mathbf{h}_{v}^{(l)} \leftarrow \sigma\left(\mathbf{W}^{(l)} \cdot \operatorname{CONCAT}\left(\mathbf{h}_{v}^{(l-1)}, \mathbf{h}_{N(v)}^{(l)}\right)\right)\).
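A GraphSAGE-layer sketch with the mean aggregator ( class and variable names are assumptions for this example ) :

```python
import torch
import torch.nn as nn

class GraphSAGELayer(nn.Module):
    """GraphSAGE : AGG over neighbors, then CONCAT with h_v, then W (sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(2 * in_dim, out_dim, bias=False)  # acts on the CONCAT

    def forward(self, h, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        h_neigh = adj @ h / deg                 # AGG = mean of neighbor embeddings
        h_cat = torch.cat([h, h_neigh], dim=1)  # CONCAT( h_v, h_N(v) )
        return torch.relu(self.W(h_cat))        # sigma( W . CONCAT(...) )
```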


3) Graph Attention Networks

\(\mathbf{h}_{v}^{(l)}=\sigma\left(\sum_{u \in N(v)} \alpha_{v u} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\right)\).

  • (implicit) attention weight in GCN / GraphSAGE :
    • \(\alpha_{v u}=\frac{1}{\mid N(v) \mid}\) : the weight of node \(u\)'s message to node \(v\) ( every neighbor is equally important )
  • why not LEARN the attention weights?
    • attention coefficient : \(e_{v u}=a\left(\mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}, \mathbf{W}^{(l)} \mathbf{h}_{v}^{(l-1)}\right)\).
    • attention weight : \(\alpha_{v u}=\frac{\exp \left(e_{v u}\right)}{\sum_{k \in N(v)} \exp \left(e_{v k}\right)}\).
    • weighted sum, based on attention weight : \(\mathbf{h}_{v}^{(l)}=\sigma\left(\sum_{u \in N(v)} \alpha_{v u} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\right)\).


what should we use as the function \(a\)?

  • ex) single NN

    • \(e_{A B}=a\left(\mathbf{W}^{(l)} \mathbf{h}_{A}^{(l-1)}, \mathbf{W}^{(l)} \mathbf{h}_{B}^{(l-1)}\right)=\operatorname{Linear}\left(\operatorname{Concat}\left(\mathbf{W}^{(l)} \mathbf{h}_{A}^{(l-1)}, \mathbf{W}^{(l)} \mathbf{h}_{B}^{(l-1)}\right)\right)\).

    [Figure: a single-layer NN computing \(e_{AB}\) from the concatenated messages]
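A single-head GAT sketch using exactly this Linear-on-Concat choice of \(a\) ( PyTorch assumed; dense adjacency, and every node is assumed to have at least one neighbor so the softmax is well defined ) :

```python
import torch
import torch.nn as nn

class GATLayer(nn.Module):
    """Single attention head : e_vu -> softmax -> weighted sum (sketch)."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a = nn.Linear(2 * out_dim, 1, bias=False)  # a = Linear(Concat(., .))

    def forward(self, h, adj):
        z = self.W(h)                               # W h_u for all nodes
        n = z.size(0)
        # e[v, u] = a( W h_v, W h_u ) for every node pair (v, u)
        pairs = torch.cat([z.unsqueeze(1).expand(n, n, -1),
                           z.unsqueeze(0).expand(n, n, -1)], dim=-1)
        e = self.a(pairs).squeeze(-1)
        e = e.masked_fill(adj == 0, float('-inf'))  # keep only u in N(v)
        alpha = torch.softmax(e, dim=1)             # attention weights alpha_vu
        return torch.relu(alpha @ z)                # sigma( sum_u alpha_vu W h_u )
```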


Multi-head attention

ex) 3 heads :

  • \(\mathbf{h}_{v}^{(l)}[1]=\sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{1} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\right)\).
  • \(\mathbf{h}_{v}^{(l)}[2]=\sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{2} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\right)\).
  • \(\mathbf{h}_{v}^{(l)}[3]=\sigma\left(\sum_{u \in N(v)} \alpha_{v u}^{3} \mathbf{W}^{(l)} \mathbf{h}_{u}^{(l-1)}\right)\).
  • then… aggregate by “concatenation or summation” ( see the sketch after this list )
    • \(\mathbf{h}_{v}^{(l)}=\) AGG \(\left(\mathbf{h}_{v}^{(l)}[1], \mathbf{h}_{v}^{(l)}[2], \mathbf{h}_{v}^{(l)}[3]\right)\)
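A sketch of the 3-head version, reusing the hypothetical `GATLayer` from the previous block and aggregating by concatenation :

```python
import torch
import torch.nn as nn

class MultiHeadGAT(nn.Module):
    """Runs several independent attention heads and concatenates them (sketch)."""
    def __init__(self, in_dim, out_dim, num_heads=3):
        super().__init__()
        # GATLayer : the single-head sketch defined above
        self.heads = nn.ModuleList(GATLayer(in_dim, out_dim)
                                   for _ in range(num_heads))

    def forward(self, h, adj):
        # AGG = concatenation; the output dim becomes num_heads * out_dim
        return torch.cat([head(h, adj) for head in self.heads], dim=1)
```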


Benefits of attention?

  • key : different importance for different neighbors
  • 1) computationally efficient
    • attention coefficients : can be computed in PARALLEL across all edges
  • 2) storage efficient
    • fixed number of params ( \(O(V+E)\) entries to be stored )
  • 3) localized
    • attention over LOCAL NETWORK neighborhood
  • 4) inductive capacity
    • shared EDGE-wise mechanism ( does not depend on the global graph structure )


7-4. General GNN Layers

we can use modern DL techniques in a general GNN layer ( see the sketch after the list )

  • 1) Batch Norm
  • 2) Dropout
  • 3) Attention & Gating
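For instance, a GNN layer wrapped with Batch Norm and Dropout might look like this ( a sketch; where exactly to place BN and dropout is a design choice, not prescribed by the lecture ) :

```python
import torch
import torch.nn as nn

class GeneralGNNLayer(nn.Module):
    """Message passing + Batch Norm + Dropout (sketch)."""
    def __init__(self, in_dim, out_dim, p_drop=0.5):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.bn = nn.BatchNorm1d(out_dim)  # 1) Batch Norm over node embeddings
        self.drop = nn.Dropout(p_drop)     # 2) Dropout, here applied to messages

    def forward(self, h, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        m = self.drop(self.W(h))           # dropout on the computed messages
        return torch.relu(self.bn(adj @ m / deg))
```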


7-5. Stacking GNN Layers

how to construct a GNN?

  • stack GNN layers sequentially!
  • input : raw node feature ( \(\mathrm{x}_{v}\) )
  • output : NODE EMBEDDINGS ( \(\mathbf{h}_{v}^{(L)}\) ) after \(L\) GNN layers
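A sketch of the sequential stack, reusing the hypothetical `GCNLayer` from the GCN section above :

```python
import torch
import torch.nn as nn

class StackedGNN(nn.Module):
    """x_v --(L GNN layers)--> h_v^(L) (sketch)."""
    def __init__(self, in_dim, hid_dim, num_layers):
        super().__init__()
        dims = [in_dim] + [hid_dim] * num_layers
        self.layers = nn.ModuleList(
            GCNLayer(dims[i], dims[i + 1]) for i in range(num_layers))

    def forward(self, x, adj):
        h = x                  # input : raw node features x_v
        for layer in self.layers:
            h = layer(h, adj)  # one round of message passing per layer
        return h               # output : node embeddings h_v^(L)
```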


Over-smoothing Problem

  • over-smoothing : all nodes have similar embeddings

  • reason : TOO MANY GNN LAYERS

    ( think of the receptive field : the set of nodes that determine the embedding of a given node! )

    the more GNN layers, the larger the receptive field!


[Figure: receptive field size vs. number of GNN layers]

  • more GNN layers \(\rightarrow\) number of SHARED neighbors \(\uparrow\) ( overlapping receptive fields \(\rightarrow\) similar embeddings )


So, how to overcome?

  • solution 1) Increase the expressive power WITHIN each GNN layer

    [Figure: increasing the expressive power within each GNN layer]


  • solution 2) Add layers that DO NOT PASS MESSAGES

    • ex) skip connection

    [Figure: layers that do not pass messages, e.g. a skip connection]


Skip-connections

but… what if we still want to use many layers?

  • \(N\) skip connections \(\rightarrow\) \(2^N\) possible paths
  • meaning : mixture of “SHALLOW & DEEP” GNNs

[Figure: \(N\) skip connections create \(2^N\) possible paths, a mixture of shallow and deep GNNs]


Standard GCN

[Figure: a standard GCN layer]


GCN + Skip Connection

[Figure: a GCN layer with a skip connection]
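A sketch of a GCN layer with a skip connection ( residual form \(\mathbf{h}^{(l)}=F\left(\mathbf{h}^{(l-1)}\right)+\mathbf{h}^{(l-1)}\); input and output dimensions must match ) :

```python
import torch
import torch.nn as nn

class GCNSkipLayer(nn.Module):
    """GCN layer with a skip connection : h^(l) = sigma( F(h^(l-1)) + h^(l-1) )."""
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim, dim, bias=False)  # in/out dims equal for the skip

    def forward(self, h, adj):
        deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
        f = adj @ self.W(h) / deg  # standard GCN message passing F(.)
        return torch.relu(f + h)   # skip connection : add the input back
```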


Other options

  • final layer : aggregates all the node embeddings from the previous layers

[Figure: the final layer aggregating node embeddings from all previous layers]
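A sketch of this option ( akin to Jumping Knowledge networks; concatenation is used as the final aggregation, and the hypothetical `GCNSkipLayer` from above is reused ) :

```python
import torch
import torch.nn as nn

class JumpingGNN(nn.Module):
    """Final embedding = aggregation of ALL intermediate layer embeddings (sketch)."""
    def __init__(self, dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(GCNSkipLayer(dim) for _ in range(num_layers))

    def forward(self, h, adj):
        all_h = [h]
        for layer in self.layers:
            h = layer(h, adj)
            all_h.append(h)             # keep each layer's node embeddings
        return torch.cat(all_h, dim=1)  # final layer : aggregate by concatenation
```

Max-pooling or an LSTM over `all_h` would be alternative final aggregations; concatenation is just the simplest choice here.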
