diff --git a/docs/make.jl b/docs/make.jl
index 7e75df69d2..ce91fb6b42 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -30,6 +30,7 @@ pages = [
             "tutorials/intermediate/3_HyperNet.md",
             "tutorials/intermediate/4_PINN2DPDE.md",
             "tutorials/intermediate/5_ConvolutionalVAE.md",
+            "tutorials/intermediate/6_GCN_Cora.md",
         ],
         "Advanced" => [
             "tutorials/advanced/1_GravitationalWaveForm.md"
diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index 732ce75004..a66f2e5d3b 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -221,7 +221,11 @@ export default defineConfig({
                     {
                         text: "Convolutional VAE for MNIST using Reactant",
                         link: "/tutorials/intermediate/5_ConvolutionalVAE",
-                    }
+                    },
+                    {
+                        text: "Graph Convolutional Network on Cora",
+                        link: "/tutorials/intermediate/6_GCN_Cora",
+                    },
                 ],
             },
             {
diff --git a/docs/src/public/gcn_cora.jpg b/docs/src/public/gcn_cora.jpg
new file mode 100644
index 0000000000..bb1d52340c
Binary files /dev/null and b/docs/src/public/gcn_cora.jpg differ
diff --git a/docs/src/tutorials/index.md b/docs/src/tutorials/index.md
index 82856d1338..bfbacf5001 100644
--- a/docs/src/tutorials/index.md
+++ b/docs/src/tutorials/index.md
@@ -71,6 +71,12 @@ const intermediate = [
     src: "../conditional_vae.png",
     caption: "Convolutional VAE for MNIST using Reactant",
     desc: "Train a Convolutional VAE to generate images from a latent space."
-  }
+  },
+  {
+    href: "intermediate/6_GCN_Cora",
+    src: "../gcn_cora.jpg",
+    caption: "Graph Convolutional Network on Cora",
+    desc: "Train a Graph Convolutional Network on the Cora dataset."
+  }
 ];

diff --git a/docs/tutorials.jl b/docs/tutorials.jl
index fce7546923..4f878e6f8d 100644
--- a/docs/tutorials.jl
+++ b/docs/tutorials.jl
@@ -13,6 +13,7 @@ const INTERMEDIATE_TUTORIALS = [
     "HyperNet/main.jl" => "CUDA",
     "PINN2DPDE/main.jl" => "CUDA",
     "ConvolutionalVAE/main.jl" => "CUDA",
+    "GCN_Cora/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
     "GravitationalWaveForm/main.jl" => "CPU",
diff --git a/examples/ConvolutionalVAE/Project.toml b/examples/ConvolutionalVAE/Project.toml
index a1b5e56a50..88b49f78f8 100644
--- a/examples/ConvolutionalVAE/Project.toml
+++ b/examples/ConvolutionalVAE/Project.toml
@@ -1,5 +1,4 @@
 [deps]
-Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -13,10 +12,8 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

 [compat]
-Comonicon = "1"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3.2"
 Enzyme = "0.13.20"
@@ -30,4 +27,3 @@ Optimisers = "0.4"
 Printf = "1.10"
 Random = "1.10"
 Reactant = "0.2.9"
-StableRNGs = "1"
diff --git a/examples/ConvolutionalVAE/main.jl b/examples/ConvolutionalVAE/main.jl
index cfd95e8c4f..7a89e1e611 100644
--- a/examples/ConvolutionalVAE/main.jl
+++ b/examples/ConvolutionalVAE/main.jl
@@ -4,8 +4,7 @@
 # based on the [CVAE implementation in MLX](https://github.com/ml-explore/mlx-examples/blob/main/cvae/).

 using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, MLUtils, DataAugmentation,
-    ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers, Comonicon,
-    StableRNGs
+    ConcreteStructs, OneHotArrays, ImageShow, Images, Printf, Optimisers

 const xdev = reactant_device(; force=true)
 const cdev = cpu_device()
@@ -280,4 +279,9 @@ function main(; batchsize=128, image_size=(64, 64), num_latent_dims=8, max_num_f
     return model_img_full
 end

-main()
+img = main()
+nothing #hide
+
+# ---
+
+img #hide
diff --git a/examples/GCN_Cora/Project.toml b/examples/GCN_Cora/Project.toml
new file mode 100644
index 0000000000..56238e0485
--- /dev/null
+++ b/examples/GCN_Cora/Project.toml
@@ -0,0 +1,25 @@
+[deps]
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
+[compat]
+ConcreteStructs = "0.2.3"
+Enzyme = "0.13.28"
+Lux = "1.5"
+MLDatasets = "0.7.18"
+MLUtils = "0.4.5"
+Optimisers = "0.4.4"
+Printf = "1.10"
+Random = "1.10"
+Reactant = "0.2.18"
+Statistics = "1.10"
diff --git a/examples/GCN_Cora/main.jl b/examples/GCN_Cora/main.jl
new file mode 100644
index 0000000000..fc322471ee
--- /dev/null
+++ b/examples/GCN_Cora/main.jl
@@ -0,0 +1,141 @@
+# # [Graph Convolutional Networks on Cora](@id GCN-Tutorial-Cora)
+
+# This example is based on the [GCN MLX tutorial](https://github.com/ml-explore/mlx-examples/blob/main/gcn/).
+# While we implement the model manually here, we recommend using
+# [GNNLux.jl](https://juliagraphs.org/GraphNeuralNetworks.jl/docs/GNNLux.jl/stable/) directly.
+
+using Lux, Reactant, MLDatasets, Random, Statistics, Enzyme, GNNGraphs, MLUtils,
+    ConcreteStructs, Printf, OneHotArrays, Optimisers
+
+const xdev = reactant_device(; force=true)
+const cdev = cpu_device()
+
+# ## Loading Cora Dataset
+
+function loadcora()
+    data = Cora()
+    gph = data.graphs[1]
+    gnngraph = GNNGraph(
+        gph.edge_index; ndata=gph.node_data, edata=gph.edge_data, gph.num_nodes
+    )
+    return (
+        gph.node_data.features,
+        onehotbatch(gph.node_data.targets, data.metadata["classes"]),
+        ## We use a dense matrix here to avoid incompatibility with Reactant
+        Matrix(adjacency_matrix(gnngraph)),
+        ## We use fixed train/val/test index ranges since Reactant doesn't yet support the gather adjoint
+        (1:140, 141:640, 1709:2708)
+    )
+end
+
+# ## Model Definition
+
+function GCNLayer(args...; kwargs...)
+    return @compact(; dense=Dense(args...; kwargs...)) do (x, adj)
+        @return dense(x) * adj
+    end
+end
+
+function GCN(x_dim, h_dim, out_dim; nb_layers=2, dropout=0.5, kwargs...)
+    layer_sizes = vcat(x_dim, [h_dim for _ in 1:nb_layers])
+    gcn_layers = [GCNLayer(in_dim => out_dim; kwargs...)
+                  for (in_dim, out_dim) in zip(layer_sizes[1:(end - 1)], layer_sizes[2:end])]
+    last_layer = GCNLayer(layer_sizes[end] => out_dim; kwargs...)
+    dropout = Dropout(dropout)
+
+    return @compact(; gcn_layers, dropout, last_layer) do (x, adj, mask)
+        for layer in gcn_layers
+            x = relu.(layer((x, adj)))
+            x = dropout(x)
+        end
+        @return last_layer((x, adj))[:, mask]
+    end
+end
+
+# ## Helper Functions
+
+function loss_function(model, ps, st, (x, y, adj, mask))
+    y_pred, st = model((x, adj, mask), ps, st)
+    loss = CrossEntropyLoss(; agg=mean, logits=Val(true))(y_pred, y[:, mask])
+    return loss, st, (; y_pred)
+end
+
+accuracy(y_pred, y) = mean(onecold(y_pred) .== onecold(y)) * 100
+
+# ## Training the Model
+
+function main(;
+        hidden_dim::Int=64, dropout::Float64=0.1, nb_layers::Int=2, use_bias::Bool=true,
+        lr::Float64=0.001, weight_decay::Float64=0.0, patience::Int=20, epochs::Int=200
+)
+    rng = Random.default_rng()
+    Random.seed!(rng, 0)
+
+    features, targets, adj, (train_idx, val_idx, test_idx) = loadcora() |> xdev
+
+    gcn = GCN(size(features, 1), hidden_dim, size(targets, 1); nb_layers, dropout, use_bias)
+    ps, st = Lux.setup(rng, gcn) |> xdev
+    opt = iszero(weight_decay) ? Adam(lr) : AdamW(; eta=lr, lambda=weight_decay)
+
+    train_state = Training.TrainState(gcn, ps, st, opt)
+
+    @printf "Total Trainable Parameters: %0.4f M\n" (Lux.parameterlength(ps)/1e6)
+
+    val_loss_compiled = @compile loss_function(
+        gcn, ps, Lux.testmode(st), (features, targets, adj, val_idx))
+
+    train_model_compiled = @compile gcn((features, adj, train_idx), ps, Lux.testmode(st))
+    val_model_compiled = @compile gcn((features, adj, val_idx), ps, Lux.testmode(st))
+
+    best_loss_val = Inf
+    cnt = 0
+
+    for epoch in 1:epochs
+        (_, loss, _, train_state) = Lux.Training.single_train_step!(
+            AutoEnzyme(), loss_function, (features, targets, adj, train_idx), train_state;
+            return_gradients=Val(false)
+        )
+        train_acc = accuracy(
+            Array(train_model_compiled((features, adj, train_idx),
+                train_state.parameters, Lux.testmode(train_state.states))[1]),
+            Array(targets)[:, train_idx]
+        )
+
+        val_loss = first(val_loss_compiled(
+            gcn, train_state.parameters, Lux.testmode(train_state.states),
+            (features, targets, adj, val_idx)))
+        val_acc = accuracy(
+            Array(val_model_compiled((features, adj, val_idx),
+                train_state.parameters, Lux.testmode(train_state.states))[1]),
+            Array(targets)[:, val_idx]
+        )
+
+        @printf "Epoch %3d\tTrain Loss: %.6f\tTrain Acc: %.4f%%\tVal Loss: %.6f\t\
+                 Val Acc: %.4f%%\n" epoch loss train_acc val_loss val_acc
+
+        if val_loss < best_loss_val
+            best_loss_val = val_loss
+            cnt = 0
+        else
+            cnt += 1
+            if cnt == patience
+                @printf "Early Stopping at Epoch %d\n" epoch
+                break
+            end
+        end
+    end
+
+    test_loss = @jit(loss_function(
+        gcn, train_state.parameters, Lux.testmode(train_state.states),
+        (features, targets, adj, test_idx)))[1]
+    test_acc = accuracy(
+        Array(@jit(gcn((features, adj, test_idx),
+            train_state.parameters, Lux.testmode(train_state.states)))[1]),
+        Array(targets)[:, test_idx]
+    )
+
+    @printf "Test Loss: %.6f\tTest Acc: %.4f%%\n" test_loss test_acc
+    return
+end
+
+main()
+nothing #hide
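A note on the propagation rule: `GCNLayer` above multiplies the transformed features by the raw adjacency matrix (`dense(x) * adj`, which aggregates each node's neighbors since features are stored one column per node), following the MLX reference implementation. The original GCN formulation of Kipf & Welling instead propagates with the self-loop-augmented, symmetrically normalized adjacency `Â = D^(-1/2) (A + I) D^(-1/2)`. A minimal sketch of that preprocessing follows; the `normalize_adjacency` helper is hypothetical and not part of this PR:

```julia
using LinearAlgebra

# Hypothetical helper (not part of this PR): builds the normalized adjacency
# Â = D^(-1/2) (A + I) D^(-1/2) from the classic GCN formulation.
function normalize_adjacency(adj::AbstractMatrix)
    a = adj + I                           # add self-loops
    d = vec(sum(a; dims=2))               # node degrees
    dinv_sqrt = Diagonal(inv.(sqrt.(d)))  # D^(-1/2)
    return dinv_sqrt * a * dinv_sqrt
end
```

Swapping `Matrix(adjacency_matrix(gnngraph))` in `loadcora` for `normalize_adjacency(Matrix(adjacency_matrix(gnngraph)))` would recover the standard rule without touching anything else: the result is still a dense `num_nodes × num_nodes` matrix (possibly worth converting to `Float32` to match the features), so the Reactant compilation path should be unchanged.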
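For reviewing locally, the new example should run like the other example projects in this repository; the exact invocation below is an assumption, adjust paths to your checkout:

```julia
# Assumes the Lux.jl repository root as the working directory.
using Pkg
Pkg.activate("examples/GCN_Cora")   # use the example's own Project.toml added in this PR
Pkg.instantiate()                   # resolve and install the pinned dependencies
include(joinpath("examples", "GCN_Cora", "main.jl"))  # runs main() at the bottom of the script
```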