Skip to content

Commit 5bc044c

Browse files
Update Torch QR dispatch
1 parent 4a7b7a8 commit 5bc044c

File tree

5 files changed

+55
-43
lines changed

5 files changed

+55
-43
lines changed

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
KroneckerProduct,
1010
MatrixInverse,
1111
MatrixPinv,
12-
QRFull,
1312
SLogDet,
1413
)
1514

@@ -70,21 +69,6 @@ def matrix_inverse(x):
7069
return matrix_inverse
7170

7271

73-
@pytorch_funcify.register(QRFull)
74-
def pytorch_funcify_QRFull(op, **kwargs):
75-
mode = op.mode
76-
if mode == "raw":
77-
raise NotImplementedError("raw mode not implemented in PyTorch")
78-
79-
def qr_full(x):
80-
Q, R = torch.linalg.qr(x, mode=mode)
81-
if mode == "r":
82-
return R
83-
return Q, R
84-
85-
return qr_full
86-
87-
8872
@pytorch_funcify.register(MatrixPinv)
8973
def pytorch_funcify_Pinv(op, **kwargs):
9074
hermitian = op.hermitian
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
3+
from pytensor.link.pytorch.dispatch import pytorch_funcify
4+
from pytensor.tensor.slinalg import QR
5+
6+
7+
@pytorch_funcify.register(QR)
8+
def pytorch_funcify_QR(op, **kwargs):
9+
mode = op.mode
10+
if mode == "raw":
11+
raise NotImplementedError("raw mode not implemented in PyTorch")
12+
13+
def qr(x):
14+
Q, R = torch.linalg.qr(x, mode=mode)
15+
if mode == "r":
16+
return R
17+
return Q, R
18+
19+
return qr

tests/link/pytorch/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pytensor import config
5+
from pytensor.tensor.type import matrix
6+
7+
8+
@pytest.fixture
9+
def matrix_test():
10+
rng = np.random.default_rng(213234)
11+
12+
M = rng.normal(size=(3, 3))
13+
test_value = M.dot(M.T).astype(config.floatX)
14+
15+
x = matrix("x")
16+
return x, test_value

tests/link/pytorch/test_nlinalg.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,6 @@
88
from tests.link.pytorch.test_basic import compare_pytorch_and_py
99

1010

11-
@pytest.fixture
12-
def matrix_test():
13-
rng = np.random.default_rng(213234)
14-
15-
M = rng.normal(size=(3, 3))
16-
test_value = M.dot(M.T).astype(config.floatX)
17-
18-
x = matrix("x")
19-
return (x, test_value)
20-
21-
2211
@pytest.mark.parametrize(
2312
"func",
2413
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
@@ -34,22 +23,6 @@ def assert_fn(x, y):
3423
compare_pytorch_and_py([x], outs, [test_value], assert_fn=assert_fn)
3524

3625

37-
@pytest.mark.parametrize(
38-
"mode",
39-
(
40-
"complete",
41-
"reduced",
42-
"r",
43-
pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
44-
),
45-
)
46-
def test_qr(mode, matrix_test):
47-
x, test_value = matrix_test
48-
outs = pt_nla.qr(x, mode=mode)
49-
50-
compare_pytorch_and_py([x], outs, [test_value])
51-
52-
5326
@pytest.mark.parametrize("compute_uv", [True, False])
5427
@pytest.mark.parametrize("full_matrices", [True, False])
5528
def test_svd(compute_uv, full_matrices, matrix_test):

tests/link/pytorch/test_slinalg.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import pytest
2+
3+
import pytensor
4+
from tests.link.pytorch.test_basic import compare_pytorch_and_py
5+
6+
7+
@pytest.mark.parametrize(
8+
"mode",
9+
(
10+
"complete",
11+
"reduced",
12+
"r",
13+
pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
14+
),
15+
)
16+
def test_qr(mode, matrix_test):
17+
x, test_value = matrix_test
18+
outs = pytensor.tensor.slinalg.qr(x, mode=mode)
19+
20+
compare_pytorch_and_py([x], outs, [test_value])

0 commit comments

Comments
 (0)