CS224W: Machine Learning with Graphs - 08 GNN Augmentation and Training
Author: 互联网
GNN Augmentation and Training
0. A General GNN Framework
Idea: raw input graph $\neq$ computational graph
- Graph feature augmentation
- Graph structure manipulation
1). Why Augment Graphs?
Our assumption so far has been: raw input graph = computational graph
Reasons for breaking this assumption
- Features
  The input graph lacks features
- Graph structure
  The graph is too sparse → inefficient message passing
  The graph is too dense → message passing is too costly
  The graph is too large → cannot fit the computational graph into a GPU
It is unlikely that the input graph happens to be the optimal computation graph for embeddings.
2). Graph Augmentation Approaches
- Graph feature augmentation
  The input graph lacks features → feature augmentation
- Graph structure augmentation
  The graph is too sparse → add virtual nodes / edges (see the sketch after this list)
  The graph is too dense → sample neighbors when doing message passing
  The graph is too large → sample subgraphs to compute embeddings
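As one concrete illustration of the sparse-graph case, here is a minimal sketch of adding a single virtual node that connects to every existing node. It assumes the graph is stored as a PyTorch edge-index tensor of shape `[2, num_edges]`; the function name and tensor layout are illustrative assumptions, not part of the lecture.

```python
import torch

def add_virtual_node(edge_index: torch.Tensor, num_nodes: int):
    """Connect a new virtual node to every existing node (both directions).

    edge_index: [2, num_edges] tensor of (source, target) pairs (assumed layout).
    Returns the augmented edge_index and the new node count.
    """
    virtual_id = num_nodes                      # the virtual node gets the next free index
    nodes = torch.arange(num_nodes)
    to_virtual = torch.stack([nodes, torch.full((num_nodes,), virtual_id)])
    from_virtual = torch.stack([torch.full((num_nodes,), virtual_id), nodes])
    new_edge_index = torch.cat([edge_index, to_virtual, from_virtual], dim=1)
    return new_edge_index, num_nodes + 1

# Toy example: a 3-node path graph 0-1-2 (each undirected edge stored as two directed edges)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
aug_edge_index, n = add_virtual_node(edge_index, num_nodes=3)
print(aug_edge_index.shape)  # torch.Size([2, 10]): 4 original + 6 virtual-node edges
```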
1. Feature Augmentation on Graphs
To be updated.
1). Message Computation
Message function:
$m_u^l = \text{MSG}^l(h_u^{l-1})$
Intuition: each node will create a message, which will be sent to other nodes later
Example: a linear layer
$m_u^l = W^l h_u^{l-1}$
2). Message Aggregation
Intuition: each node $v$ aggregates the messages from its neighbors $u \in N(v)$
$h_v^l = \text{AGG}^l(\{m_u^l, u \in N(v)\})$
Example: sum, mean, max aggregator
Issue: information from node $v$ itself could get lost (the computation of $h_v^l$ does not directly depend on $h_v^{l-1}$)
Solution: include $h_v^{l-1}$ when computing $h_v^l$
- Message: compute a message from node $v$ itself, using a different message computation, e.g., $m_v^l = B^l h_v^{l-1}$
- Aggregation: after aggregating from neighbors, aggregate the message from node $v$ itself via concatenation or summation, e.g., $h_v^l = \text{CONCAT}(\text{AGG}^l(\{m_u^l, u \in N(v)\}), m_v^l)$
- Nonlinearity (activation): adds expressiveness to message computation or aggregation
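Putting these pieces together (neighbor messages, aggregation, a self message, and a nonlinearity), here is a minimal dense-adjacency sketch of one generic GNN layer in PyTorch. The class and argument names are my own assumptions for illustration; the lecture does not prescribe a specific implementation.

```python
import torch
import torch.nn as nn

class GenericGNNLayer(nn.Module):
    """One layer: neighbor message m_u = W h_u, sum aggregation over N(v),
    self message m_v = B h_v, then CONCAT and a nonlinearity."""
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)  # neighbor message weights W^l
        self.B = nn.Linear(in_dim, out_dim, bias=False)  # self message weights B^l

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # H: [num_nodes, in_dim] node features; A: [num_nodes, num_nodes] 0/1 adjacency
        neighbor_msg = A @ self.W(H)     # sum of messages from neighbors
        self_msg = self.B(H)             # message from the node itself
        # CONCAT(AGG, m_v), then nonlinearity; output dimension is 2 * out_dim
        return torch.relu(torch.cat([neighbor_msg, self_msg], dim=-1))
```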
2. Classical GNN Layers
1). Graph Convolutional Networks (GCNs)
$h_v^l = \sigma\left(W^l \sum_{u \in N(v)} \frac{h_u^{l-1}}{|N(v)|}\right) = \sigma\left(\sum_{u \in N(v)} W^l \frac{h_u^{l-1}}{|N(v)|}\right)$
Message: each neighbor computes $m_u^l = \frac{1}{|N(v)|} W^l h_u^{l-1}$ (normalized by node degree)
Aggregation: sum over messages from neighbors, then apply activation
$h_v^l = \sigma(\text{Sum}(\{m_u^l, u \in N(v)\}))$
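A minimal dense-matrix sketch of this GCN update (degree-normalized neighbor messages, sum aggregation, then an activation). It uses the simple $\frac{1}{|N(v)|}$ normalization from the formula above rather than the symmetric normalization of the original GCN paper; the class name is an illustrative assumption.

```python
import torch
import torch.nn as nn

class SimpleGCNLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # H: [N, in_dim] node features, A: [N, N] adjacency with 0/1 entries
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)   # |N(v)|, guard against isolated nodes
        # message m_u = W h_u / |N(v)|, aggregated by summing over neighbors, then sigma = ReLU
        return torch.relu((A @ self.W(H)) / deg)
```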
2). GraphSAGE
$h_v^l = \sigma(W^l \cdot \text{CONCAT}(h_v^{l-1}, \text{AGG}^l(\{h_u^{l-1}, u \in N(v)\})))$
a). GraphSAGE neighbor aggregation
- Mean: take a weighted average of neighbors (as in GCN)
  $\text{AGG} = \sum_{u \in N(v)} \frac{h_u^{l-1}}{|N(v)|}$
- Pool: transform neighbor vectors and apply a symmetric vector function (e.g., mean or max)
  $\text{AGG} = \text{Mean}(\{\text{MLP}(h_u^{l-1}), u \in N(v)\})$
- LSTM: apply an LSTM to a reshuffled ordering of the neighbors
  $\text{AGG} = \text{LSTM}([h_u^{l-1}, \forall u \in \pi(N(v))])$
b). $L_2$ normalization
Optional: apply $L_2$ normalization to $h_v^l$ at every layer:
$h_v^l \leftarrow \frac{h_v^l}{\|h_v^l\|_2} \quad \forall v \in V$, where $\|u\|_2 = \sqrt{\sum_i u_i^2}$ is the $L_2$-norm.
Without $L_2$ normalization, the embedding vectors have different scales ($L_2$-norms); in some cases, normalizing the embeddings improves performance. After $L_2$ normalization, all vectors will have the same $L_2$-norm.
3). Graph Attention Networks (GATs)
a). Not all nodes’ neighbors are equally important
- In GCN and GraphSAGE, $\frac{1}{|N(v)|}$ is the weighting factor (importance) of node $u$'s message to node $v$. It is defined explicitly based on the structural properties of the graph (node degree), and all neighbors $u \in N(v)$ are equally important to node $v$.
- The attention $\alpha_{vu}$ focuses on the important parts of the input data and fades out the rest.
- Idea: the NN should devote more computing power to the small but important part of the data; which part is important depends on the context and is learned through training.
$h_v^l = \sigma\left(\sum_{u \in N(v)} \alpha_{vu} W^l h_u^{l-1}\right)$
Goal: specify arbitrary importance to different neighbors of each node in the graph.
Idea: compute the embedding $h_v^l$ of each node in the graph following an attention strategy.
b). Attention mechanism
Let $\alpha_{vu}$ be computed as a byproduct of an attention mechanism $a$:
- Let $a$ compute attention coefficients $e_{vu}$ across pairs of nodes $u$ and $v$ based on their messages:
  $e_{vu} = a(W^l h_u^{l-1}, W^l h_v^{l-1})$
  which indicates the importance of $u$'s message to node $v$.
- Normalize $e_{vu}$ into the final attention weight $\alpha_{vu}$ with the softmax function:
  $\alpha_{vu} = \frac{\exp(e_{vu})}{\sum_{k \in N(v)} \exp(e_{vk})}$
- Compute a weighted sum based on the final attention weight $\alpha_{vu}$:
  $h_v^l = \sigma\left(\sum_{u \in N(v)} \alpha_{vu} W^l h_u^{l-1}\right)$
Form of the attention mechanism $a$: the approach is agnostic to the choice of $a$
Example: use a simple single-layer neural network ($a$ has trainable parameters in the Linear layer):
$e_{AB} = a(W^l h_A^{l-1}, W^l h_B^{l-1}) = \text{Linear}(\text{Concat}(W^l h_A^{l-1}, W^l h_B^{l-1}))$
Parameters of $a$ are trained together with the weight matrices (i.e., the parameters of $W^l$) in an end-to-end fashion.
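A minimal sketch of this attention mechanism: a single Linear layer computes $e_{vu}$ from the concatenated transformed embeddings, a softmax over each node's neighbors gives $\alpha_{vu}$, and a weighted sum produces $h_v^l$. Masking with a dense adjacency matrix and the variable names are my own assumptions (and it assumes every node has at least one neighbor).

```python
import torch
import torch.nn as nn

class SimpleGATLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a = nn.Linear(2 * out_dim, 1)   # the attention mechanism a (single Linear layer)

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        # H: [N, in_dim], A: [N, N] adjacency (0/1)
        Wh = self.W(H)                                        # [N, out_dim]
        N = Wh.size(0)
        # pairwise Concat(W h_v, W h_u) for all (v, u); e[v, u] = a(...)
        pair = torch.cat([Wh.unsqueeze(1).expand(N, N, -1),
                          Wh.unsqueeze(0).expand(N, N, -1)], dim=-1)
        e = self.a(pair).squeeze(-1)                          # [N, N] attention coefficients
        e = e.masked_fill(A == 0, float('-inf'))              # only attend over neighbors
        alpha = torch.softmax(e, dim=1)                       # normalize over N(v)
        return torch.relu(alpha @ Wh)                         # weighted sum, then nonlinearity
```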
c). Multi-head attention
To be updated
d). Benefits of attention mechanism
Key benefit: allow for (implicitly) specifying different importance values to different neighbors
- Computationally efficient: computation and aggregation can be parallelized
- Storage efficient: sparse matrix operations do not require more than $O(V+E)$ entries to be stored; fixed number of parameters, irrespective of graph size
- Localized: only attends over local network neighbors
- Inductive capability: a shared edge-wise mechanism that does not depend on the global graph structure
e). GNN layer in practice
We can include modern deep learning modules that proved to be useful in many domains
- Batch normalization: stabilize NN training
- Dropout: prevent overfitting
- Attention/Gating: control the importance of a message
- More: any other deep learning modules
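As a sketch of how these modules are often combined around a message-passing step, here is one common ordering (message passing → BatchNorm → Dropout → activation). The exact ordering and the module names are my choice for illustration, not something prescribed by the lecture.

```python
import torch
import torch.nn as nn

class GNNLayerInPractice(nn.Module):
    """Mean-aggregation message passing wrapped with BatchNorm, Dropout, and ReLU."""
    def __init__(self, in_dim: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.bn = nn.BatchNorm1d(out_dim)    # stabilizes NN training
        self.dropout = nn.Dropout(dropout)   # helps prevent overfitting

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        h = (A @ self.W(H)) / deg            # message passing (mean aggregation)
        h = self.bn(h)
        h = self.dropout(h)
        return torch.relu(h)
```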
3. Stacking GNN Layers
0). How to Connect GNN Layers into a GNN?
- Stack layers sequentially (standard way)
  Input: initial raw node feature $x_v$
  Output: node embeddings $h_v^L$ after $L$ GNN layers
- Ways of adding skip connections
1). The Over-smoothing Problem
Issue: all the node embeddings converge to the same value after stacking many GNN layers. This is bad because we want to use node embeddings to differentiate nodes
a). Receptive field of a GNN
Receptive field: the set of nodes that determine the embedding of a node of interest
In a $K$-layer GNN, each node has a receptive field of its $K$-hop neighborhood. The set of shared neighbors grows quickly as we increase the number of hops (i.e., the number of GNN layers).
b). Receptive field & over-smoothing
Stack many GNN layers → nodes will have highly overlapped receptive fields → node embeddings will be highly similar → suffer from the over-smoothing problem
c). Be cautious when stacking GNN layers
Unlike NN in other domains, adding more GNN layers does not always help
- Step 1: analyze the necessary receptive field to solve the problem (e.g., by computing the diameter of the graph)
- Step 2: set the number of GNN layers $L$ to be a bit more than the necessary receptive field. Do not set $L$ to be unnecessarily large (a rough sketch of this recipe follows below).
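As a rough sketch of this two-step recipe, one could estimate the graph diameter with NetworkX and set the layer count slightly above it. The `margin` of 1 is an arbitrary illustrative choice, and `nx.diameter` assumes a connected graph.

```python
import networkx as nx

def choose_num_layers(G: nx.Graph, margin: int = 1) -> int:
    """Set L a bit above the receptive field needed to cover the graph."""
    # For a disconnected graph, use the diameter of the largest component instead.
    diameter = nx.diameter(G)
    return diameter + margin

G = nx.path_graph(5)          # toy example: a path of 5 nodes, diameter 4
print(choose_num_layers(G))   # 5
```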
2). Expressive Power for Shallow GNNs
a). Increase the expressive power within each GNN layer
- In our previous examples, each transformation or aggregation function only includes one linear layer
- We can make the aggregation / transformation function a deep neural network
b). Add layers that do not pass messages
A GNN does not necessarily only contain GNN layers. We can add MLP layers before and after GNN layers as preprocessing layers and postprocessing layers.
- Preprocessing layers: important when encoding node features is necessary (e.g., when nodes represent images / text)
- Postprocessing layers: important when reasoning / transformation over node embeddings is needed (graph classification, knowledge graphs)
In practice, adding these layers works well.
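A minimal sketch of this overall architecture: an MLP encoder before the GNN layers and an MLP head after them. It reuses the `SimpleGCNLayer` sketched in the GCN section above, and all module names and layer sizes are illustrative assumptions.

```python
import torch.nn as nn

class GNNWithMLPs(nn.Module):
    """Pre-processing MLP -> stacked GNN layers -> post-processing MLP."""
    def __init__(self, in_dim, hidden_dim, out_dim, num_gnn_layers=3):
        super().__init__()
        self.pre = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, hidden_dim))
        self.gnn_layers = nn.ModuleList(
            [SimpleGCNLayer(hidden_dim, hidden_dim) for _ in range(num_gnn_layers)])
        self.post = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                  nn.Linear(hidden_dim, out_dim))

    def forward(self, X, A):
        h = self.pre(X)                 # encode raw node features (e.g., image / text encodings)
        for layer in self.gnn_layers:
            h = layer(h, A)             # message passing
        return self.post(h)             # reason / transform over the learned embeddings
```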
3). Add skip connections in GNNs
Observation from over-smoothing: node embeddings in earlier GNN layers can sometimes better differentiate nodes.
Solutions: we can increase the impact of earlier layers on the final node embeddings by adding shortcuts in GNNs
- Idea of skip connections
  Before adding shortcuts: $F(x)$
  After adding shortcuts: $F(x) + x$
- Why do skip connections work?
Intuition: skip connections create a mixture of models.
  $N$ skip connections lead to $2^N$ possible paths, and each path could have up to $N$ modules. We automatically get a mixture of shallow GNNs and deep GNNs.
- Other option: directly skip to the last layer
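A minimal sketch of the $F(x) + x$ idea applied to a GNN layer: wrap a message-passing step with a residual shortcut. This is one simple way to add the shortcut; the lecture does not fix a particular form.

```python
import torch
import torch.nn as nn

class ResidualGCNLayer(nn.Module):
    """h^l = F(h^{l-1}) + h^{l-1}: a GCN-style update plus a skip connection."""
    def __init__(self, dim: int):
        super().__init__()
        self.W = nn.Linear(dim, dim, bias=False)  # same in/out dim so the shapes match

    def forward(self, H: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
        deg = A.sum(dim=1, keepdim=True).clamp(min=1)
        F_x = torch.relu((A @ self.W(H)) / deg)   # F(x): the usual GNN transformation
        return F_x + H                            # + x: the skip connection
```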