Skip to content

Commit 2f8a9a2

Browse files
committed
Add XorShiftRng serde
1 parent fd99e8b commit 2f8a9a2

File tree

1 file changed

+184
-0
lines changed

1 file changed

+184
-0
lines changed

src/lib.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,14 @@ use std::mem;
257257
use std::io;
258258
use std::rc::Rc;
259259
use std::num::Wrapping as w;
260+
#[cfg(feature="serde-1")]
261+
use std::fmt;
262+
263+
#[cfg(feature = "serde-1")]
264+
use serde::{Deserialize, Deserializer, Serialize, Serializer};
265+
#[cfg(feature="serde-1")]
266+
use serde::de::Visitor;
267+
260268

261269
pub use os::OsRng;
262270

@@ -804,6 +812,154 @@ impl Rand for XorShiftRng {
804812
}
805813
}
806814

815+
#[cfg(feature = "serde-1")]
816+
impl Serialize for XorShiftRng {
817+
fn serialize<S>(&self, ser: S) -> Result<S::Ok, S::Error>
818+
where
819+
S: Serializer,
820+
{
821+
use serde::ser::SerializeStruct;
822+
823+
let mut state = ser.serialize_struct("XorShiftRng",6)?;
824+
825+
let w(x) = self.x;
826+
state.serialize_field("x", &x)?;
827+
828+
let w(y) = self.y;
829+
state.serialize_field("y", &y)?;
830+
831+
let w(z) = self.z;
832+
state.serialize_field("z", &z)?;
833+
834+
let w(w_field) = self.w;
835+
state.serialize_field("w", &w_field)?;
836+
837+
state.end()
838+
}
839+
}
840+
841+
#[cfg(feature="serde-1")]
842+
impl<'de> Deserialize<'de> for XorShiftRng {
843+
fn deserialize<D>(de: D) -> Result<XorShiftRng, D::Error>
844+
where D: Deserializer<'de> {
845+
use serde::de::{SeqAccess,MapAccess};
846+
use serde::de;
847+
848+
enum Field { X, Y, Z, W };
849+
850+
impl<'de> Deserialize<'de> for Field {
851+
fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
852+
where D: Deserializer<'de> {
853+
struct XorFieldVisitor;
854+
impl<'de> Visitor<'de> for XorFieldVisitor {
855+
type Value = Field;
856+
857+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
858+
formatter.write_str("`x`, `y`, `z`, or `w`")
859+
}
860+
861+
fn visit_str<E>(self, value: &str) -> Result<Field,E>
862+
where E: de::Error {
863+
match value {
864+
"x" => Ok(Field::X),
865+
"y" => Ok(Field::Y),
866+
"z" => Ok(Field::Z),
867+
"w" => Ok(Field::W),
868+
_ => Err(de::Error::unknown_field(value, FIELDS))
869+
}
870+
}
871+
}
872+
deserializer.deserialize_identifier(XorFieldVisitor)
873+
}
874+
}
875+
876+
struct XorVisitor;
877+
878+
const FIELDS: &[&'static str] = &["x", "y", "z", "w"];
879+
880+
impl<'de> Visitor<'de> for XorVisitor {
881+
type Value = XorShiftRng;
882+
883+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
884+
formatter.write_str("struct XorShiftRng")
885+
}
886+
887+
fn visit_seq<V>(self, mut seq: V) -> Result<XorShiftRng, V::Error>
888+
where V: SeqAccess<'de> {
889+
let x: u32 = seq.next_element()?
890+
.ok_or_else(|| de::Error::invalid_length(0,&self))?;
891+
892+
let y: u32 = seq.next_element()?
893+
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
894+
895+
let z: u32 = seq.next_element()?
896+
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
897+
898+
let w_field: u32 = seq.next_element()?
899+
.ok_or_else(|| de::Error::invalid_length(3, &self))?;
900+
901+
902+
let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));
903+
904+
Ok(XorShiftRng {
905+
x: x,y: y,z: z,w: w_field
906+
})
907+
}
908+
909+
fn visit_map<V>(self, mut map: V) -> Result<XorShiftRng, V::Error>
910+
where V: MapAccess<'de>
911+
{
912+
let mut x = None;
913+
let mut y = None;
914+
let mut z = None;
915+
let mut w_field = None;
916+
917+
while let Some(key) = map.next_key()? {
918+
match key {
919+
Field::X => {
920+
if x.is_some() {
921+
return Err(de::Error::duplicate_field("x"));
922+
}
923+
x = Some(map.next_value()?);
924+
}
925+
Field::Y => {
926+
if y.is_some() {
927+
return Err(de::Error::duplicate_field("y"));
928+
}
929+
y = Some(map.next_value()?);
930+
}
931+
Field::Z => {
932+
if z.is_some() {
933+
return Err(de::Error::duplicate_field("z"));
934+
}
935+
z = Some(map.next_value()?);
936+
}
937+
Field::W => {
938+
if w_field.is_some() {
939+
return Err(de::Error::duplicate_field("w"));
940+
}
941+
w_field = Some(map.next_value()?);
942+
}
943+
}
944+
}
945+
946+
let x = x.ok_or_else(|| de::Error::missing_field("x"))?;
947+
let y = y.ok_or_else(|| de::Error::missing_field("y"))?;
948+
let z = z.ok_or_else(|| de::Error::missing_field("z"))?;
949+
let w_field = w_field.ok_or_else(|| de::Error::missing_field("w"))?;
950+
951+
let (x,y,z,w_field) = (w(x), w(y), w(z), w(w_field));
952+
953+
Ok(XorShiftRng {
954+
x: x,y: y,z: z,w: w_field
955+
})
956+
}
957+
}
958+
959+
de.deserialize_struct("IsaacRng", FIELDS, XorVisitor)
960+
}
961+
}
962+
807963
/// A wrapper for generating floating point numbers uniformly in the
808964
/// open interval `(0,1)` (not including either endpoint).
809965
///
@@ -1312,4 +1468,32 @@ mod test {
13121468
assert_eq!(rng.next_u64(), deserialized.next_u64());
13131469
}
13141470
}
1471+
1472+
#[test]
1473+
#[cfg(feature="serde-1")]
1474+
fn test_xor_serde() {
1475+
use super::XorShiftRng;
1476+
use bincode;
1477+
use std::io::{BufWriter, BufReader};
1478+
1479+
let seed: [u32; 4] = thread_rng().gen();
1480+
let mut rng: XorShiftRng = SeedableRng::from_seed(seed);
1481+
1482+
let buf: Vec<u8> = Vec::new();
1483+
let mut buf = BufWriter::new(buf);
1484+
bincode::serialize_into(&mut buf, &rng, bincode::Infinite).expect("Could not serialize");
1485+
1486+
let buf = buf.into_inner().unwrap();
1487+
let mut read = BufReader::new(&buf[..]);
1488+
let mut deserialized: XorShiftRng = bincode::deserialize_from(&mut read, bincode::Infinite).expect("Could not deserialize");
1489+
1490+
assert_eq!(rng.x, deserialized.x);
1491+
assert_eq!(rng.y, deserialized.y);
1492+
assert_eq!(rng.z, deserialized.z);
1493+
assert_eq!(rng.w, deserialized.w);
1494+
1495+
for _ in 0..16 {
1496+
assert_eq!(rng.next_u64(), deserialized.next_u64());
1497+
}
1498+
}
13151499
}

0 commit comments

Comments
 (0)