From 740af9d0752070658f0fdf2b27eeee4c3cca38a9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 20 Jan 2025 13:27:28 -0500 Subject: [PATCH] feat: print out throughput info --- examples/RealNVP/main.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/RealNVP/main.jl b/examples/RealNVP/main.jl index 06a2fb66b..b1c8b4d7e 100644 --- a/examples/RealNVP/main.jl +++ b/examples/RealNVP/main.jl @@ -208,7 +208,11 @@ function main(; train_state = Training.TrainState(model, ps, st, opt) @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps) + total_samples = 0 + start_time = time() + for (iter, x) in enumerate(dataloader) + total_samples += size(x, ndims(x)) (_, loss, _, train_state) = Training.single_train_step!( AutoEnzyme(), loss_function, x, train_state; return_gradients=Val(false) @@ -217,7 +221,9 @@ function main(; isnan(loss) && error("NaN loss encountered in iter $(iter)!") if iter == 1 || iter == maxiters || iter % 1000 == 0 - @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\n" iter maxiters loss + throughput = total_samples / (time() - start_time) + @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\ + Throughput: %.6f samples/s\n" iter maxiters loss throughput end iter ≥ maxiters && break @@ -229,6 +235,7 @@ function main(; end trained_model = main() +nothing #hide # ## Visualizing the Results z_stages = Matrix{Float32}[]