Graph Convolutional Layers
Graph Convolutional Layer
\[X' = \sigma(\hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} X \Theta)\]
where $\hat{A} = A + I$, $A$ denotes the adjacency matrix, and $\hat{D}$ is the diagonal degree matrix with $\hat{d}_{ii} = \sum_j \hat{a}_{ij}$.
GeometricFlux.GCNConv — Type
GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform)
Graph convolutional layer. The input to the layer is a node feature array X of size (num_features, num_nodes).
Arguments
in: The dimension of input features.
out: The dimension of output features.
σ: Activation function.
bias: Add learnable bias.
init: Weights' initializer.
Examples
julia> gc = GCNConv(1024=>256, relu)
GCNConv(1024 => 256, relu)
See also WithGraph for training the layer with a static graph.
Reference: Thomas N. Kipf, Max Welling (2017)
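A minimal end-to-end sketch of driving the layer with a graph (illustrative only; it assumes a FeaturedGraph from GraphSignals.jl can be built from an adjacency matrix, and that WithGraph(fg, layer) binds the static graph so the wrapped layer maps a node feature matrix to a node feature matrix):

using Flux, GraphSignals, GeometricFlux

# Toy 4-node ring graph given as an adjacency matrix.
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]

X = rand(Float32, 1024, 4)                        # (num_features, num_nodes)

# Bind the static graph to the layer, then apply it to the node features.
model = WithGraph(FeaturedGraph(adj), GCNConv(1024=>256, relu))
Y = model(X)                                      # size (256, 4)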
Chebyshev Spectral Graph Convolutional Layer
\[X' = \sum^{K-1}_{k=0} Z^{(k)} \Theta^{(k)}\]
where $Z^{(k)}$ is the $k$-th term of Chebyshev polynomials, and can be calculated by the following recursive form:
\[Z^{(0)} = X \\ Z^{(1)} = \hat{L} X \\ Z^{(k)} = 2 \hat{L} Z^{(k-1)} - Z^{(k-2)}\]
and $\hat{L} = \frac{2}{\lambda_{max}} L - I$.
GeometricFlux.ChebConv — Type
ChebConv(in=>out, k; bias=true, init=glorot_uniform)
Chebyshev spectral graph convolutional layer.
Arguments
in: The dimension of input features.
out: The dimension of output features.
k: The order of Chebyshev polynomial.
bias: Add learnable bias.
init: Weights' initializer.
Examples
julia> cc = ChebConv(1024=>256, 5, relu)
ChebConv(1024 => 256, k=5, relu)
See also WithGraph for training the layer with a static graph.
Reference: Michaël Defferrard, Xavier Bresson, Pierre Vandergheynst (2016)
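For intuition, a plain-Julia sketch of the recursion above (not the library implementation); L̂ is the scaled Laplacian, X is a node-by-feature matrix, and Θ is a vector of K ≥ 2 weight matrices, one per Chebyshev order:

function cheb_sum(L̂, X, Θ)
    Zprev, Z = X, L̂ * X                       # Z⁽⁰⁾ and Z⁽¹⁾
    out = Zprev * Θ[1] + Z * Θ[2]
    for k in 3:length(Θ)
        Zprev, Z = Z, 2 * (L̂ * Z) - Zprev     # Z⁽ᵏ⁾ = 2 L̂ Z⁽ᵏ⁻¹⁾ - Z⁽ᵏ⁻²⁾
        out += Z * Θ[k]
    end
    return out
end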
Graph Neural Network Layer
\[\textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \Theta_2 \textbf{x}_j)\]
GeometricFlux.GraphConv — Type
GraphConv(in => out, σ=identity, aggr=+; bias=true, init=glorot_uniform)
Graph neural network layer.
Arguments
in: The dimension of input features.
out: The dimension of output features.
σ: Activation function.
aggr: An aggregate function applied to the results of the message function. +, -, *, /, max, min and mean are available.
bias: Add learnable bias.
init: Weights' initializer.
Examples
julia> GraphConv(1024=>256, relu)
GraphConv(1024 => 256, relu, aggr=+)
julia> GraphConv(1024=>256, relu, *)
GraphConv(1024 => 256, relu, aggr=*)
See also WithGraph for training the layer with a static graph.
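A plain-Julia sketch of the update above for a single node i (illustrative only); X holds one feature column per node, nbrs is the neighbor list of i, and Θ1, Θ2 are the two weight matrices:

function graphconv_node(σ, Θ1, Θ2, X, nbrs, i)
    m = sum(Θ2 * X[:, j] for j in nbrs)    # neighbor messages, aggregated with aggr=+
    return σ.(Θ1 * X[:, i] .+ m)
end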
SAmple and aggreGatE (GraphSAGE) Network
\[\hat{\textbf{x}}_j = sample(\textbf{x}_j), \forall j \in \mathcal{N}(i) \\ \textbf{m}_i = aggregate(\hat{\textbf{x}}_j) \\ \textbf{x}_i' = \sigma (\Theta_1 \textbf{x}_i + \Theta_2 \textbf{m}_i)\]
GeometricFlux.SAGEConv — Type
SAGEConv(in => out, σ=identity, aggr=mean; normalize=true, project=false,
         bias=true, num_sample=10, init=glorot_uniform)
SAmple and aggreGatE convolutional layer for GraphSAGE network.
Arguments
in: The dimension of input features.
out: The dimension of output features.
σ: Activation function.
aggr: An aggregate function applied to the results of the message function. mean, max, LSTM and GCNConv are available.
normalize::Bool: Whether to normalize features across all nodes or not.
project::Bool: Whether to project, i.e. Dense(in, in), before aggregation.
bias: Add learnable bias.
num_sample::Int: Number of samples for each node from their neighbors.
init: Weights' initializer.
Examples
julia> SAGEConv(1024=>256, relu)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=10)
julia> SAGEConv(1024=>256, relu, num_sample=5)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=5)
julia> MeanAggregator(1024=>256, relu, normalize=false)
SAGEConv(1024 => 256, relu, aggr=mean, normalize=false, #sample=10)
julia> MeanPoolAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=mean, normalize=true, #sample=10)
julia> MaxPoolAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=max, normalize=true, #sample=10)
julia> LSTMAggregator(1024=>256, relu)
SAGEConv(1024 => 256, relu, aggr=LSTMCell(1024 => 1024), normalize=true, #sample=10)
See also WithGraph for training the layer with a static graph, as well as MeanAggregator, MeanPoolAggregator, MaxPoolAggregator and LSTMAggregator.
GeometricFlux.MeanAggregator — Function
MeanAggregator(in => out, σ=identity; normalize=true, project=false,
               bias=true, num_sample=10, init=glorot_uniform)
SAGEConv with mean aggregator.
See also SAGEConv.
GeometricFlux.MeanPoolAggregator — Function
MeanPoolAggregator(in => out, σ=identity; normalize=true,
                   bias=true, num_sample=10, init=glorot_uniform)
SAGEConv with meanpool aggregator.
See also SAGEConv.
GeometricFlux.MaxPoolAggregator — Function
MaxPoolAggregator(in => out, σ=identity; normalize=true,
                  bias=true, num_sample=10, init=glorot_uniform)
SAGEConv with maxpool aggregator.
See also SAGEConv.
GeometricFlux.LSTMAggregator — Function
LSTMAggregator(in => out, σ=identity; normalize=true, project=false,
               bias=true, num_sample=10, init=glorot_uniform)
SAGEConv with LSTM aggregator.
See also SAGEConv.
Reference: William L Hamilton, Rex Ying, Jure Leskovec (2017) and GraphSAGE website
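A plain-Julia sketch of the mean-aggregator update above for one node i (illustrative only; for simplicity, neighbors are drawn with replacement here):

using LinearAlgebra, Statistics

function sage_mean_node(σ, Θ1, Θ2, X, nbrs, i; num_sample=10)
    sampled = rand(nbrs, num_sample)            # x̂_j = sample(x_j)
    m = mean(X[:, j] for j in sampled)          # m_i = aggregate(x̂_j)
    h = σ.(Θ1 * X[:, i] .+ Θ2 * m)
    return h ./ norm(h)                         # normalize=true
end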
Graph Attentional Layer
\[\textbf{x}_i' = \alpha_{i,i} \Theta \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \Theta \textbf{x}_j\]
where the attention coefficient $\alpha_{i,j}$ can be calculated from
\[\alpha_{i,j} = \frac{exp(LeakyReLU(\textbf{a}^T [\Theta \textbf{x}_i || \Theta \textbf{x}_j]))}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} exp(LeakyReLU(\textbf{a}^T [\Theta \textbf{x}_i || \Theta \textbf{x}_k]))}\]
GeometricFlux.GATConv — Type
GATConv(in => out, σ=identity; heads=1, concat=true,
        init=glorot_uniform, bias=true, negative_slope=0.2)
Graph attentional layer.
Arguments
in: The dimension of input features.
out: The dimension of output features.
bias::Bool: Keyword argument, whether to learn the additive bias.
σ: Activation function.
heads: Number of attention heads.
concat: Concatenate layer output or not. If not, layer output is averaged.
negative_slope::Real: Keyword argument, the parameter of LeakyReLU.
Examples
julia> GATConv(1024=>256, relu)
GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
julia> GATConv(1024=>256, relu, heads=4)
GATConv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
julia> GATConv(1024=>256, relu, heads=4, concat=false)
GATConv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
julia> GATConv(1024=>256, relu, negative_slope=0.1f0)
GATConv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
See also WithGraph for training the layer with a static graph.
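A plain-Julia sketch of the attention coefficients above for one node i (illustrative only); a is the attention vector, Θ the weight matrix, and the node itself is included to match the $\mathcal{N}(i) \cup \{i\}$ normalization:

using Flux, LinearAlgebra   # leakyrelu and softmax come from Flux/NNlib

function gat_alpha(a, Θ, X, nbrs, i; negative_slope=0.2f0)
    ks = vcat(i, nbrs)
    e = [dot(a, vcat(Θ * X[:, i], Θ * X[:, k])) for k in ks]
    α = softmax(leakyrelu.(e, negative_slope))
    return Dict(zip(ks, α))                     # α_{i,k} for each k in N(i) ∪ {i}
end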
Graph Attentional Layer v2
\[\textbf{x}_i' = \alpha_{i,i} \Theta \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \Theta \textbf{x}_j\]
where the attention coefficient $\alpha_{i,j}$ can be calculated from
\[\alpha_{i,j} = \frac{exp(\textbf{a}^T LeakyReLU(\Theta [\textbf{x}_i || \textbf{x}_j]))}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} exp(\textbf{a}^T LeakyReLU(\Theta [\textbf{x}_i || \textbf{x}_k]))}\]
GeometricFlux.GATv2Conv — Type
GATv2Conv(in => out, σ=identity; heads=1, concat=true,
          init=glorot_uniform, negative_slope=0.2)
Graph attentional layer v2.
Arguments
in: The dimension of input features.
out: The dimension of output features.
σ: Activation function.
heads: Number of attention heads.
concat: Concatenate layer output or not. If not, layer output is averaged.
negative_slope::Real: Keyword argument, the parameter of LeakyReLU.
Examples
julia> GATv2Conv(1024=>256, relu)
GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.2))
julia> GATv2Conv(1024=>256, relu, heads=4)
GATv2Conv(1024=>1024, relu, heads=4, concat=true, LeakyReLU(λ=0.2))
julia> GATv2Conv(1024=>256, relu, heads=4, concat=false)
GATv2Conv(1024=>1024, relu, heads=4, concat=false, LeakyReLU(λ=0.2))
julia> GATv2Conv(1024=>256, relu, negative_slope=0.1f0)
GATv2Conv(1024=>256, relu, heads=1, concat=true, LeakyReLU(λ=0.1))
See also WithGraph for training the layer with a static graph.
Reference: Shaked Brody, Uri Alon, Eran Yahav (2022)
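For comparison with GAT above, a sketch of the v2 score (illustrative only): the LeakyReLU now sits inside the score, before the dot product with a:

using Flux, LinearAlgebra

gatv2_score(a, Θ, xi, xj; negative_slope=0.2f0) =
    dot(a, leakyrelu.(Θ * vcat(xi, xj), negative_slope))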
Gated Graph Convolution Layer
\[\textbf{h}^{(0)}_i = \textbf{x}_i || \textbf{0} \\ \textbf{h}^{(l)}_i = GRU(\textbf{h}^{(l-1)}_i, \sum_{j \in \mathcal{N}(i)} \Theta \textbf{h}^{(l-1)}_j)\]
where $\textbf{h}^{(l)}_i$ denotes the hidden state of node $i$ after the $l$-th GRU step. The dimension of the input $\textbf{x}_i$ needs to be less than or equal to out.
GeometricFlux.GatedGraphConv — Type
GatedGraphConv([fg,] out, num_layers; aggr=+, init=glorot_uniform)
Gated graph convolution layer.
Arguments
out: The dimension of output features.
num_layers: The number of gated recurrent units.
aggr: An aggregate function applied to the results of the message function. +, -, *, /, max, min and mean are available.
Examples
julia> GatedGraphConv(256, 4)
GatedGraphConv((256 => 256)^4, aggr=+)
julia> GatedGraphConv(256, 4, aggr=*)
GatedGraphConv((256 => 256)^4, aggr=*)
See also WithGraph for training the layer with a static graph.
Reference: Yujia Li, Daniel Tarlow, Marc Brockschmidt, Richard Zemel (2016)
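A plain-Julia sketch of one gated step above for node i (illustrative only); gru is any callable implementing $GRU(\textbf{h}, \textbf{m})$, H holds the hidden state $\textbf{h}^{(l-1)}$ per node, and Θ is the message weight matrix:

function gatedgraph_node(gru, Θ, H, nbrs, i)
    m = sum(Θ * H[:, j] for j in nbrs)   # aggregated messages, aggr=+
    return gru(H[:, i], m)               # h⁽ˡ⁾_i
end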
Edge Convolutional Layer
\[\textbf{x}_i' = \sum_{j \in \mathcal{N}(i)} f_{\Theta}(\textbf{x}_i || \textbf{x}_j - \textbf{x}_i)\]
where $f_{\Theta}$ denotes a neural network parametrized by $\Theta$, i.e., an MLP.
GeometricFlux.EdgeConv — Type
EdgeConv(nn; aggr=max)
Edge convolutional layer.
Arguments
nn: A neural network (e.g. a Dense layer or an MLP).
aggr: An aggregate function applied to the results of the message function. +, max and mean are available.
Examples
julia> EdgeConv(Dense(1024, 256, relu))
EdgeConv(Dense(1024 => 256, relu), aggr=max)
julia> EdgeConv(Dense(1024, 256, relu), aggr=+)
EdgeConv(Dense(1024 => 256, relu), aggr=+)
See also WithGraph for training the layer with a static graph.
Reference: Yue Wang, Yongbin Sun, Ziwei Liu, Sanjay E. Sarma, Michael M. Bronstein, Justin M. Solomon (2019)
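A plain-Julia sketch of the update above for one node i (illustrative only); nn is any callable, e.g. a Flux Dense layer, applied to $[\textbf{x}_i || \textbf{x}_j - \textbf{x}_i]$, with the messages summed as in the formula:

function edgeconv_node(nn, X, nbrs, i)
    msgs = (nn(vcat(X[:, i], X[:, j] .- X[:, i])) for j in nbrs)
    return sum(msgs)    # swap sum for an elementwise max/mean reduction for other aggr
end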
Graph Isomorphism Network
\[\textbf{x}_i' = f_{\Theta}\left((1 + \varepsilon) \cdot \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \textbf{x}_j \right)\]
where $f_{\Theta}$ denotes a neural network parametrized by $\Theta$, i.e., an MLP.
GeometricFlux.GINConv — Type
GINConv(nn, [eps=0])
Graph Isomorphism Network.
Arguments
nn: A neural network/layer.
eps: Weighting factor.
Examples
julia> GINConv(Dense(1024, 256, relu))
GINConv(Dense(1024 => 256, relu), ϵ=0.0)
julia> GINConv(Dense(1024, 256, relu), 1.f-6)
GINConv(Dense(1024 => 256, relu), ϵ=1.0e-6)
See also WithGraph for training the layer with a static graph.
Reference: Keyulu Xu, Weihua Hu, Jure Leskovec, Stefanie Jegelka (2019)
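A plain-Julia one-liner for the update above (illustrative only); nn is any callable, e.g. a Flux Dense layer or MLP, and nbrs is the neighbor list of node i:

gin_node(nn, ε, X, nbrs, i) = nn((1 + ε) .* X[:, i] .+ sum(X[:, j] for j in nbrs))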
Crystal Graph Convolutional Network
\[\textbf{x}_i' = \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma\left( \textbf{z}_{i,j} \textbf{W}_f + \textbf{b}_f \right) \odot \text{softplus}\left(\textbf{z}_{i,j} \textbf{W}_s + \textbf{b}_s \right)\]
where $\textbf{z}_{i,j} = [\textbf{x}_i, \textbf{x}_j, \textbf{e}_{i,j}]$ denotes the concatenation of node features, neighboring node features, and edge features. The operation $\odot$ represents elementwise multiplication, and $\sigma$ denotes the sigmoid function.
GeometricFlux.CGConv — Type
CGConv((node_dim, edge_dim), init, bias=true)
Crystal Graph Convolutional network. Uses both node and edge features.
Arguments
node_dim: Dimensionality of the input node features. This is also the output dimensionality.
edge_dim: Dimensionality of the input edge features.
init: Initialization algorithm for each of the weight matrices.
bias: Whether or not to learn an additive bias parameter.
Examples
julia> CGConv((128, 32))
CGConv(node dim=128, edge dim=32)
See also WithGraph for training the layer with a static graph.
Reference: Tian Xie, Jeffrey C. Grossman (2018)
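A plain-Julia sketch of the update above for one node i (illustrative only); X holds node features, E[(i, j)] is the edge feature vector of edge (i, j), and (Wf, bf, Ws, bs) are the gate and core weights, written here as W * z rather than the row form z W used in the formula:

using Flux   # sigmoid and softplus come from Flux/NNlib

function cgconv_node(Wf, bf, Ws, bs, X, E, nbrs, i)
    acc = zero(X[:, i])
    for j in nbrs
        z = vcat(X[:, i], X[:, j], E[(i, j)])
        acc .+= sigmoid.(Wf * z .+ bf) .* softplus.(Ws * z .+ bs)
    end
    return X[:, i] .+ acc
end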