Skip to content

Commit f872636

Browse files
committed
Merge pull request #52 from bashtage/add-version
ENH: Add version to state
2 parents e83c366 + 6697448 commit f872636

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

doc/source/xoroshiro128plus.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
XorShift128+ Randomstate
2-
************************
1+
XoroShiro128+ Randomstate
2+
*************************
33

44
.. currentmodule:: randomstate.prng.xoroshiro128plus
55

randomstate/randomstate.pyx

+12-2
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ cdef class RandomState:
168168
__MAXSIZE = <uint64_t>sys.maxsize
169169
cdef object __seed
170170
cdef object __stream
171+
cdef object __version
171172

172173
IF RS_RNG_SEED==1:
173174
def __init__(self, seed=None):
@@ -176,6 +177,8 @@ cdef class RandomState:
176177
IF RS_RNG_MOD_NAME == 'dsfmt':
177178
self.rng_state.buffered_uniforms = <double *>PyArray_malloc_aligned(2 * DSFMT_N * sizeof(double))
178179
self.lock = Lock()
180+
self.__version = 0
181+
179182
self.__seed = seed
180183
self.__stream = None
181184

@@ -186,6 +189,8 @@ cdef class RandomState:
186189
self.rng_state.rng = <rng_t *>PyArray_malloc_aligned(sizeof(rng_t))
187190
self.rng_state.binomial = &self.binomial_info
188191
self.lock = Lock()
192+
self.__version = 0
193+
189194
self.__seed = seed
190195
self.__stream = stream
191196

@@ -444,7 +449,8 @@ cdef class RandomState:
444449
'state': _get_state(self.rng_state),
445450
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
446451
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
447-
'seed': self.__seed}
452+
'seed': self.__seed,
453+
'version': self.__version}
448454
if self.__stream is not None:
449455
state['stream'] = self.__stream
450456
return state
@@ -485,7 +491,8 @@ cdef class RandomState:
485491
'state': _get_state(self.rng_state),
486492
'gauss': {'has_gauss': self.rng_state.has_gauss, 'gauss': self.rng_state.gauss},
487493
'uint32': {'has_uint32': self.rng_state.has_uint32, 'uint32': self.rng_state.uinteger},
488-
'seed': self.__seed}
494+
'seed': self.__seed,
495+
'version': self.__version}
489496
if self.__stream is not None:
490497
state['stream'] = self.__stream
491498
return state
@@ -545,6 +552,9 @@ cdef class RandomState:
545552

546553
if state['name'] != rng_name:
547554
raise ValueError('Not a ' + rng_name + ' RNG state')
555+
if 'version' in state:
556+
if state['version'] != 0:
557+
raise NotImplementedError('Support for multiple version has not been implemented.')
548558

549559
_set_state(&self.rng_state, state['state'])
550560
self.rng_state.has_gauss = state['gauss']['has_gauss']

randomstate/tests/test_smoke.py

+5
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,11 @@ def test_pickle(self):
445445
print(unpick.get_state())
446446
assert_(comp_state(self.rs.get_state(), unpick.get_state()))
447447

448+
def test_version(self):
449+
state = self.rs.get_state()
450+
assert_('version' in state)
451+
assert_(state['version'] == 0)
452+
448453
def test_seed_array(self):
449454
if self.seed_vector_bits is None:
450455
raise SkipTest

0 commit comments

Comments
 (0)