Commit fd1ba7c

[MINOR] MM Specializations
This commit adds specializations for matrix multiplication:

1. dense-sparse inputs with sparse output
2. ultra-sparse output with dense-dense inputs
3. sparse output when the right-hand side is a sparse vector

Furthermore, I modified the call stack to branch to the native MM inside LibMatrixMult, which allows easy native support for CLA by calling LibMatrixMult directly instead of having to go through a MatrixBlock.

Closes #2212
1 parent e022eaf commit fd1ba7c
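
Taken together, the diffs below make LibMatrixMult.matrixMult(m1, m2, ret, k) the single entry point that decides between the native kernel and the pure-Java implementation. Below is a minimal usage sketch of that entry point, reusing the generator that MatrixMultiplyTest uses further down; the TestUtils import path, the hypothetical class name, and passing null for the output block are assumptions for illustration, not part of this commit:

import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils; // assumed package of the test utility

public class MatMulDispatchSketch {
	public static void main(String[] args) {
		// Random 1000x100 block at sparsity 0.3 times 100x1000 block at sparsity 0.0001,
		// mirroring one of the test cases added below.
		MatrixBlock a = TestUtils.ceil(TestUtils.generateTestMatrixBlock(1000, 100, -10, 10, 0.3, 13));
		MatrixBlock b = TestUtils.ceil(TestUtils.generateTestMatrixBlock(100, 1000, -10, 10, 0.0001, 14));

		// After this commit, the branch to LibMatrixNative happens inside this call
		// when NativeHelper reports a loaded library; otherwise the Java path runs.
		// Passing null for ret (assumed to trigger output allocation) and 6 threads.
		MatrixBlock c = LibMatrixMult.matrixMult(a, b, null, 6);
		System.out.println(c.getNumRows() + " x " + c.getNumColumns() + ", nnz = " + c.getNonZeros());
	}
}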

File tree: 4 files changed, +163 -24 lines

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java

Lines changed: 136 additions & 7 deletions
@@ -159,6 +159,13 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
 	 * @return ret Matrix Block
 	 */
 	public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
+		if(NativeHelper.isNativeLibraryLoaded())
+			return LibMatrixNative.matrixMult(m1, m2, ret, k);
+		else
+			return matrixMult(m1, m2, ret, false, k);
+	}
+
+	public static MatrixBlock matrixMultNonNative(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
 		return matrixMult(m1, m2, ret, false, k);
 	}

@@ -256,7 +263,7 @@ private static void singleThreadedMatrixMult(MatrixBlock m1, MatrixBlock m2, Mat
 		// core matrix mult computation
 		if(ultraSparse && !fixedRet)
 			matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
-		else if(!m1.sparse && !m2.sparse)
+		else if(!m1.sparse && !m2.sparse && !ret.sparse)
 			matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen);
 		else if(m1.sparse && m2.sparse)
 			matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, ru2);

@@ -1257,6 +1264,100 @@ public static void matrixMultDenseDenseMM(DenseBlock a, DenseBlock b, DenseBlock
 	}
 
 	private static void matrixMultDenseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl, int ru) {
+		if(ret.isInSparseFormat()){
+			if(!m1.sparse && !m2.sparse)
+				matrixMultDenseDenseOutSparse(m1, m2, ret, pm2, rl, ru);
+			else
+				matrixMultDenseSparseOutSparse(m1, m2, ret, pm2, rl, ru);
+		}
+		else
+			matrixMultDenseSparseOutDense(m1, m2, ret, pm2, rl, ru);
+	}
+
+	private static void matrixMultDenseDenseOutSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2,
+		int rl, int ru) {
+		final DenseBlock a = m1.getDenseBlock();
+		final DenseBlock b = m2.getDenseBlock();
+		final SparseBlock c = ret.getSparseBlock();
+		final int m = m1.rlen; // rows left
+		final int cd = m1.clen; // common dim
+		final int n = m2.clen;
+
+		final int rl1 = pm2 ? 0 : rl;
+		final int ru1 = pm2 ? m : ru;
+		final int rl2 = pm2 ? rl : 0;
+		final int ru2 = pm2 ? ru : cd;
+
+		final int blocksizeK = 32;
+		final int blocksizeI = 32;
+
+		for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+			for(int bk = rl2, bimin = Math.min(ru1, bi + blocksizeI); bk < ru2; bk += blocksizeK) {
+				final int bkmin = Math.min(ru2, bk + blocksizeK);
+				// core sub block matrix multiplication
+				for(int i = bi; i < bimin; i++) { // rows left
+					final double[] avals = a.values(i);
+					final int aix = a.pos(i);
+					for(int k = bk; k < bkmin; k++) { // common dimension
+						final double aval = avals[aix + k];
+						if(aval != 0) {
+							final double[] bvals = b.values(k);
+							final int bpos = b.pos(k);
+							for(int j = 0; j < n; j++) {
+								final double bv = bvals[bpos + j];
+								c.add(i, j, aval * bv);
+							}
+						}
+					}
+				}
+			}
+		}
+	}
+
+	private static void matrixMultDenseSparseOutSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2,
+		int rl, int ru) {
+		final DenseBlock a = m1.getDenseBlock();
+		final SparseBlock b = m2.getSparseBlock();
+		final SparseBlock c = ret.getSparseBlock();
+		final int m = m1.rlen; // rows left
+		final int cd = m1.clen; // common dim
+
+		final int rl1 = pm2 ? 0 : rl;
+		final int ru1 = pm2 ? m : ru;
+		final int rl2 = pm2 ? rl : 0;
+		final int ru2 = pm2 ? ru : cd;
+
+		final int blocksizeK = 32;
+		final int blocksizeI = 32;
+
+		for(int bi = rl1; bi < ru1; bi += blocksizeI) {
+			for(int bk = rl2, bimin = Math.min(ru1, bi + blocksizeI); bk < ru2; bk += blocksizeK) {
+				final int bkmin = Math.min(ru2, bk + blocksizeK);
+				// core sub block matrix multiplication
+				for(int i = bi; i < bimin; i++) { // rows left
+					final double[] avals = a.values(i);
+					final int aix = a.pos(i);
+					for(int k = bk; k < bkmin; k++) { // common dimension
+						final double aval = avals[aix + k];
+						if(aval == 0 || b.isEmpty(k))
+							continue;
+						final int[] bIdx = b.indexes(k);
+						final double[] bVals = b.values(k);
+						final int bPos = b.pos(k);
+						final int bEnd = bPos + b.size(k);
+						for(int j = bPos; j < bEnd; j++) {
+							c.add(i, bIdx[j], aval * bVals[j]);
+						}
+					}
+				}
+			}
+		}
+	}
+
+	private static void matrixMultDenseSparseOutDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean pm2, int rl,
+		int ru) {
 		DenseBlock a = m1.getDenseBlock();
 		DenseBlock c = ret.getDenseBlock();
 		int m = m1.rlen;
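
For readers unfamiliar with the i-k-j accumulation used in the two new kernels above, here is a minimal standalone sketch of dense-times-sparse multiplication into a sparse output, using plain arrays (CSR right-hand side) and a map accumulator instead of SystemDS's DenseBlock/SparseBlock; the class name and HashMap output are illustrative only, and the cache blocking of the real code is omitted:

import java.util.HashMap;
import java.util.Map;

public class DenseSparseOutSparseSketch {
	// a: m x cd dense, row-major; b: cd x n sparse in CSR form (bPtr, bIdx, bVals).
	static Map<Long, Double> multiply(double[] a, int m, int cd, int[] bPtr, int[] bIdx, double[] bVals, int n) {
		Map<Long, Double> c = new HashMap<>(); // sparse output: key i*n+j -> value
		for(int i = 0; i < m; i++)
			for(int k = 0; k < cd; k++) {
				final double aval = a[i * cd + k];
				if(aval == 0) // skip zeros of the dense left input, as the kernels above do
					continue;
				for(int p = bPtr[k]; p < bPtr[k + 1]; p++) // nonzeros of row k of b
					c.merge((long) i * n + bIdx[p], aval * bVals[p], Double::sum); // c[i,j] += a[i,k]*b[k,j]
			}
		return c;
	}

	public static void main(String[] args) {
		// 2x2 dense left {{1,0},{3,4}}, 2x3 sparse right with b[0][1]=2 and b[1][2]=5.
		double[] a = {1, 0, 3, 4};
		int[] bPtr = {0, 1, 2}, bIdx = {1, 2};
		double[] bVals = {2, 5};
		System.out.println(multiply(a, 2, 2, bPtr, bIdx, bVals, 3));
		// expected entries: c[0,1]=2.0 (key 1), c[1,1]=6.0 (key 4), c[1,2]=20.0 (key 5)
	}
}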
@@ -1907,8 +2008,10 @@ private static void matrixMultUltraSparseRight(MatrixBlock m1, MatrixBlock m2, M
 		if(ret.isInSparseFormat()){
 			if(m1.isInSparseFormat())
 				matrixMultUltraSparseRightSparseMCSRLeftSparseOut(m1, m2, ret, rl, ru);
-			else
+			else if (m2.isInSparseFormat())
 				matrixMultUltraSparseRightDenseLeftSparseOut(m1, m2, ret, rl, ru);
+			else
+				matrixMultUltraSparseDenseInput(m1, m2, ret, rl, ru);
 		}
 		else if(ret.getDenseBlock().isContiguous())
 			matrixMultUltraSparseRightDenseOut(m1, m2, ret, rl, ru);
@@ -1990,6 +2093,30 @@ private static void matrixMultUltraSparseRightDenseLeftSparseOut(MatrixBlock m1,
 		}
 	}
 
+	private static void matrixMultUltraSparseDenseInput(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru){
+		final int cd = m1.clen;
+		final int rc = m2.clen;
+		final DenseBlock a = m1.denseBlock;
+		final DenseBlock b = m2.denseBlock;
+		final SparseBlockMCSR c = (SparseBlockMCSR) ret.sparseBlock;
+
+		for(int i = rl; i < ru; i++) {
+			// it is known that the left matrix is most likely containing many zeros.
+			final double[] av = a.values(i);
+			final int pos = a.pos(i);
+			for(int k = 0; k < cd; k++) {
+				final double v = av[pos + k];
+				if(v != 0) {
+					final double[] bv = b.values(k);
+					final int posb = b.pos(k);
+					for(int j = 0; j < rc; j++) {
+						c.add(i, j, bv[posb + j] * v);
+					}
+				}
+			}
+		}
+	}
+
 	private static void mmDenseMatrixSparseRow(int bpos, int blen, int[] bixs, double[] bvals, int k, int i,
 		DenseBlock a, SparseBlockMCSR c) {
 		final double[] aval = a.values(i);
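
A rough cost argument for the new ultra-sparse kernel with dense inputs: the row-wise scan skips zero entries of the left matrix, so it performs on the order of nnz(m1) * m2.clen multiply-adds instead of m1.rlen * m1.clen * m2.clen. For example, a 100000x10 left operand at sparsity 0.00003 (the dimensions of one of the new test cases below) holds about 0.00003 * 100000 * 10 = 30 nonzeros; if such a block happens to be held in dense form, multiplying it with a dense 10x3 right operand needs roughly 30 * 3 = 90 products rather than the 100000 * 10 * 3 = 3,000,000 of a fully dense triple loop, with the few produced rows written directly into the MCSR output.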
@@ -4419,6 +4546,8 @@ public static boolean isUltraSparseMatrixMult(MatrixBlock m1, MatrixBlock m2, bo
 	}
 
 	public static boolean isSparseOutputMatrixMult(MatrixBlock m1, MatrixBlock m2) {
+		if(m2.rlen == 1 && m2.nonZeros < m2.clen / 4) // vector right ... that is sparse.
+			return true;
 		//output is a matrix (not vector), very likely sparse, and output rows fit into L1 cache
 		if( !(m1.sparse && m2.sparse && m1.rlen > 1 && m2.clen > 1) )
 			return false;
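
As a worked instance of the new early exit in isSparseOutputMatrixMult: a right-hand operand with m2.rlen == 1 and 1000 columns of which 100 are nonzero satisfies 100 < 1000/4 = 250, so the product is flagged for a sparse output block regardless of the left operand's format; this corresponds to the "sparse vector right side" specialization named in the commit message.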
@@ -4551,7 +4680,7 @@ private static class MatrixMultTask implements Callable<Object>
 		private final boolean _pm2r; //par over m2 rows
 		private final boolean _pm2c; //par over m2 rows
 		private final boolean _m1Perm; //sparse permutation
-		private final boolean _sparse; //sparse output
+		// private final boolean _sparse; //sparse output
 		private final int _rl;
 		private final int _ru;
 		private final ConcurrentHashMap<double[], double[]> _cache;
@@ -4565,7 +4694,7 @@ protected MatrixMultTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
 		_pm2r = pm2r;
 		_pm2c = pm2c;
 		_m1Perm = m1Perm;
-		_sparse = sparse;
+		// _sparse = sparse;
 		_rl = rl;
 		_ru = ru;
 		_cache = cache;
@@ -4594,14 +4723,14 @@ public Object call() {
 			//compute block matrix multiplication
 			if( _ret.sparse ) //ultra-sparse
 				matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm, rl, ru);
-			else if(!_m1.sparse && !_m2.sparse)
+			else if(!_m1.sparse && !_m2.sparse && !_ret.sparse){
 				if(_m1.denseBlock instanceof DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu == _m2.clen)
 					matrixMultDenseDenseMMDedup((DenseBlockFP64DEDUP) _m1.denseBlock, _m2.denseBlock, (DenseBlockFP64DEDUP) _ret.denseBlock, _m2.clen, _m1.clen, rl, ru, _cache);
 				else
 					matrixMultDenseDense(_m1, _m2, _ret, _tm2, _pm2r, rl, ru, cl, cu);
-
+			}
 			else if(_m1.sparse && _m2.sparse)
-				matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _sparse, rl, ru);
+				matrixMultSparseSparse(_m1, _m2, _ret, _pm2r, _ret.sparse, rl, ru);
 			else if(_m1.sparse)
 				matrixMultSparseDense(_m1, _m2, _ret, _pm2r, rl, ru);
 			else
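
Note that the task dispatch above now reads _ret.sparse directly where it previously used the cached _sparse flag (removed in the two hunks before), so the parallel tasks make the same sparse-output decision as singleThreadedMatrixMult, and the added !_ret.sparse guard keeps dense-dense inputs with a sparse output block away from the dense-output kernel.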

src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock
 		else
 			LOG.warn("Was valid for native MM but native lib was not loaded");
 
-		return LibMatrixMult.matrixMult(m1, m2, ret, k);
+		return LibMatrixMult.matrixMultNonNative(m1, m2, ret, k);
 	}
 
 	public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {
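
The fallback in LibMatrixNative.matrixMult has to call matrixMultNonNative rather than matrixMult: since LibMatrixMult.matrixMult now branches back to LibMatrixNative whenever the native library is loaded, calling it from this fallback would re-enter the native wrapper and loop whenever the native path declines the operation.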

src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java

Lines changed: 1 addition & 4 deletions
@@ -4994,10 +4994,7 @@ public final MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m
 	public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
 		checkAggregateBinaryOperations(m1, m2, op);
 		final int k = op.getNumThreads();
-		if(NativeHelper.isNativeLibraryLoaded())
-			return LibMatrixNative.matrixMult(m1, m2, ret, k);
-		else
-			return LibMatrixMult.matrixMult(m1, m2, ret, k);
+		return LibMatrixMult.matrixMult(m1, m2, ret, k);
 	}
 
 	protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op) {

src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java

Lines changed: 25 additions & 12 deletions
@@ -52,10 +52,13 @@ public class MatrixMultiplyTest {
 	// parallelization degree
 	private final int k;
 
-	public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int p) {
+	public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int p, boolean self) {
 		try {
 			this.left = TestUtils.ceil(TestUtils.generateTestMatrixBlock(i, j, -10, 10, i == 1 && j == 1 ? 1 : s, 13));
-			this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 1 ? 1 : s2, 14));
+			if(self)
+				this.right = left;
+			else
+				this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, k == 1 && k == 1 ? 1 : s2, 14));
 
 			this.exp = multiply(left, right, 1);
 			this.k = p;
@@ -83,23 +86,33 @@ public static Collection<Object[]> data() {
 					for(int i = 0; i < is.length; i++) {
 						for(int j = 0; j < js.length; j++) {
 							for(int k = 0; k < ks.length; k++) {
-								tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], par[p]});
+								tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], par[p], false});
 							}
 						}
 					}
 				}
 			}
 		}
 
-		tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6});
-
-		tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6});
-		tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6});
+		tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0001, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.01, 0.3, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.3, 0.0005, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.005, 0.3, 6, false});
+
+		tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0001, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.01, 0.6, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.6, 0.0005, 6, false});
+		tests.add(new Object[]{1000, 100, 1000, 0.005, 0.6, 6, false});
+
+		// 0.00004 ultra sparse turn point
+		tests.add(new Object[]{100, 100, 10000, 0.5, 0.00003, 6, false});
+		tests.add(new Object[]{10000, 100, 100, 0.00003, 0.6, 6, false});
+
+		tests.add(new Object[]{3, 10, 100000, 1.0, 0.00003, 6, false});
+		tests.add(new Object[]{100000, 10, 3, 0.00003, 1.0, 6, false});
+
+		tests.add(new Object[]{1000, 1000, 1000, 0.005, 0.6, 6, true});
 
 	}
 	catch(Exception e) {
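
The final test case passes self = true, so the constructor above reuses the 1000x1000 left block as the right operand; presumably this covers multiplying a sparse block with itself (m1 == m2), where both sides share the same underlying data.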
