|
| 1 | +NAMESPACE_BEGIN(nanobind) |
| 2 | +NAMESPACE_BEGIN(detail) |
| 3 | + |
| 4 | +template <typename Func, typename Return, typename... Args, typename... Extra> |
| 5 | +object func_create(Func &&f, Return (*)(Args...), const Extra &...extra) { |
| 6 | + struct capture { |
| 7 | + std::remove_reference_t<Func> f; |
| 8 | + }; |
| 9 | + |
| 10 | + // Store the capture object in the function record if there is space |
| 11 | + constexpr bool IsSmall = sizeof(capture) <= sizeof(void *) * 3; |
| 12 | + constexpr bool IsTrivial = std::is_trivially_destructible_v<capture>; |
| 13 | + |
| 14 | + void *func_rec = func_alloc(); |
| 15 | + void (*free_capture)(void *ptr) = nullptr; |
| 16 | + |
| 17 | + if constexpr (IsSmall) { |
| 18 | + capture *cap = std::launder((capture *) func_rec); |
| 19 | + new (cap) capture{ std::forward<Func>(f) }; |
| 20 | + |
| 21 | + if constexpr (!IsTrivial) { |
| 22 | + free_capture = [](void *func_rec_2) { |
| 23 | + capture *cap_2 = std::launder((capture *) func_rec_2); |
| 24 | + cap_2->~capture(); |
| 25 | + }; |
| 26 | + } |
| 27 | + } else { |
| 28 | + void **cap = std::launder((void **) func_rec); |
| 29 | + cap[0] = new capture{ std::forward<Func>(f) }; |
| 30 | + |
| 31 | + free_capture = [](void *func_rec_2) { |
| 32 | + void **cap_2 = std::launder((void **) func_rec_2); |
| 33 | + delete (capture *) cap_2[0]; |
| 34 | + }; |
| 35 | + } |
| 36 | + |
| 37 | + auto impl = [](void *func_rec_2) -> PyObject * { |
| 38 | + capture *cap; |
| 39 | + if constexpr (IsSmall) |
| 40 | + cap = std::launder((capture *) func_rec_2); |
| 41 | + else |
| 42 | + cap = std::launder((void **) func_rec_2)[0]; |
| 43 | + |
| 44 | + cap->f(); |
| 45 | + |
| 46 | + return nullptr; |
| 47 | + }; |
| 48 | + |
| 49 | + (detail::func_apply(func_rec, extra), ...); |
| 50 | + |
| 51 | + return reinterpret_steal<object>(func_init(func_rec, free_capture, impl)); |
| 52 | +} |
| 53 | + |
| 54 | +template <typename T> |
| 55 | +constexpr bool is_lambda_v = !std::is_function_v<T> && !std::is_pointer_v<T> && |
| 56 | + !std::is_member_pointer_v<T>; |
| 57 | + |
| 58 | + |
| 59 | +/// Strip the class from a method type |
| 60 | +template <typename T> struct remove_class { }; |
| 61 | +template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...)> { using type = R (A...); }; |
| 62 | +template <typename C, typename R, typename... A> struct remove_class<R (C::*)(A...) const> { using type = R (A...); }; |
| 63 | + |
| 64 | +NAMESPACE_END(detail) |
| 65 | + |
| 66 | +template <bool V> using enable_if_t = std::enable_if_t<V, int>; |
| 67 | + |
| 68 | +template <typename Return, typename... Args, typename... Extra> |
| 69 | +object cpp_function(Return (*f)(Args...), const Extra&... extra) { |
| 70 | + return detail::func_create(f, f, extra...); |
| 71 | +} |
| 72 | + |
| 73 | +/// Construct a cpp_function from a lambda function (possibly with internal state) |
| 74 | +template <typename Func, typename... Extra, |
| 75 | + enable_if_t<detail::is_lambda_v<std::remove_reference_t<Func>>> = 0> |
| 76 | +object cpp_function(Func &&f, const Extra &...extra) { |
| 77 | + using RawFunc = |
| 78 | + typename detail::remove_class<decltype(&Func::operator())>::type; |
| 79 | + return detail::func_create(std::forward<Func>(f), (RawFunc *) nullptr, |
| 80 | + extra...); |
| 81 | +} |
| 82 | + |
| 83 | +/// Construct a cpp_function from a class method (non-const, no ref-qualifier) |
| 84 | +template <typename Return, typename Class, typename... Args, typename... Extra> |
| 85 | +object cpp_function(Return (Class::*f)(Args...), const Extra&... extra) { |
| 86 | + return detail::func_create( |
| 87 | + [f](Class *c, Args... args) -> Return { |
| 88 | + return (c->*f)(std::forward<Args>(args)...); |
| 89 | + }, |
| 90 | + (Return(*)(Class *, Args...)) nullptr, extra...); |
| 91 | +} |
| 92 | + |
| 93 | +/// Construct a cpp_function from a class method (const, no ref-qualifier) |
| 94 | +template <typename Return, typename Class, typename... Args, typename... Extra> |
| 95 | +object cpp_function(Return (Class::*f)(Args...) const, const Extra &...extra) { |
| 96 | + return detail::func_create( |
| 97 | + [f](const Class *c, Args... args) -> Return { |
| 98 | + return (c->*f)(std::forward<Args>(args)...); |
| 99 | + }, |
| 100 | + (Return(*)(const Class *, Args...)) nullptr, extra...); |
| 101 | +} |
| 102 | + |
| 103 | +template <typename Func, typename... Extra> |
| 104 | +module_ &module_::def(const char *name_, Func &&f, const Extra &...extra) { |
| 105 | + object func = cpp_function(std::forward<Func>(f), name(name_), scope(*this), |
| 106 | + pred(getattr(*this, name_, none())), extra...); |
| 107 | + if (PyModule_AddObject(m_ptr, name_, func.release().ptr())) |
| 108 | + detail::fail("module::def(): could not add object!"); |
| 109 | + return *this; |
| 110 | +} |
| 111 | + |
| 112 | +NAMESPACE_END(nanobind) |
0 commit comments