Skip to content

Commit 03885c5

Browse files
Refactor: class to fn && Fix: warning
- Added docstrings
- Fixed small typo in test file name

Co-authored-by: Afonso Antunes <[email protected]>
1 parent d2066a1 commit 03885c5

File tree

3 files changed

+88
-49
lines changed

3 files changed

+88
-49
lines changed

pandas/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@
347347
"wide_to_long",
348348
]
349349

350-
from .core.accessor import AccessorEntryPointLoader
350+
from .core.accessor import accessor_entry_point_loader
351351

352-
AccessorEntryPointLoader.load()
353-
del AccessorEntryPointLoader
352+
accessor_entry_point_loader()
353+
del accessor_entry_point_loader

pandas/core/accessor.py

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
from pandas import Index
2727
from pandas.core.generic import NDFrame
2828

29-
from importlib.metadata import entry_points
29+
from importlib.metadata import (
30+
EntryPoints,
31+
entry_points,
32+
)
3033

3134

3235
class DirNamesMixin:
@@ -398,44 +401,78 @@ def register_index_accessor(name: str) -> Callable[[TypeT], TypeT]:
398401
return _register_accessor(name, Index)
399402

400403

401-
class AccessorEntryPointLoader: # is this a good name for the class?
402-
"""Loader class for registering accessors via entry points."""
404+
def accessor_entry_point_loader() -> None:
405+
"""
406+
Load and register pandas accessors declared via entry points.
403407
404-
ENTRY_POINT_GROUP: str = "pandas_accessor"
408+
This function scans the 'pandas.accessor' entry point group for accessors
409+
registered by third-party packages. Each entry point is expected to follow
410+
the format:
405411
406-
@classmethod
407-
def load(cls) -> None:
408-
"""loads and registers accessors defined by 'pandas_accessor'."""
409-
accessors = entry_points(group=cls.ENTRY_POINT_GROUP)
410-
unique_accessors_names: set[str] = set()
411-
412-
for accessor in accessors:
413-
# Verifies duplicated accessor names
414-
if accessor.name in unique_accessors_names:
415-
try:
416-
pkg_name: str = accessor.dist.name
417-
except Exception:
418-
pkg_name = "unknown"
419-
warnings.warn(
420-
"Warning: you have two accessors with the same name:"
421-
f" '{accessor.name}' has already been registered"
422-
f" by the package '{pkg_name}'. So the '{accessor.name}' "
423-
f"provided by the package '{pkg_name}' is not "
424-
f"being used. Uninstall the package you don't want"
425-
"to use if you want to get rid of this warning.\n",
426-
UserWarning,
427-
stacklevel=2,
428-
)
412+
accessor_name = "package.module:AccessorClass"
429413
430-
else:
431-
unique_accessors_names.add(accessor.name)
414+
For example:
415+
416+
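# pyproject.toml of the third-party package
417+
[project.entry-points."pandas.accessor"]
418+
myplugin = "myplugin.accessor:MyPluginAccessor"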
419+
420+
421+
For each valid entry point:
422+
- The accessor class is dynamically imported and registered using
423+
the appropriate registration decorator function
424+
(e.g. register_dataframe_accessor).
425+
- If two packages declare the same accessor name, a warning is issued,
426+
and only the first one is used.
427+
428+
Notes
429+
-----
430+
- This function is only intended to be called at pandas startup.
431+
432+
Warns
433+
-----
434+
UserWarning
435+
If two accessors share the same name, the second one is ignored.
436+
437+
Examples
438+
--------
439+
>>> df.myplugin.do_something() # Assuming such accessor was registered
440+
"""
441+
442+
ENTRY_POINT_GROUP: str = "pandas.accessor"
443+
444+
accessors: EntryPoints = entry_points(group=ENTRY_POINT_GROUP)
445+
accessor_package_dict: dict[str, str] = {}
446+
447+
for new_accessor in accessors:
448+
try:
449+
new_pkg_name: str = new_accessor.dist.name
450+
except AttributeError:
451+
new_pkg_name: str = "Unknown"
452+
453+
# Verifies duplicated accessor names
454+
if new_accessor.name in accessor_package_dict:
455+
loaded_pkg_name: str = accessor_package_dict.get(new_accessor.name)
456+
457+
warnings.warn(
458+
"Warning: you have two accessors with the same name:"
459+
f" '{new_accessor.name}' has already been registered"
460+
f" by the package '{new_pkg_name}'. So the "
461+
f"'{new_accessor.name}' provided by the package "
462+
f"'{loaded_pkg_name}' is not being used. "
463+
"Uninstall the package you don't want"
464+
"to use if you want to get rid of this warning.\n",
465+
UserWarning,
466+
stacklevel=2,
467+
)
468+
469+
accessor_package_dict.update({new_accessor.name: new_pkg_name})
432470

433-
def make_property(ep):
434-
def accessor(self) -> Any:
435-
cls_ = ep.load()
436-
return cls_(self)
471+
def make_accessor(ep):
472+
def accessor(self) -> Any:
473+
cls_ = ep.load()
474+
return cls_(self)
437475

438-
return accessor
476+
return accessor
439477

440-
# _register_accessor()
441-
register_dataframe_accessor(accessor.name)(make_property(accessor))
478+
register_dataframe_accessor(new_accessor.name)(make_accessor(new_accessor))

pandas/tests/test_plugis_entrypoint_loader.py renamed to pandas/tests/test_plugins_entrypoint_loader.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import pandas as pd
22
import pandas._testing as tm
3-
from pandas.core.accessor import AccessorEntryPointLoader
3+
from pandas.core.accessor import accessor_entry_point_loader
44

55
# TODO: test for pkg names
66

7+
PANDAS_ENTRY_POINT_GROUP: str = "pandas.accessor"
8+
79

810
def test_no_accessors(monkeypatch):
911
# GH29076
@@ -15,7 +17,7 @@ def mock_entry_points(*, group):
1517
# Patch entry_points in the correct module
1618
monkeypatch.setattr("pandas.core.accessor.entry_points", mock_entry_points)
1719

18-
AccessorEntryPointLoader.load()
20+
accessor_entry_point_loader()
1921

2022

2123
def test_load_dataframe_accessors(monkeypatch):
@@ -36,14 +38,14 @@ def test_method(self):
3638

3739
# Mock entry_points
3840
def mock_entry_points(*, group):
39-
if group == AccessorEntryPointLoader.ENTRY_POINT_GROUP:
41+
if group == PANDAS_ENTRY_POINT_GROUP:
4042
return [MockEntryPoint()]
4143
return []
4244

4345
# Patch entry_points in the correct module
4446
monkeypatch.setattr("pandas.core.accessor.entry_points", mock_entry_points)
4547

46-
AccessorEntryPointLoader.load()
48+
accessor_entry_point_loader()
4749

4850
# Create DataFrame and verify that the accessor was registered
4951
df = pd.DataFrame({"a": [1, 2, 3]})
@@ -82,15 +84,15 @@ def which(self):
8284
return Accessor2
8385

8486
def mock_entry_points(*, group):
85-
if group == AccessorEntryPointLoader.ENTRY_POINT_GROUP:
87+
if group == PANDAS_ENTRY_POINT_GROUP:
8688
return [MockEntryPoint1(), MockEntryPoint2()]
8789
return []
8890

8991
monkeypatch.setattr("pandas.core.accessor.entry_points", mock_entry_points)
9092

9193
# Check that the UserWarning is raised
9294
with tm.assert_produces_warning(UserWarning, match="duplicate_accessor") as record:
93-
AccessorEntryPointLoader.load()
95+
accessor_entry_point_loader()
9496

9597
messages = [str(w.message) for w in record]
9698
assert any("you have two accessors with the same name:" in msg for msg in messages)
@@ -131,15 +133,15 @@ def which(self):
131133
return Accessor2
132134

133135
def mock_entry_points(*, group):
134-
if group == AccessorEntryPointLoader.ENTRY_POINT_GROUP:
136+
if group == PANDAS_ENTRY_POINT_GROUP:
135137
return [MockEntryPoint1(), MockEntryPoint2()]
136138
return []
137139

138140
monkeypatch.setattr("pandas.core.accessor.entry_points", mock_entry_points)
139141

140142
# Check that no UserWarning is raised
141143
with tm.assert_produces_warning(None, check_stacklevel=False):
142-
AccessorEntryPointLoader.load()
144+
accessor_entry_point_loader()
143145

144146
df = pd.DataFrame({"x": [1, 2, 3]})
145147
assert hasattr(df, "accessor1"), "Accessor1 not registered"
@@ -193,15 +195,15 @@ def which(self):
193195
return Accessor3
194196

195197
def mock_entry_points(*, group):
196-
if group == AccessorEntryPointLoader.ENTRY_POINT_GROUP:
198+
if group == PANDAS_ENTRY_POINT_GROUP:
197199
return [MockEntryPoint1(), MockEntryPoint2(), MockEntryPoint3()]
198200
return []
199201

200202
monkeypatch.setattr("pandas.core.accessor.entry_points", mock_entry_points)
201203

202204
# Capture warnings
203205
with tm.assert_produces_warning(UserWarning, match="duplicate_accessor") as record:
204-
AccessorEntryPointLoader.load()
206+
accessor_entry_point_loader()
205207

206208
messages = [str(w.message) for w in record]
207209

0 commit comments

Comments
 (0)