Skip to content

Commit

Permalink
dev(hansbug): try fix
Browse files Browse the repository at this point in the history
  • Loading branch information
HansBug committed Oct 20, 2024
1 parent 356a7f6 commit 68bc5d8
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 32 deletions.
52 changes: 27 additions & 25 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,38 @@ jobs:
- 'windows-latest' # need to be fixed, see: https://github.com/opendilab/treevalue/issues/41
- 'macos-latest'
python-version:
- '3.7'
- '3.8'
- '3.9'
- '3.10'
- '3.11'
- '3.12'
numpy-version:
- '1.18.0'
- '1.22.0'
- '1.24.0'
- '2.0.2'
torch-version:
- '1.2.0'
- '1.6.0'
- '1.10.0'
- '1.13.0'
- '2.0.1'
- '2.4.0'
exclude:
- python-version: '3.7'
numpy-version: '1.22.0'
- python-version: '3.7'
numpy-version: '1.24.0'
- python-version: '3.9'
numpy-version: '1.18.0'
- python-version: '3.10'
numpy-version: '1.18.0'
- python-version: '3.11'
numpy-version: '1.18.0'
- python-version: '3.12'
numpy-version: '1.18.0'
- python-version: '3.7'
torch-version: '2.0.1'
- python-version: '3.8'
torch-version: '1.2.0'
- python-version: '3.9'
torch-version: '1.2.0'
- python-version: '3.9'
torch-version: '1.6.0'
- python-version: '3.10'
torch-version: '1.2.0'
- python-version: '3.10'
torch-version: '1.6.0'
- python-version: '3.10'
torch-version: '1.10.0'
- python-version: '3.11'
torch-version: '1.2.0'
- python-version: '3.11'
torch-version: '1.6.0'
- python-version: '3.11'
torch-version: '1.10.0'
- os: 'windows-latest'
torch-version: '1.2.0'
- os: 'windows-latest'
torch-version: '1.6.0'
- python-version: '3.12'
torch-version: '1.10.0'
- os: 'windows-latest'
python-version: '3.11'
torch-version: '1.13.0'
Expand All @@ -75,6 +57,15 @@ jobs:
- os: 'ubuntu-latest'
python-version: '3.11'
numpy-version: '1.22.0'
- os: 'windows-latest'
python-version: '3.12'
torch-version: '1.13.0'
- os: 'macos-latest'
python-version: '3.12'
torch-version: '1.13.0'
- os: 'ubuntu-latest'
python-version: '3.12'
numpy-version: '1.22.0'
- os: 'windows-latest'
python-version: '3.9'
numpy-version: '1.18.0'
Expand All @@ -90,6 +81,17 @@ jobs:
- os: 'macos-latest'
python-version: '3.11'
numpy-version: '1.22.0'
- os: 'windows-latest'
python-version: '3.12'
numpy-version: '1.18.0'
- os: 'macos-latest'
python-version: '3.12'
numpy-version: '1.18.0'
- os: 'macos-latest'
python-version: '3.12'
numpy-version: '1.22.0'
- python-version: '3.8'
numpy-version: '2.0.2'

steps:
- name: Get system version for Linux
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
treevalue>=1.4.11
torch>=1.1.0
treevalue>=1.5.0
torch>=1.10.0
hbutils>=0.6.13
numpy
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _load_req(file: str):
url='https://github.com/opendilab/DI-treetensor',

# environment
python_requires=">=3.7",
python_requires=">=3.8",
install_requires=requirements,
tests_require=group_requirements['test'],
extras_require=group_requirements,
Expand All @@ -56,10 +56,11 @@ def _load_req(file: str):
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: Implementation :: PyPy'
],
)
14 changes: 13 additions & 1 deletion treetensor/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from .funcs import get_func_from_numpy
from ..config.meta import __VERSION__

try:
from numpy.core._multiarray_umath import _ArrayFunctionDispatcher
except (ImportError, ModuleNotFoundError):
_ArrayFunctionDispatcher = None

__all__ = [
*_funcs_all,
*_array_all,
Expand All @@ -23,6 +28,13 @@
)
_np_all = set(np.__all__)

_l_func_types = [BuiltinFunctionType, FunctionType]
if _ArrayFunctionDispatcher:
_l_func_types.append(_ArrayFunctionDispatcher)
if getattr(np, 'ufunc'):
_l_func_types.append(np.ufunc)
_func_types = tuple(_l_func_types)


class _Module(ModuleType):
def __init__(self, module):
Expand All @@ -40,7 +52,7 @@ def __getattr__(self, name):
return getattr(self.__origin__, name)
else:
item = getattr(np, name)
if isinstance(item, (FunctionType, BuiltinFunctionType)) and not name.startswith('_'):
if isinstance(item, _func_types) and not name.startswith('_'):
return get_func_from_numpy(name)
elif isinstance(item, _basic_types) and name in _np_all:
return item
Expand Down
3 changes: 1 addition & 2 deletions treetensor/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def __getattr__(self, name):
item = getattr(torch, name)
if isinstance(item, (FunctionType, BuiltinFunctionType)) and not name.startswith('_'):
return get_func_from_torch(name)
elif (isinstance(item, torch.dtype)) or \
isinstance(item, _basic_types) and name in _torch_all:
elif isinstance(item, torch.dtype) or isinstance(item, _basic_types):
return item
else:
raise AttributeError(f'Attribute {repr(name)} not found in {repr(__name__)}.')
Expand Down

0 comments on commit 68bc5d8

Please sign in to comment.