Skip to content

Commit

Permalink
Add matrix multiplier methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtoews committed Jan 26, 2025
1 parent 756471b commit 75f07b1
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 26 deletions.
1 change: 1 addition & 0 deletions CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CHANGES
(#111).
- Source was moved to a single-module affine.py in the src directory (#112).
- Add numpy __array__ interface (#108).
- Add support for ``@`` matrix multiplier methods (#122).

2.4.0 (2023-01-19)
------------------
Expand Down
7 changes: 4 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ Matrices can be created by passing the values ``a, b, c, d, e, f`` to the
Affine(0.7071067811865476, -0.7071067811865475, 0.0,
0.7071067811865475, 0.7071067811865476, 0.0)
These matrices can be applied to ``(x, y)`` tuples to obtain transformed
coordinates ``(x', y')``.
These matrices can be applied to ``(x, y)`` tuples using the
``*`` operator (or the ``@`` matrix multiplier operator for
future releases) to obtain transformed coordinates ``(x', y')``.

.. code-block:: pycon
Expand Down Expand Up @@ -91,7 +92,7 @@ origin can be easily computed.
>>> fwd * (col, row)
(-237481.5, 195036.4)
The reverse transformation is obtained using the ``~`` operator.
The reverse transformation is obtained using the ``~`` inverse operator.

.. code-block:: pycon
Expand Down
65 changes: 63 additions & 2 deletions src/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,61 @@ def __add__(self, other):

__iadd__ = __add__

def __matmul__(self, other):
"""Matrix multiplication.
Apply the transform using matrix multiplication, creating
a resulting object of the same type. A transform may be applied
to another transform or vector array.
Parameters
----------
other : Affine or iterable of (vx, vy, [vw])
Returns
-------
Affine or a tuple of two three floats
"""
sa, sb, sc, sd, se, sf = self[:6]
if isinstance(other, Affine):
oa, ob, oc, od, oe, of = other[:6]
return self.__class__(
sa * oa + sb * od,
sa * ob + sb * oe,
sa * oc + sb * of + sc,
sd * oa + se * od,
sd * ob + se * oe,
sd * oc + se * of + sf,
)
# vector of 2 or 3 values
try:
other = tuple(map(float, other))
except (TypeError, ValueError):
return NotImplemented
num_values = len(other)
if num_values == 2:
vx, vy = other
elif num_values == 3:
vx, vy, vw = other
if vw != 1.0:
raise ValueError("third value must be 1.0")
else:
raise TypeError("expected vector of 2 or 3 values")
px = vx * sa + vy * sb + sc
py = vx * sd + vy * se + sf
if num_values == 2:
return (px, py)
return (px, py, vw)

def __rmatmul__(self, other):
return NotImplemented

def __imatmul__(self, other):
"""Provide wrapper for `__matmul__`, however `other` is not modified inplace."""
if isinstance(other, (Affine, tuple)):
return self.__matmul__(other)
return NotImplemented

def __mul__(self, other):
"""Multiplication.
Expand All @@ -564,6 +619,12 @@ def __mul__(self, other):
-------
Affine or a tuple of two floats
"""
# TODO: consider enable this for 3.1
# warnings.warn(
# "Use `@` matmul instead of `*` mul operator for matrix multiplication",
# PendingDeprecationWarning,
# stacklevel=2,
# )
sa, sb, sc, sd, se, sf = self[:6]
if isinstance(other, Affine):
oa, ob, oc, od, oe, of = other[:6]
Expand Down Expand Up @@ -667,7 +728,7 @@ def loadsw(s: str) -> Affine:
raise ValueError(f"Expected 6 coefficients, found {len(coeffs)}")
a, d, b, e, c, f = (float(x) for x in coeffs)
center = Affine(a, b, c, d, e, f)
return center * Affine.translation(-0.5, -0.5)
return center @ Affine.translation(-0.5, -0.5)


def dumpsw(obj: Affine) -> str:
Expand All @@ -680,7 +741,7 @@ def dumpsw(obj: Affine) -> str:
-------
str
"""
center = obj * Affine.translation(0.5, 0.5)
center = obj @ Affine.translation(0.5, 0.5)
return "\n".join(repr(getattr(center, x)) for x in list("adbecf")) + "\n"


Expand Down
29 changes: 28 additions & 1 deletion tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from affine import Affine
from affine import Affine, identity

try:
import numpy as np
Expand Down Expand Up @@ -66,3 +66,30 @@ def test_linalg():
)
testing.assert_allclose(~tfm, expected_inv)
testing.assert_allclose(np.linalg.inv(ar), expected_inv)


def test_matmul():
A = Affine(2, 0, 3, 0, 3, 2)
Ar = np.array(A)

# matrix @ matrix = matrix
res = A @ identity
assert isinstance(res, Affine)
testing.assert_equal(res, Ar)
res = Ar @ np.eye(3)
assert isinstance(res, np.ndarray)
testing.assert_equal(res, Ar)

# matrix @ vector = vector
v = (2, 3, 1)
vr = np.array(v)
expected_p = (7, 11, 1)
res = A @ v
assert isinstance(res, tuple)
testing.assert_equal(res, expected_p)
res = A @ vr
assert isinstance(res, tuple)
testing.assert_equal(res, expected_p)
res = Ar @ vr
assert isinstance(res, np.ndarray)
testing.assert_equal(res, expected_p)
8 changes: 4 additions & 4 deletions tests/test_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def test_rotation_angle():
|
0---------*
Affine.rotation(45.0) * (1.0, 0.0) == (0.707..., 0.707...)
Affine.rotation(45.0) @ (1.0, 0.0) == (0.707..., 0.707...)
|
| *
|
|
0----------
"""
x, y = Affine.rotation(45.0) * (1.0, 0.0)
x, y = Affine.rotation(45.0) @ (1.0, 0.0)
sqrt2div2 = math.sqrt(2.0) / 2.0
assert x == pytest.approx(sqrt2div2)
assert y == pytest.approx(sqrt2div2)
Expand Down Expand Up @@ -55,8 +55,8 @@ def test_rotation_matrix_pivot():
rot = Affine.rotation(90.0, pivot=(1.0, 1.0))
exp = (
Affine.translation(1.0, 1.0)
* Affine.rotation(90.0)
* Affine.translation(-1.0, -1.0)
@ Affine.rotation(90.0)
@ Affine.translation(-1.0, -1.0)
)
for r, e in zip(rot, exp):
assert r == pytest.approx(e)
71 changes: 55 additions & 16 deletions tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def test_permutation_constructor():
perm = Affine.permutation()
assert isinstance(perm, Affine)
assert tuple(perm) == (0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
assert (perm * perm).is_identity
assert (perm @ perm).is_identity


def test_translation_constructor():
Expand Down Expand Up @@ -348,16 +348,16 @@ def test_sub():
Affine(1, 2, 3, 4, 5, 6) - Affine(6, 5, 4, 3, 2, 1)


def test_mul_by_identity():
def test_matmul_by_identity():
t = Affine(1, 2, 3, 4, 5, 6)
assert tuple(t * Affine.identity()) == tuple(t)
assert tuple(t @ Affine.identity()) == tuple(t)


def test_mul_transform():
t = Affine.rotation(5) * Affine.rotation(29)
def test_matmul_transform():
t = Affine.rotation(5) @ Affine.rotation(29)
assert isinstance(t, Affine)
seq_almost_equal(t, Affine.rotation(34))
t = Affine.scale(3, 5) * Affine.scale(2)
t = Affine.scale(3, 5) @ Affine.scale(2)
seq_almost_equal(t, Affine.scale(6, 10))


Expand All @@ -369,7 +369,7 @@ def test_itransform():

A = Affine.rotation(33)
pts = [(4, 1), (-1, 0), (3, 2)]
pts_expect = [A * pt for pt in pts]
pts_expect = [A @ pt for pt in pts]
r = A.itransform(pts)
assert r is None
assert pts == pts_expect
Expand All @@ -380,7 +380,12 @@ def test_mul_wrong_type():
Affine(1, 2, 3, 4, 5, 6) * None


def test_mul_sequence_wrong_member_types():
def test_matmul_wrong_type():
with pytest.raises(TypeError):
Affine(1, 2, 3, 4, 5, 6) @ None


def test_matmul_sequence_wrong_member_types():
class NotPtSeq:
@classmethod
def from_points(cls, points):
Expand All @@ -392,6 +397,9 @@ def __iter__():
with pytest.raises(TypeError):
Affine(1, 2, 3, 4, 5, 6) * NotPtSeq()

with pytest.raises(TypeError):
Affine(1, 2, 3, 4, 5, 6) @ NotPtSeq()


def test_imul_transform():
t = Affine.translation(3, 5)
Expand All @@ -400,12 +408,19 @@ def test_imul_transform():
seq_almost_equal(t, Affine.translation(1, 8.5))


def test_imatmul_transform():
t = Affine.translation(3, 5)
t @= Affine.translation(-2, 3.5)
assert isinstance(t, Affine)
seq_almost_equal(t, Affine.translation(1, 8.5))


def test_inverse():
seq_almost_equal(~Affine.identity(), Affine.identity())
seq_almost_equal(~Affine.translation(2, -3), Affine.translation(-2, 3))
seq_almost_equal(~Affine.rotation(-33.3), Affine.rotation(33.3))
t = Affine(1, 2, 3, 4, 5, 6)
seq_almost_equal(~t * t, Affine.identity())
seq_almost_equal(~t @ t, Affine.identity())


def test_cant_invert_degenerate():
Expand Down Expand Up @@ -489,17 +504,31 @@ def test_rmul_notimplemented():
(1.0, 1.0) * t


def test_imatmul_not_implemented():
t = Affine.identity()
with pytest.raises(TypeError):
t @= 2.0


def test_mul_tuple():
t = Affine(1, 2, 3, 4, 5, 6)
assert t * (2, 2) == (9, 24)
with pytest.raises(TypeError):
t * (2, 2, 1)


def test_rmatmul_notimplemented():
t = Affine.identity()
with pytest.raises(TypeError):
(1.0, 1.0) @ t


def test_associative():
point = (12, 5)
trans = Affine.translation(-10.0, -5.0)
rot90 = Affine.rotation(90.0)
result1 = rot90 * (trans * point)
result2 = (rot90 * trans) * point
result1 = rot90 @ (trans @ point)
result2 = (rot90 @ trans) @ point
seq_almost_equal(result1, (0.0, 2.0))
seq_almost_equal(result1, result2)

Expand All @@ -508,8 +537,8 @@ def test_roundtrip():
point = (12, 5)
trans = Affine.translation(3, 4)
rot37 = Affine.rotation(37.0)
point_prime = (trans * rot37) * point
roundtrip_point = ~(trans * rot37) * point_prime
point_prime = (trans @ rot37) @ point
roundtrip_point = ~(trans @ rot37) @ point_prime
seq_almost_equal(point, roundtrip_point)


Expand All @@ -526,14 +555,14 @@ def test_eccentricity():


def test_eccentricity_complex():
assert (Affine.scale(2, 3) * Affine.rotation(77)).eccentricity == pytest.approx(
assert (Affine.scale(2, 3) @ Affine.rotation(77)).eccentricity == pytest.approx(
math.sqrt(5) / 3
)
assert (Affine.rotation(77) * Affine.scale(2, 3)).eccentricity == pytest.approx(
assert (Affine.rotation(77) @ Affine.scale(2, 3)).eccentricity == pytest.approx(
math.sqrt(5) / 3
)
assert (
Affine.translation(32, -47) * Affine.rotation(77) * Affine.scale(2, 3)
Affine.translation(32, -47) @ Affine.rotation(77) @ Affine.scale(2, 3)
).eccentricity == pytest.approx(math.sqrt(5) / 3)


Expand Down Expand Up @@ -561,8 +590,13 @@ class TextPoint:
def __rmul__(self, other):
return other * (1, 2)

def __rmatmul__(self, other):
return other @ (1, 2)

assert Affine.identity() * TextPoint() == (1, 2)

assert Affine.identity() @ TextPoint() == (1, 2)


# See gh-71 for bug report motivating this test.
def test_mul_fallback_type_error():
Expand All @@ -577,8 +611,13 @@ def __iter__(self):
def __rmul__(self, other):
return other * (1, 2)

def __rmatmul__(self, other):
return other @ (1, 2)

assert Affine.identity() * TextPoint() == (1, 2)

assert Affine.identity() @ TextPoint() == (1, 2)


def test_init_invalid_g():
with pytest.raises(ValueError, match="g must"):
Expand Down

0 comments on commit 75f07b1

Please sign in to comment.