docs: add GCN Cora example (#1210)
avik-pal authored Jan 19, 2025
1 parent 30e7b01 commit 926029d
Showing 9 changed files with 186 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/make.jl
@@ -30,6 +30,7 @@ pages = [
"tutorials/intermediate/3_HyperNet.md",
"tutorials/intermediate/4_PINN2DPDE.md",
"tutorials/intermediate/5_ConvolutionalVAE.md",
"tutorials/intermediate/6_GCN_Cora.md",
],
"Advanced" => [
"tutorials/advanced/1_GravitationalWaveForm.md"
6 changes: 5 additions & 1 deletion docs/src/.vitepress/config.mts
@@ -221,7 +221,11 @@ export default defineConfig({
{
text: "Convolutional VAE for MNIST using Reactant",
link: "/tutorials/intermediate/5_ConvolutionalVAE",
-}
+},
+{
+text: "Graph Convolutional Network on Cora",
+link: "/tutorials/intermediate/6_GCN_Cora",
+},
],
},
{
Binary file added docs/src/public/gcn_cora.jpg
6 changes: 6 additions & 0 deletions docs/src/tutorials/index.md
@@ -71,6 +71,12 @@ const intermediate = [
src: "../conditional_vae.png",
caption: "Convolutional VAE for MNIST using Reactant",
desc: "Train a Convolutional VAE to generate images from a latent space."
},
+{
+href: "intermediate/6_GCN_Cora",
+src: "../gcn_cora.jpg",
+caption: "Graph Convolutional Network on Cora",
+desc: "Train a Graph Convolutional Network on the Cora dataset."
+}
];
1 change: 1 addition & 0 deletions docs/tutorials.jl
@@ -13,6 +13,7 @@ const INTERMEDIATE_TUTORIALS = [
"HyperNet/main.jl" => "CUDA",
"PINN2DPDE/main.jl" => "CUDA",
"ConvolutionalVAE/main.jl" => "CUDA",
"GCN_Cora/main.jl" => "CUDA",
]
const ADVANCED_TUTORIALS = [
"GravitationalWaveForm/main.jl" => "CPU",
4 changes: 0 additions & 4 deletions examples/ConvolutionalVAE/Project.toml
@@ -1,5 +1,4 @@
[deps]
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -13,10 +12,8 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
Comonicon = "1"
ConcreteStructs = "0.2.3"
DataAugmentation = "0.3.2"
Enzyme = "0.13.20"
Expand All @@ -30,4 +27,3 @@ Optimisers = "0.4"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.9"
StableRNGs = "1"
10 changes: 7 additions & 3 deletions examples/ConvolutionalVAE/main.jl
@@ -4,8 +4,7 @@
# based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).

using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
-ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon,
-StableRNGs
+ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()
@@ -280,4 +279,9 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
return model_img_full
end

-main()
+img = main()
+nothing #hide
+
+# ---
+
+img #hide
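
The `#hide` suffix here is the Literate.jl/Documenter.jl convention used throughout these tutorials: a line ending in `#hide` still executes but is omitted from the rendered page, and a `# ---` prose line renders as a horizontal rule while splitting the surrounding code into separate output blocks. A minimal sketch of the idiom, with `compute_image` as a hypothetical stand-in:

img = compute_image()
nothing #hide   # runs, but is hidden; suppresses the assignment's printed output

# ---

img #hide       # hidden line whose value is rendered as this block's result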
25 changes: 25 additions & 0 deletions examples/GCN_Cora/Project.toml
@@ -0,0 +1,25 @@
[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
Enzyme = "0.13.28"
Lux = "1.5"
MLDatasets = "0.7.18"
MLUtils = "0.4.5"
Optimisers = "0.4.4"
Printf = "1.10"
Random = "1.10"
Reactant = "0.2.18"
Statistics = "1.10"
141 changes: 141 additions & 0 deletions examples/GCN_Cora/main.jl
@@ -0,0 +1,141 @@
# # [Graph Convolutional Networks on Cora](@id GCN-Tutorial-Cora)

# This example is based on the [GCN MLX tutorial](https://github.com/ml-explore/mlx-examples/blob/main/gcn/). While we build the model manually here, we recommend using
# [GNNLux.jl](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/) directly.

using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, MLUtils,
ConcreteStructs, Printf, OneHotArrays, Optimisers

const xdev = reactant_device(; force=true)
const cdev = cpu_device()

# ## Loading Cora Dataset

function loadcora()
data = Cora()
gph = data.graphs[1]
gnngraph = GNNGraph(
gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
)
return (
gph.node_data.features,
onehotbatch(gph.node_data.targets, data.metadata["classes"]),
## We use a dense matrix here to avoid incompatibility with Reactant
Matrix(adjacency_matrix(gnngraph)),
## We hardcode the train/val/test index ranges since Reactant doesn't yet support the gather adjoint
(1:140, 141:640, 1709:2708)
)
end
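
# For reference, Cora has 2708 nodes with 1433-dimensional bag-of-words features and 7
# classes, so `loadcora` returns a 1433×2708 feature matrix, a 7×2708 one-hot target
# matrix, a 2708×2708 dense adjacency matrix, and the standard 140/500/1000
# train/validation/test node split.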

# ## Model Definition

function GCNLayer(args...; kwargs...)
return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
@return dense(x) * adj
end
end
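
# With node features stored column-wise in a `(num_features, num_nodes)` matrix,
# `dense(x) * adj` right-multiplies the transformed features by the adjacency matrix,
# so each output column sums the transformed features of that node's neighbours,
# i.e. the `W * X * A` form of graph-convolutional propagation.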

function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
gcn_layers = [GCNLayer(in_dim => out_dim; kwargs...)
for (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])]
last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
dropout = Dropout(dropout)

return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
for layer in gcn_layers
x = relu.(layer((x, adj)))
x = dropout(x)
end
@return last_layer((x, adj))[:, mask]
end
end
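
# Each hidden layer applies graph convolution, `relu`, and dropout; the final layer
# projects to `out_dim`, and indexing with `mask` keeps only the requested nodes (the
# workaround for the missing gather adjoint noted in `loadcora`).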

# ## Helper Functions

function loss_function(model, ps, st, (x, y, adj, mask))
y_pred, st = model((x, adj, mask), ps, st)
loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
return loss, st, (; y_pred)
end

accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100

# ## Training the Model

function main(;
hidden_dim::Int=64, dropout::Float64=0.1, nb_layers::Int=2, use_bias::Bool=true,
lr::Float64=0.001, weight_decay::Float64=0.0, patience::Int=20, epochs::Int=200
)
rng = Random.default_rng()
Random.seed!(rng, 0)

features, targets, adj, (train_idx, val_idx, test_idx) = loadcora() |> xdev

gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
ps, st = Lux.setup(rng, gcn) |> xdev
opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)

train_state = Training.TrainState(gcn, ps, st, opt)

@printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)

val_loss_compiled = @compile loss_function(
gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))

train_model_compiled = @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
val_model_compiled = @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
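## `@compile` traces a call once with the given arguments and returns an executable
## specialized to their shapes; these compiled callables are reused in the loop below
## with arrays of the same shapes.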

best_loss_val = Inf
cnt = 0

for epoch in 1:epochs
(_, loss, _, train_state) = Lux.Training.single_train_step!(
AutoEnzyme(), loss_function, (features, targets, adj, train_idx), train_state;
return_gradients=Val(false)
)
train_acc = accuracy(
Array(train_model_compiled((features, adj, train_idx),
train_state.parameters, Lux.testmode(train_state.states))[1]),
Array(targets)[:, train_idx]
)

val_loss = first(val_loss_compiled(
gcn, train_state.parameters, Lux.testmode(train_state.states),
(features, targets, adj, val_idx)))
val_acc = accuracy(
Array(val_model_compiled((features, adj, val_idx),
train_state.parameters, Lux.testmode(train_state.states))[1]),
Array(targets)[:, val_idx]
)

@printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc

if val_loss < best_loss_val
best_loss_val = val_loss
cnt = 0
else
cnt += 1
if cnt == patience
@printf "Early Stopping at Epoch %d\n" epoch
break
end
end
end

test_loss = @jit(loss_function(
gcn, train_state.parameters, Lux.testmode(train_state.states),
(features, targets, adj, test_idx)))[1]
test_acc = accuracy(
Array(@jit(gcn((features, adj, test_idx),
train_state.parameters, Lux.testmode(train_state.states)))[1]),
Array(targets)[:, test_idx]
)

@printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
return
end

main()
nothing #hide
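
Every hyperparameter of the run is exposed as a keyword argument of `main`, so variants are one-liners (an illustrative sketch; the values are arbitrary):

main(; hidden_dim=128, nb_layers=3, dropout=0.2, lr=5e-3, weight_decay=5e-4, epochs=100)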
