Message passing scheme
Message passing scheme is a popular GNN scheme used in many frameworks. It exploits the connectivity between neighboring nodes and forms a general approach for spatial graph convolutional neural networks. It comprises two user-defined functions and one aggregate function. The message function processes information from the edge states and from the node states of a node's neighbors and of the node itself. The messages arriving at each node are combined by the aggregate function into node-level information for the update function. The update function takes the current node state and the aggregated message and produces a new node state.
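Written as equations, one layer of the scheme described above can be sketched as follows. The notation is introduced here rather than taken from the page: \mathcal{N}(i) is the set of neighbors of node i, \bigoplus stands for whichever aggregate function is chosen (for example a sum), and x_i' is the new state of node i.

```latex
m_i = \bigoplus_{j \in \mathcal{N}(i)} \operatorname{message}(x_i, x_j, e_{ij}),
\qquad
x_i' = \operatorname{update}(m_i, x_i)
```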
The message passing scheme is realized as an abstract type MessagePassing. Any subtype of MessagePassing is a message passing layer, which uses the following default message and update functions:
message(mp, x_i, x_j, e_ij) = x_j
update(mp, m, x) = m
Here mp denotes a message passing layer. The message function accepts the node state x_i for node i, its neighbor state x_j for node j, and the corresponding edge state e_ij for edge (i,j). The default message function simply returns the neighbor state x_j, so node i receives the state of each of its neighbors as a message. The update function takes the aggregated message m and the current node state x, and by default outputs m unchanged.
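A minimal, self-contained Julia sketch of these defaults follows. It re-declares MessagePassing locally so it runs without any package, which makes it an illustration of the described behavior rather than the library's source; IdentityLayer is a hypothetical layer name.

```julia
# Illustrative sketch of the default message/update pair described above.
# MessagePassing is re-declared locally so this runs without any package;
# IdentityLayer is a hypothetical layer that relies on the defaults.
abstract type MessagePassing end

# Default message: pass the neighbor state x_j through unchanged.
message(mp::MessagePassing, x_i, x_j, e_ij) = x_j

# Default update: the new node state is the aggregated message m.
update(mp::MessagePassing, m, x) = m

struct IdentityLayer <: MessagePassing end

layer = IdentityLayer()
x_i, x_j, e_ij = [1.0, 2.0], [3.0, 4.0], [0.5]

msg = message(layer, x_i, x_j, e_ij)   # → [3.0, 4.0]
new_x = update(layer, msg, x_i)        # → [3.0, 4.0]
```

Because both defaults are defined on the abstract type, any subtype that defines neither function behaves like IdentityLayer.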
Message function
A message function accepts feature vector representing node state x_i
, feature vectors for neighbor state x_j
and corresponding edge state e_ij
. A vector is expected to output from message
for message. User can override message
for customized message passing layer to provide desired behavior.
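As a sketch of overriding message, the snippet below defines a hypothetical layer, ScaledEdgeLayer, whose message scales the neighbor state by a scalar edge weight and a layer parameter. MessagePassing is again re-declared locally; the field name scale and the assumption that e_ij holds a single scalar weight are illustrative choices, not part of the documented API.

```julia
# Sketch of a custom message function via dispatch. MessagePassing is
# re-declared locally; ScaledEdgeLayer, its field `scale`, and the assumption
# that e_ij holds one scalar weight are hypothetical choices.
abstract type MessagePassing end

message(mp::MessagePassing, x_i, x_j, e_ij) = x_j   # default behavior

struct PlainLayer <: MessagePassing end

struct ScaledEdgeLayer <: MessagePassing
    scale::Float64
end

# Override: weight the neighbor state by the edge weight and the layer scale.
message(mp::ScaledEdgeLayer, x_i, x_j, e_ij) = (mp.scale * e_ij[1]) .* x_j

x_i, x_j, e_ij = [0.0, 0.0], [1.0, 3.0], [0.5]
message(PlainLayer(), x_i, x_j, e_ij)            # → [1.0, 3.0] (default)
message(ScaledEdgeLayer(4.0), x_i, x_j, e_ij)    # → [2.0, 6.0]
```

Julia's multiple dispatch selects the more specific method for ScaledEdgeLayer, while every other subtype keeps the default.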
Aggregate messages
Messages from the message function are combined by an aggregate function, and the aggregated message is passed to the update function for node-level computation. The aggregate function is chosen when calling propagate:
propagate(mp, fg::FeaturedGraph, aggr::Symbol=:add)
The propagate function runs the whole message passing layer. fg is the input graph (a FeaturedGraph) for the layer, and aggr selects which aggregate function propagate uses. The default, :add, sums all messages.
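To make the data flow of propagate concrete, here is a rough sketch of the loop it performs. It does not use FeaturedGraph; instead it assumes a plain representation chosen for illustration: node features are the columns of a matrix X, adj[i] lists the neighbors j of node i, and edge features live in a Dict keyed by (i, j). propagate_sketch is a hypothetical name, and only the default :add aggregation is handled.

```julia
# Sketch of the loop behind a propagate-style call (illustrative only).
# Assumptions, not GeometricFlux's API: node features are the columns of X,
# adj[i] lists the neighbors j of node i, edge features are in a Dict keyed
# by (i, j), and only the default :add aggregation is implemented.
abstract type MessagePassing end
message(mp::MessagePassing, x_i, x_j, e_ij) = x_j   # default message
update(mp::MessagePassing, m, x) = m                # default update

struct IdentityLayer <: MessagePassing end          # hypothetical layer

function propagate_sketch(mp::MessagePassing, X::AbstractMatrix,
                          adj::Vector{Vector{Int}},
                          E::Dict{Tuple{Int,Int},Vector{Float64}},
                          aggr::Symbol = :add)
    aggr === :add || throw(ArgumentError("this sketch only handles :add"))
    out = similar(X)
    for i in axes(X, 2)
        x_i = X[:, i]
        # One message per neighbor j of node i.
        msgs = [message(mp, x_i, X[:, j], E[(i, j)]) for j in adj[i]]
        # :add sums the messages; a node without neighbors gets zeros.
        m = isempty(msgs) ? zero(x_i) : sum(msgs)
        out[:, i] = update(mp, m, x_i)
    end
    return out
end

X = [1.0 2.0 3.0;
     0.0 1.0 0.0]                 # 2 features × 3 nodes
adj = [[2, 3], [1], Int[]]        # node 3 has no neighbors
E = Dict((1, 2) => [1.0], (1, 3) => [1.0], (2, 1) => [1.0])
Y = propagate_sketch(IdentityLayer(), X, adj, E)
# → [5.0 1.0 0.0; 1.0 0.0 0.0]
```

With the default message and update, each output column is simply the sum of the neighbors' feature columns.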
The following values of aggr are available:
:add : sum over all messages
:sub : negative of the sum over all messages
:mul : product over all messages
:div : inverse of the product over all messages
:max : maximum of all messages
:min : minimum of all messages
:mean : average of all messages
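The descriptions above can be read as reductions over the list of messages reaching a node. The following sketch maps each symbol to such a reduction; aggregate_sketch is a hypothetical helper rather than a library function, and treating each operation as element-wise over message vectors is an assumption of this sketch.

```julia
# Sketch: one reduction per aggr symbol, following the list above.
# aggregate_sketch is hypothetical; every operation is applied element-wise
# across the message vectors (an assumption of this sketch).
function aggregate_sketch(aggr::Symbol, msgs::Vector{<:AbstractVector})
    isempty(msgs) && throw(ArgumentError("need at least one message"))
    if aggr === :add
        return reduce(+, msgs)                        # sum
    elseif aggr === :sub
        return -reduce(+, msgs)                       # negative of the sum
    elseif aggr === :mul
        return reduce((a, b) -> a .* b, msgs)         # product
    elseif aggr === :div
        return 1 ./ reduce((a, b) -> a .* b, msgs)    # inverse of the product
    elseif aggr === :max
        return reduce((a, b) -> max.(a, b), msgs)     # maximum
    elseif aggr === :min
        return reduce((a, b) -> min.(a, b), msgs)     # minimum
    elseif aggr === :mean
        return reduce(+, msgs) ./ length(msgs)        # average
    else
        throw(ArgumentError("unknown aggregate: $aggr"))
    end
end

msgs = [[1.0, 4.0], [2.0, 2.0]]
aggregate_sketch(:add, msgs)    # → [3.0, 6.0]
aggregate_sketch(:div, msgs)    # → [0.5, 0.125]
aggregate_sketch(:mean, msgs)   # → [1.5, 3.0]
```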
Update function
An update function takes the aggregated message m and the current node state x as arguments. It is expected to return a vector, which becomes the new node state for the next layer. Users can override update to give a customized message passing layer the desired behavior.
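As a sketch of overriding update, the hypothetical layer ResidualReLULayer below adds the aggregated message to the current node state and applies a ReLU, while a layer without an override keeps the default of returning m. MessagePassing is re-declared locally so the snippet is self-contained; the residual-plus-ReLU rule is only an illustrative choice.

```julia
# Sketch of a custom update function via dispatch. MessagePassing is
# re-declared locally; ResidualReLULayer is a hypothetical layer.
abstract type MessagePassing end

update(mp::MessagePassing, m, x) = m   # default: new state is the message

struct PlainLayer <: MessagePassing end
struct ResidualReLULayer <: MessagePassing end

# Override: add the message to the current state (residual connection),
# then clamp negative entries to zero (ReLU).
update(mp::ResidualReLULayer, m, x) = max.(x .+ m, 0.0)

m = [1.0, -3.0]
x = [0.5, 1.0]
update(PlainLayer(), m, x)          # → [1.0, -3.0]
update(ResidualReLULayer(), m, x)   # → [1.5, 0.0]
```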