Skip to content
This repository was archived by the owner on Jan 8, 2024. It is now read-only.

Commit 8fb651d

Browse files
committed
Rest of the equality and comparison operators.
1 parent f5559a3 commit 8fb651d

File tree

6 files changed

+186
-13
lines changed

6 files changed

+186
-13
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "imrc"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
edition = "2021"
55

66
[lib]

src/lib.rs

Lines changed: 121 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,25 @@ impl HashMapPy {
142142
format!("HashMap({{{}}})", contents.collect::<Vec<_>>().join(", "))
143143
}
144144

145-
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
145+
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
146146
match op {
147-
CompareOp::Eq => (self.inner.len() == other.inner.len()).into_py(py),
148-
CompareOp::Ne => (self.inner.len() != other.inner.len()).into_py(py),
149-
_ => py.NotImplemented(),
147+
CompareOp::Eq => Ok((self.inner.len() == other.inner.len()
148+
&& self
149+
.inner
150+
.iter()
151+
.map(|(k1, v1)| (v1, other.inner.get(&k1)))
152+
.map(|(v1, v2)| PyAny::eq(v1.extract(py)?, v2))
153+
.all(|r| r.unwrap_or(false)))
154+
.into_py(py)),
155+
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
156+
|| self
157+
.inner
158+
.iter()
159+
.map(|(k1, v1)| (v1, other.inner.get(&k1)))
160+
.map(|(v1, v2)| PyAny::ne(v1.extract(py)?, v2))
161+
.all(|r| r.unwrap_or(true)))
162+
.into_py(py)),
163+
_ => Ok(py.NotImplemented()),
150164
}
151165
}
152166

@@ -255,6 +269,10 @@ impl<'source> FromPyObject<'source> for HashSetPy {
255269
}
256270
}
257271

272+
fn is_subset(one: &HashSet<Key>, two: &HashSet<Key>) -> bool {
273+
one.iter().all(|v| two.contains(v))
274+
}
275+
258276
#[pymethods]
259277
impl HashSetPy {
260278
#[new]
@@ -268,6 +286,22 @@ impl HashSetPy {
268286
}
269287
}
270288

289+
fn __and__(&self, other: &Self) -> Self {
290+
self.intersection(&other)
291+
}
292+
293+
fn __or__(&self, other: &Self) -> Self {
294+
self.union(&other)
295+
}
296+
297+
fn __sub__(&self, other: &Self) -> Self {
298+
self.difference(&other)
299+
}
300+
301+
fn __xor__(&self, other: &Self) -> Self {
302+
self.symmetric_difference(&other)
303+
}
304+
271305
fn __iter__(slf: PyRef<'_, Self>) -> PyResult<Py<KeyIterator>> {
272306
let iter = slf
273307
.inner
@@ -292,11 +326,19 @@ impl HashSetPy {
292326
format!("HashSet({{{}}})", contents.collect::<Vec<_>>().join(", "))
293327
}
294328

295-
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyObject {
329+
fn __richcmp__(&self, other: &Self, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
296330
match op {
297-
CompareOp::Eq => (self.inner.len() == other.inner.len()).into_py(py),
298-
CompareOp::Ne => (self.inner.len() != other.inner.len()).into_py(py),
299-
_ => py.NotImplemented(),
331+
CompareOp::Eq => Ok((self.inner.len() == other.inner.len()
332+
&& is_subset(&self.inner, &other.inner))
333+
.into_py(py)),
334+
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
335+
|| self.inner.iter().any(|k| !other.inner.contains(k)))
336+
.into_py(py)),
337+
CompareOp::Lt => Ok((self.inner.len() < other.inner.len()
338+
&& is_subset(&self.inner, &other.inner))
339+
.into_py(py)),
340+
CompareOp::Le => Ok(is_subset(&self.inner, &other.inner).into_py(py)),
341+
_ => Ok(py.NotImplemented()),
300342
}
301343
}
302344

@@ -326,6 +368,69 @@ impl HashSetPy {
326368
}
327369
}
328370

371+
fn difference(&self, other: &Self) -> Self {
372+
let mut inner = self.inner.clone();
373+
for value in other.inner.iter() {
374+
inner.remove(value);
375+
}
376+
HashSetPy { inner }
377+
}
378+
379+
fn intersection(&self, other: &Self) -> Self {
380+
let mut inner: HashSet<Key> = HashSet::new();
381+
let larger: &HashSet<Key>;
382+
let iter;
383+
if self.inner.len() > other.inner.len() {
384+
larger = &self.inner;
385+
iter = other.inner.iter();
386+
} else {
387+
larger = &other.inner;
388+
iter = self.inner.iter();
389+
}
390+
for value in iter {
391+
if larger.contains(value) {
392+
inner.insert(value.to_owned());
393+
}
394+
}
395+
HashSetPy { inner }
396+
}
397+
398+
fn symmetric_difference(&self, other: &Self) -> Self {
399+
let mut inner: HashSet<Key>;
400+
let iter;
401+
if self.inner.len() > other.inner.len() {
402+
inner = self.inner.clone();
403+
iter = other.inner.iter();
404+
} else {
405+
inner = other.inner.clone();
406+
iter = self.inner.iter();
407+
}
408+
for value in iter {
409+
if inner.contains(value) {
410+
inner.remove(value);
411+
} else {
412+
inner.insert(value.to_owned());
413+
}
414+
}
415+
HashSetPy { inner }
416+
}
417+
418+
fn union(&self, other: &Self) -> Self {
419+
let mut inner: HashSet<Key>;
420+
let iter;
421+
if self.inner.len() > other.inner.len() {
422+
inner = self.inner.clone();
423+
iter = other.inner.iter();
424+
} else {
425+
inner = other.inner.clone();
426+
iter = self.inner.iter();
427+
}
428+
for value in iter {
429+
inner.insert(value.to_owned());
430+
}
431+
HashSetPy { inner }
432+
}
433+
329434
#[pyo3(signature = (*iterables))]
330435
fn update(&self, iterables: &PyTuple) -> PyResult<HashSetPy> {
331436
let mut inner = self.inner.clone();
@@ -414,6 +519,14 @@ impl VectorPy {
414519
.map(|(e1, e2)| PyAny::eq(e1.extract(py)?, e2))
415520
.all(|r| r.unwrap_or(false)))
416521
.into_py(py)),
522+
CompareOp::Ne => Ok((self.inner.len() != other.inner.len()
523+
|| self
524+
.inner
525+
.iter()
526+
.zip(other.inner.iter())
527+
.map(|(e1, e2)| PyAny::ne(e1.extract(py)?, e2))
528+
.any(|r| r.unwrap_or(true)))
529+
.into_py(py)),
417530
_ => Ok(py.NotImplemented()),
418531
}
419532
}

tests/test_hash_trie_map.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def test_update_with_multiple_arguments():
244244
def test_update_one_argument():
245245
x = HashMap(a=1)
246246

247-
assert x.update({"b": "2"}) == HashMap(a=1, b=2)
247+
assert x.update({"b": 2}) == HashMap(a=1, b=2)
248248

249249

250250
def test_update_no_arguments():
@@ -309,3 +309,16 @@ def test_convert_hashtriemap():
309309
def test_fast_convert_hashtriemap():
310310
m = HashMap({i: i * 2 for i in range(3)})
311311
assert HashMap.convert(m) is m
312+
313+
314+
def test_more_eq():
315+
# Non-pyrsistent-test-suite test
316+
o = object()
317+
318+
assert HashMap([(o, o), (1, o)]) == HashMap([(o, o), (1, o)])
319+
assert HashMap([(o, "foo")]) == HashMap([(o, "foo")])
320+
assert HashMap() == HashMap([])
321+
322+
assert HashMap({1: 2}) != HashMap({1: 3})
323+
assert HashMap({o: 1}) != HashMap({o: o})
324+
assert HashMap([]) != HashMap([(o, 1)])

tests/test_hash_trie_set.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def test_contains():
9696
assert 4 not in s
9797

9898

99-
@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
10099
def test_supports_set_operations():
101100
s1 = HashSet([1, 2, 3])
102101
s2 = HashSet([3, 4, 5])
@@ -114,7 +113,6 @@ def test_supports_set_operations():
114113
assert s1.symmetric_difference(s2) == s1 ^ s2
115114

116115

117-
@pytest.mark.xfail(reason="Can't figure out inheriting collections.abc yet")
118116
def test_supports_set_comparisons():
119117
s1 = HashSet([1, 2, 3])
120118
s3 = HashSet([1, 2])
@@ -151,3 +149,33 @@ def test_update_no_elements():
151149

152150
def test_iterable():
153151
assert HashSet(iter("a")) == HashSet(iter("a"))
152+
153+
154+
def test_more_eq():
155+
# Non-pyrsistent-test-suite test
156+
o = object()
157+
158+
assert HashSet([o]) == HashSet([o])
159+
assert HashSet([o, o]) == HashSet([o, o])
160+
assert HashSet([o]) == HashSet([o, o])
161+
assert HashSet() == HashSet([])
162+
assert not (HashSet([1, 2]) == HashSet([1, 3]))
163+
assert not (HashSet([o, 1]) == HashSet([o, o]))
164+
assert not (HashSet([]) == HashSet([o]))
165+
166+
assert HashSet([1, 2]) != HashSet([1, 3])
167+
assert HashSet([]) != HashSet([o])
168+
assert not (HashSet([o]) != HashSet([o]))
169+
assert not (HashSet([o, o]) != HashSet([o, o]))
170+
assert not (HashSet([o]) != HashSet([o, o]))
171+
assert not (HashSet() != HashSet([]))
172+
173+
174+
def test_more_set_comparisons():
175+
s = HashSet([1, 2, 3])
176+
177+
assert s == s
178+
assert not (s < s)
179+
assert s <= s
180+
assert not (s > s)
181+
assert s >= s

tests/test_list.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,22 @@ def test_hashing():
106106
def test_sequence():
107107
m = Vector("asdf")
108108
assert m == Vector(["a", "s", "d", "f"])
109+
110+
111+
def test_more_eq():
112+
# Non-pyrsistent-test-suite test
113+
o = object()
114+
115+
assert Vector([o, o]) == Vector([o, o])
116+
assert Vector([o]) == Vector([o])
117+
assert Vector() == Vector([])
118+
assert not (Vector([1, 2]) == Vector([1, 3]))
119+
assert not (Vector([o]) == Vector([o, o]))
120+
assert not (Vector([]) == Vector([o]))
121+
122+
assert Vector([1, 2]) != Vector([1, 3])
123+
assert Vector([o]) != Vector([o, o])
124+
assert Vector([]) != Vector([o])
125+
assert not (Vector([o, o]) != Vector([o, o]))
126+
assert not (Vector([o]) != Vector([o]))
127+
assert not (Vector() != Vector([]))

0 commit comments

Comments
 (0)