@@ -168,6 +168,7 @@ cdef class RandomState:
168
168
__MAXSIZE = < uint64_t > sys .maxsize
169
169
cdef object __seed
170
170
cdef object __stream
171
+ cdef object __version
171
172
172
173
IF RS_RNG_SEED == 1 :
173
174
def __init__ (self , seed = None ):
@@ -176,6 +177,8 @@ cdef class RandomState:
176
177
IF RS_RNG_MOD_NAME == 'dsfmt' :
177
178
self .rng_state .buffered_uniforms = < double * > PyArray_malloc_aligned (2 * DSFMT_N * sizeof (double ))
178
179
self .lock = Lock ()
180
+ self .__version = 0
181
+
179
182
self .__seed = seed
180
183
self .__stream = None
181
184
@@ -186,6 +189,8 @@ cdef class RandomState:
186
189
self .rng_state .rng = < rng_t * > PyArray_malloc_aligned (sizeof (rng_t ))
187
190
self .rng_state .binomial = & self .binomial_info
188
191
self .lock = Lock ()
192
+ self .__version = 0
193
+
189
194
self .__seed = seed
190
195
self .__stream = stream
191
196
@@ -444,7 +449,8 @@ cdef class RandomState:
444
449
'state' : _get_state (self .rng_state ),
445
450
'gauss' : {'has_gauss' : self .rng_state .has_gauss , 'gauss' : self .rng_state .gauss },
446
451
'uint32' : {'has_uint32' : self .rng_state .has_uint32 , 'uint32' : self .rng_state .uinteger },
447
- 'seed' : self .__seed }
452
+ 'seed' : self .__seed ,
453
+ 'version' : self .__version }
448
454
if self .__stream is not None :
449
455
state ['stream' ] = self .__stream
450
456
return state
@@ -485,7 +491,8 @@ cdef class RandomState:
485
491
'state' : _get_state (self .rng_state ),
486
492
'gauss' : {'has_gauss' : self .rng_state .has_gauss , 'gauss' : self .rng_state .gauss },
487
493
'uint32' : {'has_uint32' : self .rng_state .has_uint32 , 'uint32' : self .rng_state .uinteger },
488
- 'seed' : self .__seed }
494
+ 'seed' : self .__seed ,
495
+ 'version' : self .__version }
489
496
if self .__stream is not None :
490
497
state ['stream' ] = self .__stream
491
498
return state
@@ -545,6 +552,9 @@ cdef class RandomState:
545
552
546
553
if state ['name' ] != rng_name :
547
554
raise ValueError ('Not a ' + rng_name + ' RNG state' )
555
+ if 'version' in state :
556
+ if state ['version' ] != 0 :
557
+ raise NotImplementedError ('Support for multiple version has not been implemented.' )
548
558
549
559
_set_state (& self .rng_state , state ['state' ])
550
560
self .rng_state .has_gauss = state ['gauss' ]['has_gauss' ]
0 commit comments