Skip to content

Commit bb4de49

Browse files
committed
stubgen.py: handle method aliases similarly to function aliases
1 parent 8ce0dee commit bb4de49

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

src/stubgen.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,14 @@ def put_nb_func(self, fn: NbFunction, name: Optional[str] = None) -> None:
376376
self.write_ln(f"@{overload}")
377377
self.put_nb_overload(fn, s, name)
378378

379+
def put_nb_method(self, fn: NbFunction, name: Optional[str] = None) -> None:
380+
fn_name = getattr(fn, "__name__", None)
381+
# Check if this function is an alias
382+
if name and fn_name and name != fn_name:
383+
self.write_ln(f"{name} = {fn_name}\n")
384+
return
385+
self.put_nb_func(fn, name)
386+
379387
def put_function(self, fn: Callable[..., Any], name: Optional[str] = None, parent: Optional[object] = None):
380388
"""Append a function of an arbitrary type to the stub"""
381389
# Don't generate a constructor for nanobind classes that aren't constructible
@@ -848,7 +856,7 @@ def put(self, value: object, name: Optional[str] = None, parent: Optional[object
848856
elif tp_mod == "nanobind":
849857
if tp_name == "nb_method":
850858
value = cast(NbFunction, value)
851-
self.put_nb_func(value, name)
859+
self.put_nb_method(value, name)
852860
elif tp_name == "nb_static_property":
853861
value = cast(NbStaticProperty, value)
854862
self.put_nb_static_property(name, value)

tests/test_typing.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ NB_MODULE(test_typing_ext, m) {
5757

5858
m.def("makeNestedClass", [] { return NestedClass(); });
5959

60-
// Aliases to local functoins and types
60+
// Aliases to functions and types
6161
m.attr("FooAlias") = m.attr("Foo");
6262
m.attr("f_alias") = m.attr("f");
63+
nb::type<Foo>().attr("lt_alias") = nb::type<Foo>().attr("__lt__");
6364

6465
// Custom signature generation for classes and methods
6566
struct CustomSignature { int value; };

tests/test_typing_ext.pyi.ref

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class Foo:
3232

3333
def __ge__(self, arg: Foo, /) -> bool: ...
3434

35+
lt_alias = __lt__
36+
3537
FooAlias: TypeAlias = Foo
3638

3739
T = TypeVar("T", contravariant=True)

0 commit comments

Comments
 (0)