// g++ -std=c++11 -Llib -I../include -I../src mxm_gen.cpp -ltaco; ./a.out

#include "taco.h"
#include "taco/index_notation/transformations.h"
#include "codegen/codegen_c.h"
#include "taco/lower/lower.h"

#include <iostream>
#include <vector>
#include <string>
#include <memory>
#include <fstream>
#include <sstream>
#include <limits>

using namespace taco;

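// Scalar operators for the kernels below. Each Func pairs a lowering rule
// (how to emit IR for the scalar operation) with algebraic properties that
// taco can exploit when deriving iteration spaces: an Annihilator equal to
// the fill value means only intersections of stored entries matter, while an
// Identity equal to the fill value lets absent operands be skipped.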
ir::Expr orImpl(const std::vector<ir::Expr>& v) {
  return ir::Or::make(v[0], v[1]);
}
Func OrOp("or", orImpl, {Annihilator(true), Identity(false), Commutative(), Associative()});

ir::Expr andImpl(const std::vector<ir::Expr>& v) {
  return ir::And::make(v[0], v[1]);
}
Func AndOp("and", andImpl, {Annihilator(false), Identity(true), Commutative(), Associative()});

ir::Expr addImpl(const std::vector<ir::Expr>& v) {
  return ir::Add::make(v[0], v[1]);
}
Func AddOp("add", addImpl, {Annihilator(std::numeric_limits<double>::infinity()), Identity(0), Commutative(), Associative()});

ir::Expr minImpl(const std::vector<ir::Expr>& v) {
  return ir::Min::make(v[0], v[1]);
}
Func MinOp("min", minImpl, {Identity(std::numeric_limits<double>::infinity()), Commutative(), Associative()});

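// MaskOp keeps an entry of its first argument only where the second argument
// has no entry: its iteration algebra intersects r[0] with the complement of
// r[1]. It is defined here for masked variants of these kernels and is not
// used in main() below.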
ir::Expr maskImpl(const std::vector<ir::Expr>& v) {
  return v[0];
}
struct MaskAlgebra {
  IterationAlgebra operator()(const std::vector<IndexExpr>& r) {
    return Intersect(r[0], Complement(r[1]));
  }
};
Func MaskOp("mask", maskImpl, MaskAlgebra());

static bool compare(std::vector<IndexVar> vars1, std::vector<IndexVar> vars2) {
  return vars1 == vars2;
}

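// Rewrite a concretized statement of the form
//
//   forall(i, forall(k, forall(j, A(i,j) reduceOp= mul(B(i,k), C(k,j)))))
//
// over CSR-like formats into a Gustavson-style SpGEMM: for each row i,
// accumulate scaled rows of C into a dense workspace w, then scatter w into
// A(i,:). Statements that do not match the pattern are returned unchanged.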
static IndexStmt optimizeSpGEMM(IndexStmt stmt) {
  // Match the loop nest forall(i, forall(k, forall(j, ...))).
  if (!isa<Forall>(stmt)) {
    return stmt;
  }
  Forall foralli = to<Forall>(stmt);
  IndexVar i = foralli.getIndexVar();

  if (!isa<Forall>(foralli.getStmt())) {
    return stmt;
  }
  Forall forallk = to<Forall>(foralli.getStmt());
  IndexVar k = forallk.getIndexVar();

  if (!isa<Forall>(forallk.getStmt())) {
    return stmt;
  }
  Forall forallj = to<Forall>(forallk.getStmt());
  IndexVar j = forallj.getIndexVar();

  // The innermost statement must be a reduction A(i,j) reduceOp= mul(...).
  if (!isa<Assignment>(forallj.getStmt())) {
    return stmt;
  }
  Assignment assignment = to<Assignment>(forallj.getStmt());
  IndexExpr reduceOp = assignment.getOperator();

  if (!isa<Call>(assignment.getRhs())) {
    return stmt;
  }
  Call mul = to<Call>(assignment.getRhs());

  taco_iassert(isa<Access>(assignment.getLhs()));
  if (!isa<Access>(mul.getArgs()[0])) {
    return stmt;
  }
  if (!isa<Access>(mul.getArgs()[1])) {
    return stmt;
  }

  Access Aaccess = to<Access>(assignment.getLhs());
  Access Baccess = to<Access>(mul.getArgs()[0]);
  Access Caccess = to<Access>(mul.getArgs()[1]);

  if (Aaccess.getIndexVars().size() != 2 ||
      Baccess.getIndexVars().size() != 2 ||
      Caccess.getIndexVars().size() != 2) {
    return stmt;
  }

  if (!compare(Aaccess.getIndexVars(), {i,j}) ||
      !compare(Baccess.getIndexVars(), {i,k}) ||
      !compare(Caccess.getIndexVars(), {k,j})) {
    return stmt;
  }

  // The result must be (possibly unordered) CSR: dense rows over compressed
  // columns, with no mode permutation.
  TensorVar A = Aaccess.getTensorVar();
  if (A.getFormat().getModeFormats()[0].getName() != "dense" ||
      A.getFormat().getModeFormats()[1].getName() != "compressed" ||
      A.getFormat().getModeOrdering()[0] != 0 ||
      A.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }

  // I think we can do a linear combination of rows as long as there are no
  // permutations in the formats and the operand levels are ordered whenever
  // the corresponding result level is ordered. The i -> k -> j loops should
  // then iterate over the data structures without issue.
  TensorVar B = Baccess.getTensorVar();
  if ((!B.getFormat().getModeFormats()[0].isOrdered() && A.getFormat().getModeFormats()[0].isOrdered()) ||
      (!B.getFormat().getModeFormats()[1].isOrdered() && A.getFormat().getModeFormats()[1].isOrdered()) ||
      B.getFormat().getModeOrdering()[0] != 0 ||
      B.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }

  TensorVar C = Caccess.getTensorVar();
  if ((!C.getFormat().getModeFormats()[0].isOrdered() && A.getFormat().getModeFormats()[0].isOrdered()) ||
      (!C.getFormat().getModeFormats()[1].isOrdered() && A.getFormat().getModeFormats()[1].isOrdered()) ||
      C.getFormat().getModeOrdering()[0] != 0 ||
      C.getFormat().getModeOrdering()[1] != 1) {
    return stmt;
  }


  // It's an SpGEMM statement, so return an optimized SpGEMM statement. The
  // dense workspace w holds one output row (length = number of columns of A).
  TensorVar w("w",
              Type(A.getType().getDataType(),
                   {A.getType().getShape().getDimension(1)}),
              taco::dense,
              A.getFill());
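  // where(consumer, producer): the producer loops accumulate row i of the
  // result into the workspace w; the consumer then copies w into A(i,:).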
  return forall(i,
                where(forall(j,
                             A(i,j) = w(j)),
                      forall(k,
                             forall(j,
                                    Assignment(w(j), mul, reduceOp)))));
}

void printToFile(std::string filename, IndexStmt stmt) {
  std::stringstream source;

  std::shared_ptr<ir::CodeGen> codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen);
  ir::Stmt compute = lower(stmt, "evaluate", true, true);
  codegen->compile(compute, true);

  std::ofstream source_file;
  source_file.open(filename + ".c");
  source_file << source.str();
  source_file.close();
}
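
// The generated mxm.c defines a single evaluate() kernel (the name passed to
// lower() above) over taco's taco_tensor_t runtime struct. A hypothetical
// driver, assuming taco's usual result-first argument order:
//
//   taco_tensor_t A = ..., B = ..., C = ...;  // caller packs the CSR arrays
//   evaluate(&A, &B, &C);                     // assembles and computes A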

// Unordered CSR: dense rows over compressed columns whose coordinates need
// not be sorted. The ZEROLESS variant additionally promises that no explicit
// fill values are stored, which lets taco elide reads of the values array in
// the boolean kernels.
Format UCSR({Dense, Compressed(ModeFormat::NOT_ORDERED)});
Format UZCSR({Dense, Compressed({ModeFormat::NOT_ORDERED, ModeFormat::ZEROLESS})});

int main() {
  IndexVar i("i"), j("j"), k("k");
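
  // Toggle between two semirings over the same kernel structure: the live
  // branch emits (min, +) matrix multiplication over doubles (one relaxation
  // step of all-pairs shortest paths, hence the infinity fill value); the
  // #else branch emits boolean (or, and) multiplication (one step of graph
  // reachability).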
#if 1
  Tensor<double> A("A", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  Tensor<double> B("B", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  Tensor<double> C("C", {200, 200}, UCSR, std::numeric_limits<double>::infinity());
  A(i,j) = Reduction(MinOp(), k, AddOp(B(i,k), C(k,j)));
#else
  Tensor<bool> A("A", {200, 200}, UCSR);
  Tensor<bool> B("B", {200, 200}, UZCSR);
  Tensor<bool> C("C", {200, 200}, UZCSR);
  A(i,j) = Reduction(OrOp(), k, AndOp(B(i,k), C(k,j)));
#endif
  IndexStmt stmt = A.getAssignment().concretize();
  stmt = reorderLoopsTopologically(stmt);  // order loops i -> k -> j to match the CSR layouts
  stmt = optimizeSpGEMM(stmt);             // rewrite to the row-wise workspace formulation
  stmt = stmt.assemble(A.getTensorVar(), AssembleStrategy::Insert);  // size A via a query pass, then insert
  IndexVar qi = to<Forall>(to<Assemble>(stmt).getQueries()).getIndexVar();
  // Parallelize both the compute loop and the assembly query loop over rows.
  stmt = stmt.parallelize(i, ParallelUnit::CPUThread,
                          OutputRaceStrategy::NoRaces)
             .parallelize(qi, ParallelUnit::CPUThread,
                          OutputRaceStrategy::NoRaces);
  stmt = scalarPromote(stmt);  // promote reductions into scalar temporaries
  printToFile("mxm", stmt);
  return 0;
}