|
7 | 7 | #include <ATen/Tensor.h>
|
8 | 8 | #include <ATen/Utils.h>
|
9 | 9 | #include <ATen/mps/MPSStream.h>
|
| 10 | +#include <ATen/native/mps/MetalShaderLibrary.h> |
10 | 11 | #include <ATen/native/mps/TensorFactory.h>
|
11 | 12 | #include <c10/core/ScalarType.h>
|
12 | 13 | #include <torch/library.h>
|
@@ -342,46 +343,6 @@ inline bool is_dense_in_storage(const TensorBase& t) {
|
342 | 343 | return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
|
343 | 344 | }
|
344 | 345 |
|
345 |
| -class MetalShaderLibrary { |
346 |
| - public: |
347 |
| - MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {} |
348 |
| - MetalShaderLibrary(const std::string& src, unsigned nparams_) |
349 |
| - : shaderSource(src), nparams(nparams_), compile_options(nullptr) {} |
350 |
| - MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_) |
351 |
| - : shaderSource(src), nparams(nparams_), compile_options(compile_options_) {} |
352 |
| - MetalShaderLibrary(const MetalShaderLibrary&) = delete; |
353 |
| - inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) { |
354 |
| - return getLibraryPipelineState(getLibrary(), fname).first; |
355 |
| - } |
356 |
| - id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, |
357 |
| - const std::initializer_list<std::string>& params) { |
358 |
| - return getLibraryPipelineState(getLibrary(params), fname).first; |
359 |
| - } |
360 |
| - inline id<MTLFunction> getMTLFunction(const std::string& fname) { |
361 |
| - return getLibraryPipelineState(getLibrary(), fname).second; |
362 |
| - } |
363 |
| - id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) { |
364 |
| - return getLibraryPipelineState(getLibrary(params), fname).second; |
365 |
| - } |
366 |
| - static MetalShaderLibrary& getBundledLibrary(); |
367 |
| - |
368 |
| - protected: |
369 |
| - virtual id<MTLLibrary> getLibrary(); |
370 |
| - virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params); |
371 |
| - id<MTLLibrary> library = nil; |
372 |
| - |
373 |
| - private: |
374 |
| - std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, |
375 |
| - const std::string& fname); |
376 |
| - |
377 |
| - id<MTLLibrary> compileLibrary(const std::string& src); |
378 |
| - std::string shaderSource; |
379 |
| - unsigned nparams; |
380 |
| - MTLCompileOptions* compile_options; |
381 |
| - std::unordered_map<std::string, id<MTLLibrary>> libMap; |
382 |
| - std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap; |
383 |
| -}; |
384 |
| - |
385 | 346 | namespace detail {
|
386 | 347 | template <typename T>
|
387 | 348 | class has_size_type {
|
|
0 commit comments