Skip to content

Commit a6e4c40

Browse files
authored
fix: make random number generation consistent with Jax (#1191)
1 parent 54b24cd commit a6e4c40

File tree

2 files changed

+36
-19
lines changed

2 files changed

+36
-19
lines changed

src/Ops.jl

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -615,20 +615,19 @@ end
615615
return TracedRNumber{T}((), res)
616616
end
617617

618-
# function bitcast_convert(
619-
# ::Type{TracedRArray{U,N}},
620-
# x::TracedRArray{T,N};
621-
# location=mlir_stacktrace(
622-
# "bitcast_convert", @__FILE__, @__LINE__
623-
# ),
624-
# ) where {T,N}
625-
# res = MLIR.IR.result(
626-
# stablehlo.bitcast_convert(
627-
# x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location
628-
# ),
629-
# )
630-
# return TracedRArray{T,N}((), res, size(x))
631-
# end
618+
function bitcast_convert(
619+
::Type{TracedRArray{U,N}},
620+
x::TracedRArray{T,N};
621+
location=mlir_stacktrace("bitcast_convert", @__FILE__, @__LINE__),
622+
) where {T,U,N}
623+
res = MLIR.IR.result(
624+
stablehlo.bitcast_convert(
625+
x.mlir_data; result_0=mlir_type(TracedRArray{U,N}, size(x)), location
626+
),
627+
)
628+
return TracedRArray{U,N}((), res, size(x))
629+
end
630+
632631
@noinline function bitcast_convert(
633632
::Type{U},
634633
x::TracedRNumber{T};
@@ -1244,7 +1243,8 @@ end
12441243
)
12451244
12461245
Generate a random array of type `T` with the given shape and seed from a uniform random
1247-
distribution between 0 and 1. Returns a NamedTuple with the following fields:
1246+
distribution between 0 and 1 (for floating point types). Returns a NamedTuple with the
1247+
following fields:
12481248
12491249
- `output_state`: The state of the random number generator after the operation.
12501250
- `output`: The generated array.
@@ -1283,6 +1283,7 @@ distribution between 0 and 1. Returns a NamedTuple with the following fields:
12831283
)
12841284
end
12851285

1286+
# https://github.com/jax-ml/jax/blob/474dcd409d6fa4c048014851922460f9d4fc199e/jax/_src/random.py#L444-L464
12861287
@noinline function rng_bit_generator(
12871288
::Type{T},
12881289
seed::TracedRArray{UInt64,1},
@@ -1291,11 +1292,20 @@ end
12911292
location=mlir_stacktrace("rng_bit_generator", @__FILE__, @__LINE__),
12921293
) where {T<:AbstractFloat}
12931294
nbits = sizeof(T) * 8
1294-
uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
1295+
@assert nbits (8, 16, 32, 64) "Unsupported type: $(T)"
1296+
uT = nbits == 8 ? UInt8 : (nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64))
12951297
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
1296-
output = divide(
1297-
convert(TracedRArray{T,ndims(output)}, output),
1298-
fill(T(typemax(uT)), Tuple(shape); location),
1298+
float_bits = or(
1299+
shift_right_logical(
1300+
output,
1301+
fill(uT(nbits - Reactant.nmantissa(T)), size(output); location);
1302+
location,
1303+
),
1304+
fill(reinterpret(uT, T(1)), size(output); location),
1305+
)
1306+
output = subtract(
1307+
bitcast_convert(TracedRArray{T,length(shape)}, float_bits; location),
1308+
fill(T(1), size(output); location),
12991309
)
13001310
return (; output_state, output)
13011311
end

src/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,3 +791,10 @@ end
791791
$(Expr(:meta, :generated_only))
792792
return $(Expr(:meta, :generated, call_with_reactant_generator))
793793
end
794+
795+
@static if isdefined(Core, :BFloat16)
796+
nmantissa(::Type{Core.BFloat16}) = 7
797+
end
798+
nmantissa(::Type{Float16}) = 10
799+
nmantissa(::Type{Float32}) = 23
800+
nmantissa(::Type{Float64}) = 52

0 commit comments

Comments
 (0)