- import functools
from typing import Optional

import numpy as np

- from pytensor import In, Out, get_scalar_constant_value
+ from pytensor import In, Out
from pytensor.compile import optdb, pfunc
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
- from pytensor.tensor import NoneConst, and_, empty, minimum, set_subtensor
+ from pytensor.tensor import (
+     NoneConst,
+     add,
+     and_,
+     empty,
+     get_scalar_constant_value,
+     set_subtensor,
+ )
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.type import DenseTensorType, TensorType


def validate_loop_update_types(update):
    assert update.outputs[0].type.dtype == "bool"
-     for input_state, output_state in zip(update.inputs, update.outputs[1:]):
-         assert input_state.type == output_state.type
+     for i, (input_state, output_state) in enumerate(
+         zip(update.inputs, update.outputs[1:])
+     ):
+         if input_state.type != output_state.type:
+             raise TypeError(
+                 f"The {i}-th input and output states of the inner loop function have different types: "
+                 f"{input_state.type} vs {output_state.type}."
+             )


class Loop(Op):
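The validation change above replaces a bare `assert` with a `TypeError` that names the offending state. A minimal sketch of what it rejects (illustrative only, assuming `validate_loop_update_types` from this module is in scope):

```
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph

# Toy update graph: one int64 state, a boolean "continue" output,
# and a next state that was (incorrectly) upcast to float64.
x = pt.scalar("x", dtype="int64")
resume = pt.lt(x, 10)               # dtype "bool", as required for outputs[0]
next_x = (x + 1).astype("float64")  # no longer matches the input state's type

update_fg = FunctionGraph([x], [resume, next_x])
validate_loop_update_types(update_fg)  # raises TypeError for the 0-th state
```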
@@ -128,11 +140,11 @@ class Scan(Op):

    Roughly equivalent to
    ```
-     def scan(fn, initial_states, sequences, constants, max_iters):
+     def scan(fn, initial_states, constants, max_iters):
        traces = [[]*len(initial_states)]
        states = initial_states
-         for (idx, *subsequences) in zip(*(range(max_iters), *sequences)):
-             resume, states = fn(*states, *subsequences, *constants)
+         for i in range(max_iters):
+             resume, states = fn(*states, *constants)
            for trace, state in zip(traces, states):
                trace.append(state)
            if not resume:
@@ -142,15 +154,12 @@ def scan(fn, initial_states, sequences, constants, max_iters):
    Not all types of states can be collected, for instance RandomGenerator. For these
    `None` is returned in place of the respective traces

-     The number of iterations is bounded by max_iters or the shortest of sequences.
-
    This Op must always be converted to a Loop during compilation.
    """

    def __init__(
        self,
        update_fg: FunctionGraph,  # (*state, *consts) -> (bool, *state)
-         n_sequences: int,
        reverse_fg: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update_fg)
@@ -170,61 +179,29 @@ def __init__(
                # We can't concatenate all types of states, such as RandomTypes
                self.trace_types.append(NoneConst.type)

-         self.n_sequences = n_sequences
-         self.sequence_types = []
-         for inner_seq in update_fg.inputs[
-             self.n_states : self.n_states + self.n_sequences
-         ]:
-             # TODO: Accomodate other sequence types
-             assert isinstance(inner_seq.type, DenseTensorType)
-             self.sequence_types.append(
-                 DenseTensorType(
-                     shape=(None, *inner_seq.type.shape), dtype=inner_seq.type.dtype
-                 )
-             )
-
-         self.non_sequence_types = [
-             inp.type for inp in update_fg.inputs[self.n_states + self.n_sequences :]
-         ]
-         self.n_non_sequences = len(self.non_sequence_types)
+         self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
+         self.n_constants = len(self.constant_types)

        self.update_fg = update_fg.clone(check_integrity=False)
        self.reverse_fg = (
            reverse_fg.clone(check_integrity=False) if reverse_fg is not None else None
        )

    def make_node(self, max_iters, *inputs):
-         assert len(inputs) == self.n_states + self.n_sequences + self.n_non_sequences
-
-         if self.n_sequences == 0 and max_iters is None:
-             raise ValueError("Must provide max_iters in Scans without sequences")
+         assert len(inputs) == self.n_states + self.n_constants

-         if max_iters is not None:
-             max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters)
+         max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters)

        states = inputs[: self.n_states]
        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, states)
        ]

-         sequences = inputs[self.n_states : self.n_states + self.n_sequences]
-         sequences = [
+         constants = inputs[self.n_states :]
+         constants = [
            inp_type.filter_variable(inp)
-             for inp_type, inp in zip(self.sequence_types, sequences)
-         ]
-         if sequences:
-             leading_dims = [seq.shape[0] for seq in sequences]
-             shortest_dim = functools.reduce(minimum, leading_dims)
-             if max_iters is None:
-                 max_iters = shortest_dim
-             else:
-                 max_iters = minimum(max_iters, shortest_dim)
-
-         non_sequences = inputs[self.n_states + self.n_sequences :]
-         non_sequences = [
-             inp_type.filter_variable(inp)
-             for inp_type, inp in zip(self.non_sequence_types, non_sequences)
+             for inp_type, inp in zip(self.constant_types, constants)
        ]

        # If there is no loop condition, `max_iters` exclusively defines the number of iterations
@@ -249,7 +226,7 @@ def make_node(self, max_iters, *inputs):

        return Apply(
            self,
-             [max_iters, *states, *sequences, *non_sequences],
+             [max_iters, *states, *constants],
            [output_type() for output_type in self.state_types + trace_types],
        )

@@ -299,20 +276,16 @@ def scan_to_loop(fgraph, node):
    It roughly creates the following computational graph
    ```

-     def scan(fn, initial_states, sequences, constants, max_iters):
-
-         def update_fn(idx, states, traces, sequences, constants, max_iters)
-             subsequences = [seq[idx] for seq in subsequences]
-             resume, states = inner_fn(states, subsequences, constants)
-             for trace, state in zip(traces, states):
-                 trace[idx] = state
-             return (resume and (idx < max_iters)), idx + 1, states, traces
-
+     def scan(fn, idx, initial_states, constants, max_iters):
        idx = 0
+         states = initial_states
        traces = [empty(max_iters, *initial_state.shape) for initial_state in initial_states]
        while True:
-             resume, idx, states, traces = update_fn(idx, *states, *traces, *sequences, *constants, max_iters)
-             if not resume:
+             resume, states = fn(*states, *constants)
+             for trace, state in zip(traces, states):
+                 trace[idx] = state
+             idx += 1
+             if not resume or idx >= max_iters:
                break
        traces = [trace[:idx] for trace in traces]
        return states, traces
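The Python-level `trace[idx] = state` in the sketch above is what the rewrite expresses symbolically with `empty` and `set_subtensor`. A standalone illustration (shapes and names are made up, not taken from the commit):

```
import pytensor
import pytensor.tensor as pt

max_iters = 5
idx = pt.scalar("idx", dtype="int64")
state = pt.scalar("state", dtype="float64")

# Preallocated trace buffer; entries other than `idx` remain uninitialized.
trace = pt.empty((max_iters,), dtype="float64")
next_trace = pt.set_subtensor(trace[idx], state)  # symbolic trace[idx] = state

f = pytensor.function([idx, state], next_trace)
print(f(2, 3.14))  # length-5 array whose entry 2 is 3.14
```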
@@ -339,7 +312,6 @@ def update_fn(idx, states, traces, sequences, constants, max_iters)

    # Inputs to the new Loop
    max_iters = node.inputs[0]
-     init_idx = constant(np.array(0, dtype="int64"), name="idx")
    init_states = node.inputs[1 : 1 + op.n_states]
    init_traces = [
        empty(
@@ -348,79 +320,103 @@ def update_fn(idx, states, traces, sequences, constants, max_iters)
        )
        for trace_idx in used_traces_idxs
    ]
-     sequences = node.inputs[1 + op.n_states : 1 + op.n_states + op.n_sequences]
-     non_sequences = node.inputs[1 + op.n_states + op.n_sequences :]
+     constants = node.inputs[1 + op.n_states :]

-     new_fg = op.update_fg.clone(check_integrity=False)
+     update_fg = op.update_fg.clone(check_integrity=False)

-     # Inner index
-     inner_prev_idx = init_idx.type()
-     inner_prev_idx.name = "prev_idx"
+     # Check if inner_fg computes an index already, otherwise create a new one
+     has_idx = False
+     if len(node.inputs) > 1:
+         try:
+             outer_inp = node.inputs[1]
+             outer_is_zero = get_scalar_constant_value(outer_inp) == 0
+         except NotScalarConstantError:
+             pass
+         else:
+             if (
+                 outer_is_zero
+                 and len(update_fg.inputs) > 0
+                 and len(update_fg.outputs) > 1
+             ):
+                 inner_out = update_fg.outputs[1]
+                 if (
+                     inner_out.owner is not None
+                     and inner_out.owner.op == add
+                     and len(inner_out.owner.inputs) == 2
+                 ):
+                     left, right = inner_out.owner.inputs
+                     if left is update_fg.inputs[0]:
+                         try:
+                             has_idx = (
+                                 get_scalar_constant_value(
+                                     right, only_process_constants=True
+                                 )
+                                 == 1
+                             )
+                         except NotScalarConstantError:
+                             pass
+
+     if has_idx:
+         init_idx = outer_inp
+         inner_idx = inner_out.owner.inputs[0]
+         inner_next_idx = inner_out
+     if not has_idx:
+         init_idx = constant(np.array(0, dtype="int64"), name="idx")
+         inner_idx = init_idx.type()
+         inner_idx.name = "idx"
+         inner_next_idx = inner_idx + 1
+         inner_next_idx.name = "next_idx"

    # Inner traces
-     inner_prev_states = new_fg.inputs[: op.n_states]
-     inner_prev_traces = [init_trace.type() for init_trace in init_traces]
-     for s, t in zip(inner_prev_states, inner_prev_traces):
-         t.name = "prev_trace"
+     inner_states = update_fg.inputs[: op.n_states]
+     inner_traces = [init_trace.type() for init_trace in init_traces]
+     for s, t in zip(inner_states, inner_traces):
+         t.name = "trace"
        if s.name:
            t.name = "_".join((t.name, s.name))

-     inner_non_sequences = new_fg.inputs[op.n_states + op.n_sequences :]
-
-     # Replace inner sub-sequences by sequence[idx]
-     inner_seqs_news = []
-     if op.n_sequences:
-         inner_subseqs_old = new_fg.inputs[op.n_states : op.n_states + op.n_sequences]
-         inner_subseqs_new = []
-         for sequence in sequences:
-             inner_seq_new = sequence.type()
-             inner_seq_new.name = sequence.name or "sequence"
-             inner_seqs_news.append(inner_seq_new)
-             inner_subseq_new = inner_seq_new[inner_prev_idx]
-             inner_subseq_new.name = inner_seq_new.name + "[prev_idx]"
-             inner_subseqs_new.append(inner_subseq_new)
-
-         # Replace inner_sequence input by sequence[idx]
-         replacements = tuple(zip(inner_subseqs_old, inner_subseqs_new))
-         new_fg.replace_all(replacements, import_missing=True)
-
-     # Inner continue condition and index
-     inner_continue_cond, *inner_next_states = new_fg.outputs
-     inner_next_idx = inner_prev_idx + 1
-     inner_next_idx.name = "next_idx"
+     inner_constants = update_fg.inputs[op.n_states :]
+
+     # Inner continue condition
+     inner_continue_cond, *inner_next_states = update_fg.outputs
    inner_next_traces = [
-         set_subtensor(prev_trace[inner_prev_idx], inner_next_states[trace_idx])
-         for trace_idx, prev_trace in zip(used_traces_idxs, inner_prev_traces)
+         set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
+         for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
    ]
    for t in inner_next_traces:
        t.name = "next_trace"
    inner_max_iters = max_iters.type()
    inner_continue_cond = and_(inner_continue_cond, inner_next_idx < inner_max_iters)
    inner_continue_cond.name = "continue(?)"

-     new_fg = FunctionGraph(
+     if not has_idx:
+         init_states = [init_idx] + init_states
+         inner_states = [inner_idx] + inner_states
+         inner_next_states = [inner_next_idx] + inner_next_states
+
+     new_update_fg = FunctionGraph(
        inputs=[
-             inner_prev_idx,
-             *inner_prev_states,
-             *inner_prev_traces,
-             *inner_seqs_news,
-             *inner_non_sequences,
+             *inner_states,
+             *inner_traces,
+             *inner_constants,
            inner_max_iters,
        ],
        outputs=[
            inner_continue_cond,
-             inner_next_idx,
            *inner_next_states,
            *inner_next_traces,
        ],
    )

    # TODO: Implement Reverse?
-     loop_op = Loop(update_fg=new_fg)
-
-     final_idx, *new_outs = loop_op(
-         init_idx, *init_states, *init_traces, *sequences, *non_sequences, max_iters
-     )
+     loop_op = Loop(update_fg=new_update_fg)
+
+     new_outs = loop_op(*init_states, *init_traces, *constants, max_iters)
+     if has_idx:
+         # idx was part of the original scan, and therefore has a corresponding trace
+         final_idx = new_outs[0]
+     else:
+         final_idx, *new_outs = new_outs
    new_states = new_outs[: op.n_states]
    new_traces = new_outs[op.n_states :]
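The index-detection block above relies on `get_scalar_constant_value` raising `NotScalarConstantError` for anything that is not statically known. A minimal sketch of that probing pattern (the helper `is_constant_zero` is illustrative, not part of the commit):

```
import pytensor.tensor as pt
from pytensor.tensor import get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError

def is_constant_zero(var):
    """True only if `var` is statically known to equal 0."""
    try:
        return bool(get_scalar_constant_value(var) == 0)
    except NotScalarConstantError:
        return False

print(is_constant_zero(pt.constant(0)))                 # True
print(is_constant_zero(pt.scalar("n", dtype="int64")))  # False: symbolic, not constant
```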