
Commit 7bcd42c

Committed Jan 12, 2023
Keep outer graph visible to the scan user function, including sequences
Sequences are now demoted to just another constant input of the Scan Op. The user-facing function automatically builds the indexing graph needed to iterate over sequences. Extra logic in the `scan_to_loop` rewrite avoids creating duplicated indexes, while staying on guard for Scans created elsewhere.
1 parent c46cd53 commit 7bcd42c
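In practice, the inner function now receives `sequence[idx]` built directly on the outer sequence variable, so the connection to the outer graph is preserved. A minimal sketch of the resulting user-facing behavior (illustrative only, based on the new `scan` helper in pytensor/loop/basic.py and the tests below):

```python
import numpy as np
import pytensor
from pytensor.loop.basic import scan
from pytensor.tensor import vector

xs = vector("xs")
# The lambda receives xs[idx] directly; `scan` builds the indexing graph
# over the outer sequence and bounds n_steps by its leading dimension.
_, [ys] = scan(lambda x: x * 2, sequences=[xs])

fn = pytensor.function([xs], ys)
np.testing.assert_array_equal(fn(np.arange(3.0)), np.arange(3.0) * 2)
```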

File tree

4 files changed: +197 -147 lines changed
 

pytensor/loop/basic.py

Lines changed: 47 additions & 16 deletions
@@ -1,12 +1,14 @@
+import functools
 from typing import List, Tuple
 
 import numpy as np
 
-from pytensor import Variable, as_symbolic
+from pytensor import Variable, as_symbolic, clone_replace
 from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import Constant, truncated_graph_inputs
 from pytensor.loop.op import Scan
 from pytensor.scan.utils import until
-from pytensor.tensor import as_tensor, empty_like
+from pytensor.tensor import as_tensor, constant, empty_like, minimum
 
 
 def scan(
@@ -20,6 +22,8 @@ def scan(
     if sequences is None and n_steps is None:
         raise ValueError("Must provide n_steps when scanning without sequences")
 
+    # TODO: init_states should be made opaque to the inner function,
+    # since any relationship to the outer graph no longer holds
     if init_states is None:
         init_states = []
     else:
@@ -34,20 +38,31 @@ def scan(
             sequences = [sequences]
         sequences = [as_tensor(s) for s in sequences]
 
+    if sequences:
+        leading_dims = [seq.shape[0] for seq in sequences]
+        shortest_dim = functools.reduce(minimum, leading_dims)
+        if n_steps is None:
+            n_steps = shortest_dim
+        else:
+            n_steps = minimum(n_steps, shortest_dim)
+
     if non_sequences is None:
         non_sequences = []
    else:
         if not isinstance(non_sequences, (tuple, list)):
             non_sequences = [non_sequences]
         non_sequences = [as_symbolic(n) for n in non_sequences]
 
+    # Create subsequence inputs for the inner function
+    idx = constant(0, dtype="int64", name="idx")
+    symbolic_idx = idx.type(name="idx")
+    subsequences = [s[symbolic_idx] for s in sequences]
     # Note: Old scan order is sequences + init + non_sequences
-    inner_sequences = [s[0] for s in sequences]
-    inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences]
-    inner_outputs = fn(*inner_inputs)
-    if not isinstance(inner_outputs, (tuple, list)):
-        inner_outputs = [inner_outputs]
-    next_states = [out for out in inner_outputs if not isinstance(out, until)]
+    fn_inputs = init_states + subsequences + non_sequences
+    fn_outputs = fn(*fn_inputs)
+    if not isinstance(fn_outputs, (tuple, list)):
+        fn_outputs = [fn_outputs]
+    next_states = [out for out in fn_outputs if not isinstance(out, until)]
 
     if len(next_states) > len(init_states):
         if not init_states:
@@ -61,27 +76,43 @@ def scan(
     prev_states = []
     for i, (init_state, next_state) in enumerate(zip(init_states, next_states)):
         if init_state is None:
+            # next_state may reference idx, let's replace that by the initial value
+            [next_state] = clone_replace(
+                output=[next_state], replace={symbolic_idx: idx}
+            )
             init_state = empty_like(next_state)
             init_state.name = "empty_init_state"
-            inner_inputs.insert(i, init_state.type())
         prev_states.append(init_state)
 
-    until_condition = [out.condition for out in inner_outputs if isinstance(out, until)]
+    until_condition = [out.condition for out in fn_outputs if isinstance(out, until)]
     if not until_condition:
         until_condition = [as_tensor(np.array(True))]
     if len(until_condition) > 1:
         raise ValueError("Only one until condition can be returned")
 
-    update_fg = FunctionGraph(
-        inputs=inner_inputs, outputs=until_condition + next_states
+    fgraph_inputs = [symbolic_idx] + prev_states + sequences + non_sequences
+    fgraph_outputs = until_condition + [symbolic_idx + 1] + next_states
+
+    all_fgraph_inputs = truncated_graph_inputs(
+        fgraph_outputs, ancestors_to_include=fgraph_inputs
+    )
+    extra_fgraph_inputs = [
+        inp
+        for inp in all_fgraph_inputs
+        if (not isinstance(inp, Constant) and inp not in fgraph_inputs)
+    ]
+    fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
+    update_fg = FunctionGraph(inputs=fgraph_inputs, outputs=fgraph_outputs)
+
+    scan_op = Scan(update_fg=update_fg)
+    scan_outs = scan_op(
+        n_steps, idx, *prev_states, *sequences, *non_sequences, *extra_fgraph_inputs
     )
-    scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences))
-    scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences)
     assert isinstance(scan_outs, list)
     last_states = scan_outs[: scan_op.n_states]
     traces = scan_outs[scan_op.n_states :]
-
-    return last_states, traces
+    # Don't return the inner index state
+    return last_states[1:], traces[1:]
 
 
 def map(
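The `truncated_graph_inputs` step above is what makes the outer graph visible: any intermediate variable the user function closes over becomes an explicit extra input of the inner `FunctionGraph` and an extra constant of the Scan Op. A rough standalone illustration of that collection step (toy graph, not part of this commit):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import Constant, truncated_graph_inputs

x = pt.scalar("x")
y = pt.scalar("y")
z = y * 2
out = x + z
# Truncate the graph of `out` so it still contains `x`: the boundary
# variable `z` is reported as an input of the truncated graph, and the
# new `scan` appends such variables as extra constant inputs.
inputs = truncated_graph_inputs([out], ancestors_to_include=[x])
extra = [inp for inp in inputs if not isinstance(inp, Constant) and inp is not x]
assert extra == [z]
```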

pytensor/loop/op.py

Lines changed: 106 additions & 110 deletions
@@ -1,14 +1,20 @@
-import functools
 from typing import Optional
 
 import numpy as np
 
-from pytensor import In, Out, get_scalar_constant_value
+from pytensor import In, Out
 from pytensor.compile import optdb, pfunc
 from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.scalar import constant
-from pytensor.tensor import NoneConst, and_, empty, minimum, set_subtensor
+from pytensor.tensor import (
+    NoneConst,
+    add,
+    and_,
+    empty,
+    get_scalar_constant_value,
+    set_subtensor,
+)
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import Shape_i
 from pytensor.tensor.type import DenseTensorType, TensorType
@@ -17,8 +23,14 @@
 
 def validate_loop_update_types(update):
     assert update.outputs[0].type.dtype == "bool"
-    for input_state, output_state in zip(update.inputs, update.outputs[1:]):
-        assert input_state.type == output_state.type
+    for i, (input_state, output_state) in enumerate(
+        zip(update.inputs, update.outputs[1:])
+    ):
+        if input_state.type != output_state.type:
+            raise TypeError(
+                f"The {i}-th input and output states of the inner loop function have different types: "
+                f"{input_state.type} vs {output_state.type}."
+            )
 
 
 class Loop(Op):
@@ -128,11 +140,11 @@ class Scan(Op):
 
     Roughly equivalent to
     ```
-    def scan(fn, initial_states, sequences, constants, max_iters):
+    def scan(fn, initial_states, constants, max_iters):
         traces = [[]*len(initial_states)]
         states = initial_states
-        for (idx, *subsequences) in zip(*(range(max_iters), *sequences)):
-            resume, states = fn(*states, *subsequences, *constants)
+        for i in range(max_iters):
+            resume, states = fn(*states, *constants)
             for trace, state in zip(traces, states):
                 trace.append(state)
             if not resume:
@@ -142,15 +154,12 @@ def scan(fn, initial_states, sequences, constants, max_iters):
     Not all types of states can be collected, for instance RandomGenerator. For these
     `None` is returned in place of the respective traces
 
-    The number of iterations is bounded by max_iters or the shortest of sequences.
-
     This Op must always be converted to a Loop during compilation.
     """
 
     def __init__(
         self,
         update_fg: FunctionGraph,  # (*state, *consts) -> (bool, *state)
-        n_sequences: int,
         reverse_fg: Optional[FunctionGraph] = None,
     ):
         validate_loop_update_types(update_fg)
@@ -170,61 +179,29 @@ def __init__(
         # We can't concatenate all types of states, such as RandomTypes
         self.trace_types.append(NoneConst.type)
 
-        self.n_sequences = n_sequences
-        self.sequence_types = []
-        for inner_seq in update_fg.inputs[
-            self.n_states : self.n_states + self.n_sequences
-        ]:
-            # TODO: Accomodate other sequence types
-            assert isinstance(inner_seq.type, DenseTensorType)
-            self.sequence_types.append(
-                DenseTensorType(
-                    shape=(None, *inner_seq.type.shape), dtype=inner_seq.type.dtype
-                )
-            )
-
-        self.non_sequence_types = [
-            inp.type for inp in update_fg.inputs[self.n_states + self.n_sequences :]
-        ]
-        self.n_non_sequences = len(self.non_sequence_types)
+        self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
+        self.n_constants = len(self.constant_types)
 
         self.update_fg = update_fg.clone(check_integrity=False)
         self.reverse_fg = (
             reverse_fg.clone(check_integrity=False) if reverse_fg is not None else None
         )
 
     def make_node(self, max_iters, *inputs):
-        assert len(inputs) == self.n_states + self.n_sequences + self.n_non_sequences
-
-        if self.n_sequences == 0 and max_iters is None:
-            raise ValueError("Must provide max_iters in Scans without sequences")
+        assert len(inputs) == self.n_states + self.n_constants
 
-        if max_iters is not None:
-            max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters)
+        max_iters = TensorType(dtype="int64", shape=()).filter_variable(max_iters)
 
         states = inputs[: self.n_states]
         states = [
             inp_type.filter_variable(inp)
             for inp_type, inp in zip(self.state_types, states)
         ]
 
-        sequences = inputs[self.n_states : self.n_states + self.n_sequences]
-        sequences = [
+        constants = inputs[self.n_states :]
+        constants = [
             inp_type.filter_variable(inp)
-            for inp_type, inp in zip(self.sequence_types, sequences)
-        ]
-        if sequences:
-            leading_dims = [seq.shape[0] for seq in sequences]
-            shortest_dim = functools.reduce(minimum, leading_dims)
-            if max_iters is None:
-                max_iters = shortest_dim
-            else:
-                max_iters = minimum(max_iters, shortest_dim)
-
-        non_sequences = inputs[self.n_states + self.n_sequences :]
-        non_sequences = [
-            inp_type.filter_variable(inp)
-            for inp_type, inp in zip(self.non_sequence_types, non_sequences)
+            for inp_type, inp in zip(self.constant_types, constants)
         ]
 
         # If there is no loop condition, `max_iters` exclusively defines the number of iterations
@@ -249,7 +226,7 @@ def make_node(self, max_iters, *inputs):
 
         return Apply(
             self,
-            [max_iters, *states, *sequences, *non_sequences],
+            [max_iters, *states, *constants],
             [output_type() for output_type in self.state_types + trace_types],
         )

@@ -299,20 +276,16 @@ def scan_to_loop(fgraph, node):
     It roughly creates the following computational graph
     ```
 
-    def scan(fn, initial_states, sequences, constants, max_iters):
-
-        def update_fn(idx, states, traces, sequences, constants, max_iters)
-            subsequences = [seq[idx] for seq in subsequences]
-            resume, states = inner_fn(states, subsequences, constants)
-            for trace, state in zip(traces, states):
-                trace[idx] = state
-            return (resume and (idx < max_iters)), idx + 1, states, traces
-
+    def scan(fn, idx, initial_states, constants, max_iters):
         idx = 0
+        states = initial_states
         traces = [empty(max_iters, *initial_state.shape) for initial_state in initial_states]
         while True:
-            resume, idx, states, traces = update_fn(idx, *states, *traces, *sequences, *constants, max_iters)
-            if not resume:
+            resume, states = fn(*states, *traces, *constants)
+            for trace, state in zip(traces, states):
+                trace[idx] = state
+            idx += 1
+            if not resume or idx >= max_iters:
                 break
         traces = [trace[: idx] for trace in traces]
         return states, traces
@@ -339,7 +312,6 @@ def update_fn(idx, states, traces, sequences, constants, max_iters)
 
     # Inputs to the new Loop
     max_iters = node.inputs[0]
-    init_idx = constant(np.array(0, dtype="int64"), name="idx")
     init_states = node.inputs[1 : 1 + op.n_states]
     init_traces = [
         empty(
@@ -348,79 +320,103 @@ def update_fn(idx, states, traces, sequences, constants, max_iters)
         )
         for trace_idx in used_traces_idxs
     ]
-    sequences = node.inputs[1 + op.n_states : 1 + op.n_states + op.n_sequences]
-    non_sequences = node.inputs[1 + op.n_states + op.n_sequences :]
+    constants = node.inputs[1 + op.n_states :]
 
-    new_fg = op.update_fg.clone(check_integrity=False)
+    update_fg = op.update_fg.clone(check_integrity=False)
 
-    # Inner index
-    inner_prev_idx = init_idx.type()
-    inner_prev_idx.name = "prev_idx"
+    # Check if inner_fg computes an index already, otherwise create a new one
+    has_idx = False
+    if len(node.inputs) > 1:
+        try:
+            outer_inp = node.inputs[1]
+            outer_is_zero = get_scalar_constant_value(outer_inp) == 0
+        except NotScalarConstantError:
+            pass
+        else:
+            if (
+                outer_is_zero
+                and len(update_fg.inputs) > 0
+                and len(update_fg.outputs) > 1
+            ):
+                inner_out = update_fg.outputs[1]
+                if (
+                    inner_out.owner is not None
+                    and inner_out.owner.op == add
+                    and len(inner_out.owner.inputs) == 2
+                ):
+                    left, right = inner_out.owner.inputs
+                    if left is update_fg.inputs[0]:
+                        try:
+                            has_idx = (
+                                get_scalar_constant_value(
+                                    right, only_process_constants=True
+                                )
+                                == 1
+                            )
+                        except NotScalarConstantError:
+                            pass
+
+    if has_idx:
+        init_idx = outer_inp
+        inner_idx = inner_out.owner.inputs[0]
+        inner_next_idx = inner_out
+    if not has_idx:
+        init_idx = constant(np.array(0, dtype="int64"), name="idx")
+        inner_idx = init_idx.type()
+        inner_idx.name = "idx"
+        inner_next_idx = inner_idx + 1
+        inner_next_idx.name = "next_idx"
 
     # Inner traces
-    inner_prev_states = new_fg.inputs[: op.n_states]
-    inner_prev_traces = [init_trace.type() for init_trace in init_traces]
-    for s, t in zip(inner_prev_states, inner_prev_traces):
-        t.name = "prev_trace"
+    inner_states = update_fg.inputs[: op.n_states]
+    inner_traces = [init_trace.type() for init_trace in init_traces]
+    for s, t in zip(inner_states, inner_traces):
+        t.name = "trace"
         if s.name:
             t.name = "_".join((t.name, s.name))
 
-    inner_non_sequences = new_fg.inputs[op.n_states + op.n_sequences :]
-
-    # Replace inner sub-sequences by sequence[idx]
-    inner_seqs_news = []
-    if op.n_sequences:
-        inner_subseqs_old = new_fg.inputs[op.n_states : op.n_states + op.n_sequences]
-        inner_subseqs_new = []
-        for sequence in sequences:
-            inner_seq_new = sequence.type()
-            inner_seq_new.name = sequence.name or "sequence"
-            inner_seqs_news.append(inner_seq_new)
-            inner_subseq_new = inner_seq_new[inner_prev_idx]
-            inner_subseq_new.name = inner_seq_new.name + "[prev_idx]"
-            inner_subseqs_new.append(inner_subseq_new)
-
-        # Replace inner_sequence input by sequence[idx]
-        replacements = tuple(zip(inner_subseqs_old, inner_subseqs_new))
-        new_fg.replace_all(replacements, import_missing=True)
-
-    # Inner continue condition and index
-    inner_continue_cond, *inner_next_states = new_fg.outputs
-    inner_next_idx = inner_prev_idx + 1
-    inner_next_idx.name = "next_idx"
+    inner_constants = update_fg.inputs[op.n_states :]
+
+    # Inner continue condition
+    inner_continue_cond, *inner_next_states = update_fg.outputs
     inner_next_traces = [
-        set_subtensor(prev_trace[inner_prev_idx], inner_next_states[trace_idx])
-        for trace_idx, prev_trace in zip(used_traces_idxs, inner_prev_traces)
+        set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
+        for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
     ]
     for t in inner_next_traces:
         t.name = "next_trace"
     inner_max_iters = max_iters.type()
     inner_continue_cond = and_(inner_continue_cond, inner_next_idx < inner_max_iters)
     inner_continue_cond.name = "continue(?)"
 
-    new_fg = FunctionGraph(
+    if not has_idx:
+        init_states = [init_idx] + init_states
+        inner_states = [inner_idx] + inner_states
+        inner_next_states = [inner_next_idx] + inner_next_states
+
+    new_update_fg = FunctionGraph(
         inputs=[
-            inner_prev_idx,
-            *inner_prev_states,
-            *inner_prev_traces,
-            *inner_seqs_news,
-            *inner_non_sequences,
+            *inner_states,
+            *inner_traces,
+            *inner_constants,
             inner_max_iters,
         ],
         outputs=[
             inner_continue_cond,
-            inner_next_idx,
             *inner_next_states,
             *inner_next_traces,
         ],
     )
 
     # TODO: Implement Reverse?
-    loop_op = Loop(update_fg=new_fg)
-
-    final_idx, *new_outs = loop_op(
-        init_idx, *init_states, *init_traces, *sequences, *non_sequences, max_iters
-    )
+    loop_op = Loop(update_fg=new_update_fg)
+
+    new_outs = loop_op(*init_states, *init_traces, *constants, max_iters)
+    if has_idx:
+        # idx was part of the original scan, and therefore has a corresponding trace
+        final_idx = new_outs[0]
+    else:
+        final_idx, *new_outs = new_outs
     new_states = new_outs[: op.n_states]
     new_traces = new_outs[op.n_states :]
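The duplicated-index guard in `scan_to_loop` above only fires on a narrow pattern: the first outer input is the constant 0 and the second inner output is exactly `inputs[0] + 1`. An inner graph matching that shape, mirroring `test_foreach_scan` below (names are illustrative):

```python
import numpy as np
from pytensor.graph import FunctionGraph
from pytensor.tensor import constant, scalar, vector

idx = scalar("idx", dtype="int64")
x0 = scalar("x0")
xs = vector("xs")
# Second output is `idx + 1`; passing 0 as the initial value of `idx`
# lets `scan_to_loop` reuse this index instead of adding its own counter.
update_fg = FunctionGraph(
    [idx, x0, xs], [constant(np.array(True)), idx + 1, xs[idx] * 2]
)
```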

tests/loop/basic.py

Lines changed: 31 additions & 1 deletion
@@ -1,8 +1,9 @@
 import numpy as np
 
 import pytensor
+from pytensor import grad
 from pytensor.loop.basic import filter, map, reduce, scan
-from pytensor.tensor import eq, vector, zeros
+from pytensor.tensor import arange, eq, vector, zeros
 
 
 def test_scan():
@@ -19,6 +20,35 @@ def test_scan():
     )
 
 
+def test_scan_taking_grads_non_sequence():
+    xs = vector("xs")
+    ys = xs**2
+
+    _, [J] = scan(
+        lambda i, y, xs: grad(y[i], wrt=xs),
+        sequences=arange(ys.shape[0]),
+        non_sequences=[ys, xs],
+    )
+
+    f = pytensor.function([xs], J)
+    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])
+
+
+def test_scan_taking_grads_sequence():
+    # This is not possible with the old Scan
+    xs = vector("xs")
+    ys = xs**2
+
+    _, [J] = scan(
+        lambda y, xs: grad(y, wrt=xs),
+        sequences=[ys],
+        non_sequences=[xs],
+    )
+
+    f = pytensor.function([xs], J)
+    np.testing.assert_array_equal(f([4, 4]), np.c_[[8, 0], [0, 8]])
+
+
 def test_map():
     xs = vector("xs")
     ys = map(

tests/loop/test_op.py

Lines changed: 13 additions & 20 deletions
@@ -41,7 +41,7 @@ def test_fori_scan():
     update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])
 
     n_iters = 10
-    y, ys = Scan(n_sequences=0, update_fg=update_fg)(n_iters, x)
+    y, ys = Scan(update_fg=update_fg)(n_iters, x)
 
     fn = function([x], [y, ys])
 
@@ -69,7 +69,7 @@ def test_fori_scan_shape():
     update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])
 
     n_iters = 10
-    _, ys = Scan(n_sequences=0, update_fg=update_fg)(n_iters, x)
+    _, ys = Scan(update_fg=update_fg)(n_iters, x)
 
     fn = function([x], ys.shape, on_unused_input="ignore")
     nodes = tuple(fn.maker.fgraph.apply_nodes)
@@ -84,9 +84,7 @@ def test_while_scan():
     update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])
 
     max_iters = 1000
-    _, y, _, ys = Scan(n_sequences=0, update_fg=update_fg)(
-        max_iters, np.array(0, dtype="int64"), x
-    )
+    _, y, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)
 
     fn = function([x], [y, ys])
 
@@ -99,11 +97,10 @@ def test_while_scan():
     )
     assert len(loop_nodes) == 1
     (loop_node,) = loop_nodes
-    assert len(loop_node.outputs) == 4
+    assert len(loop_node.outputs) == 3
     assert loop_node.outputs[0].type.shape == ()
     assert loop_node.outputs[1].type.shape == ()
-    assert loop_node.outputs[2].type.shape == ()
-    assert loop_node.outputs[3].type.shape == (1000,)
+    assert loop_node.outputs[2].type.shape == (1000,)
 
     y_eval, ys_eval = fn(0)
     np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
@@ -116,9 +113,7 @@ def test_while_scan_shape():
     update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])
 
     max_iters = 1000
-    _, _, _, ys = Scan(n_sequences=0, update_fg=update_fg)(
-        max_iters, np.array(0, dtype="int64"), x
-    )
+    _, _, _, ys = Scan(update_fg=update_fg)(max_iters, np.array(0, dtype="int64"), x)
 
     fn = function([x], ys.shape)
     loop_nodes = tuple(
@@ -129,18 +124,18 @@ def test_while_scan_shape():
 
 
 def test_foreach_scan():
-    dummy_init = empty(())
-    x = scalar("x")
+    idx = scalar("idx", dtype="int64")
+    dummy_x0 = empty(())
+    xs = vector("xs")
     const = scalar("const")
     update_fg = FunctionGraph(
-        [dummy_init, x, const], [constant(np.array(True)), x * const]
+        [idx, dummy_x0, xs, const], [constant(np.array(True)), idx + 1, xs[idx] * const]
     )
 
-    xs = vector("xs")
-    _, ys = Scan(n_sequences=1, update_fg=update_fg)(None, dummy_init, xs, const)
+    n_steps = xs.shape[0]
+    _, _, _, ys = Scan(update_fg=update_fg)(n_steps, 0, dummy_x0, xs, const)
 
     fn = pytensor.function([xs, const], ys)
-    pytensor.dprint(fn, print_type=True)
 
     np.testing.assert_almost_equal(fn(np.arange(10), 100), np.arange(10) * 100)
 
@@ -157,9 +152,7 @@ def test_fori_random_scan():
         [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
     )
 
-    _, new_rng, ys, rngs = Scan(n_sequences=0, update_fg=update_fg)(
-        n_iters, dummy_init, rng_shared
-    )
+    _, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared)
     assert isinstance(rngs.type, NoneTypeT)
 
     fn = function([], ys, updates={rng_shared: new_rng})
