diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index 690efbb6c76896..1e20961069f720 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -1,5 +1,16 @@ #pragma once +#ifdef __OBJC__ #include +typedef id MTLLibrary_t; +typedef id MTLFunction_t; +typedef id MTLComputePipelineState_t; +#else +typedef void MTLCompileOptions; +typedef void* MTLLibrary_t; +typedef void* MTLFunction_t; +typedef void* MTLComputePipelineState_t; +#endif + #include #include @@ -19,19 +30,19 @@ class MetalShaderLibrary { compile_options(compile_options_) {} MetalShaderLibrary(const MetalShaderLibrary&) = delete; virtual ~MetalShaderLibrary() = default; - inline id getPipelineStateForFunc( + inline MTLComputePipelineState_t getPipelineStateForFunc( const std::string& fname) { return getLibraryPipelineState(getLibrary(), fname).first; } - id getPipelineStateForFunc( + MTLComputePipelineState_t getPipelineStateForFunc( const std::string& fname, const std::initializer_list& params) { return getLibraryPipelineState(getLibrary(params), fname).first; } - inline id getMTLFunction(const std::string& fname) { + inline MTLFunction_t getMTLFunction(const std::string& fname) { return getLibraryPipelineState(getLibrary(), fname).second; } - id getMTLFunction( + MTLFunction_t getMTLFunction( const std::string& fname, const std::initializer_list& params) { return getLibraryPipelineState(getLibrary(params), fname).second; @@ -39,23 +50,23 @@ class MetalShaderLibrary { static MetalShaderLibrary& getBundledLibrary(); protected: - virtual id getLibrary(); - virtual id getLibrary( + virtual MTLLibrary_t getLibrary(); + virtual MTLLibrary_t getLibrary( const std::initializer_list& params); - id library = nil; + MTLLibrary_t library = nullptr; private: - std::pair, id> - getLibraryPipelineState(id lib, const std::string& fname); - - id compileLibrary(const std::string& src); + std::pair getLibraryPipelineState( + MTLLibrary_t lib, + const std::string& fname); + MTLLibrary_t compileLibrary(const std::string& src); std::string shaderSource; unsigned nparams; MTLCompileOptions* compile_options; - std::unordered_map> libMap; + std::unordered_map libMap; std::unordered_map< std::string, - std::pair, id>> + std::pair> cplMap; };