3838 ->ReportAggregatesOnly(true ) \
3939 ->UseRealTime()
4040
41+ template<typename T>
42+ taco::Tensor<T> castToType(std::string name, taco::Tensor<double > tensor) {
43+ taco::Tensor<T> result (name, tensor.getDimensions (), tensor.getFormat ());
44+ std::vector<int > coords (tensor.getOrder ());
45+ for (auto & value : taco::iterate<double >(tensor)) {
46+ for (int i = 0 ; i < tensor.getOrder (); i++) {
47+ coords[i] = value.first [i];
48+ }
49+ result.insert (coords, T (value.second ));
50+ }
51+ result.pack ();
52+ return result;
53+ }
54+
55+ struct TacoTensorFileCache {
56+ template <typename T>
57+ taco::Tensor<double > read (std::string path, T format) {
58+ if (this ->lastPath == path) {
59+ return lastLoaded;
60+ }
61+ // TODO (rohany): Not worrying about whether the format was the same as what was asked for.
62+ this ->lastLoaded = taco::read (path, format);
63+ this ->lastPath = path;
64+ return this ->lastLoaded ;
65+ }
66+
67+ template <typename T, typename U>
68+ taco::Tensor<T> readIntoType (std::string name, std::string path, U format) {
69+ auto tensor = this ->read <U>(path, format);
70+ return castToType<T>(name, tensor);
71+ }
72+
73+ taco::Tensor<double > lastLoaded;
74+ std::string lastPath;
75+ };
76+
4177std::string getTacoTensorPath ();
4278taco::TensorBase loadRandomTensor (std::string name, std::vector<int > dims, float sparsity, taco::Format format);
4379
80+ // TODO (rohany): Cache the tensor shifts too.
4481template <typename T, typename T2>
4582taco::Tensor<T> shiftLastMode (std::string name, taco::Tensor<T2> original) {
4683 taco::Tensor<T> result (name, original.getDimensions (), original.getFormat ());
@@ -63,30 +100,4 @@ taco::Tensor<T> shiftLastMode(std::string name, taco::Tensor<T2> original) {
63100 return result;
64101}
65102
66- template <typename T>
67- taco::Tensor<T> castToType (std::string name, taco::Tensor<double > tensor) {
68- taco::Tensor<T> result (name, tensor.getDimensions (), tensor.getFormat ());
69- std::vector<int > coords (tensor.getOrder ());
70- for (auto & value : taco::iterate<double >(tensor)) {
71- for (int i = 0 ; i < tensor.getOrder (); i++) {
72- coords[i] = value.first [i];
73- }
74- result.insert (coords, T (value.second ));
75- }
76- result.pack ();
77- return result;
78- }
79-
80- template <typename T>
81- taco::Tensor<T> readIntoType (std::string name, std::string path, taco::ModeFormat format) {
82- auto tensor = taco::read (path, format);
83- return castToType<T>(name, tensor);
84- }
85-
86- template <typename T>
87- taco::Tensor<T> readIntoType (std::string name, std::string path, taco::Format format) {
88- auto tensor = taco::read (path, format);
89- return castToType<T>(name, tensor);
90- }
91-
92103#endif // TACO_BENCH_BENCH_H
0 commit comments