Skip to content

Commit d40e910

Browse files
malfetpobin6
authored andcommitted
[MPS] Move MetalShaderLibrary to its own header (pytorch#141475)
In preparation to be used from libtorch_python Pull Request resolved: pytorch#141475 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#141474
1 parent 01757af commit d40e910

File tree

2 files changed

+62
-40
lines changed

2 files changed

+62
-40
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#pragma once
2+
#include <Metal/Metal.h>
3+
#include <unordered_map>
4+
#include <vector>
5+
6+
namespace at::native::mps {
7+
class MetalShaderLibrary {
8+
public:
9+
MetalShaderLibrary(const std::string& src)
10+
: shaderSource(src), nparams(0), compile_options(nullptr) {}
11+
MetalShaderLibrary(const std::string& src, unsigned nparams_)
12+
: shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
13+
MetalShaderLibrary(
14+
const std::string& src,
15+
unsigned nparams_,
16+
MTLCompileOptions* compile_options_)
17+
: shaderSource(src),
18+
nparams(nparams_),
19+
compile_options(compile_options_) {}
20+
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
21+
inline id<MTLComputePipelineState> getPipelineStateForFunc(
22+
const std::string& fname) {
23+
return getLibraryPipelineState(getLibrary(), fname).first;
24+
}
25+
id<MTLComputePipelineState> getPipelineStateForFunc(
26+
const std::string& fname,
27+
const std::initializer_list<std::string>& params) {
28+
return getLibraryPipelineState(getLibrary(params), fname).first;
29+
}
30+
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
31+
return getLibraryPipelineState(getLibrary(), fname).second;
32+
}
33+
id<MTLFunction> getMTLFunction(
34+
const std::string& fname,
35+
const std::initializer_list<std::string>& params) {
36+
return getLibraryPipelineState(getLibrary(params), fname).second;
37+
}
38+
static MetalShaderLibrary& getBundledLibrary();
39+
40+
protected:
41+
virtual id<MTLLibrary> getLibrary();
42+
virtual id<MTLLibrary> getLibrary(
43+
const std::initializer_list<std::string>& params);
44+
id<MTLLibrary> library = nil;
45+
46+
private:
47+
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>
48+
getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
49+
50+
id<MTLLibrary> compileLibrary(const std::string& src);
51+
std::string shaderSource;
52+
unsigned nparams;
53+
MTLCompileOptions* compile_options;
54+
std::unordered_map<std::string, id<MTLLibrary>> libMap;
55+
std::unordered_map<
56+
std::string,
57+
std::pair<id<MTLComputePipelineState>, id<MTLFunction>>>
58+
cplMap;
59+
};
60+
61+
} // namespace at::native::mps

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/Tensor.h>
88
#include <ATen/Utils.h>
99
#include <ATen/mps/MPSStream.h>
10+
#include <ATen/native/mps/MetalShaderLibrary.h>
1011
#include <ATen/native/mps/TensorFactory.h>
1112
#include <c10/core/ScalarType.h>
1213
#include <torch/library.h>
@@ -342,46 +343,6 @@ inline bool is_dense_in_storage(const TensorBase& t) {
342343
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
343344
}
344345

345-
class MetalShaderLibrary {
346-
public:
347-
MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
348-
MetalShaderLibrary(const std::string& src, unsigned nparams_)
349-
: shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
350-
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
351-
: shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
352-
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
353-
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
354-
return getLibraryPipelineState(getLibrary(), fname).first;
355-
}
356-
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
357-
const std::initializer_list<std::string>& params) {
358-
return getLibraryPipelineState(getLibrary(params), fname).first;
359-
}
360-
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
361-
return getLibraryPipelineState(getLibrary(), fname).second;
362-
}
363-
id<MTLFunction> getMTLFunction(const std::string& fname, const std::initializer_list<std::string>& params) {
364-
return getLibraryPipelineState(getLibrary(params), fname).second;
365-
}
366-
static MetalShaderLibrary& getBundledLibrary();
367-
368-
protected:
369-
virtual id<MTLLibrary> getLibrary();
370-
virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
371-
id<MTLLibrary> library = nil;
372-
373-
private:
374-
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib,
375-
const std::string& fname);
376-
377-
id<MTLLibrary> compileLibrary(const std::string& src);
378-
std::string shaderSource;
379-
unsigned nparams;
380-
MTLCompileOptions* compile_options;
381-
std::unordered_map<std::string, id<MTLLibrary>> libMap;
382-
std::unordered_map<std::string, std::pair<id<MTLComputePipelineState>, id<MTLFunction>>> cplMap;
383-
};
384-
385346
namespace detail {
386347
template <typename T>
387348
class has_size_type {

0 commit comments

Comments
 (0)