Skip to content

Commit 4b94e4a

Browse files
malfetpobin6
authored and
pobin6
committed
[MPS] Make MetalShaderLibrary usable from C++ (pytorch#141477)
By guarding Metal framework include and defining all ObjC protocols to dummy `void*` Pull Request resolved: pytorch#141477 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#141474, pytorch#141475, pytorch#141476
1 parent 1641f9c commit 4b94e4a

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

aten/src/ATen/native/mps/MetalShaderLibrary.h

+24-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#pragma once
2+
#ifdef __OBJC__
23
#include <Metal/Metal.h>
4+
typedef id<MTLLibrary> MTLLibrary_t;
5+
typedef id<MTLFunction> MTLFunction_t;
6+
typedef id<MTLComputePipelineState> MTLComputePipelineState_t;
7+
#else
8+
typedef void MTLCompileOptions;
9+
typedef void* MTLLibrary_t;
10+
typedef void* MTLFunction_t;
11+
typedef void* MTLComputePipelineState_t;
12+
#endif
13+
314
#include <unordered_map>
415
#include <vector>
516

@@ -19,43 +30,43 @@ class MetalShaderLibrary {
1930
compile_options(compile_options_) {}
2031
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
2132
virtual ~MetalShaderLibrary() = default;
22-
inline id<MTLComputePipelineState> getPipelineStateForFunc(
33+
inline MTLComputePipelineState_t getPipelineStateForFunc(
2334
const std::string& fname) {
2435
return getLibraryPipelineState(getLibrary(), fname).first;
2536
}
26-
id<MTLComputePipelineState> getPipelineStateForFunc(
37+
MTLComputePipelineState_t getPipelineStateForFunc(
2738
const std::string& fname,
2839
const std::initializer_list<std::string>& params) {
2940
return getLibraryPipelineState(getLibrary(params), fname).first;
3041
}
31-
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
42+
inline MTLFunction_t getMTLFunction(const std::string& fname) {
3243
return getLibraryPipelineState(getLibrary(), fname).second;
3344
}
34-
id<MTLFunction> getMTLFunction(
45+
MTLFunction_t getMTLFunction(
3546
const std::string& fname,
3647
const std::initializer_list<std::string>& params) {
3748
return getLibraryPipelineState(getLibrary(params), fname).second;
3849
}
3950
static MetalShaderLibrary& getBundledLibrary();
4051

4152
protected:
42-
virtual id<MTLLibrary> getLibrary();
43-
virtual id<MTLLibrary> getLibrary(
53+
virtual MTLLibrary_t getLibrary();
54+
virtual MTLLibrary_t getLibrary(
4455
const std::initializer_list<std::string>& params);
45-
id<MTLLibrary> library = nil;
56+
MTLLibrary_t library = nullptr;
4657

4758
private:
48-
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>
49-
getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
50-
51-
id<MTLLibrary> compileLibrary(const std::string& src);
59+
std::pair<MTLComputePipelineState_t, MTLFunction_t> getLibraryPipelineState(
60+
MTLLibrary_t lib,
61+
const std::string& fname);
62+
MTLLibrary_t compileLibrary(const std::string& src);
5263
std::string shaderSource;
5364
unsigned nparams;
5465
MTLCompileOptions* compile_options;
55-
std::unordered_map<std::string, id<MTLLibrary>> libMap;
66+
std::unordered_map<std::string, MTLLibrary_t> libMap;
5667
std::unordered_map<
5768
std::string,
58-
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>>
69+
std::pair<MTLComputePipelineState_t, MTLFunction_t>>
5970
cplMap;
6071
};
6172

0 commit comments

Comments
 (0)