From 75f07b131eb64ad259339bd95b9627e8c9d44caa Mon Sep 17 00:00:00 2001 From: Mike Taves Date: Sun, 26 Jan 2025 14:50:08 +1300 Subject: [PATCH] Add matrix multiplier methods --- CHANGES.txt | 1 + README.rst | 7 ++-- src/affine.py | 65 +++++++++++++++++++++++++++++++++++-- tests/test_numpy.py | 29 ++++++++++++++++- tests/test_rotation.py | 8 ++--- tests/test_transform.py | 71 +++++++++++++++++++++++++++++++---------- 6 files changed, 155 insertions(+), 26 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 61e6d74..4157ae0 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -13,6 +13,7 @@ CHANGES (#111). - Source was moved to a single-module affine.py in the src directory (#112). - Add numpy __array__ interface (#108). +- Add support for ``@`` matrix multiplier methods (#122). 2.4.0 (2023-01-19) ------------------ diff --git a/README.rst b/README.rst index fcd5b1b..cc428dd 100644 --- a/README.rst +++ b/README.rst @@ -51,8 +51,9 @@ Matrices can be created by passing the values ``a, b, c, d, e, f`` to the Affine(0.7071067811865476, -0.7071067811865475, 0.0, 0.7071067811865475, 0.7071067811865476, 0.0) -These matrices can be applied to ``(x, y)`` tuples to obtain transformed -coordinates ``(x', y')``. +These matrices can be applied to ``(x, y)`` tuples using the +``*`` operator (or the ``@`` matrix multiplier operator for +future releases) to obtain transformed coordinates ``(x', y')``. .. code-block:: pycon @@ -91,7 +92,7 @@ origin can be easily computed. >>> fwd * (col, row) (-237481.5, 195036.4) -The reverse transformation is obtained using the ``~`` operator. +The reverse transformation is obtained using the ``~`` inverse operator. .. code-block:: pycon diff --git a/src/affine.py b/src/affine.py index c6a9a3b..224b60c 100644 --- a/src/affine.py +++ b/src/affine.py @@ -549,6 +549,61 @@ def __add__(self, other): __iadd__ = __add__ + def __matmul__(self, other): + """Matrix multiplication. + + Apply the transform using matrix multiplication, creating + a resulting object of the same type. A transform may be applied + to another transform or vector array. + + Parameters + ---------- + other : Affine or iterable of (vx, vy, [vw]) + + Returns + ------- + Affine or a tuple of two three floats + """ + sa, sb, sc, sd, se, sf = self[:6] + if isinstance(other, Affine): + oa, ob, oc, od, oe, of = other[:6] + return self.__class__( + sa * oa + sb * od, + sa * ob + sb * oe, + sa * oc + sb * of + sc, + sd * oa + se * od, + sd * ob + se * oe, + sd * oc + se * of + sf, + ) + # vector of 2 or 3 values + try: + other = tuple(map(float, other)) + except (TypeError, ValueError): + return NotImplemented + num_values = len(other) + if num_values == 2: + vx, vy = other + elif num_values == 3: + vx, vy, vw = other + if vw != 1.0: + raise ValueError("third value must be 1.0") + else: + raise TypeError("expected vector of 2 or 3 values") + px = vx * sa + vy * sb + sc + py = vx * sd + vy * se + sf + if num_values == 2: + return (px, py) + return (px, py, vw) + + def __rmatmul__(self, other): + return NotImplemented + + def __imatmul__(self, other): + """Provide wrapper for `__matmul__`, however `other` is not modified inplace.""" + if isinstance(other, (Affine, tuple)): + return self.__matmul__(other) + return NotImplemented + def __mul__(self, other): """Multiplication. @@ -564,6 +619,12 @@ def __mul__(self, other): ------- Affine or a tuple of two floats """ + # TODO: consider enable this for 3.1 + # warnings.warn( + # "Use `@` matmul instead of `*` mul operator for matrix multiplication", + # PendingDeprecationWarning, + # stacklevel=2, + # ) sa, sb, sc, sd, se, sf = self[:6] if isinstance(other, Affine): oa, ob, oc, od, oe, of = other[:6] @@ -667,7 +728,7 @@ def loadsw(s: str) -> Affine: raise ValueError(f"Expected 6 coefficients, found {len(coeffs)}") a, d, b, e, c, f = (float(x) for x in coeffs) center = Affine(a, b, c, d, e, f) - return center * Affine.translation(-0.5, -0.5) + return center @ Affine.translation(-0.5, -0.5) def dumpsw(obj: Affine) -> str: @@ -680,7 +741,7 @@ def dumpsw(obj: Affine) -> str: ------- str """ - center = obj * Affine.translation(0.5, 0.5) + center = obj @ Affine.translation(0.5, 0.5) return "\n".join(repr(getattr(center, x)) for x in list("adbecf")) + "\n" diff --git a/tests/test_numpy.py b/tests/test_numpy.py index 200045d..74bc49f 100644 --- a/tests/test_numpy.py +++ b/tests/test_numpy.py @@ -2,7 +2,7 @@ import pytest -from affine import Affine +from affine import Affine, identity try: import numpy as np @@ -66,3 +66,30 @@ def test_linalg(): ) testing.assert_allclose(~tfm, expected_inv) testing.assert_allclose(np.linalg.inv(ar), expected_inv) + + +def test_matmul(): + A = Affine(2, 0, 3, 0, 3, 2) + Ar = np.array(A) + + # matrix @ matrix = matrix + res = A @ identity + assert isinstance(res, Affine) + testing.assert_equal(res, Ar) + res = Ar @ np.eye(3) + assert isinstance(res, np.ndarray) + testing.assert_equal(res, Ar) + + # matrix @ vector = vector + v = (2, 3, 1) + vr = np.array(v) + expected_p = (7, 11, 1) + res = A @ v + assert isinstance(res, tuple) + testing.assert_equal(res, expected_p) + res = A @ vr + assert isinstance(res, tuple) + testing.assert_equal(res, expected_p) + res = Ar @ vr + assert isinstance(res, np.ndarray) + testing.assert_equal(res, expected_p) diff --git a/tests/test_rotation.py b/tests/test_rotation.py index fce3676..edb315e 100644 --- a/tests/test_rotation.py +++ b/tests/test_rotation.py @@ -16,7 +16,7 @@ def test_rotation_angle(): | 0---------* - Affine.rotation(45.0) * (1.0, 0.0) == (0.707..., 0.707...) + Affine.rotation(45.0) @ (1.0, 0.0) == (0.707..., 0.707...) | | * @@ -24,7 +24,7 @@ def test_rotation_angle(): | 0---------- """ - x, y = Affine.rotation(45.0) * (1.0, 0.0) + x, y = Affine.rotation(45.0) @ (1.0, 0.0) sqrt2div2 = math.sqrt(2.0) / 2.0 assert x == pytest.approx(sqrt2div2) assert y == pytest.approx(sqrt2div2) @@ -55,8 +55,8 @@ def test_rotation_matrix_pivot(): rot = Affine.rotation(90.0, pivot=(1.0, 1.0)) exp = ( Affine.translation(1.0, 1.0) - * Affine.rotation(90.0) - * Affine.translation(-1.0, -1.0) + @ Affine.rotation(90.0) + @ Affine.translation(-1.0, -1.0) ) for r, e in zip(rot, exp): assert r == pytest.approx(e) diff --git a/tests/test_transform.py b/tests/test_transform.py index 031f09f..04a910c 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -150,7 +150,7 @@ def test_permutation_constructor(): perm = Affine.permutation() assert isinstance(perm, Affine) assert tuple(perm) == (0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0) - assert (perm * perm).is_identity + assert (perm @ perm).is_identity def test_translation_constructor(): @@ -348,16 +348,16 @@ def test_sub(): Affine(1, 2, 3, 4, 5, 6) - Affine(6, 5, 4, 3, 2, 1) -def test_mul_by_identity(): +def test_matmul_by_identity(): t = Affine(1, 2, 3, 4, 5, 6) - assert tuple(t * Affine.identity()) == tuple(t) + assert tuple(t @ Affine.identity()) == tuple(t) -def test_mul_transform(): - t = Affine.rotation(5) * Affine.rotation(29) +def test_matmul_transform(): + t = Affine.rotation(5) @ Affine.rotation(29) assert isinstance(t, Affine) seq_almost_equal(t, Affine.rotation(34)) - t = Affine.scale(3, 5) * Affine.scale(2) + t = Affine.scale(3, 5) @ Affine.scale(2) seq_almost_equal(t, Affine.scale(6, 10)) @@ -369,7 +369,7 @@ def test_itransform(): A = Affine.rotation(33) pts = [(4, 1), (-1, 0), (3, 2)] - pts_expect = [A * pt for pt in pts] + pts_expect = [A @ pt for pt in pts] r = A.itransform(pts) assert r is None assert pts == pts_expect @@ -380,7 +380,12 @@ def test_mul_wrong_type(): Affine(1, 2, 3, 4, 5, 6) * None -def test_mul_sequence_wrong_member_types(): +def test_matmul_wrong_type(): + with pytest.raises(TypeError): + Affine(1, 2, 3, 4, 5, 6) @ None + + +def test_matmul_sequence_wrong_member_types(): class NotPtSeq: @classmethod def from_points(cls, points): @@ -392,6 +397,9 @@ def __iter__(): with pytest.raises(TypeError): Affine(1, 2, 3, 4, 5, 6) * NotPtSeq() + with pytest.raises(TypeError): + Affine(1, 2, 3, 4, 5, 6) @ NotPtSeq() + def test_imul_transform(): t = Affine.translation(3, 5) @@ -400,12 +408,19 @@ def test_imul_transform(): seq_almost_equal(t, Affine.translation(1, 8.5)) +def test_imatmul_transform(): + t = Affine.translation(3, 5) + t @= Affine.translation(-2, 3.5) + assert isinstance(t, Affine) + seq_almost_equal(t, Affine.translation(1, 8.5)) + + def test_inverse(): seq_almost_equal(~Affine.identity(), Affine.identity()) seq_almost_equal(~Affine.translation(2, -3), Affine.translation(-2, 3)) seq_almost_equal(~Affine.rotation(-33.3), Affine.rotation(33.3)) t = Affine(1, 2, 3, 4, 5, 6) - seq_almost_equal(~t * t, Affine.identity()) + seq_almost_equal(~t @ t, Affine.identity()) def test_cant_invert_degenerate(): @@ -489,17 +504,31 @@ def test_rmul_notimplemented(): (1.0, 1.0) * t +def test_imatmul_not_implemented(): + t = Affine.identity() + with pytest.raises(TypeError): + t @= 2.0 + + def test_mul_tuple(): t = Affine(1, 2, 3, 4, 5, 6) assert t * (2, 2) == (9, 24) + with pytest.raises(TypeError): + t * (2, 2, 1) + + +def test_rmatmul_notimplemented(): + t = Affine.identity() + with pytest.raises(TypeError): + (1.0, 1.0) @ t def test_associative(): point = (12, 5) trans = Affine.translation(-10.0, -5.0) rot90 = Affine.rotation(90.0) - result1 = rot90 * (trans * point) - result2 = (rot90 * trans) * point + result1 = rot90 @ (trans @ point) + result2 = (rot90 @ trans) @ point seq_almost_equal(result1, (0.0, 2.0)) seq_almost_equal(result1, result2) @@ -508,8 +537,8 @@ def test_roundtrip(): point = (12, 5) trans = Affine.translation(3, 4) rot37 = Affine.rotation(37.0) - point_prime = (trans * rot37) * point - roundtrip_point = ~(trans * rot37) * point_prime + point_prime = (trans @ rot37) @ point + roundtrip_point = ~(trans @ rot37) @ point_prime seq_almost_equal(point, roundtrip_point) @@ -526,14 +555,14 @@ def test_eccentricity(): def test_eccentricity_complex(): - assert (Affine.scale(2, 3) * Affine.rotation(77)).eccentricity == pytest.approx( + assert (Affine.scale(2, 3) @ Affine.rotation(77)).eccentricity == pytest.approx( math.sqrt(5) / 3 ) - assert (Affine.rotation(77) * Affine.scale(2, 3)).eccentricity == pytest.approx( + assert (Affine.rotation(77) @ Affine.scale(2, 3)).eccentricity == pytest.approx( math.sqrt(5) / 3 ) assert ( - Affine.translation(32, -47) * Affine.rotation(77) * Affine.scale(2, 3) + Affine.translation(32, -47) @ Affine.rotation(77) @ Affine.scale(2, 3) ).eccentricity == pytest.approx(math.sqrt(5) / 3) @@ -561,8 +590,13 @@ class TextPoint: def __rmul__(self, other): return other * (1, 2) + def __rmatmul__(self, other): + return other @ (1, 2) + assert Affine.identity() * TextPoint() == (1, 2) + assert Affine.identity() @ TextPoint() == (1, 2) + # See gh-71 for bug report motivating this test. def test_mul_fallback_type_error(): @@ -577,8 +611,13 @@ def __iter__(self): def __rmul__(self, other): return other * (1, 2) + def __rmatmul__(self, other): + return other @ (1, 2) + assert Affine.identity() * TextPoint() == (1, 2) + assert Affine.identity() @ TextPoint() == (1, 2) + def test_init_invalid_g(): with pytest.raises(ValueError, match="g must"):