Skip to content

Commit 2249a29

Browse files
v0drofacebook-github-bot
authored andcommitted
Fix segfault with torch.orgqr. (pytorch#46700)
Summary: Fixes pytorch#41768 The fault was that a NULL `tau` would get passed to LAPACK function. This PR fixes that by checking whether the `tau` contains 0 elements at the beginning of the function. Pull Request resolved: pytorch#46700 Reviewed By: albanD Differential Revision: D24616427 Pulled By: mruberry fbshipit-source-id: 92e8f1489b113c0ceeca6e54dea8b810a51a63c3
1 parent f629fbe commit 2249a29

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

aten/src/TH/generic/THLapack.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
#define TH_GENERIC_FILE "TH/generic/THLapack.cpp"
33
#else
44

5-
65
TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
76
TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
87
TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);

aten/src/TH/generic/THTensorLapack.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,8 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau)
410410
{
411411
if (a == NULL) a = ra_;
412412
THArgCheck(THTensor_nDimension(a) == 2, 1, "'input' should be 2 dimensional");
413-
THArgCheck(!a->is_empty(), 1, "'input' should not be empty");
413+
THArgCheck(!a->is_empty(), 2, "'input' should not be empty");
414+
THArgCheck(!tau->is_empty(), 3, "'tau' should not be empty");
414415

415416
THTensor *ra__ = NULL;
416417
ra__ = THTensor_(cloneColumnMajor)(ra_, a);

test/test_torch.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15539,7 +15539,8 @@ def test_orgqr_errors(self, device):
1553915539
((10,), (2,), r"'input' should be 2 dimensional"),
1554015540
((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"),
1554115541
((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"),
15542-
((0, 0), (0,), r"'input' should not be empty")
15542+
((0, 0), (0,), r"'input' should not be empty"),
15543+
((2, 2), (2, 0,), r"'tau' should not be empty")
1554315544
]
1554415545
for a_size, tau_size, error_regex in test_cases:
1554515546
a = torch.rand(*a_size, device=device)

0 commit comments

Comments
 (0)