Skip to content

Commit

Permalink
feat: print out throughput info
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 20, 2025
1 parent d2c9cc7 commit 740af9d
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion examples/RealNVP/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,11 @@ function main(;
train_state = Training.TrainState(model, ps, st, opt)
@printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

total_samples = 0
start_time = time()

for (iter, x) in enumerate(dataloader)
total_samples += size(x, ndims(x))
(_, loss, _, train_state) = Training.single_train_step!(
AutoEnzyme(), loss_function, x, train_state;
return_gradients=Val(false)
Expand All @@ -217,7 +221,9 @@ function main(;
isnan(loss) && error("NaN loss encountered in iter $(iter)!")

if iter == 1 || iter == maxiters || iter % 1000 == 0
@printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\n" iter maxiters loss
throughput = total_samples / (time() - start_time)
@printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
Throughput: %.6f samples/s\n" iter maxiters loss throughput
end

iter maxiters && break
Expand All @@ -229,6 +235,7 @@ function main(;
end

trained_model = main()
nothing #hide

# ## Visualizing the Results
z_stages = Matrix{Float32}[]
Expand Down

0 comments on commit 740af9d

Please sign in to comment.