Skip to content

Commit 679ed31

Browse files
authored
Fix to BaseTransport.transform_labels() (#208)
* Fix to BaseTransport.transform_labels() Issue #207 * Fix - forgot to commit
1 parent 23db72c commit 679ed31

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

ot/da.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1111,7 +1111,7 @@ def transform_labels(self, ys=None):
11111111
D1 = np.zeros((n, len(ysTemp)))
11121112

11131113
# perform label propagation
1114-
transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
1114+
transp = self.coupling_ / np.sum(self.coupling_, 0, keepdims=True)
11151115

11161116
# set nans to 0
11171117
transp[~ np.isfinite(transp)] = 0

0 commit comments

Comments
 (0)