Skip to content

Commit ec94cbc

Browse files
salilsdesaipytorchmergebot
authored andcommitted
[Vulkan] Remove GLSL Code Gen (pytorch#91912)
@bypass-github-export-checks GLSL Code Gen is not used, so this diff removes - GLSL parts of ShaderSource - Anything enclosed by USE_VULKAN_SHADERC_RUNTIME, as well as the flag itself - gen_vulkan_glsl script Plus some additional refactoring Differential Revision: [D41358861](https://our.internmc.facebook.com/intern/diff/D41358861/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D41358861/)! Pull Request resolved: pytorch#91912 Approved by: https://github.com/mcr229
1 parent 28eb3c8 commit ec94cbc

13 files changed

+214
-500
lines changed

CMakeLists.txt

-4
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,6 @@ option(USE_SOURCE_DEBUG_ON_MOBILE "Enable " ON)
266266
option(USE_LITE_INTERPRETER_PROFILER "Enable " ON)
267267
option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
268268
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
269-
option(USE_VULKAN_SHADERC_RUNTIME "Vulkan - Use runtime shader compilation as opposed to build-time (needs libshaderc)" OFF)
270269
# option USE_XNNPACK: try to enable xnnpack by default.
271270
option(USE_XNNPACK "Use XNNPACK" ON)
272271
option(USE_ZMQ "Use ZMQ" OFF)
@@ -746,9 +745,6 @@ if(USE_VULKAN)
746745
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
747746
endif()
748747

749-
if(USE_VULKAN_SHADERC_RUNTIME)
750-
string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_SHADERC_RUNTIME")
751-
endif()
752748
endif()
753749

754750
if(BUILD_LITE_INTERPRETER)

aten/src/ATen/gen_vulkan_glsl.py

-115
This file was deleted.

aten/src/ATen/native/vulkan/api/Shader.cpp

+14-49
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
#include <ATen/native/vulkan/api/Shader.h>
22

3-
#ifdef USE_VULKAN_SHADERC_RUNTIME
4-
#include <shaderc/shaderc.hpp>
5-
#endif /* USE_VULKAN_SHADERC_RUNTIME */
6-
73
namespace at {
84
namespace native {
95
namespace vulkan {
@@ -14,38 +10,19 @@ namespace api {
1410
//
1511

1612
ShaderInfo::ShaderInfo()
17-
: type(ShaderInfo::Type::SPIRV),
18-
src_code{
19-
.spirv =
20-
{
21-
nullptr,
22-
0u,
23-
},
13+
: src_code{
14+
nullptr,
15+
0u,
2416
} {}
2517

26-
ShaderInfo::ShaderInfo(std::string name, const char* const glsl_src)
27-
: type(ShaderInfo::Type::GLSL),
28-
src_code{
29-
.glsl =
30-
{
31-
glsl_src,
32-
0u,
33-
},
34-
},
35-
kernel_name{std::move(name)} {}
36-
3718
ShaderInfo::ShaderInfo(
3819
std::string name,
3920
const uint32_t* const spirv_bin,
4021
const uint32_t size,
4122
const std::vector<VkDescriptorType>& layout)
42-
: type(Type::SPIRV),
43-
src_code{
44-
.spirv =
45-
{
46-
spirv_bin,
47-
size,
48-
},
23+
: src_code{
24+
spirv_bin,
25+
size,
4926
},
5027
kernel_name{std::move(name)},
5128
kernel_layout{layout} {}
@@ -58,13 +35,9 @@ ShaderInfo::ShaderInfo(
5835
const std::vector<uint32_t>& tile_size,
5936
const StorageType bias_storage_type,
6037
const StorageType weight_storage_type)
61-
: type(Type::SPIRV),
62-
src_code{
63-
.spirv =
64-
{
65-
spirv_bin,
66-
size,
67-
},
38+
: src_code{
39+
spirv_bin,
40+
size,
6841
},
6942
kernel_name{std::move(name)},
7043
kernel_layout{layout},
@@ -77,17 +50,9 @@ ShaderInfo::ShaderInfo(
7750
}
7851

7952
bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
80-
if (_1.type != _2.type) {
81-
return false;
82-
}
83-
84-
if (_1.type == ShaderInfo::Type::SPIRV) {
85-
return (
86-
_1.src_code.spirv.bin == _2.src_code.spirv.bin &&
87-
_1.src_code.spirv.size == _2.src_code.spirv.size);
88-
} else {
89-
return (_1.src_code.glsl.src == _2.src_code.glsl.src);
90-
}
53+
return (
54+
_1.src_code.bin == _2.src_code.bin &&
55+
_1.src_code.size == _2.src_code.size);
9156
}
9257

9358
//
@@ -153,8 +118,8 @@ void swap(ShaderLayout& lhs, ShaderLayout& rhs) noexcept {
153118

154119
ShaderModule::ShaderModule(const VkDevice device, const ShaderInfo& source)
155120
: device_(device), handle_{VK_NULL_HANDLE} {
156-
const uint32_t* code = source.src_code.spirv.bin;
157-
uint32_t size = source.src_code.spirv.size;
121+
const uint32_t* code = source.src_code.bin;
122+
uint32_t size = source.src_code.size;
158123

159124
const VkShaderModuleCreateInfo shader_module_create_info{
160125
VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, // sType

aten/src/ATen/native/vulkan/api/Shader.h

+4-13
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,9 @@ class ShaderLayout final {
4545
};
4646

4747
struct ShaderInfo final {
48-
enum class Type { GLSL, SPIRV } type;
49-
50-
union {
51-
struct {
52-
const char* src; // Null-terminated
53-
uint32_t unused; // padding
54-
} glsl;
55-
struct {
56-
const uint32_t* bin;
57-
uint32_t size;
58-
} spirv;
48+
struct {
49+
const uint32_t* bin;
50+
uint32_t size;
5951
} src_code;
6052

6153
std::string kernel_name{""};
@@ -171,8 +163,7 @@ class ShaderCache final {
171163

172164
struct Hasher {
173165
inline size_t operator()(const ShaderInfo& source) const {
174-
return c10::get_hash(
175-
source.type, source.src_code.spirv.bin, source.src_code.spirv.size);
166+
return c10::get_hash(source.src_code.bin, source.src_code.size);
176167
}
177168
};
178169

aten/src/ATen/native/vulkan/ops/Common.h

+1-10
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,9 @@
66
#include <ATen/core/Tensor.h>
77
#include <ATen/native/vulkan/api/api.h>
88
#include <ATen/native/vulkan/ops/Convert.h>
9-
10-
#define CONCAT_LITERALS(a, b) #a #b
11-
#ifdef USE_VULKAN_SHADERC_RUNTIME
12-
#include <ATen/native/vulkan/glsl.h>
13-
#define VK_KERNEL(name) \
14-
::at::native::vulkan::api::ShaderInfo { \
15-
CONCAT_LITERALS(vulkan., name), name##_glsl, \
16-
}
17-
#else
189
#include <ATen/native/vulkan/spv.h>
10+
1911
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
20-
#endif /* USE_VULKAN_SHADERC_RUNTIME */
2112

2213
namespace at {
2314
namespace native {

cmake/Summary.cmake

-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ function(caffe2_print_configuration_summary)
175175
if(${USE_VULKAN})
176176
message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}")
177177
message(STATUS " USE_VULKAN_RELAXED_PRECISION : ${USE_VULKAN_RELAXED_PRECISION}")
178-
message(STATUS " USE_VULKAN_SHADERC_RUNTIME : ${USE_VULKAN_SHADERC_RUNTIME}")
179178
endif()
180179
message(STATUS " USE_PROF : ${USE_PROF}")
181180
message(STATUS " USE_QNNPACK : ${USE_QNNPACK}")

0 commit comments

Comments
 (0)