@@ -17,6 +17,13 @@ mutable struct TracedRArray{T,N} <: RArray{T,N}
17
17
end
18
18
end
19
19
20
+ function Base. setproperty! (x:: TracedRArray , f:: Symbol , v)
21
+ if f === :mlir_data && ! isnothing (v)
22
+ @assert size (MLIR. IR. type (v)) == size (x)
23
+ end
24
+ return setfield! (x, f, v)
25
+ end
26
+
20
27
mutable struct TracedRScalar{T} <: RScalar{T}
21
28
paths:: Tuple
22
29
mlir_data:: Union{Nothing,MLIR.IR.Value}
@@ -31,6 +38,15 @@ mutable struct TracedRScalar{T} <: RScalar{T}
31
38
end
32
39
end
33
40
41
+ function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
42
+ if f === :mlir_data && ! isnothing (v)
43
+ @assert size (MLIR. IR. type (v)) == ()
44
+ end
45
+ return setfield! (x, f, v)
46
+ end
47
+
48
+ Base. eltype (:: Type{TracedRScalar{T}} ) where {T} = T
49
+
34
50
const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
35
51
const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
36
52
const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -57,7 +73,7 @@ Base.zero(::TracedRScalar{T}) where {T} = promote_to(TracedRScalar{T}, zero(T))
57
73
Base. one (:: TracedRScalar{T} ) where {T} = promote_to (TracedRScalar{T}, one (T))
58
74
59
75
function Base. convert (:: Type{<:TracedRScalar{T}} , x:: Number ) where {T}
60
- return promote_to (TracedRArray{T, 0 }, T (x))
76
+ return promote_to (TracedRScalar{T }, T (x))
61
77
end
62
78
63
79
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -119,7 +135,7 @@ function Base.setindex!(
119
135
a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
120
136
) where {T,N}
121
137
indices = [
122
- (promote_to (TracedRArray {Int, 0 }, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
138
+ (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
123
139
i in indices
124
140
]
125
141
v = promote_to (TracedRArray{T,N}, v)
@@ -220,6 +236,14 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
220
236
return TracedRArray{Base. promote_type (T, S),N}
221
237
end
222
238
239
+ function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar{S}} ) where {T,S}
240
+ return TracedRScalar{Base. promote_type (T, S)}
241
+ end
242
+
243
+ function Base. convert (:: Type{TracedRScalar{T}} , x:: Number ) where {T}
244
+ return promote_to (TracedRScalar{T}, x)
245
+ end
246
+
223
247
function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
224
248
if isa (rhs, TracedRArray)
225
249
rhs isa TracedRArray{T,N} && return rhs
@@ -277,12 +301,8 @@ function promote_to(::Type{TracedRScalar{T}}, rhs) where {T}
277
301
)
278
302
end
279
303
280
- function promote_to (:: TracedRArray{T,N} , rhs) where {T,N}
281
- return promote_to (TracedRArray{T,N}, rhs)
282
- end
283
- function promote_to (:: TracedRScalar{T} , rhs) where {T}
284
- return promote_to (TracedRScalar{T}, rhs)
285
- end
304
+ promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
305
+ promote_to (:: TracedRScalar{T} , rhs) where {T} = promote_to (TracedRScalar{T}, rhs)
286
306
287
307
for (jlop, hloop) in (
288
308
(:(Base. min), :minimum ),
@@ -293,66 +313,35 @@ for (jlop, hloop) in (
293
313
(:(Base.:/ ), :divide ),
294
314
(:(Base.:^ ), :power ),
295
315
)
296
- @eval begin
297
- function $ (jlop)(
298
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
299
- ) where {T}
300
- return TracedRArray {T,0} (
301
- (),
302
- MLIR. IR. result (
303
- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
304
- ),
305
- (),
306
- )
307
- end
308
-
309
- function $ (jlop)(
310
- @nospecialize (lhs:: TracedRArray{T1,0} ), @nospecialize (rhs:: TracedRArray{T2,0} )
311
- ) where {T1,T2}
312
- commonTy = TracedRArray{Base. promote_type (T1, T2),0 }
313
- lhs = promote_to (commonTy, lhs)
314
- rhs = promote_to (commonTy, rhs)
315
- return $ (jlop)(lhs, rhs)
316
- end
317
- end
318
-
319
- for otherType in (Number, Any)
320
- @eval begin
321
- function $ (jlop)(
322
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: $ (otherType))
323
- ) where {T}
324
- rhs = promote_to (lhs, rhs)
325
- return $ (jlop)(lhs, rhs)
326
- end
327
-
328
- function $ (jlop)(
329
- @nospecialize (lhs:: $ (otherType)), @nospecialize (rhs:: TracedRArray{T,0} )
330
- ) where {T}
331
- lhs = promote_to (rhs, lhs)
332
- return $ (jlop)(lhs, rhs)
333
- end
334
- end
316
+ @eval function $ (jlop)(
317
+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
318
+ ) where {T}
319
+ return TracedRArray {T} (
320
+ (),
321
+ MLIR. IR. result (
322
+ MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
323
+ ),
324
+ )
335
325
end
336
326
end
337
327
338
328
function Base. ifelse (
339
- @nospecialize (pred:: TracedRArray {Bool,0 } ),
340
- @nospecialize (x:: TracedRArray {T1,0 } ),
341
- @nospecialize (y:: TracedRArray {T2,0 } )
329
+ @nospecialize (pred:: TracedRScalar {Bool} ),
330
+ @nospecialize (x:: TracedRScalar {T1} ),
331
+ @nospecialize (y:: TracedRScalar {T2} )
342
332
) where {T1,T2}
343
- return TracedRArray {promote_type(T1, T2),0 } (
333
+ return TracedRScalar {promote_type(T1, T2)} (
344
334
(),
345
335
MLIR. IR. result (
346
336
MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
347
337
),
348
- size (pred),
349
338
)
350
339
end
351
340
352
- Base. abs2 (x:: Reactant.TracedRArray{T,0 } ) where {T} = x * conj (x)
341
+ Base. abs2 (x:: Reactant.TracedRScalar{T } ) where {T} = x * conj (x)
353
342
354
343
function Base. literal_pow (
355
- :: Base.RefValue{typeof(^)} , x:: TracedRArray{T,0 } , :: Base.RefValue{Val{P}}
344
+ :: Base.RefValue{typeof(^)} , x:: TracedRScalar{T } , :: Base.RefValue{Val{P}}
356
345
) where {T,P}
357
346
return Base. literal_pow (^ , x, Val (P))
358
347
end
@@ -369,14 +358,10 @@ for (jlop, hloop) in (
369
358
(:(Base. log), :log ),
370
359
(:(Base. sqrt), :sqrt ),
371
360
)
372
- @eval begin
373
- function $jlop (@nospecialize (lhs:: TracedRArray{T,0} )) where {T}
374
- return TracedRArray {T,0} (
375
- (),
376
- MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 ),
377
- size (lhs),
378
- )
379
- end
361
+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar{T} )) where {T}
362
+ return TracedRScalar {T} (
363
+ (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
364
+ )
380
365
end
381
366
end
382
367
@@ -443,6 +428,7 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
443
428
residx = 1
444
429
445
430
for a in linear_results
431
+ @show a
446
432
if has_residx (a)
447
433
path = get_residx (a)
448
434
set! (result, path[2 : end ], MLIR. IR. result (res, residx))
@@ -478,37 +464,22 @@ for (jlop, hloop, hlocomp, merge) in (
478
464
(:(Base.:(<= )), :compare , " LE" , nothing ),
479
465
(:(Base.:(< )), :compare , " LT" , nothing ),
480
466
)
481
- @eval begin
482
- function $ (jlop)(
483
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
484
- ) where {T}
485
- return TracedRArray {Bool,0} (
486
- (),
487
- MLIR. IR. result (
488
- MLIR. Dialects. stablehlo.$ hloop (
489
- lhs. mlir_data,
490
- rhs. mlir_data;
491
- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
492
- MLIR. IR. context (), $ hlocomp
493
- ),
467
+ @eval function $ (jlop)(
468
+ @nospecialize (lhs:: TracedRScalar{T} ), @nospecialize (rhs:: TracedRScalar{T} )
469
+ ) where {T}
470
+ return TracedRScalar {Bool} (
471
+ (),
472
+ MLIR. IR. result (
473
+ MLIR. Dialects. stablehlo.$ (hloop)(
474
+ lhs. mlir_data,
475
+ rhs. mlir_data;
476
+ comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
477
+ MLIR. IR. context (), $ hlocomp
494
478
),
495
- 1 ,
496
479
),
497
- size (lhs),
498
- )
499
- end
500
-
501
- function $ (jlop)(
502
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)
503
- ) where {T}
504
- return $ (jlop)(lhs, promote_to (lhs, rhs))
505
- end
506
-
507
- function $ (jlop)(
508
- @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0} )
509
- ) where {T}
510
- return $ (jlop)(promote_to (rhs, lhs), rhs)
511
- end
480
+ 1 ,
481
+ ),
482
+ )
512
483
end
513
484
514
485
if merge != = nothing
@@ -598,7 +569,7 @@ function Base.mapreduce(
598
569
fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
599
570
600
571
args = (
601
- TracedRArray {T,0 } ((), MLIR. IR. argument (fnbody, i), ()) for
572
+ TracedRScalar {T } ((), MLIR. IR. argument (fnbody, i), ()) for
602
573
(i, ty) in enumerate (in_tys)
603
574
)
604
575
0 commit comments