Skip to content

Commit fe3ecb8

Browse files
committed
bind_vector/map: implemented __repr__ function, fixed api<T> in-place ops
1 parent d80e994 commit fe3ecb8

File tree

7 files changed

+63
-3
lines changed

7 files changed

+63
-3
lines changed

include/nanobind/nb_lib.h

+7
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,16 @@ NB_CORE void set_implicit_cast_warnings(bool value) noexcept;
419419

420420
NB_CORE bool iterable_check(PyObject *o) noexcept;
421421

422+
// ========================================================================
423+
422424
NB_CORE void slice_compute(PyObject *slice, Py_ssize_t size,
423425
Py_ssize_t &start, Py_ssize_t &stop,
424426
Py_ssize_t &step, size_t &slice_length);
425427

428+
// ========================================================================
429+
430+
NB_CORE PyObject *repr_list(PyObject *o);
431+
NB_CORE PyObject *repr_map(PyObject *o);
432+
426433
NAMESPACE_END(detail)
427434
NAMESPACE_END(NB_NAMESPACE)

include/nanobind/nb_types.h

+15
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ public: \
5151
detail::obj_op_2(derived().ptr(), o.derived().ptr(), op)); \
5252
}
5353

54+
#define NB_API_OP_2_IO(name) \
55+
template <typename T> NB_INLINE decltype(auto) name(const api<T> &o) { \
56+
return operator=(handle::name(o)); \
57+
}
5458

5559
// A few forward declarations
5660
class object;
@@ -220,6 +224,16 @@ class object : public handle {
220224
temp.dec_ref();
221225
return *this;
222226
}
227+
228+
NB_API_OP_2_IO(operator+=)
229+
NB_API_OP_2_IO(operator-=)
230+
NB_API_OP_2_IO(operator*=)
231+
NB_API_OP_2_IO(operator/=)
232+
NB_API_OP_2_IO(operator|=)
233+
NB_API_OP_2_IO(operator&=)
234+
NB_API_OP_2_IO(operator^=)
235+
NB_API_OP_2_IO(operator<<=)
236+
NB_API_OP_2_IO(operator>>=)
223237
};
224238

225239
template <typename T> NB_INLINE T borrow(handle h) {
@@ -670,3 +684,4 @@ NAMESPACE_END(NB_NAMESPACE)
670684
#undef NB_API_OP_1
671685
#undef NB_API_OP_2
672686
#undef NB_API_OP_2_I
687+
#undef NB_API_OP_2_IO

include/nanobind/stl/bind_map.h

+5
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ class_<Map> bind_map(handle scope, const char *name, Args &&...args) {
5252
[](const Map &m) { return !m.empty(); },
5353
"Check whether the map is nonempty")
5454

55+
.def("__repr__",
56+
[](handle_t<Map> h) {
57+
return steal<str>(detail::repr_map(h.ptr()));
58+
})
59+
5560
.def("__contains__",
5661
[](const Map &m, const Key &k) { return m.find(k) != m.end(); })
5762

include/nanobind/stl/bind_vector.h

+5
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ class_<Vector> bind_vector(handle scope, const char *name, Args &&...args) {
5757
[](const Vector &v) { return !v.empty(); },
5858
"Check whether the vector is nonempty")
5959

60+
.def("__repr__",
61+
[](handle_t<Vector> h) {
62+
return steal<str>(detail::repr_list(h.ptr()));
63+
})
64+
6065
.def("__iter__",
6166
[](Vector &v) {
6267
return make_iterator(type<Vector>(), "Iterator",

src/common.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -962,5 +962,34 @@ bool iterable_check(PyObject *o) noexcept {
962962
#endif
963963
}
964964

965+
// ========================================================================
966+
967+
NB_CORE PyObject *repr_list(PyObject *o) {
968+
object s = steal(nb_inst_name(o));
969+
s += str("([");
970+
size_t len = obj_len(o);
971+
for (size_t i = 0; i < len; ++i) {
972+
s += repr(handle(o)[i]);
973+
if (i + 1 < len)
974+
s += str(", ");
975+
}
976+
s += str("])");
977+
return s.release().ptr();
978+
}
979+
980+
NB_CORE PyObject *repr_map(PyObject *o) {
981+
object s = steal(nb_inst_name(o));
982+
s += str("({");
983+
bool first = true;
984+
for (handle kv : handle(o).attr("items")()) {
985+
if (!first)
986+
s += str(", ");
987+
s += repr(kv[0]) + str(": ") + repr(kv[1]);
988+
first = false;
989+
}
990+
s += str("})");
991+
return s.release().ptr();
992+
}
993+
965994
NAMESPACE_END(detail)
966995
NAMESPACE_END(NB_NAMESPACE)

tests/test_stl_bind_map.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_map_string_double():
3434
assert len(mm2) == 1
3535
mm2.clear()
3636
assert len(mm2) == 0
37+
assert repr(mm) == "test_bind_map_ext.MapStringDouble({'a': 1.0, 'b': 2.5})"
3738

3839
with pytest.warns(RuntimeWarning, match="implicit conversion from type 'dict' to type 'test_bind_map_ext.MapStringDouble' failed"):
3940
with pytest.raises(TypeError):
@@ -80,7 +81,6 @@ def test_map_string_double():
8081
assert list(um.keys()) == list(um)
8182
assert sorted(list(um.items())) == [("ua", 1.1), ("ub", 2.6)]
8283
assert list(zip(um.keys(), um.values())) == list(um.items())
83-
assert "UnorderedMapStringDouble" in str(um)
8484

8585
assert type(keys).__qualname__ == 'MapStringDouble.KeyView'
8686
assert type(values).__qualname__ == 'MapStringDouble.ValueView'

tests/test_stl_bind_vector.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def test01_vector_int():
1010
# test construction from a generator
1111
v_int1 = t.VectorInt(x for x in range(5))
1212
assert t.VectorInt(v_int1) == t.VectorInt([0, 1, 2, 3, 4])
13+
assert repr(v_int1) == "test_bind_vector_ext.VectorInt([0, 1, 2, 3, 4])"
1314

1415
v_int2 = t.VectorInt([0, 0])
1516
assert v_int == v_int2
@@ -121,8 +122,6 @@ def check_del(s):
121122
del l1c[s]
122123
del l2c[s]
123124
l2c = list(l2c)
124-
print(repr(l1c))
125-
print(repr(l2c))
126125
assert l1c == l2c
127126

128127
check_same(slice(1, 13, 4))

0 commit comments

Comments
 (0)