Skip to content

Commit

Permalink
[MPS] Make MetalShaderLibrary usable from C++ (pytorch#141477)
Browse files Browse the repository at this point in the history
By guarding Metal framework include and defining all ObjC protocols to dummy `void*`
Pull Request resolved: pytorch#141477
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#141474, pytorch#141475, pytorch#141476
  • Loading branch information
malfet authored and Ryo-not-rio committed Dec 2, 2024
1 parent 17d0276 commit dea27c3
Showing 1 changed file with 24 additions and 13 deletions.
37 changes: 24 additions & 13 deletions aten/src/ATen/native/mps/MetalShaderLibrary.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
#pragma once
#ifdef __OBJC__
#include <Metal/Metal.h>
typedef id<MTLLibrary> MTLLibrary_t;
typedef id<MTLFunction> MTLFunction_t;
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
#else
typedef void MTLCompileOptions;
typedef void* MTLLibrary_t;
typedef void* MTLFunction_t;
typedef void* MTLComputePipelineState_t;
#endif

#include <unordered_map>
#include <vector>

Expand All @@ -19,43 +30,43 @@ class MetalShaderLibrary {
compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
virtual ~MetalShaderLibrary() = default;
inline id<MTLComputePipelineState> getPipelineStateForFunc(
inline MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
}
id<MTLComputePipelineState> getPipelineStateForFunc(
MTLComputePipelineState_t getPipelineStateForFunc(
const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
inline MTLFunction_t getMTLFunction(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).second;
}
id<MTLFunction> getMTLFunction(
MTLFunction_t getMTLFunction(
const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).second;
}
static MetalShaderLibrary& getBundledLibrary();

protected:
virtual id<MTLLibrary> getLibrary();
virtual id<MTLLibrary> getLibrary(
virtual MTLLibrary_t getLibrary();
virtual MTLLibrary_t getLibrary(
const std::initializer_list<std::string>& params);
id<MTLLibrary> library = nil;
MTLLibrary_t library = nullptr;

private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>
getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);

id<MTLLibrary> compileLibrary(const std::string& src);
std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
MTLLibrary_t lib,
const std::string& fname);
MTLLibrary_t compileLibrary(const std::string& src);
std::string shaderSource;
unsigned nparams;
MTLCompileOptions* compile_options;
std::unordered_map<std::string, id<MTLLibrary>> libMap;
std::unordered_map<std::string, MTLLibrary_t> libMap;
std::unordered_map<
std::string,
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>>
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
cplMap;
};

Expand Down

0 comments on commit dea27c3

Please sign in to comment.