@@ -234,6 +234,8 @@ This mostly aims to re-work the given expression into `some(steps(A))[i,j]`,
234
234
but also pushes `A = f(x)` into `store.top`.
235
235
"""
236
236
function standardise (ex, store:: NamedTuple , call:: CallInfo ; LHS= false )
237
+ @nospecialize ex
238
+
237
239
# This acts only on single indexing expressions:
238
240
if @capture (ex, A_{ijk__})
239
241
static= true
@@ -378,6 +380,7 @@ target dims not correctly handled yet -- what do I want? TODO
378
380
Simple glue / stand. does not permutedims, but broadcasting may have to... avoid twice?
379
381
"""
380
382
function standardglue (ex, target, store:: NamedTuple , call:: CallInfo )
383
+ @nospecialize ex
381
384
382
385
# The sole target here is indexing expressions:
383
386
if @capture (ex, A_[inner__])
@@ -469,6 +472,7 @@ This beings the expression to have target indices,
469
472
by permutedims and if necessary broadcasting, always using `readycast()`.
470
473
"""
471
474
function targetcast (ex, target, store:: NamedTuple , call:: CallInfo )
475
+ @nospecialize ex
472
476
473
477
# If just one naked expression, then we won't broadcast:
474
478
if @capture (ex, A_[ijk__])
503
507
This is walked over the expression to prepare for `@__dot__` etc, by `targetcast()`.
504
508
"""
505
509
function readycast (ex, target, store:: NamedTuple , call:: CallInfo )
510
+ @nospecialize ex
506
511
507
512
# Scalar functions can be protected entirely from broadcasting:
508
513
# TODO this means A[i,j] + rand()/10 doesn't work, /(...,10) is a function!
@@ -578,6 +583,7 @@ If there are more than two factors, it recurses, and you get `(A*B) * C`,
578
583
or perhaps tuple `(A*B, C)`.
579
584
"""
580
585
function matmultarget (ex, target, parsed, store:: NamedTuple , call:: CallInfo )
586
+ @nospecialize ex
581
587
582
588
@capture (ex, A_ * B_ * C__ | * (A_, B_, C__) ) || throw (MacroError (" can't @matmul that!" , call))
583
589
@@ -631,6 +637,7 @@ pushing calculation steps into store.
631
637
Also a convenient place to tidy all indices, including e.g. `fun(M[:,j],N[j]).same[i']`.
632
638
"""
633
639
function recursemacro (ex, store:: NamedTuple , call:: CallInfo )
640
+ @nospecialize ex
634
641
635
642
# Actually look for recursion
636
643
if @capture (ex, @reduce (subex__) )
@@ -675,6 +682,8 @@ This saves to `store` the sizes of all input tensors, and their sub-slices if an
675
682
however it should not destroy this so that `sz_j` can be got later.
676
683
"""
677
684
function rightsizes (ex, store:: NamedTuple , call:: CallInfo )
685
+ @nospecialize ex
686
+
678
687
:recurse in call. flags && return nothing # outer version took care of this
679
688
680
689
if @capture (ex, A_[outer__][inner__] | A_[outer__]{inner__} )
@@ -1115,8 +1124,7 @@ end
1115
1124
1116
1125
tensorprimetidy (v:: Vector ) = Any[ tensorprimetidy (x) for x in v ]
1117
1126
function tensorprimetidy (ex)
1118
- MacroTools. postwalk (ex) do x
1119
-
1127
+ MacroTools. postwalk (ex) do @nospecialize x
1120
1128
@capture (x, ((ij__,) \ k_) ) && return :( ($ (ij... ),$ k) )
1121
1129
@capture (x, i_ \ j_ ) && return :( ($ i,$ j) )
1122
1130
@@ -1172,7 +1180,7 @@ containsindexing(s) = false
1172
1180
function containsindexing (ex:: Expr )
1173
1181
flag = false
1174
1182
# MacroTools.postwalk(x -> @capture(x, A_[ijk__]) && (flag=true), ex)
1175
- MacroTools. postwalk (ex) do x
1183
+ MacroTools. postwalk (ex) do @nospecialize x
1176
1184
# @capture(x, A_[ijk__]) && !(all(isconstant, ijk)) && (flag=true)
1177
1185
if @capture (x, A_[ijk__])
1178
1186
# @show x ijk # TODO this is a bit broken? @pretty @cast Z[i,j] := W[i] * exp(X[1][i] - X[2][j])
@@ -1185,7 +1193,7 @@ end
1185
1193
listindices (s:: Symbol ) = []
1186
1194
function listindices (ex:: Expr )
1187
1195
list = []
1188
- MacroTools. postwalk (ex) do x
1196
+ MacroTools. postwalk (ex) do @nospecialize x
1189
1197
if @capture (x, A_[ijk__])
1190
1198
flat, _ = indexparse (nothing , ijk)
1191
1199
push! (list, flat)
0 commit comments