Skip to content

Commit cd629ae

Browse files
committed
Add XorShiftRng serde
1 parent 1fb9b81 commit cd629ae

File tree

1 file changed

+183
-0
lines changed

1 file changed

+183
-0
lines changed

src/lib.rs

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,13 @@ use std::io;
258258
use std::rc::Rc;
259259
use std::num::Wrapping as w;
260260
use std::time;
261+
#[cfg(feature="serde-1")]
262+
use std::fmt;
263+
264+
#[cfg(feature = "serde-1")]
265+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
266+
#[cfg(feature="serde-1")]
267+
use serde::de::Visitor;
261268

262269
pub use os::OsRng;
263270

@@ -805,6 +812,154 @@ impl Rand for XorShiftRng {
805812
}
806813
}
807814

815+
#[cfg(feature = "serde-1")]
816+
impl Serialize for XorShiftRng {
817+
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
818+
where
819+
S: Serializer,
820+
{
821+
use serde::ser::SerializeStruct;
822+
823+
let mut state = ser.serialize_struct("XorShiftRng",6)?;
824+
825+
let w(x) = self.x;
826+
state.serialize_field("x", &x)?;
827+
828+
let w(y) = self.y;
829+
state.serialize_field("y", &y)?;
830+
831+
let w(z) = self.z;
832+
state.serialize_field("z", &z)?;
833+
834+
let w(w_field) = self.w;
835+
state.serialize_field("w", &w_field)?;
836+
837+
state.end()
838+
}
839+
}
840+
841+
#[cfg(feature="serde-1")]
842+
impl<'de> Deserialize<'de> for XorShiftRng {
843+
fn deserialize<D>(de: D) -> Result<XorShiftRng, D::Error>
844+
where D: Deserializer<'de> {
845+
use serde::de::{SeqAccess,MapAccess};
846+
use serde::de;
847+
848+
enum Field { X, Y, Z, W };
849+
850+
impl<'de> Deserialize<'de> for Field {
851+
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
852+
where D: Deserializer<'de> {
853+
struct XorFieldVisitor;
854+
impl<'de> Visitor<'de> for XorFieldVisitor {
855+
type Value = Field;
856+
857+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
858+
formatter.write_str("`x`, `y`, `z`, or `w`")
859+
}
860+
861+
fn visit_str<E>(self, value: &str) -> Result<Field,E>
862+
where E: de::Error {
863+
match value {
864+
"x" => Ok(Field::X),
865+
"y" => Ok(Field::Y),
866+
"z" => Ok(Field::Z),
867+
"w" => Ok(Field::W),
868+
_ => Err(de::Error::unknown_field(value, FIELDS))
869+
}
870+
}
871+
}
872+
deserializer.deserialize_identifier(XorFieldVisitor)
873+
}
874+
}
875+
876+
struct XorVisitor;
877+
878+
const FIELDS: &[&'static str] = &["x", "y", "z", "w"];
879+
880+
impl<'de> Visitor<'de> for XorVisitor {
881+
type Value = XorShiftRng;
882+
883+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
884+
formatter.write_str("struct XorShiftRng")
885+
}
886+
887+
fn visit_seq<V>(self, mut seq: V) -> Result<XorShiftRng, V::Error>
888+
where V: SeqAccess<'de> {
889+
let x: u32 = seq.next_element()?
890+
.ok_or_else(|| de::Error::invalid_length(0,&self))?;
891+
892+
let y: u32 = seq.next_element()?
893+
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
894+
895+
let z: u32 = seq.next_element()?
896+
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
897+
898+
let w_field: u32 = seq.next_element()?
899+
.ok_or_else(|| de::Error::invalid_length(3, &self))?;
900+
901+
902+
let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));
903+
904+
Ok(XorShiftRng {
905+
x: x,y: y,z: z,w: w_field
906+
})
907+
}
908+
909+
fn visit_map<V>(self, mut map: V) -> Result<XorShiftRng, V::Error>
910+
where V: MapAccess<'de>
911+
{
912+
let mut x = None;
913+
let mut y = None;
914+
let mut z = None;
915+
let mut w_field = None;
916+
917+
while let Some(key) = map.next_key()? {
918+
match key {
919+
Field::X => {
920+
if x.is_some() {
921+
return Err(de::Error::duplicate_field("x"));
922+
}
923+
x = Some(map.next_value()?);
924+
}
925+
Field::Y => {
926+
if y.is_some() {
927+
return Err(de::Error::duplicate_field("y"));
928+
}
929+
y = Some(map.next_value()?);
930+
}
931+
Field::Z => {
932+
if z.is_some() {
933+
return Err(de::Error::duplicate_field("z"));
934+
}
935+
z = Some(map.next_value()?);
936+
}
937+
Field::W => {
938+
if w_field.is_some() {
939+
return Err(de::Error::duplicate_field("w"));
940+
}
941+
w_field = Some(map.next_value()?);
942+
}
943+
}
944+
}
945+
946+
let x = x.ok_or_else(|| de::Error::missing_field("x"))?;
947+
let y = y.ok_or_else(|| de::Error::missing_field("y"))?;
948+
let z = z.ok_or_else(|| de::Error::missing_field("z"))?;
949+
let w_field = w_field.ok_or_else(|| de::Error::missing_field("w"))?;
950+
951+
let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));
952+
953+
Ok(XorShiftRng {
954+
x: x,y: y,z: z,w: w_field
955+
})
956+
}
957+
}
958+
959+
de.deserialize_struct("IsaacRng", FIELDS, XorVisitor)
960+
}
961+
}
962+
808963
/// A wrapper for generating floating point numbers uniformly in the
809964
/// open interval `(0,1)` (not including either endpoint).
810965
///
@@ -1328,4 +1483,32 @@ mod test {
13281483
assert_eq!(rng.next_u64(), deserialized.next_u64());
13291484
}
13301485
}
1486+
1487+
#[cfg(feature="serde-1")]
1488+
#[test]
1489+
fn test_xor_serde() {
1490+
use super::XorShiftRng;
1491+
use bincode;
1492+
use std::io::{BufWriter, BufReader};
1493+
1494+
let seed: [u32; 4] = thread_rng().gen();
1495+
let mut rng: XorShiftRng = SeedableRng::from_seed(seed);
1496+
1497+
let buf: Vec<u8> = Vec::new();
1498+
let mut buf = BufWriter::new(buf);
1499+
bincode::serialize_into(&mut buf, &rng, bincode::Infinite).expect("Could not serialize");
1500+
1501+
let buf = buf.into_inner().unwrap();
1502+
let mut read = BufReader::new(&buf[..]);
1503+
let mut deserialized: XorShiftRng = bincode::deserialize_from(&mut read, bincode::Infinite).expect("Could not deserialize");
1504+
1505+
assert_eq!(rng.x, deserialized.x);
1506+
assert_eq!(rng.y, deserialized.y);
1507+
assert_eq!(rng.z, deserialized.z);
1508+
assert_eq!(rng.w, deserialized.w);
1509+
1510+
for _ in 0..16 {
1511+
assert_eq!(rng.next_u64(), deserialized.next_u64());
1512+
}
1513+
}
13311514
}

0 commit comments

Comments
 (0)