@@ -615,20 +615,19 @@ end
615
615
return TracedRNumber {T} ((), res)
616
616
end
617
617
618
- # function bitcast_convert(
619
- # ::Type{TracedRArray{U,N}},
620
- # x::TracedRArray{T,N};
621
- # location=mlir_stacktrace(
622
- # "bitcast_convert", @__FILE__, @__LINE__
623
- # ),
624
- # ) where {T,N}
625
- # res = MLIR.IR.result(
626
- # stablehlo.bitcast_convert(
627
- # x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location
628
- # ),
629
- # )
630
- # return TracedRArray{T,N}((), res, size(x))
631
- # end
618
+ function bitcast_convert (
619
+ :: Type{TracedRArray{U,N}} ,
620
+ x:: TracedRArray{T,N} ;
621
+ location= mlir_stacktrace (" bitcast_convert" , @__FILE__ , @__LINE__ ),
622
+ ) where {T,U,N}
623
+ res = MLIR. IR. result (
624
+ stablehlo. bitcast_convert (
625
+ x. mlir_data; result_0= mlir_type (TracedRArray{U,N}, size (x)), location
626
+ ),
627
+ )
628
+ return TracedRArray {U,N} ((), res, size (x))
629
+ end
630
+
632
631
@noinline function bitcast_convert (
633
632
:: Type{U} ,
634
633
x:: TracedRNumber{T} ;
@@ -1244,7 +1243,8 @@ end
1244
1243
)
1245
1244
1246
1245
Generate a random array of type `T` with the given shape and seed from a uniform random
1247
- distribution between 0 and 1. Returns a NamedTuple with the following fields:
1246
+ distribution between 0 and 1 (for floating point types). Returns a NamedTuple with the
1247
+ following fields:
1248
1248
1249
1249
- `output_state`: The state of the random number generator after the operation.
1250
1250
- `output`: The generated array.
@@ -1283,6 +1283,7 @@ distribution between 0 and 1. Returns a NamedTuple with the following fields:
1283
1283
)
1284
1284
end
1285
1285
1286
+ # https://github.com/jax-ml/jax/blob/474dcd409d6fa4c048014851922460f9d4fc199e/jax/_src/random.py#L444-L464
1286
1287
@noinline function rng_bit_generator (
1287
1288
:: Type{T} ,
1288
1289
seed:: TracedRArray{UInt64,1} ,
@@ -1291,11 +1292,20 @@ end
1291
1292
location= mlir_stacktrace (" rng_bit_generator" , @__FILE__ , @__LINE__ ),
1292
1293
) where {T<: AbstractFloat }
1293
1294
nbits = sizeof (T) * 8
1294
- uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
1295
+ @assert nbits ∈ (8 , 16 , 32 , 64 ) " Unsupported type: $(T) "
1296
+ uT = nbits == 8 ? UInt8 : (nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64))
1295
1297
(; output_state, output) = rng_bit_generator (uT, seed, shape; algorithm, location)
1296
- output = divide (
1297
- convert (TracedRArray{T,ndims (output)}, output),
1298
- fill (T (typemax (uT)), Tuple (shape); location),
1298
+ float_bits = or (
1299
+ shift_right_logical (
1300
+ output,
1301
+ fill (uT (nbits - Reactant. nmantissa (T)), size (output); location);
1302
+ location,
1303
+ ),
1304
+ fill (reinterpret (uT, T (1 )), size (output); location),
1305
+ )
1306
+ output = subtract (
1307
+ bitcast_convert (TracedRArray{T,length (shape)}, float_bits; location),
1308
+ fill (T (1 ), size (output); location),
1299
1309
)
1300
1310
return (; output_state, output)
1301
1311
end
0 commit comments