Skip to content

Commit 0bc607e

Browse files
committed
add in changes to get taco/image running on lanka without TNS explosion
1 parent d54a4c0 commit 0bc607e

File tree

1 file changed

+70
-17
lines changed

1 file changed

+70
-17
lines changed

taco/image.cpp

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,35 @@ Func andOp1("logical_and", Boolean(), andAlgebra());
6868
Func xorAndOp("fused_xor_and", Boolean(), xorAndAlgebra());
6969
Func testOp("test", Boolean(), testConstructionAlgebra());
7070
static void bench_image_xor(benchmark::State& state, const Format& f) {
71-
int num = state.range(0);
7271
auto t1 = 0.5;
7372
auto t2 = 0.55;
74-
Tensor<int64_t> matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
75-
Tensor<int64_t> matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
73+
74+
auto num_str = getEnvVar("IMAGE_NUM");
75+
if (num_str == "") {
76+
state.error_occurred();
77+
return;
78+
}
79+
80+
int num = std::stoi(num_str);
81+
82+
taco::Tensor<int64_t> matrix1, matrix2;
83+
try {
84+
matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
85+
matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
86+
} catch (TacoException& e) {
87+
// Counters don't show up in the generated CSV if we used SkipWithError, so
88+
// just add in the label that this run is skipped.
89+
state.SetLabel(num_str+"/SKIPPED-FAILED-READ");
90+
return;
91+
}
92+
7693
auto dims = matrix1.getDimensions();
7794

7895
for (auto _ : state) {
7996
state.PauseTiming();
8097
Tensor<int64_t> result("result", dims, f, 1);
8198
IndexVar i("i"), j("j");
82-
result(i, j) = testOp(matrix1(i, j), matrix2(i, j));
99+
result(i, j) = xorOp1(matrix1(i, j), matrix2(i, j));
83100
result.setAssembleWhileCompute(true);
84101
result.compile();
85102
state.ResumeTiming();
@@ -98,15 +115,33 @@ static void CustomArguments(benchmark::internal::Benchmark* b) {
98115
for (int i = 1; i <= 98; ++i)
99116
b->Args({i});
100117
}
101-
TACO_BENCH_ARGS(bench_image_xor, csr, CSR)->Apply(CustomArguments);
118+
TACO_BENCH_ARGS(bench_image_xor, csr, CSR);
102119

103120
static void bench_image_fused(benchmark::State& state, const Format& f) {
104-
int num = state.range(0);
121+
// int num = state.range(0);
105122
auto t1 = 0.5;
106123
auto t2 = 0.55;
107-
Tensor<int64_t> matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
108-
Tensor<int64_t> matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
109-
Tensor<int64_t> matrix3 = castToTypeZero<int64_t>("C", loadImageTensor("C", num, f, 3 /* variant */));
124+
125+
auto num_str = getEnvVar("IMAGE_NUM");
126+
if (num_str == "") {
127+
state.error_occurred();
128+
return;
129+
}
130+
131+
int num = std::stoi(num_str);
132+
133+
taco::Tensor<int64_t> matrix1, matrix2, matrix3;
134+
try {
135+
matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
136+
matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
137+
matrix3 = castToTypeZero<int64_t>("C", loadImageTensor("C", num, f, 3 /* variant */));
138+
} catch (TacoException& e) {
139+
// Counters don't show up in the generated CSV if we used SkipWithError, so
140+
// just add in the label that this run is skipped.
141+
state.SetLabel(num_str+"/SKIPPED-FAILED-READ");
142+
return;
143+
}
144+
110145
auto dims = matrix1.getDimensions();
111146

112147
// write("temp/taco-mat1-" + std::to_string(num) + ".tns", matrix1);
@@ -153,14 +188,32 @@ static void bench_image_fused(benchmark::State& state, const Format& f) {
153188
// codegen->compile(compute, true);
154189
}
155190
}
156-
TACO_BENCH_ARGS(bench_image_fused, csr, CSR)->Apply(CustomArguments);
191+
TACO_BENCH_ARGS(bench_image_fused, csr, CSR);
157192

158193
static void bench_image_window(benchmark::State& state, const Format& f, double window_size) {
159-
int num = state.range(0);
194+
// int num = state.range(0);
160195
auto t1 = 0.5;
161196
auto t2 = 0.55;
162-
Tensor<int64_t> matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
163-
Tensor<int64_t> matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
197+
198+
auto num_str = getEnvVar("IMAGE_NUM");
199+
if (num_str == "") {
200+
state.error_occurred();
201+
return;
202+
}
203+
204+
int num = std::stoi(num_str);
205+
206+
taco::Tensor<int64_t> matrix1, matrix2, matrix3;
207+
try {
208+
matrix1 = castToTypeZero<int64_t>("A", loadImageTensor("A", num, f, t1, 1 /* variant */));
209+
matrix2 = castToTypeZero<int64_t>("B", loadImageTensor("B", num, f, t2, 2 /* variant */));
210+
} catch (TacoException& e) {
211+
// Counters don't show up in the generated CSV if we used SkipWithError, so
212+
// just add in the label that this run is skipped.
213+
state.SetLabel(num_str+"/SKIPPED-FAILED-READ");
214+
return;
215+
}
216+
164217
auto dims = matrix1.getDimensions();
165218

166219
int mid0 = (dims[0]/2.0);
@@ -195,7 +248,7 @@ static void bench_image_window(benchmark::State& state, const Format& f, double
195248
// codegen->compile(compute, true);
196249
}
197250
}
198-
TACO_BENCH_ARGS(bench_image_window, csr/0.25, CSR, 0.25)->Apply(CustomArguments);
199-
TACO_BENCH_ARGS(bench_image_window, csr/0.2, CSR, 0.2)->Apply(CustomArguments);
200-
TACO_BENCH_ARGS(bench_image_window, csr/0.15, CSR, 0.15)->Apply(CustomArguments);
201-
TACO_BENCH_ARGS(bench_image_window, csr/0.1, CSR, 0.1)->Apply(CustomArguments);
251+
TACO_BENCH_ARGS(bench_image_window, csr/0.25, CSR, 0.25);
252+
TACO_BENCH_ARGS(bench_image_window, csr/0.2, CSR, 0.2);
253+
TACO_BENCH_ARGS(bench_image_window, csr/0.15, CSR, 0.15);
254+
TACO_BENCH_ARGS(bench_image_window, csr/0.1, CSR, 0.1);

0 commit comments

Comments
 (0)