@@ -442,23 +442,26 @@ def rv_op(cls, rhos, sigma, init_dist, steps, ar_order, constant_term, size=None
442442 rhos_bcast = pt .broadcast_to (rhos , rhos_bcast_shape )
443443
444444 def step (* args ):
445- * prev_xs , reversed_rhos , sigma , rng = args
445+ * prev_xs , rng , reversed_rhos , sigma = args
446446 if constant_term :
447447 mu = reversed_rhos [- 1 ] + pt .sum (prev_xs * reversed_rhos [:- 1 ], axis = 0 )
448448 else :
449449 mu = pt .sum (prev_xs * reversed_rhos , axis = 0 )
450450 next_rng , new_x = Normal .dist (mu = mu , sigma = sigma , rng = rng ).owner .outputs
451- return new_x , { rng : next_rng }
451+ return new_x , next_rng
452452
453453 # We transpose inputs as scan iterates over first dimension
454- innov , innov_updates = pytensor .scan (
454+ innov , noise_next_rng = pytensor .scan (
455455 fn = step ,
456- outputs_info = [{"initial" : init_dist .T , "taps" : range (- ar_order , 0 )}],
457- non_sequences = [rhos_bcast .T [::- 1 ], sigma .T , noise_rng ],
456+ outputs_info = [
457+ {"initial" : init_dist .T , "taps" : range (- ar_order , 0 )},
458+ noise_rng ,
459+ ],
460+ non_sequences = [rhos_bcast .T [::- 1 ], sigma .T ],
458461 n_steps = steps ,
459462 strict = True ,
463+ return_updates = False ,
460464 )
461- (noise_next_rng ,) = tuple (innov_updates .values ())
462465 ar = pt .concatenate ([init_dist , innov .T ], axis = - 1 )
463466
464467 return AutoRegressiveRV (
@@ -710,24 +713,25 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
710713
711714 # Create OpFromGraph representing random draws from GARCH11 process
712715
713- def step (prev_y , prev_sigma , omega , alpha_1 , beta_1 , rng ):
716+ def step (prev_y , prev_sigma , rng , omega , alpha_1 , beta_1 ):
714717 new_sigma = pt .sqrt (
715718 omega + alpha_1 * pt .square (prev_y ) + beta_1 * pt .square (prev_sigma )
716719 )
717720 next_rng , new_y = Normal .dist (mu = 0 , sigma = new_sigma , rng = rng ).owner .outputs
718- return ( new_y , new_sigma ), { rng : next_rng }
721+ return new_y , new_sigma , next_rng
719722
720- ( y_t , _ ), innov_updates = pytensor .scan (
723+ y_t , _ , noise_next_rng = pytensor .scan (
721724 fn = step ,
722725 outputs_info = [
723726 init_dist ,
724727 pt .broadcast_to (initial_vol .astype ("floatX" ), init_dist .shape ),
728+ noise_rng ,
725729 ],
726- non_sequences = [omega , alpha_1 , beta_1 , noise_rng ],
730+ non_sequences = [omega , alpha_1 , beta_1 ],
727731 n_steps = steps ,
728732 strict = True ,
733+ return_updates = False ,
729734 )
730- (noise_next_rng ,) = tuple (innov_updates .values ())
731735
732736 garch11 = pt .concatenate ([init_dist [None , ...], y_t ], axis = 0 ).dimshuffle (
733737 (* range (1 , y_t .ndim ), 0 )
@@ -816,12 +820,13 @@ def garch11_logp(
816820 def volatility_update (x , vol , w , a , b ):
817821 return pt .sqrt (w + a * pt .square (x ) + b * pt .square (vol ))
818822
819- vol , _ = pytensor .scan (
823+ vol = pytensor .scan (
820824 fn = volatility_update ,
821825 sequences = [value_dimswapped [:- 1 ]],
822826 outputs_info = [initial_vol ],
823827 non_sequences = [omega , alpha_1 , beta_1 ],
824828 strict = True ,
829+ return_updates = False ,
825830 )
826831 sigma_t = pt .concatenate ([[initial_vol ], vol ])
827832 # Compute and collapse logp across time dimension
@@ -861,21 +866,21 @@ def rv_op(cls, init_dist, steps, sde_pars, dt, sde_fn, size=None):
861866
862867 # Create OpFromGraph representing random draws from SDE process
863868 def step (* prev_args ):
864- prev_y , * prev_sde_pars , rng = prev_args
869+ prev_y , rng , * prev_sde_pars = prev_args
865870 f , g = sde_fn (prev_y , * prev_sde_pars )
866871 mu = prev_y + dt * f
867872 sigma = pt .sqrt (dt ) * g
868873 next_rng , next_y = Normal .dist (mu = mu , sigma = sigma , rng = rng ).owner .outputs
869- return next_y , { rng : next_rng }
874+ return next_y , next_rng
870875
871- y_t , innov_updates = pytensor .scan (
876+ y_t , noise_next_rng = pytensor .scan (
872877 fn = step ,
873- outputs_info = [init_dist ],
874- non_sequences = [* sde_pars , noise_rng ],
878+ outputs_info = [init_dist , noise_rng ],
879+ non_sequences = [* sde_pars ],
875880 n_steps = steps ,
876881 strict = True ,
882+ return_updates = False ,
877883 )
878- (noise_next_rng ,) = tuple (innov_updates .values ())
879884
880885 sde_out = pt .concatenate ([init_dist [None , ...], y_t ], axis = 0 ).dimshuffle (
881886 (* range (1 , y_t .ndim ), 0 )
0 commit comments