Graph Convolutional Layers

Graph Convolutional Layer

\[X' = \sigma(\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} X \Theta)\]

where $\hat{A} = A + I$, $A$ denotes the adjacency matrix, and $\hat{D}$ is the diagonal degree matrix of $\hat{A}$, with $\hat{d}_{ii} = \sum_{j} \hat{a}_{ij}$.
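The propagation rule can be checked numerically in plain Julia. This is a minimal sketch using only the LinearAlgebra standard library, not GeometricFlux's implementation. The helper name `gcn_propagate` is hypothetical, and it follows the formula's layout (one row per node), which is the transpose of the `(num_features, num_nodes)` layout that GCNConv expects.

```julia
using LinearAlgebra

# Hypothetical helper: one GCN propagation step in the formula's layout.
# A is the (num_nodes, num_nodes) adjacency matrix, X is (num_nodes, in),
# Θ is (in, out), and σ is applied elementwise.
function gcn_propagate(A::AbstractMatrix, X::AbstractMatrix, Θ::AbstractMatrix, σ=identity)
    Ahat = A + I                              # Â = A + I adds self-loops
    dhat = vec(sum(Ahat, dims=2))             # d̂_ii = Σ_j â_ij
    Dinv_sqrt = Diagonal(1 ./ sqrt.(dhat))    # D̂^(-1/2)
    return σ.(Dinv_sqrt * Ahat * Dinv_sqrt * X * Θ)
end

# Path graph 1 - 2 - 3 with two input features per node
A = [0.0 1.0 0.0;
     1.0 0.0 1.0;
     0.0 1.0 0.0]
X = [1.0 0.0;
     0.0 1.0;
     1.0 1.0]
Θ = [1.0 0.0;
     0.0 1.0]          # identity weights: the output is just the normalized smoothing of X
Xnew = gcn_propagate(A, X, Θ)
```

With self-loops, node 1 has degree 2 and node 2 has degree 3, so `Xnew[1, 1]` is $1/2$ and `Xnew[2, 1]` is $2/\sqrt{6}$.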

GeometricFlux.GCNConv (Type)
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform)

Graph convolutional layer. The input to the layer is a node feature array X of size (num_features, num_nodes).

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • σ: Activation function.
  • bias: Add learnable bias.
  • init: Weights' initializer.

Examples

julia> gc = GCNConv(1024=>256, relu)
GCNConv(1024 => 256, relu)

See also WithGraph for training a layer with a static graph.

Reference: Thomas N. Kipf, Max Welling (2017)


Chebyshev Spectral Graph Convolutional Layer

\[X' = \sum^{K-1}_{k=0} Z^{(k)} \Theta^{(k)}\]

where $Z^{(k)}$ is the $k$-th term of the Chebyshev expansion, computed by the following recursion:

\[Z^{(0)} = X \\ Z^{(1)} = \hat{L} X \\ Z^{(k)} = 2 \hat{L} Z^{(k-1)} - Z^{(k-2)}\]

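and $\hat{L} = \frac{2}{\lambda_{max}} L - I$ is the rescaled Laplacian, where $L$ is the normalized graph Laplacian and $\lambda_{max}$ is its largest eigenvalue.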
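The recursion can be sketched directly. This minimal example assumes $L$ is the symmetric normalized Laplacian $I - D^{-1/2} A D^{-1/2}$ and computes $\lambda_{max}$ exactly with `eigvals`, which is only practical for small graphs. The function name `cheb_conv` is hypothetical, and this is not GeometricFlux's implementation.

```julia
using LinearAlgebra

# Hypothetical helper: A is a symmetric adjacency matrix without isolated nodes,
# X is (num_nodes, in), and Θs holds K weight matrices Θ⁽⁰⁾, ..., Θ⁽ᴷ⁻¹⁾ of size (in, out).
function cheb_conv(A::AbstractMatrix, X::AbstractMatrix, Θs::Vector{<:AbstractMatrix})
    K = length(Θs)
    Dinv_sqrt = Diagonal(1 ./ sqrt.(vec(sum(A, dims=2))))
    L = I - Dinv_sqrt * A * Dinv_sqrt                # normalized Laplacian (assumption)
    λmax = maximum(eigvals(Symmetric(Matrix(L))))    # exact, so only for small graphs
    Lhat = (2 / λmax) * L - I                        # L̂ = 2L/λmax - I

    Zprev, Z = X, Lhat * X                           # Z⁽⁰⁾ = X, Z⁽¹⁾ = L̂X
    out = Zprev * Θs[1]
    if K >= 2
        out += Z * Θs[2]
    end
    for k in 3:K
        Zprev, Z = Z, 2 * Lhat * Z - Zprev           # Z⁽ᵏ⁾ = 2L̂Z⁽ᵏ⁻¹⁾ - Z⁽ᵏ⁻²⁾
        out += Z * Θs[k]
    end
    return out
end

A = [0.0 1.0 0.0;
     1.0 0.0 1.0;
     0.0 1.0 0.0]
X = rand(3, 4)                   # 3 nodes, 4 input features
Θs = [rand(4, 2) for _ in 1:3]   # K = 3 terms, 2 output features
Y = cheb_conv(A, X, Θs)
```

With a single term ($K = 1$) the layer reduces to $X \Theta^{(0)}$, which is a quick sanity check on the indexing.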

GeometricFlux.ChebConv (Type)
ChebConv(in=>out, k, σ=identity; bias=true, init=glorot_uniform)

Chebyshev spectral graph convolutional layer.

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • k: The order of the Chebyshev polynomial.
  • σ: Activation function.
  • bias: Add learnable bias.
  • init: Weights' initializer.

Examples

julia> cc = ChebConv(1024=>256, 5, relu)
ChebConv(1024 => 256, k=5, relu)

See also WithGraph for training a layer with a static graph.

Reference: Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst (2016)


Graph Neural Network Layer

\[\textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \Theta_2 \textbf{x}_j)\]
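Read per node, the formula adds a transformed copy of the node's own features to the transformed sum of its neighbors' features. Below is a minimal sketch with `aggr=+` in the `(num_features, num_nodes)` layout. `graph_conv` and the adjacency-list argument `nbrs` are hypothetical, and this is not GeometricFlux's message-passing implementation.

```julia
# Hypothetical helper: X is (in, num_nodes), nbrs[i] lists the neighbors of node i,
# Θ1 and Θ2 are (out, in) weight matrices, σ is applied elementwise.
function graph_conv(X::AbstractMatrix, nbrs::Vector{Vector{Int}}, Θ1, Θ2, σ=identity)
    out = similar(X, size(Θ1, 1), size(X, 2))
    for i in axes(X, 2)
        m = zeros(eltype(X), size(X, 1))
        for j in nbrs[i]
            m .+= X[:, j]                      # aggr = +: sum the neighbor features
        end
        # Θ2 * Σ_j x_j equals Σ_j Θ2 x_j by linearity
        out[:, i] = σ.(Θ1 * X[:, i] + Θ2 * m)
    end
    return out
end

X = [1.0 2.0 3.0;
     0.0 1.0 0.0]             # 2 features, 3 nodes
nbrs = [[2], [1, 3], [2]]     # path graph 1 - 2 - 3
Θ1 = [1.0 0.0;
      0.0 1.0]
Θ2 = [1.0 1.0;
      0.0 0.0]
Y = graph_conv(X, nbrs, Θ1, Θ2)
```

For an undirected graph with adjacency matrix $A$, the same result can be written as $\Theta_1 X + \Theta_2 X A$ in this layout.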

GeometricFlux.GraphConv (Type)
GraphConv(in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)

Graph neural network layer.

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • σ: Activation function.
  • aggr: An aggregate function applied to the result of the message function. +, -, *, /, max, min and mean are available.

  • bias: Add learnable bias.
  • init: Weights' initializer.

Examples

julia> GraphConv(1024=>256, relu)
GraphConv(1024 => 256, relu, aggr=+)

julia> GraphConv(1024=>256, relu, *)
GraphConv(1024 => 256, relu, aggr=*)

See also WithGraph for training a layer with a static graph.

Reference: Christopher Morris, Martin Ritzert, Matthias Fey, William L. Hamilton, Jan Eric Lenssen, Gaurav Rattan, Martin Grohe (2019)


SAmple and aggreGatE (GraphSAGE) Network

\[\hat{\textbf{x}}_j = sample(\textbf{x}_j), \forall j \in \mathcal{N}(i) \\ \textbf{m}_i = aggregate(\hat{\textbf{x}}_j) \\ \textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \Theta_2 \textbf{m}_i)\]
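The sample and aggregate steps can be sketched with the mean aggregator, assuming neighbors are sampled uniformly without replacement and that every node has at least one neighbor. `sage_mean` is a hypothetical name, and the `normalize` and `project` options are not modeled here.

```julia
using Random, Statistics

# Hypothetical helper: X is (in, num_nodes), nbrs[i] lists the neighbors of node i,
# Θ1 and Θ2 are (out, in). Up to num_sample neighbors are drawn without replacement.
function sage_mean(X, nbrs, Θ1, Θ2; num_sample=10, σ=identity, rng=Random.default_rng())
    out = similar(X, size(Θ1, 1), size(X, 2))
    for i in axes(X, 2)
        js = nbrs[i]
        picked = js[randperm(rng, length(js))[1:min(num_sample, length(js))]]  # sample
        m = vec(mean(X[:, picked], dims=2))                                    # aggregate
        out[:, i] = σ.(Θ1 * X[:, i] + Θ2 * m)                                  # combine
    end
    return out
end

X = [1.0 2.0 4.0;
     0.0 1.0 0.0]
nbrs = [[2], [1, 3], [2]]
Θ1 = [1.0 0.0; 0.0 1.0]
Θ2 = [1.0 0.0; 0.0 1.0]
Y = sage_mean(X, nbrs, Θ1, Θ2; num_sample=10)   # every neighbor is kept
Y1 = sage_mean(X, nbrs, Θ1, Θ2; num_sample=1)   # one random neighbor per node
```

When `num_sample` is at least the node degree, sampling keeps every neighbor and the result is deterministic. Smaller values trade accuracy for a bounded cost per node.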

GeometricFlux.SAGEConv (Type)
SAGEConv(in => out, σ=identity, aggr=mean; normalize=true, project=false,
         bias=true, num_sample=10, init=glorot_uniform)

SAmple and aggreGatE convolutional layer for GraphSAGE network.

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • σ: Activation function.
  • aggr: An aggregate function applied to the result of the message function. mean, max, LSTM and GCNConv are available.

  • normalize::Bool: Whether to normalize features across all nodes or not.
  • project::Bool: Whether to project, i.e. Dense(in, in), before aggregation.
  • bias: Add learnable bias.
  • num_sample::Int: Number of neighbors sampled for each node.
  • init: Weights' initializer.

Examples

julia> SAGEConv(1024=>256, relu)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=10)

julia> SAGEConv(1024=>256, relu, num_sample=5)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=5)

julia> MeanAggregator(1024=>256, relu, normalize=false)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=false, #sample=10)

julia> MeanPoolAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=mean, normalize=true, #sample=10)

julia> MaxPoolAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=max, normalize=true, #sample=10)

julia> LSTMAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, aggr=LSTMCell(1024 => 1024), normalize=true, #sample=10)

See also WithGraph for training a layer with a static graph, as well as MeanAggregator, MeanPoolAggregator, MaxPoolAggregator and LSTMAggregator.

GeometricFlux.MeanAggregator (Function)
MeanAggregator(in => out, σ=identity; normalize=true, project=false,
               bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with a mean aggregator.

See also SAGEConv.

GeometricFlux.LSTMAggregator (Function)
LSTMAggregator(in => out, σ=identity; normalize=true, project=false,
               bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with an LSTM aggregator.

See also SAGEConv.

Reference: William L. Hamilton, Rex Ying, Jure Leskovec (2017) and the GraphSAGE website


Graph Attentional Layer

\[\textbf{x}_i' = \alpha_{i,i} \Theta \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \Theta \textbf{x}_j\]

where the attention coefficient $\alpha_{i,j}$ can be calculated from

\[\alpha_{i,j} = \frac{exp(LeakyReLU(\textbf{a}^T [\Theta \textbf{x}_i || \Theta \textbf{x}_j]))}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} exp(LeakyReLU(\textbf{a}^T [\Theta \textbf{x}_i || \Theta \textbf{x}_k]))}\]
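A single-head sketch of the coefficients and the update, in the `(num_features, num_nodes)` layout. The names `gat_coefficients`, `gat_node` and `leakyrelu` are hypothetical. The softmax subtracts the maximum score for numerical stability, which leaves the coefficients unchanged. Multi-head concatenation or averaging is not modeled.

```julia
using LinearAlgebra

# Hypothetical helper: LeakyReLU with a configurable negative slope.
leakyrelu(x, slope=0.2) = x >= 0 ? x : slope * x

# X is (in, num_nodes), Θ is (out, in), a has length 2*out, nbrs[i] lists the neighbors of i.
# Returns the node indices N(i) ∪ {i} (self first) and their coefficients α_{i,k}.
function gat_coefficients(X, nbrs, Θ, a, i; negative_slope=0.2)
    H = Θ * X                                        # Θx_k for every node k
    ks = vcat(i, nbrs[i])
    e = [leakyrelu(dot(a, vcat(H[:, i], H[:, k])), negative_slope) for k in ks]
    w = exp.(e .- maximum(e))                        # softmax, shifted for stability
    return ks, w ./ sum(w)
end

# x_i' = α_{i,i} Θx_i + Σ_j α_{i,j} Θx_j
function gat_node(X, nbrs, Θ, a, i; negative_slope=0.2)
    ks, α = gat_coefficients(X, nbrs, Θ, a, i; negative_slope=negative_slope)
    H = Θ * X
    return sum(α[n] * H[:, k] for (n, k) in enumerate(ks))
end

X = [1.0 0.0 1.0;
     0.0 1.0 1.0]
nbrs = [[2], [1, 3], [2]]
Θ = [1.0 0.0;
     0.0 1.0]
a = [0.5, -0.5, 1.0, 1.0]
ks, α = gat_coefficients(X, nbrs, Θ, a, 2)
x2 = gat_node(X, nbrs, Θ, a, 2)
```

With `a` set to zeros, every score is equal and the update reduces to the plain mean of $\Theta \textbf{x}_k$ over $\mathcal{N}(i) \cup \{i\}$.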

GeometricFlux.GATConv (Type)
GATConv(in => out, σ=identity; heads=1, concat=true,
        init=glorot_uniform, bias=true, negative_slope=0.2)

Graph attentional layer.

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • bias::Bool: Keyword argument, whether to learn the additive bias.
  • σ: Activation function.
  • heads: Number of attention heads.
  • concat: Whether to concatenate the outputs of the attention heads. If false, the outputs are averaged.
  • negative_slope::Real: Keyword argument, the parameter of LeakyReLU.

Examples

julia> GATConv(1024=>256, relu)
GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))

julia> GATConv(1024=>256, relu, heads=4)
GATConv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))

julia> GATConv(1024=>256, relu, heads=4, concat=false)
GATConv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))

julia> GATConv(1024=>256, relu, negative_slope=0.1f0)
GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))

See also WithGraph for training a layer with a static graph.

Reference: Petar Veličković, Guillem Cucurull, Arantxa Casanova, Adriana Romero, Pietro Liò, Yoshua Bengio (2018)


Graph Attentional Layer v2

\[\textbf{x}_i' = \alpha_{i,i} \Theta \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \Theta \textbf{x}_j\]

where the attention coefficient $\alpha_{i,j}$ can be calculated from

\[\alpha_{i,j} = \frac{exp(\textbf{a}^T LeakyReLU(\Theta [\textbf{x}_i || \textbf{x}_j]))}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} exp(\textbf{a}^T LeakyReLU(\Theta [\textbf{x}_i || \textbf{x}_k]))}\]
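The only change from GATConv is where the nonlinearity sits: LeakyReLU is applied to $\Theta [\textbf{x}_i || \textbf{x}_k]$ before the dot product with $\textbf{a}$. Below is a minimal sketch of the scoring step, assuming $\Theta$ is an `(out, 2in)` matrix acting on the concatenated features. `gatv2_coefficients` and `leakyrelu` are hypothetical names, not the library's API.

```julia
using LinearAlgebra

leakyrelu(x, slope=0.2) = x >= 0 ? x : slope * x

# Hypothetical GATv2 scoring: X is (in, num_nodes), Θ is (out, 2*in) and acts on the
# concatenation [x_i || x_k]; LeakyReLU is applied before the dot product with a (length out).
function gatv2_coefficients(X, nbrs, Θ, a, i; negative_slope=0.2)
    ks = vcat(i, nbrs[i])                                           # N(i) ∪ {i}, self first
    e = [dot(a, leakyrelu.(Θ * vcat(X[:, i], X[:, k]), negative_slope)) for k in ks]
    w = exp.(e .- maximum(e))                                       # stable softmax
    return ks, w ./ sum(w)
end

X = [1.0 0.0 1.0;
     0.0 1.0 1.0]
nbrs = [[2], [1, 3], [2]]
Θ = [1.0 0.0 -1.0  0.0;
     0.0 1.0  0.0 -1.0]      # with this Θ, Θ[x_i || x_k] = x_i - x_k
ks, α = gatv2_coefficients(X, nbrs, Θ, [1.0, 1.0], 2)
```

Because the nonlinearity now sits between $\Theta$ and $\textbf{a}$, the ranking of neighbors can depend on the query node $i$, which is the motivation for the v2 variant.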

GeometricFlux.GATv2Conv (Type)
GATv2Conv(in => out, σ=identity; heads=1, concat=true,
          init=glorot_uniform, negative_slope=0.2)

Graph attentional layer v2.

Arguments

  • in: The dimension of input features.
  • out: The dimension of output features.
  • σ: Activation function.
  • heads: Number of attention heads.
  • concat: Whether to concatenate the outputs of the attention heads. If false, the outputs are averaged.
  • negative_slope::Real: Keyword argument, the parameter of LeakyReLU.

Examples

julia> GATv2Conv(1024=>256, relu)
GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))

julia> GATv2Conv(1024=>256, relu, heads=4)
GATv2Conv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))

julia> GATv2Conv(1024=>256, relu, heads=4, concat=false)
GATv2Conv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))

julia> GATv2Conv(1024=>256, relu, negative_slope=0.1f0)
GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))

See also WithGraph for training a layer with a static graph.

Reference: Shaked Brody, Uri Alon, Eran Yahav (2022)


Gated Graph Convolution Layer

\[\textbf{h}^{(0)}_i = \textbf{x}_i || \textbf{0} \\ \textbf{h}^{(l)}_i = GRU(\textbf{h}^{(l-1)}_i, \sum_{j \in \mathcal{N}(i)} \Theta \textbf{h}^{(l-1)}_j)\]

where $\textbf{h}^{(l)}_i$ denotes the hidden state of node $i$ after the $l$-th GRU step. The dimension of the input $\textbf{x}_i$ must be less than or equal to out, since $\textbf{x}_i$ is zero-padded to length out.
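A sketch of the recurrence, assuming a single $\Theta$ and a single GRU cell shared across all steps, and a hand-written GRU (`TinyGRU`, hypothetical) with the standard reset and update gates and no biases. GeometricFlux's own parameterization may differ. The padding line is where the `in <= out` requirement comes from.

```julia
sigmoid(x) = 1 / (1 + exp(-x))

# Hypothetical minimal GRU cell (no biases): h is the hidden state, m the incoming message.
struct TinyGRU
    Wr
    Ur
    Wz
    Uz
    Wh
    Uh
end
TinyGRU(d::Int) = TinyGRU((0.1 .* randn(d, d) for _ in 1:6)...)

function (g::TinyGRU)(h, m)
    r = sigmoid.(g.Wr * m + g.Ur * h)              # reset gate
    z = sigmoid.(g.Wz * m + g.Uz * h)              # update gate
    hcand = tanh.(g.Wh * m + g.Uh * (r .* h))      # candidate state
    return (1 .- z) .* h + z .* hcand
end

# X is (in, num_nodes) with in <= out, Θ is (out, out), nbrs[i] lists the neighbors of i
# (assumed non-empty). Θ and the GRU are shared across all steps in this sketch.
function gated_graph_conv(X, nbrs, Θ, gru, num_layers)
    out = size(Θ, 1)
    @assert size(X, 1) <= out "input dimension must not exceed out"
    H = vcat(X, zeros(out - size(X, 1), size(X, 2)))        # h⁽⁰⁾ = x || 0
    for _ in 1:num_layers
        M = similar(H)
        for i in axes(H, 2)
            M[:, i] = Θ * sum(H[:, j] for j in nbrs[i])     # Σ_j Θh_j by linearity (aggr = +)
        end
        H = reduce(hcat, [gru(H[:, i], M[:, i]) for i in axes(H, 2)])   # GRU update per node
    end
    return H
end

X = [1.0 0.0 1.0;
     0.0 1.0 1.0]            # in = 2, 3 nodes
nbrs = [[2], [1, 3], [2]]
Θ = 0.1 .* randn(4, 4)       # out = 4
gru = TinyGRU(4)
H = gated_graph_conv(X, nbrs, Θ, gru, 3)
```

With `num_layers = 0` the function returns the zero-padded input, which makes the role of $\textbf{h}^{(0)}_i$ easy to inspect.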

GeometricFlux.GatedGraphConv (Type)
GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)

Gated graph convolution layer.

Arguments

  • out: The dimension of output features.
  • num_layers: The number of GRU update steps.
  • aggr: An aggregate function applied to the result of the message function. +, -, *, /, max, min and mean are available.

Examples

julia> GatedGraphConv(256, 4)
GatedGraphConv((256 => 256)^4, aggr=+)

julia> GatedGraphConv(256, 4, aggr=*)
GatedGraphConv((256 => 256)^4, aggr=*)

See also WithGraph for training a layer with a static graph.

Reference: Yujia Li, Daniel Tarlow, Marc Brockschmidt, Richard Zemel (2016)


Edge Convolutional Layer

\[\textbf{x}_i' = \sum_{j \in \mathcal{N}(i)} f_{\Theta}(\textbf{x}_i || \textbf{x}_j - \textbf{x}_i)\]

where $f_{\Theta}$ denotes a neural network parametrized by $\Theta$, e.g., an MLP.
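A sketch of the layer with the neural network passed in as any Julia callable. `edge_conv` is a hypothetical name. Aggregation is done pairwise with `aggr`, so this version handles `max` and `+` but not `mean`, and it assumes every node has at least one neighbor.

```julia
# Hypothetical helper: X is (in, num_nodes), nbrs[i] lists the neighbors of i, and
# f maps a 2*in vector to an out vector (for example, a Dense layer).
function edge_conv(f, X, nbrs; aggr=max)
    cols = map(axes(X, 2)) do i
        xi = X[:, i]
        msgs = [f(vcat(xi, X[:, j] - xi)) for j in nbrs[i]]   # f_Θ(x_i || x_j - x_i)
        reduce((u, v) -> aggr.(u, v), msgs)                  # elementwise aggregation
    end
    return reduce(hcat, cols)
end

X = [1.0 0.0 3.0;
     0.0 2.0 1.0]
nbrs = [[2], [1, 3], [2]]
Y = edge_conv(identity, X, nbrs)              # aggr = max
Ys = edge_conv(identity, X, nbrs; aggr=+)     # aggr = +
```

Using `identity` for $f_{\Theta}$ exposes the raw edge features: node 2 sees $[0, 2, 1, -2]$ from node 1 and $[0, 2, 3, -1]$ from node 3, so the max aggregation gives $[0, 2, 3, -1]$.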

GeometricFlux.EdgeConv (Type)
EdgeConv(nn; aggr=max)

Edge convolutional layer.

Arguments

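  • nn: A neural network (e.g. a Dense layer or an MLP). Because it is applied to the concatenation $\textbf{x}_i || (\textbf{x}_j - \textbf{x}_i)$, its input dimension is twice the node feature dimension.
  • aggr: An aggregate function applied to the result of the message function. +, max and mean are available.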

Examples

julia> EdgeConv(Dense(1024, 256, relu))
EdgeConv(Dense(1024 => 256, relu), aggr=max)

julia> EdgeConv(Dense(1024, 256, relu), aggr=+)
EdgeConv(Dense(1024 => 256, relu), aggr=+)

See also WithGraph for training a layer with a static graph.

Reference: Yue Wang, Yongbin Sun, Ziwei Liu, Sanjay E. Sarma, Michael M. Bronstein, Justin M. Solomon (2019)


Graph Isomorphism Network

\[\textbf{x}_i' = f_{\Theta}\left((1 + \varepsilon) \cdot \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \textbf{x}_j \right)\]

where $f_{\Theta}$ denotes a neural network parametrized by $\Theta$, e.g., an MLP.
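Because the neighbor sum is linear, the whole layer can be written as one matrix expression, $f_{\Theta}((1 + \varepsilon) X + X A)$, in the `(num_features, num_nodes)` layout, where $A_{ji} = 1$ when $j \in \mathcal{N}(i)$. Below is a minimal sketch; `gin_conv` is a hypothetical name, and the ReLU "MLP" is only for illustration.

```julia
# Hypothetical helper: X is (in, num_nodes), A[j, i] = 1 when j ∈ N(i), and f is any
# callable acting on the columns of a matrix (a Dense layer or an MLP in practice).
gin_conv(f, X, A; eps=0.0) = f((1 + eps) .* X + X * A)

X = [1.0 2.0 3.0;
     0.0 1.0 0.0]                       # 2 features, 3 nodes
A = [0.0 1.0 0.0;
     1.0 0.0 1.0;
     0.0 1.0 0.0]                       # path graph 1 - 2 - 3
Y = gin_conv(identity, X, A)            # sums only, no learned transformation
W = [1.0 -1.0]
Y2 = gin_conv(M -> max.(W * M, 0.0), X, A; eps=0.0)   # a one-layer ReLU "MLP"
```

Here `Y` is `[3.0 6.0 5.0; 1.0 1.0 1.0]`. Raising `eps` increases the weight of each node's own features relative to its neighbors.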

GeometricFlux.GINConv (Type)
GINConv(nn, [eps=0])

Graph Isomorphism Network.

Arguments

  • nn: A neural network/layer.
  • eps: The weighting factor $\varepsilon$ applied to each node's own features.

Examples

julia> GINConv(Dense(1024, 256, relu))
GINConv(Dense(1024 => 256, relu), ϵ=0.0)

julia> GINConv(Dense(1024, 256, relu), 1.f-6)
GINConv(Dense(1024 => 256, relu), ϵ=1.0e-6)

See also WithGraph for training a layer with a static graph.

Reference: Keyulu Xu, Weihua Hu, Jure Leskovec, Stefanie Jegelka (2019)


Crystal Graph Convolutional Network

\[\textbf{x}_i' = \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma\left( \textbf{z}_{i,j} \textbf{W}_f + \textbf{b}_f \right) \odot \text{softplus}\left(\textbf{z}_{i,j} \textbf{W}_s + \textbf{b}_s \right)\]

where $\textbf{z}_{i,j} = [\textbf{x}_i, \textbf{x}_j, \textbf{e}_{i,j}]$ denotes the concatenation of node features, neighboring node features, and edge features. The operation $\odot$ represents elementwise multiplication, and $\sigma$ denotes the sigmoid function.
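A sketch of one update in column-vector form, so $\textbf{z}_{i,j} \textbf{W}_f$ becomes `Wf * z` with `Wf` of size `(node_dim, 2node_dim + edge_dim)`. The dictionary-based edge storage and the name `cg_conv` are assumptions for illustration. The residual term $\textbf{x}_i$ is why the output dimension must equal `node_dim`.

```julia
sigmoid(x) = 1 / (1 + exp(-x))
softplus(x) = x > 0 ? x + log1p(exp(-x)) : log1p(exp(x))   # numerically stable log(1 + eˣ)

# Hypothetical helper: X is (node_dim, num_nodes), E[(i, j)] holds the feature vector of
# edge (i, j), nbrs[i] lists the neighbors of i. Wf, Ws are (node_dim, 2*node_dim + edge_dim)
# and bf, bs have length node_dim.
function cg_conv(X, E, nbrs, Wf, bf, Ws, bs)
    Y = copy(X)                                          # residual term x_i
    for i in axes(X, 2), j in nbrs[i]
        z = vcat(X[:, i], X[:, j], E[(i, j)])            # z_ij = [x_i, x_j, e_ij]
        Y[:, i] .+= sigmoid.(Wf * z + bf) .* softplus.(Ws * z + bs)   # gate ⊙ softplus
    end
    return Y
end

X = [1.0 0.0;
     0.0 1.0]                                   # node_dim = 2, 2 nodes
nbrs = [[2], [1]]
E = Dict((1, 2) => [1.0], (2, 1) => [1.0])      # edge_dim = 1
Wf, bf = zeros(2, 5), zeros(2)
Ws, bs = zeros(2, 5), zeros(2)
Y = cg_conv(X, E, nbrs, Wf, bf, Ws, bs)
```

With all-zero weights each neighbor contributes $\sigma(0) \cdot \text{softplus}(0) = \tfrac{1}{2}\log 2$ per feature, and a node with no neighbors keeps its input features unchanged.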

GeometricFlux.CGConv (Type)
CGConv((node_dim, edge_dim), init, bias=true)

Crystal Graph Convolutional network. Uses both node and edge features.

Arguments

  • node_dim: Dimensionality of the input node features. Because of the residual term $\textbf{x}_i$, this is necessarily also the output dimensionality.
  • edge_dim: Dimensionality of the input edge features.
  • init: Initialization algorithm for each of the weight matrices.
  • bias: Whether or not to learn an additive bias parameter.

Examples

julia> CGConv((128, 32))
CGConv(node dim=128, edge dim=32)

See also WithGraph for training a layer with a static graph.

Reference: Tian Xie, Jeffrey C. Grossman (2018)