6
6
7
7
# ## Package Imports
8
8
9
- using ArgCheck, CairoMakie, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, FileIO,
10
- ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers, ProgressBars,
11
- Random, Setfield, StableRNGs, Statistics, Zygote
9
+ using ArgCheck, ConcreteStructs, Comonicon, DataAugmentation, DataDeps, Enzyme , FileIO,
10
+ ImageCore, JLD2, Lux, LuxCUDA, MLUtils, Optimisers, ParameterSchedulers,
11
+ ProgressTables, Random, Reactant , Setfield, StableRNGs, Statistics, Zygote
12
12
using TensorBoardLogger: TBLogger, log_value, log_images
13
13
14
14
# Turn off scalar indexing of GPU arrays: element-by-element access from the host then
# raises an error instead of silently running slowly.
CUDA. allowscalar (false )
@@ -19,16 +19,15 @@ CUDA.allowscalar(false)
19
19
# [the Keras example](https://keras.io/examples/generative/ddim/). Embed noise variances to
20
20
# a sinusoidal embedding.
21
21
22
"""
    sinusoidal_embedding(x::AbstractArray{T, 4}, min_freq, max_freq, embedding_dims::Int)

Return a sinusoidal embedding of `x`.

`n = embedding_dims ÷ 2` frequencies are spaced logarithmically between `min_freq` and
`max_freq`. The result concatenates `sinpi.(2 .* x .* freqs)` and
`cospi.(2 .* x .* freqs)` along the third dimension, so it has `2n` channels.

# Arguments
- `x`: array of size `(1, 1, 1, batch)`.
- `min_freq`, `max_freq`: positive frequency bounds; their logarithms are converted to `T`.
- `embedding_dims`: requested embedding width; `embedding_dims ÷ 2` frequencies are used.

# Throws
- `DimensionMismatch` if the first three dimensions of `x` are not all `1`.
"""
function sinusoidal_embedding(
        x::AbstractArray{T, 4}, min_freq, max_freq, embedding_dims::Int) where {T}
    if size(x)[1:3] != (1, 1, 1)
        throw(DimensionMismatch("Input shape must be (1, 1, 1, batch)"))
    end

    # Convert to the element type of `x` so the broadcast below does not promote.
    lower, upper = T(log(min_freq)), T(log(max_freq))
    n = embedding_dims ÷ 2
    # `range` rejects `length = 1` with differing endpoints; with a single frequency,
    # collapse the range onto `min_freq` instead of throwing.
    stop = n == 1 ? lower : upper
    freqs = exp.(reshape(range(lower, stop; length=n), 1, 1, n, 1))
    x_ = 2 .* x .* freqs
    return cat(sinpi.(x_), cospi.(x_); dims=Val(3))
end
34
33
@@ -97,7 +96,7 @@ function unet_model(image_size::Tuple{Int, Int}; channels=[32, 64, 96, 128],
97
96
return @compact (;
98
97
upsample, conv_in, conv_out, down_blocks, residual_blocks, up_blocks,
99
98
min_freq, max_freq, embedding_dims,
100
- num_blocks= (length (channels) - 1 )) do x:: Tuple{AbstractArray{<:Real, 4}, AbstractArray{<:Real, 4} }
99
+ num_blocks= (length (channels) - 1 )) do x:: Tuple{<:AbstractArray, <:AbstractArray }
101
100
# ! format: on
102
101
noisy_images, noise_variances = x
103
102
@@ -125,8 +124,7 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0,
125
124
unet = unet_model (args... ; kwargs... )
126
125
bn = BatchNorm (3 ; affine= false , track_stats= true )
127
126
128
- return @compact (; unet, bn, rng, min_signal_rate,
129
- max_signal_rate, dispatch= :DDIM ) do x:: AbstractArray{<:Real, 4}
127
+ return @compact (; unet, bn, rng, min_signal_rate, max_signal_rate, dispatch= :DDIM ) do x
130
128
images = bn (x)
131
129
rng = Lux. replicate (rng)
132
130
@@ -144,10 +142,10 @@ function ddim(rng::AbstractRNG, args...; min_signal_rate=0.02f0,
144
142
end
145
143
end
146
144
147
- function diffusion_schedules (diffusion_times :: AbstractArray{T, 4} , min_signal_rate :: T ,
148
- max_signal_rate :: T ) where {T <: Real }
149
- start_angle = acos (max_signal_rate)
150
- end_angle = acos (min_signal_rate)
145
+ function diffusion_schedules (
146
+ diffusion_times :: AbstractArray{T, 4} , min_signal_rate, max_signal_rate ) where {T}
147
+ start_angle = T ( acos (max_signal_rate) )
148
+ end_angle = T ( acos (min_signal_rate) )
151
149
152
150
diffusion_angles = @. start_angle + (end_angle - start_angle) * diffusion_times
153
151
@@ -157,8 +155,14 @@ function diffusion_schedules(diffusion_times::AbstractArray{T, 4}, min_signal_ra
157
155
return noise_rates, signal_rates
158
156
end
159
157
160
- function denoise (unet, noisy_images:: AbstractArray{T, 4} , noise_rates:: AbstractArray{T, 4} ,
161
- signal_rates:: AbstractArray{T, 4} ) where {T <: Real }
158
+ function denoise (
159
+ unet, noisy_images:: AbstractArray{T1, 4} , noise_rates:: AbstractArray{T2, 4} ,
160
+ signal_rates:: AbstractArray{T3, 4} ) where {T1, T2, T3}
161
+ T = promote_type (T1, T2, T3)
162
+ noisy_images = T .(noisy_images)
163
+ noise_rates = T .(noise_rates)
164
+ signal_rates = T .(signal_rates)
165
+
162
166
pred_noises = unet ((noisy_images, noise_rates .^ 2 ))
163
167
pred_images = @. (noisy_images - pred_noises * noise_rates) / signal_rates
164
168
return pred_noises, pred_images
0 commit comments