Skip to content

Commit 759585b

Browse files
author
Olivia W Hsu
committed
Merge branch 'main' of github.com:tensor-compiler/array-programming-benchmarks into main
2 parents f2c6f7a + 3202b3a commit 759585b

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

gen/mxm_gen.cpp

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// g++ -std=c++11 -Llib -I../include -I../src mxm_gen.cpp -ltaco;./a.out
2+
3+
#include "taco.h"
4+
#include "taco/index_notation/transformations.h"
5+
#include "codegen/codegen_c.h"
6+
#include "taco/lower/lower.h"
7+
8+
#include <iostream>
9+
#include <vector>
10+
#include <string>
11+
#include <memory>
12+
#include <fstream>
13+
14+
using namespace taco;
15+
16+
ir::Expr orImpl(const std::vector<ir::Expr>& v) {
17+
return ir::Or::make(v[0], v[1]);
18+
}
19+
Func OrOp("or", orImpl, {Annihilator(true), Identity(false), Commutative(), Associative()});
20+
21+
ir::Expr andImpl(const std::vector<ir::Expr>& v) {
22+
return ir::And::make(v[0], v[1]);
23+
}
24+
Func AndOp("and", andImpl, {Annihilator(false), Identity(true), Commutative(), Associative()});
25+
26+
ir::Expr addImpl(const std::vector<ir::Expr>& v) {
27+
return ir::Add::make(v[0], v[1]);
28+
}
29+
Func AddOp("add", addImpl, {Annihilator(std::numeric_limits<double>::infinity()), Identity(0), Commutative(), Associative()});
30+
31+
ir::Expr minImpl(const std::vector<ir::Expr>& v) {
32+
return ir::Min::make(v[0], v[1]);
33+
}
34+
Func MinOp("min", minImpl, {Identity(std::numeric_limits<double>::infinity()), Commutative(), Associative()});
35+
36+
ir::Expr maskImpl(const std::vector<ir::Expr>& v) {
37+
return v[0];
38+
}
39+
struct MaskAlgebra {
40+
IterationAlgebra operator()(const std::vector<IndexExpr>& r) {
41+
return Intersect(r[0], Complement(r[1]));
42+
}
43+
};
44+
Func MaskOp("mask", maskImpl, MaskAlgebra());
45+
46+
// Returns true when the two index-variable lists are element-wise equal
// (same variables, same order).
// Fix: take both vectors by const reference instead of by value — the
// original copied both arguments on every call from optimizeSpGEMM; a
// const& still binds to the braced-init-list temporaries used at the
// call sites, so no caller changes.
static bool compare(const std::vector<IndexVar>& vars1, const std::vector<IndexVar>& vars2) {
  return vars1 == vars2;
}
49+
50+
// Pattern-matches a concretized statement of the shape
//   forall(i, forall(k, forall(j, A(i,j) reduceOp= mul(B(i,k), C(k,j)))))
// over row-major dense/compressed operands and rewrites it into a
// linear-combination-of-rows SpGEMM that accumulates each output row into a
// dense workspace `w`. Any statement that does not match is returned unchanged.
static IndexStmt optimizeSpGEMM(IndexStmt stmt) {
  // Peel the three nested foralls; bail out (returning `stmt` untouched) at
  // the first structural mismatch.
  if (!isa<Forall>(stmt)) {
    return stmt;
  }
  Forall foralli = to<Forall>(stmt);
  IndexVar i = foralli.getIndexVar();

  if (!isa<Forall>(foralli.getStmt())) {
    return stmt;
  }
  Forall forallk = to<Forall>(foralli.getStmt());
  IndexVar k = forallk.getIndexVar();

  if (!isa<Forall>(forallk.getStmt())) {
    return stmt;
  }
  Forall forallj = to<Forall>(forallk.getStmt());
  IndexVar j = forallj.getIndexVar();

  // Innermost statement must be a reduction assignment whose RHS is a
  // two-argument multiply-like Call on two tensor accesses.
  if (!isa<Assignment>(forallj.getStmt())) {
    return stmt;
  }
  Assignment assignment = to<Assignment>(forallj.getStmt());
  IndexExpr reduceOp = assignment.getOperator();

  if (!isa<Call>(assignment.getRhs())) {
    return stmt;
  }
  Call mul = to<Call>(assignment.getRhs());

  taco_iassert(isa<Access>(assignment.getLhs()));
  if (!isa<Access>(mul.getArgs()[0])) {
    return stmt;
  }
  if (!isa<Access>(mul.getArgs()[1])) {
    return stmt;
  }

  Access Aaccess = to<Access>(assignment.getLhs());
  Access Baccess = to<Access>(mul.getArgs()[0]);
  Access Caccess = to<Access>(mul.getArgs()[1]);

  // All three operands must be matrices.
  if (Aaccess.getIndexVars().size() != 2 ||
      Baccess.getIndexVars().size() != 2 ||
      Caccess.getIndexVars().size() != 2) {
    return stmt;
  }

  // Index pattern must be exactly A(i,j) = B(i,k) * C(k,j) under the
  // i -> k -> j loop order matched above.
  if (!compare(Aaccess.getIndexVars(), {i,j}) ||
      !compare(Baccess.getIndexVars(), {i,k}) ||
      !compare(Caccess.getIndexVars(), {k,j})) {
    return stmt;
  }

  // Output must be row-major CSR: (dense, compressed) with identity mode
  // ordering, since the rewrite assembles one dense row at a time.
  TensorVar A = Aaccess.getTensorVar();
  if (A.getFormat().getModeFormats()[0].getName() != "dense" ||
      A.getFormat().getModeFormats()[1].getName() != "compressed" ||
      A.getFormat().getModeOrdering()[0] != 0 ||
      A.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }

  // I think we can do a linear combination of rows as long as there are no
  // permutations in the format and the level formats are ordered. The
  // i -> k -> j loops should iterate over the data structures without issue.
  // NOTE: in the conditions below `&&` binds tighter than `||`, i.e. each
  // clause reads (operand-level unordered AND output-level ordered).
  TensorVar B = Baccess.getTensorVar();
  if (!B.getFormat().getModeFormats()[0].isOrdered() && A.getFormat().getModeFormats()[0].isOrdered() ||
      !B.getFormat().getModeFormats()[1].isOrdered() && A.getFormat().getModeFormats()[1].isOrdered() ||
      B.getFormat().getModeOrdering()[0] != 0 ||
      B.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }

  TensorVar C = Caccess.getTensorVar();
  if (!C.getFormat().getModeFormats()[0].isOrdered() && A.getFormat().getModeFormats()[0].isOrdered() ||
      !C.getFormat().getModeFormats()[1].isOrdered() && A.getFormat().getModeFormats()[1].isOrdered() ||
      C.getFormat().getModeOrdering()[0] != 0 ||
      C.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }

  // It's an SpMM statement so return an optimized SpMM statement:
  // accumulate row i into a dense workspace `w` (sized by A's column
  // dimension, carrying A's fill value), then copy `w` into A(i,:).
  TensorVar w("w",
              Type(A.getType().getDataType(),
                   {A.getType().getShape().getDimension(1)}),
              taco::dense,
              A.getFill());
  return forall(i,
                where(forall(j,
                             A(i,j) = w(j)),
                      forall(k,
                             forall(j,
                                    Assignment(w(j), mul, reduceOp)))));
}
143+
144+
void printToFile(std::string filename, IndexStmt stmt) {
145+
std::stringstream source;
146+
147+
std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen);
148+
ir::Stmt compute = lower(stmt, "evaluate", true, true);
149+
codegen->compile(compute, true);
150+
151+
std::ofstream source_file;
152+
source_file.open(filename + ".c");
153+
source_file << source.str();
154+
source_file.close();
155+
}
156+
157+
// CSR-like format whose compressed (column) level is NOT_ORDERED, i.e. column
// entries within a row need not be sorted.
Format UCSR({Dense, Compressed(ModeFormat::NOT_ORDERED)});
// Like UCSR, but the compressed level is additionally ZEROLESS — presumably
// stored entries are guaranteed nonzero/non-fill; confirm against taco's
// ModeFormat documentation.
Format UZCSR({Dense, Compressed({ModeFormat::NOT_ORDERED, ModeFormat::ZEROLESS})});
159+
160+
// Builds an mxm (matrix-matrix multiply) kernel over a chosen semiring,
// applies the SpGEMM workspace rewrite, parallelizes it, and emits the
// generated C code to "mxm.c".
int main() {
  IndexVar i("i"), j("j"), k("k");
#if 1
  // (min, +) tropical semiring over doubles; +inf is the fill value so
  // absent entries act as the additive identity of min.
  Tensor<double> A("A", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  Tensor<double> B("B", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  Tensor<double> C("C", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  A(i,j) = Reduction(MinOp(), k, AddOp(B(i,k), C(k,j)));
#else
  // Alternative: boolean (or, and) semiring variant of the same kernel.
  Tensor<bool> A("A", {200, 200}, UCSR);
  Tensor<bool> B("B", {200, 200}, UZCSR);
  Tensor<bool> C("C", {200, 200}, UZCSR);
  A(i,j) = Reduction(OrOp(), k, AndOp(B(i,k), C(k,j)));
#endif
  // Concretize, order loops i -> k -> j, then apply the workspace rewrite.
  IndexStmt stmt = A.getAssignment().concretize();
  stmt = reorderLoopsTopologically(stmt);
  stmt = optimizeSpGEMM(stmt);
  // Assemble A's sparse structure with the two-phase (query + insert) strategy.
  stmt = stmt.assemble(A.getTensorVar(), AssembleStrategy::Insert);
  // Parallelize both the assembly query loop and the main row loop over
  // CPU threads; rows are independent, so no output races are expected.
  IndexVar qi = to<Forall>(to<Assemble>(stmt).getQueries()).getIndexVar();
  stmt = stmt.parallelize(i, ParallelUnit::CPUThread,
                          OutputRaceStrategy::NoRaces)
             .parallelize(qi, ParallelUnit::CPUThread,
                          OutputRaceStrategy::NoRaces);
  stmt = scalarPromote(stmt);
  printToFile("mxm", stmt);
  return 0;
}

0 commit comments

Comments
 (0)