Skip to content

Commit 4c777a0

Browse files
committed
feat: migrate DDIM to Reactant
1 parent 3c331d7 commit 4c777a0

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

Diff for: examples/DDIM/Project.toml

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
[deps]
22
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
3-
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
43
Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
54
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
65
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
76
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
7+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
88
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
99
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
1010
ImageIO = "82e4d734-157c-48bb-816b-45c225c6df19"
@@ -14,8 +14,9 @@ LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
1414
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1515
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1616
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
17-
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
17+
ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
1920
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2021
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2122
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -24,7 +25,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2425

2526
[compat]
2627
ArgCheck = "2.3.0"
27-
CairoMakie = "0.12"
2828
Comonicon = "1"
2929
ConcreteStructs = "0.2.3"
3030
DataAugmentation = "0.3"
@@ -38,7 +38,6 @@ LuxCUDA = "0.3"
3838
MLUtils = "0.4"
3939
Optimisers = "0.4.1"
4040
ParameterSchedulers = "0.4.1"
41-
ProgressBars = "1"
4241
Random = "1.10"
4342
Setfield = "1"
4443
StableRNGs = "1.0.2"

Diff for: examples/DDIM/main.jl

+23-19
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
# ## Package Imports
88

9-
using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO,
10-
ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars,
11-
Random, Setfield, StableRNGs, Statistics, Zygote
9+
using ArgCheck, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, Enzyme, FileIO,
10+
ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers,
11+
ProgressTables, Random, Reactant, Setfield, StableRNGs, Statistics, Zygote
1212
using TensorBoardLogger: TBLogger, log_value, log_images
1313

1414
CUDA.allowscalar(false)
@@ -19,16 +19,15 @@ CUDA.allowscalar(false)
1919
# [the Keras example](https://keras.io/examples/generative/ddim/). Embed noise variances to
2020
# embedding.
2121

22-
function sinusoidal_embedding(x::AbstractArray{T, 4}, min_freq::T, max_freq::T,
23-
embedding_dims::Int) where {T <: AbstractFloat}
24-
size(x)[1:3] != (1, 1, 1) &&
22+
function sinusoidal_embedding(
23+
x::AbstractArray{T, 4}, min_freq, max_freq, embedding_dims::Int) where {T}
24+
if size(x)[1:3] != (1, 1, 1)
2525
throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)"))
26+
end
2627

27-
lower, upper = log(min_freq), log(max_freq)
28+
lower, upper = T(log(min_freq)), T(log(max_freq))
2829
n = embedding_dims ÷ 2
29-
d = (upper - lower) / (n - 1)
30-
freqs = reshape(exp.(lower:d:upper) |> get_device(x), 1, 1, n, 1)
31-
x_ = 2 .* x .* freqs
30+
x_ = 2 .* x .* exp.(reshape(range(lower, upper; length=n), 1, 1, n, 1))
3231
return cat(sinpi.(x_), cospi.(x_); dims=Val(3))
3332
end
3433

@@ -97,7 +96,7 @@ function unet_model(image_size::Tuple{Int, Int}; channels=[32, 64, 96, 128],
9796
return @compact(;
9897
upsample, conv_in, conv_out, down_blocks, residual_blocks, up_blocks,
9998
min_freq, max_freq, embedding_dims,
100-
num_blocks=(length(channels) - 1)) do x::Tuple{AbstractArray{<:Real, 4}, AbstractArray{<:Real, 4}}
99+
num_blocks=(length(channels) - 1)) do x::Tuple{<:AbstractArray, <:AbstractArray}
101100
#! format: on
102101
noisy_images, noise_variances = x
103102

@@ -125,8 +124,7 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0,
125124
unet = unet_model(args...; kwargs...)
126125
bn = BatchNorm(3; affine=false, track_stats=true)
127126

128-
return @compact(; unet, bn, rng, min_signal_rate,
129-
max_signal_rate, dispatch=:DDIM) do x::AbstractArray{<:Real, 4}
127+
return @compact(; unet, bn, rng, min_signal_rate, max_signal_rate, dispatch=:DDIM) do x
130128
images = bn(x)
131129
rng = Lux.replicate(rng)
132130

@@ -144,10 +142,10 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0,
144142
end
145143
end
146144

147-
function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_rate::T,
148-
max_signal_rate::T) where {T <: Real}
149-
start_angle = acos(max_signal_rate)
150-
end_angle = acos(min_signal_rate)
145+
function diffusion_schedules(
146+
diffusion_times::AbstractArray{T, 4}, min_signal_rate, max_signal_rate) where {T}
147+
start_angle = T(acos(max_signal_rate))
148+
end_angle = T(acos(min_signal_rate))
151149

152150
diffusion_angles = @. start_angle + (end_angle - start_angle) * diffusion_times
153151

@@ -157,8 +155,14 @@ function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_ra
157155
return noise_rates, signal_rates
158156
end
159157

160-
function denoise(unet, noisy_images::AbstractArray{T, 4}, noise_rates::AbstractArray{T, 4},
161-
signal_rates::AbstractArray{T, 4}) where {T <: Real}
158+
function denoise(
159+
unet, noisy_images::AbstractArray{T1, 4}, noise_rates::AbstractArray{T2, 4},
160+
signal_rates::AbstractArray{T3, 4}) where {T1, T2, T3}
161+
T = promote_type(T1, T2, T3)
162+
noisy_images = T.(noisy_images)
163+
noise_rates = T.(noise_rates)
164+
signal_rates = T.(signal_rates)
165+
162166
pred_noises = unet((noisy_images, noise_rates .^ 2))
163167
pred_images = @. (noisy_images - pred_noises * noise_rates) / signal_rates
164168
return pred_noises, pred_images

0 commit comments

Comments
 (0)