From babbf6f5fc0b29f77f80fffe37c0d448c0048d9c Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 18 Sep 2023 17:00:42 +0800 Subject: [PATCH 1/7] dev(narugo): add supported for vmap --- test/torch/funcs/test_wrapper.py | 100 +++++++++++++++++++++++++++++ treetensor/torch/funcs/__init__.py | 3 + treetensor/torch/funcs/base.py | 14 ++++ treetensor/torch/funcs/wrapper.py | 13 ++++ 4 files changed, 130 insertions(+) create mode 100644 test/torch/funcs/test_wrapper.py create mode 100644 treetensor/torch/funcs/wrapper.py diff --git a/test/torch/funcs/test_wrapper.py b/test/torch/funcs/test_wrapper.py new file mode 100644 index 0000000000..4a02f9114c --- /dev/null +++ b/test/torch/funcs/test_wrapper.py @@ -0,0 +1,100 @@ +from unittest import skipUnless + +import pytest +import torch +from hbutils.testing import vpip + +import treetensor.torch as ttorch +from treetensor.torch import Size + + +@pytest.fixture() +def treetensor_x(): + return ttorch.randn({ + 'a': (2, 5, 7), + 'b': { + 'x': (3, 4, 6), + } + }) + + +@pytest.fixture() +def treetensor_y(): + return ttorch.randn({ + 'a': (2, 5, 7), + 'b': { + 'x': (3, 4, 6), + } + }) + + +@pytest.mark.unittest +class TestTorchTensorWrapper: + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(f) + batched_pow = ttorch.vmap(f) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (2,), + 'b': { + 'x': (3,) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() + + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap_in_dims(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(f, in_dims=1) + batched_pow = ttorch.vmap(f, in_dims=1) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (5,), + 'b': { + 'x': (4,) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() + + @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') + def test_vmap_nested(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + n_pow = torch.vmap(torch.vmap(f)) + batched_pow = ttorch.vmap(ttorch.vmap(f)) + r = batched_pow(treetensor_x, treetensor_y) + + assert r.shape == Size({ + 'a': (2, 5), + 'b': { + 'x': (3, 4) + }, + }) + assert ttorch.isclose( + r, + ttorch.tensor({ + 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'b': { + 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + } + }) + ).all() diff --git a/treetensor/torch/funcs/__init__.py b/treetensor/torch/funcs/__init__.py index 98b029bf89..51a4c89230 100644 --- a/treetensor/torch/funcs/__init__.py +++ b/treetensor/torch/funcs/__init__.py @@ -14,6 +14,8 @@ from .operation import __all__ as _operation_all from .reduction import * from .reduction import __all__ as _reduction_all +from .wrapper import * +from .wrapper import __all__ as _wrapper_all from ...utils import module_autoremove __all__ = [ @@ -24,6 +26,7 @@ *_matrix_all, *_operation_all, *_reduction_all, + *_wrapper_all, ] _current_module = sys.modules[__name__] diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index d5ee52a939..3b5c78fbaa 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,3 +1,5 @@ 
+from functools import wraps
+
 import torch
 from treevalue import func_treelize as original_func_treelize
 
@@ -11,3 +13,15 @@
 auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)])
 get_func_from_torch = module_func_loader(torch, Tensor,
                                          [(torch.is_tensor, Tensor)])
+
+
+def wrap_for_treelize(*args, **kwargs):
+    def _decorator(func):
+        @wraps(func)
+        def _new_func(*args_, **kwargs_):
+            retval = func(*args_, **kwargs_)
+            return func_treelize(*args, **kwargs)(retval)
+
+        return _new_func
+
+    return _decorator
diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py
new file mode 100644
index 0000000000..cb33bc5270
--- /dev/null
+++ b/treetensor/torch/funcs/wrapper.py
@@ -0,0 +1,13 @@
+import torch
+
+from .base import doc_from_base, wrap_for_treelize
+
+__all__ = [
+    'vmap',
+]
+
+
+@doc_from_base()
+@wrap_for_treelize()
+def vmap(func, *args, **kwargs):
+    return torch.vmap(func, *args, **kwargs)

From 53b83be850bedd9558975a04d332bd00fa102bf5 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Mon, 18 Sep 2023 17:05:24 +0800
Subject: [PATCH 2/7] dev(narugo): add rand

---
 test/torch/funcs/test_construct.py  | 44 ++++++++++++++++++++++
 treetensor/torch/funcs/construct.py | 57 +++++++++++++++++++++++++++++
 2 files changed, 101 insertions(+)

diff --git a/test/torch/funcs/test_construct.py b/test/torch/funcs/test_construct.py
index 6ea120b858..6a7e59e75c 100644
--- a/test/torch/funcs/test_construct.py
+++ b/test/torch/funcs/test_construct.py
@@ -190,6 +190,50 @@ def test_randn_like(self):
             }
         })
 
+    @choose_mark()
+    def test_rand(self):
+        _target = ttorch.rand(200, 300)
+        assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
+        assert _target.shape == torch.Size([200, 300])
+
+        _target = ttorch.rand({
+            'a': (2, 3),
+            'b': (5, 6),
+            'x': {
+                'c': (2, 3, 4),
+            }
+        })
+        assert _target.shape == ttorch.Size({
+            'a': torch.Size([2, 3]),
+            'b': torch.Size([5, 6]),
+            'x': {
+                'c': torch.Size([2, 3, 4]),
+            }
+        })
+
+    @choose_mark()
+    def test_rand_like(self):
+        _target = ttorch.rand_like(torch.ones(200, 300))
+        assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
+        assert _target.shape == torch.Size([200, 300])
+
+        _target = ttorch.rand_like({
+            'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
+            'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
+            'x': {
+                'c': torch.tensor([5, 6, 7], dtype=torch.float32),
+                'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
+            }
+        })
+        assert _target.shape == ttorch.Size({
+            'a': torch.Size([2, 3]),
+            'b': torch.Size([4]),
+            'x': {
+                'c': torch.Size([3]),
+                'd': torch.Size([1, 1, 2]),
+            }
+        })
+
     @choose_mark()
     def test_randint(self):
         _target = ttorch.randint(-10, 10, {
diff --git a/treetensor/torch/funcs/construct.py b/treetensor/torch/funcs/construct.py
index ef166e9787..a0062ccfe3 100644
--- a/treetensor/torch/funcs/construct.py
+++ b/treetensor/torch/funcs/construct.py
@@ -10,6 +10,7 @@
     'tensor', 'as_tensor', 'clone',
     'zeros', 'zeros_like',
     'randn', 'randn_like',
+    'rand', 'rand_like',
     'randint', 'randint_like',
     'ones', 'ones_like',
     'full', 'full_like',
@@ -216,6 +217,62 @@ def randn_like(input, *args, **kwargs):
     return stream_call(torch.randn_like, input, *args, **kwargs)
 
 
+@doc_from_base()
+@args_treelize
+@func_treelize()
+def rand(*args, **kwargs):
+    """
+    In ``treetensor``, you can use ``rand`` to create a tree of tensors with numbers
+    drawn from a uniform distribution on the interval ``[0, 1)``.
+
+    Example::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> ttorch.rand(2, 3)  # the same as torch.rand(2, 3)
+        tensor([[0.8534, 0.5754, 0.2507],
+                [0.0826, 0.4110, 0.9748]])
+
+        >>> ttorch.rand({'a': (2, 3), 'b': {'x': (4, )}})
+        <Tensor 0x...>
+        ├── a --> tensor([[0.5398, 0.7529, 0.0339],
+        │                 [0.5722, 0.1900, 0.7945]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([0.7181, 0.1670, 0.3587, 0.5129])
+    """
+    return stream_call(torch.rand, *args, **kwargs)
+
+
+# noinspection PyShadowingBuiltins
+@doc_from_base()
+@args_treelize
+@func_treelize()
+def rand_like(input, *args, **kwargs):
+    """
+    In ``treetensor``, you can use ``rand_like`` to create a tree of tensors with numbers
+    drawn from a uniform distribution on the interval ``[0, 1)``, shaped like another tree.
+
+    Example::
+
+        >>> import torch
+        >>> import treetensor.torch as ttorch
+        >>> ttorch.rand_like(torch.ones(2, 3))  # the same as torch.rand_like(torch.ones(2, 3))
+        tensor([[0.8436, 0.2601, 0.9687],
+                [0.6430, 0.1765, 0.1732]])
+
+        >>> ttorch.rand_like({
+        ...     'a': torch.ones(2, 3),
+        ...     'b': {'x': torch.ones(4, )},
+        ... })
+        <Tensor 0x...>
+        ├── a --> tensor([[0.1532, 0.3965, 0.2956],
+        │                 [0.0750, 0.6475, 0.1421]])
+        └── b --> <Tensor 0x...>
+            └── x --> tensor([0.1730, 0.6085, 0.6487, 0.1022])
+    """
+    return stream_call(torch.rand_like, input, *args, **kwargs)
+
+
 @doc_from_base()
 @args_treelize
 @func_treelize()

From 0d5588a9df444372993dbc60691c59ecc2aa8b0c Mon Sep 17 00:00:00 2001
From: HansBug
Date: Mon, 18 Sep 2023 17:31:43 +0800
Subject: [PATCH 3/7] dev(narugo): fix docs issue

---
 requirements-doc.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements-doc.txt b/requirements-doc.txt
index c836abd335..6e7c6200de 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -1,7 +1,7 @@
 Jinja2~=3.0.0
 sphinx~=3.2.0
 sphinx_rtd_theme~=0.4.3
-enum_tools
+enum_tools~=0.9.0
 sphinx-toolbox
 plantumlcli>=0.0.2
 packaging

From b298918829518ffab4d1effed94eb20ef3e96660 Mon Sep 17 00:00:00 2001
From: HansBug
Date: Mon, 18 Sep 2023 17:35:14 +0800
Subject: [PATCH 4/7] dev(narugo): do not test numpy 1.18 on python 3.9

---
 .github/workflows/test.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 198119f41d..fa5b65194b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -36,6 +36,8 @@ jobs:
             numpy-version: '1.22.0'
           - python-version: '3.7'
             numpy-version: '1.24.0'
+          - python-version: '3.9'
+            numpy-version: '1.18.0'
           - python-version: '3.10'
             numpy-version: '1.18.0'
           - python-version: '3.11'

From b539dc7996c5e68467eb97b1996af2e7f22d090c Mon Sep 17 00:00:00 2001
From: HansBug
Date: Mon, 18 Sep 2023 19:30:59 +0800
Subject: [PATCH 5/7] dev(narugo): add check for torch 2

---
 .github/workflows/doc.yml         | 2 +-
 test/torch/funcs/test_wrapper.py  | 6 ++++++
 treetensor/torch/funcs/wrapper.py | 8 +++++++-
 3 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml
index 1ecd55c02c..fca8581f31 100644
--- a/.github/workflows/doc.yml
+++ b/.github/workflows/doc.yml
@@ -14,7 +14,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ 3.7 ]
+        python-version: [ 3.8 ]
 
     services:
       plantuml:
diff --git a/test/torch/funcs/test_wrapper.py b/test/torch/funcs/test_wrapper.py
index 4a02f9114c..9ce2ac9e73 100644
--- a/test/torch/funcs/test_wrapper.py
+++ b/test/torch/funcs/test_wrapper.py
@@ -98,3 +98,9 @@ def test_vmap_nested(self, treetensor_x, treetensor_y):
                 }
             })
         ).all()
+
+    @skipUnless(vpip('torch') < '2', 'Torch 1.x required.')
+    def
test_vmap_torch_1x(self, treetensor_x, treetensor_y): + f = lambda x, y: (x.sum() + y.mean() * 2) + with pytest.raises(NotImplementedError): + _ = ttorch.vmap(f) diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py index cb33bc5270..8fdf5f131b 100644 --- a/treetensor/torch/funcs/wrapper.py +++ b/treetensor/torch/funcs/wrapper.py @@ -1,4 +1,5 @@ import torch +from hbutils.testing import vpip from .base import doc_from_base, wrap_for_treelize @@ -6,8 +7,13 @@ 'vmap', ] +_is_torch_2 = vpip('torch') >= '2' + @doc_from_base() @wrap_for_treelize() def vmap(func, *args, **kwargs): - return torch.vmap(func, *args, **kwargs) + if _is_torch_2: + return torch.vmap(func, *args, **kwargs) + else: + raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.') From f0ffabe127180a56cb3b10c940e42eb89749de94 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 18 Sep 2023 20:48:05 +0800 Subject: [PATCH 6/7] dev(narugo): fix bug on torch 1.x --- treetensor/torch/funcs/base.py | 3 +++ treetensor/torch/funcs/wrapper.py | 22 ++++++++++++---------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index 3b5c78fbaa..1bf21a72dc 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,6 +1,7 @@ from functools import wraps import torch +from hbutils.testing import vpip from treevalue import func_treelize as original_func_treelize from ..tensor import Tensor @@ -14,6 +15,8 @@ get_func_from_torch = module_func_loader(torch, Tensor, [(torch.is_tensor, Tensor)]) +_is_torch_2 = vpip('torch') >= '2' + def wrap_for_treelize(*args, **kwargs): def _decorator(func): diff --git a/treetensor/torch/funcs/wrapper.py b/treetensor/torch/funcs/wrapper.py index 8fdf5f131b..2f442c6d99 100644 --- a/treetensor/torch/funcs/wrapper.py +++ b/treetensor/torch/funcs/wrapper.py @@ -1,19 +1,21 @@ import torch -from hbutils.testing import vpip -from .base import doc_from_base, wrap_for_treelize +from .base import doc_from_base, wrap_for_treelize, _is_torch_2 __all__ = [ 'vmap', ] -_is_torch_2 = vpip('torch') >= '2' - - -@doc_from_base() -@wrap_for_treelize() -def vmap(func, *args, **kwargs): - if _is_torch_2: +if _is_torch_2: + @doc_from_base() + @wrap_for_treelize() + def vmap(func, *args, **kwargs): return torch.vmap(func, *args, **kwargs) - else: + +else: + def vmap(func, *args, **kwargs): + """ + .. warning: + :method:`treetensor.torch.vmap` is not supported for torch 1.x. 
+ """ raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.') From 114f98bd55109493195ed0120f0f07e2e9b3da1d Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 19 Sep 2023 15:36:20 +0800 Subject: [PATCH 7/7] dev(narugo): fix issues --- test/torch/funcs/test_construct.py | 4 ++-- test/torch/funcs/test_wrapper.py | 30 +++++++++++++++--------------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/test/torch/funcs/test_construct.py b/test/torch/funcs/test_construct.py index 6a7e59e75c..ec037ece6d 100644 --- a/test/torch/funcs/test_construct.py +++ b/test/torch/funcs/test_construct.py @@ -219,9 +219,9 @@ def test_rand_like(self): _target = ttorch.rand_like({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), - 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), + 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float), 'x': { - 'c': torch.tensor([5, 6, 7], dtype=torch.float32), + 'c': torch.tensor([5, 6, 7], dtype=torch.float64), 'd': torch.tensor([[[8, 9]]], dtype=torch.float32), } }) diff --git a/test/torch/funcs/test_wrapper.py b/test/torch/funcs/test_wrapper.py index 9ce2ac9e73..dc147a4539 100644 --- a/test/torch/funcs/test_wrapper.py +++ b/test/torch/funcs/test_wrapper.py @@ -33,9 +33,9 @@ class TestTorchTensorWrapper: @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') def test_vmap(self, treetensor_x, treetensor_y): f = lambda x, y: (x.sum() + y.mean() * 2) - n_pow = torch.vmap(f) - batched_pow = ttorch.vmap(f) - r = batched_pow(treetensor_x, treetensor_y) + native_vf = torch.vmap(f) + tv_vf = ttorch.vmap(f) + r = tv_vf(treetensor_x, treetensor_y) assert r.shape == Size({ 'a': (2,), @@ -46,9 +46,9 @@ def test_vmap(self, treetensor_x, treetensor_y): assert ttorch.isclose( r, ttorch.tensor({ - 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'a': native_vf(treetensor_x.a, treetensor_y.a), 'b': { - 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + 'x': native_vf(treetensor_x.b.x, treetensor_y.b.x), } }) ).all() @@ -56,9 +56,9 @@ def test_vmap(self, treetensor_x, treetensor_y): @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') def test_vmap_in_dims(self, treetensor_x, treetensor_y): f = lambda x, y: (x.sum() + y.mean() * 2) - n_pow = torch.vmap(f, in_dims=1) - batched_pow = ttorch.vmap(f, in_dims=1) - r = batched_pow(treetensor_x, treetensor_y) + native_vf = torch.vmap(f, in_dims=1) + tv_vf = ttorch.vmap(f, in_dims=1) + r = tv_vf(treetensor_x, treetensor_y) assert r.shape == Size({ 'a': (5,), @@ -69,9 +69,9 @@ def test_vmap_in_dims(self, treetensor_x, treetensor_y): assert ttorch.isclose( r, ttorch.tensor({ - 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'a': native_vf(treetensor_x.a, treetensor_y.a), 'b': { - 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + 'x': native_vf(treetensor_x.b.x, treetensor_y.b.x), } }) ).all() @@ -79,9 +79,9 @@ def test_vmap_in_dims(self, treetensor_x, treetensor_y): @skipUnless(vpip('torch') >= '2', 'Torch 2 required.') def test_vmap_nested(self, treetensor_x, treetensor_y): f = lambda x, y: (x.sum() + y.mean() * 2) - n_pow = torch.vmap(torch.vmap(f)) - batched_pow = ttorch.vmap(ttorch.vmap(f)) - r = batched_pow(treetensor_x, treetensor_y) + native_vf = torch.vmap(torch.vmap(f)) + tv_vf = ttorch.vmap(ttorch.vmap(f)) + r = tv_vf(treetensor_x, treetensor_y) assert r.shape == Size({ 'a': (2, 5), @@ -92,9 +92,9 @@ def test_vmap_nested(self, treetensor_x, treetensor_y): assert ttorch.isclose( r, ttorch.tensor({ - 'a': n_pow(treetensor_x.a, treetensor_y.a), + 'a': native_vf(treetensor_x.a, 
treetensor_y.a), 'b': { - 'x': n_pow(treetensor_x.b.x, treetensor_y.b.x), + 'x': native_vf(treetensor_x.b.x, treetensor_y.b.x), } }) ).all()
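
Usage sketch (not part of the patches above): a minimal example of the ttorch.vmap wrapper introduced in this series, assuming torch >= 2 is installed and the series is applied; the tree layout and expected result shape follow the tests in test/torch/funcs/test_wrapper.py.

    import treetensor.torch as ttorch

    # Two trees of tensors; each pair of corresponding leaves shares batch dimension 0.
    x = ttorch.randn({'a': (2, 5), 'b': {'x': (3, 4)}})
    y = ttorch.randn({'a': (2, 5), 'b': {'x': (3, 4)}})

    def f(a, b):
        return a.sum() + b.mean() * 2

    vf = ttorch.vmap(f)   # treelized torch.vmap; raises NotImplementedError on torch 1.x
    r = vf(x, y)          # f is vectorized over dim 0 of each leaf independently
    assert r.shape == ttorch.Size({'a': (2,), 'b': {'x': (3,)}})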