Message passing scheme
Message passing is a popular GNN scheme implemented in many frameworks. It exploits the connectivity between neighboring nodes and forms a general approach to spatial graph convolutional neural networks. It comprises two user-defined functions and one aggregate function. A message function processes information from the edge states and from the states of a node and its neighbors. The messages arriving at each node are then combined by the aggregate function to provide node-level information to the update function. The update function takes the current node state and the aggregated message and returns a new node state.
\[\begin{aligned} m_{ij}^{(l+1)} &= message(h_i^{(l)}, h_j^{(l)}, e_{ij}^{(l)}) \\ m_{i}^{(l+1)} &= \Box_{j \in \mathcal{N}(i)} m_{ij}^{(l+1)} \\ h_i^{(l+1)} &= update(h_i^{(l)}, m_{i}^{(l+1)}) \end{aligned}\]
where $h_i$ and $h_j$ are the node features of node $i$ and of its neighbor $j$, and $e_{ij}$ is the feature of edge $(i, j)$. $m_{ij}^{(l+1)}$ denotes the message for edge $(i, j)$ computed in the $l$-th layer, and $m_{i}^{(l+1)}$ is the aggregated message for node $i$. $message$ and $update$ are the message function and the update function, respectively. The aggregate function $\Box$ can be any supported aggregate function, e.g. `max`, `sum` or `mean`.
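The three equations can be read as a loop over nodes. The following is a minimal sketch in plain Julia, not the GeometricFlux implementation: the function `message_passing_step`, the adjacency-list layout and the particular message, aggregate and update functions are all assumptions chosen for illustration.

```julia
# One message-passing step on a small graph, written in plain Julia.
# `neighbors[i]` lists the neighbors j of node i, `h[i]` is the state of node i,
# and `e[(i, j)]` is the state of edge (i, j). All names are illustrative.
function message_passing_step(h, e, neighbors; message, aggregate, update)
    map(eachindex(h)) do i
        # m_ij = message(h_i, h_j, e_ij) for every neighbor j of i
        msgs = [message(h[i], h[j], e[(i, j)]) for j in neighbors[i]]
        # m_i = aggregate of all messages arriving at node i
        m = aggregate(msgs)
        # new h_i = update(h_i, m_i)
        update(h[i], m)
    end
end

# Path graph 1 - 2 - 3 with two-dimensional node states and unit edge states.
h = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
neighbors = [[2], [1, 3], [2]]
e = Dict((i, j) => 1.0 for i in 1:3 for j in neighbors[i])

h_new = message_passing_step(h, e, neighbors;
    message = (h_i, h_j, e_ij) -> e_ij .* h_j,  # edge-weighted neighbor state
    aggregate = msgs -> reduce(+, msgs),        # sum aggregation
    update = (h_i, m) -> h_i .+ m)              # add the message to the old state
```

Node 2, for example, receives one message from node 1 and one from node 3, sums them, and adds the result to its own state.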
Reference: Justin Gilmer, Samuel S. Schoenholz, Patrick F. Riley, Oriol Vinyals, George E. Dahl (2017)
The message passing scheme is realized as an abstract type `MessagePassing`. Any subtype of `MessagePassing` is a message passing layer, which by default uses the following message and update functions:
message(mp, x_i, x_j, e_ij) = x_j
update(mp, m, x) = m
`mp` denotes a message passing layer. `message` accepts the node state `x_i` of node `i` and the state `x_j` of its neighbor `j`, as well as the corresponding edge state `e_ij` for edge `(i, j)`. The default message function returns the neighbor state `x_j` for each neighbor of node `i`. `update` takes the aggregated message `m` and the current node state `x`, and returns `m`.
GeometricFlux.MessagePassing - Type
Message function
A message function accepts a feature vector representing the node state `x_i`, the feature vector of a neighbor state `x_j` and the corresponding edge state `e_ij`. `message` is expected to return a vector, which serves as the message. Users can override `message` in a customized message passing layer to obtain the desired behavior.
GeometricFlux.message - Function
message(mp::MessagePassing, x_i, x_j, e_ij)
Message function for the message-passing scheme, returning the message from node `j` to node `i`. In the message-passing scheme, the incoming messages from the neighborhood of `i` will later be aggregated in order to `update` the features of node `i`.
By default, the function returns `x_j`. Layers subtyping `MessagePassing` should specialize this method with custom behavior.
Arguments
- `mp`: message-passing layer.
- `x_i`: the features of node `i`.
- `x_j`: the features of the neighbor `j` of node `i`.
- `e_ij`: the features of edge (`i`, `j`).
See also `update`.
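Specializing `message` relies on Julia's multiple dispatch. The sketch below illustrates the idea with a stand-in abstract type so that it runs without any package: `AbstractMessagePassing`, `PlainLayer` and `EdgeScaledLayer` are hypothetical names, and with the real library the method would presumably be added to `GeometricFlux.message` for a subtype of `MessagePassing` instead.

```julia
# Stand-in for the abstract type described above (hypothetical; the real type is
# GeometricFlux.MessagePassing).
abstract type AbstractMessagePassing end

# Defaults as documented: forward the neighbor state, return the aggregated message.
message(mp::AbstractMessagePassing, x_i, x_j, e_ij) = x_j
update(mp::AbstractMessagePassing, m, x) = m

# A layer that keeps the default behavior.
struct PlainLayer <: AbstractMessagePassing end

# A layer that specializes `message`: scale the neighbor state by the edge state.
struct EdgeScaledLayer <: AbstractMessagePassing end
message(mp::EdgeScaledLayer, x_i, x_j, e_ij) = e_ij .* x_j

x_i, x_j, e_ij = [1.0, 2.0], [3.0, 4.0], 0.5
default_msg = message(PlainLayer(), x_i, x_j, e_ij)      # falls back to x_j
scaled_msg = message(EdgeScaledLayer(), x_i, x_j, e_ij)  # uses the specialized method
```

Because dispatch picks the most specific method, `EdgeScaledLayer` still inherits the default `update` while using its own `message`.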
Aggregate messages
Messages produced by the message function are combined by an aggregate function, and the aggregated message is passed to the update function for node-level computation. The aggregate function is selected through the `aggr` argument of `propagate`:
propagate(mp, fg::FeaturedGraph, aggr::Symbol=:add)
The `propagate` function runs the whole message passing layer. `fg` is the input to the layer, and `aggr` selects the aggregate function used by `propagate`. The default, `:add`, sums all messages.
The following values of `aggr` are available:
- `:add`: sum over all messages
- `:sub`: negative of sum over all messages
- `:mul`: multiplication over all messages
- `:div`: inverse of multiplication over all messages
- `:max`: the maximum of all messages
- `:min`: the minimum of all messages
- `:mean`: the average of all messages
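To make the list concrete, the sketch below applies one plausible element-wise reading of each option to the messages arriving at a single node. It is plain Julia and does not call GeometricFlux; the `aggregators` table and the sample messages are assumptions for illustration, and the library's actual implementation may differ.

```julia
using Statistics: mean

# Messages arriving at one node, one vector per neighbor.
msgs = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.5]]

# Element-wise reductions mirroring the descriptions in the list above.
aggregators = Dict(
    :add  => (ms -> reduce(+, ms)),
    :sub  => (ms -> -reduce(+, ms)),
    :mul  => (ms -> reduce((a, b) -> a .* b, ms)),
    :div  => (ms -> 1 ./ reduce((a, b) -> a .* b, ms)),
    :max  => (ms -> reduce((a, b) -> max.(a, b), ms)),
    :min  => (ms -> reduce((a, b) -> min.(a, b), ms)),
    :mean => (ms -> mean(ms)),
)

# Apply every aggregator to the same set of messages.
aggregated = Dict(name => f(msgs) for (name, f) in aggregators)
```

Every aggregator returns a vector of the same length as a single message, so the update function sees the same shape regardless of how many neighbors a node has.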
Update function
An update function takes the aggregated message `m` and the current node state `x` as arguments. It is expected to return a vector, which becomes the new node state for the next layer. Users can override `update` in a customized message passing layer to obtain the desired behavior.
GeometricFlux.update - Function
update(mp::MessagePassing, m, x)
Update function for the message-passing scheme, returning a new set of node features `x′` based on the old features `x` and the incoming message from the neighborhood aggregation `m`.
By default, the function returns `m`. Layers subtyping `MessagePassing` should specialize this method with custom behavior.
Arguments
- `mp`: message-passing layer.
- `m`: the aggregated edge messages from the `message` function.
- `x`: the node features to be updated.
See also `message`.
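A custom `update` can carry layer parameters and combine the old state with the aggregated message. The following is a minimal sketch under stated assumptions: `AbstractMessagePassing` is a hypothetical stand-in for `MessagePassing`, and `ResidualLayer` with its `alpha` field is invented for illustration rather than taken from GeometricFlux.

```julia
# Hypothetical stand-in for GeometricFlux.MessagePassing, with the documented default.
abstract type AbstractMessagePassing end
update(mp::AbstractMessagePassing, m, x) = m

# Layer whose update mixes the old state with the aggregated message.
struct ResidualLayer <: AbstractMessagePassing
    alpha::Float64   # weight of the aggregated message (illustrative parameter)
end
update(mp::ResidualLayer, m, x) = x .+ mp.alpha .* m

x = [1.0, 2.0]   # current node state
m = [4.0, 6.0]   # aggregated message
x_new = update(ResidualLayer(0.5), m, x)
```

Unlike the default, which discards `x`, this variant keeps the previous node state in the result, a common choice when stacking several layers.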