Skip to content

Commit 28eb3c8

Browse files
salilsdesaipytorchmergebot
authored andcommitted
[Vulkan] Generate ShaderInfos Directly via Codegen in gen_vulkan_spv (pytorch#91911)
@bypass-github-export-checks Before this change, we have the data members which make up a ```ShaderInfo``` sitting in ```spv.h/.cpp``` in an unorganized manner. This diff makes the change such that the ```ShaderInfo```s are initialized directly in spv.h/.cpp Now spv.h looks like ``` #pragma once #include <stdint.h> #include <vector> #include <string> #include <ATen/native/vulkan/api/Types.h> #include <ATen/native/vulkan/api/vk_api.h> namespace at { namespace native { namespace vulkan { namespace api { struct ShaderInfo; } // namespace api extern const api::ShaderInfo adaptive_avg_pool2d_spv; ... extern const api::ShaderInfo conv2d_pw_2x2_spv; } // namespace vulkan } // namespace native } // namespace at ``` (Full File: P557399150) and spv.cpp looks like ``` #include <ATen/native/vulkan/spv.h> #include <ATen/native/vulkan/api/Shader.h> namespace at { namespace native { namespace vulkan { namespace { const uint32_t adaptive_avg_pool2d_spv_bin[] = { 119734787, ... }; ... const uint32_t conv2d_pw_2x2_spv_bin[] = { 119734787, ... }; } // namespace const api::ShaderInfo adaptive_avg_pool2d_spv( "vulkan.adaptive_avg_pool2d", adaptive_avg_pool2d_spv_bin, 3204, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, std::vector<uint32_t>(), api::StorageType::UNKNOWN, api::StorageType::UNKNOWN ); ... const api::ShaderInfo conv2d_pw_2x2_spv( "vulkan.conv2d_pw_2x2", conv2d_pw_2x2_spv_bin, 7736, {VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER}, {2, 2, 1}, api::StorageType::TEXTURE_2D, api::StorageType::TEXTURE_2D ); } // namespace vulkan } // namespace native } // namespace at ``` (Full File: P584237146) Differential Revision: [D41354313](https://our.internmc.facebook.com/intern/diff/D41354313/) Pull Request resolved: pytorch#91911 Approved by: https://github.com/mcr229
1 parent 776fef9 commit 28eb3c8

File tree

3 files changed

+61
-67
lines changed

3 files changed

+61
-67
lines changed

aten/src/ATen/native/vulkan/ops/Common.h

+1-11
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,7 @@
1616
}
1717
#else
1818
#include <ATen/native/vulkan/spv.h>
19-
#define VK_KERNEL(name) \
20-
::at::native::vulkan::api::ShaderInfo { \
21-
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
22-
name##_spv_layout \
23-
}
24-
#define VK_SHADER(name) \
25-
::at::native::vulkan::api::ShaderInfo { \
26-
CONCAT_LITERALS(vulkan., name), name##_spv, name##_spv_len, \
27-
name##_spv_layout, name##_spv_tile_size, name##_spv_bias_storage_type, \
28-
name##_spv_weight_storage_type, \
29-
}
19+
#define VK_KERNEL(name) ::at::native::vulkan::name##_spv
3020
#endif /* USE_VULKAN_SHADERC_RUNTIME */
3121

3222
namespace at {

aten/src/ATen/native/vulkan/ops/Convolution.cpp

+9-9
Original file line numberDiff line numberDiff line change
@@ -296,43 +296,43 @@ static api::ShaderInfo get_shader(
296296

297297
switch (method) {
298298
case Conv2dSlidingWindow:
299-
shader = VK_SHADER(quantized_conv2d);
299+
shader = VK_KERNEL(quantized_conv2d);
300300
break;
301301
case Conv2dDepthwise:
302-
shader = VK_SHADER(quantized_conv2d_dw);
302+
shader = VK_KERNEL(quantized_conv2d_dw);
303303
break;
304304
case Conv2dPointwise:
305-
shader = VK_SHADER(quantized_conv2d_pw_2x2);
305+
shader = VK_KERNEL(quantized_conv2d_pw_2x2);
306306
break;
307307
// todo fail for quantized transposed conv
308308
}
309309
return shader;
310310
}
311311

312312
if (transposed) {
313-
shader = VK_SHADER(conv_transpose2d);
313+
shader = VK_KERNEL(conv_transpose2d);
314314
return shader;
315315
}
316316

317317
switch (method) {
318318
case Conv2dSlidingWindow:
319-
shader = VK_SHADER(conv2d);
319+
shader = VK_KERNEL(conv2d);
320320
break;
321321
case Conv2dDepthwise:
322-
shader = VK_SHADER(conv2d_dw);
322+
shader = VK_KERNEL(conv2d_dw);
323323
if (kernel_size.size() == 4 && kernel_size[2] == 3 &&
324324
kernel_size[3] == 3) {
325325
// 1x1 refers to the output tile size
326-
shader = VK_SHADER(conv2d_dw_3x3);
326+
shader = VK_KERNEL(conv2d_dw_3x3);
327327
}
328328
if (kernel_size.size() == 4 && kernel_size[2] == 5 &&
329329
kernel_size[3] == 5) {
330330
// 1x1 refers to the output tile size
331-
shader = VK_SHADER(conv2d_dw_5x5);
331+
shader = VK_KERNEL(conv2d_dw_5x5);
332332
}
333333
break;
334334
case Conv2dPointwise:
335-
shader = VK_SHADER(conv2d_pw_2x2);
335+
shader = VK_KERNEL(conv2d_pw_2x2);
336336
break;
337337
}
338338
return shader;

tools/gen_vulkan_spv.py

+51-47
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def getBiasStorageType(lineStr):
7070
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D",
7171
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D",
7272
"BUFFER" : "api::StorageType::BUFFER",
73+
"": "api::StorageType::UNKNOWN",
7374
}
7475

7576
def determineDescriptorType(lineStr):
@@ -165,64 +166,67 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
165166
h += "#include <vector>\n"
166167
h += "#include <string>\n"
167168
h += "#include <ATen/native/vulkan/api/Types.h>\n"
168-
h += "#include <ATen/native/vulkan/api/vk_api.h>"
169+
h += "#include <ATen/native/vulkan/api/vk_api.h>\n"
169170

170-
nsbegin = "\nnamespace at {\nnamespace native {\nnamespace vulkan {\n"
171-
nsend = "\n}\n}\n} //namespace at::native::vulkan\n"
171+
nsbegin = "namespace at {\nnamespace native {\nnamespace vulkan {\n"
172+
nsend = "} // namespace vulkan\n} // namespace native\n} // namespace at\n"
172173

173174
h += nsbegin
174175

175-
cpp = "#include <ATen/native/vulkan/{}>".format(H_NAME)
176+
# Forward declaration of ShaderInfo
177+
h += "namespace api {\nstruct ShaderInfo;\n} // namespace api\n"
178+
179+
cpp = "#include <ATen/native/vulkan/{}>\n".format(H_NAME)
180+
cpp += "#include <ATen/native/vulkan/api/Shader.h>\n"
176181
cpp += nsbegin
177182

183+
shader_info_bin_code = []
184+
shader_info_cpp_code = []
185+
shader_info_h_code = []
186+
178187
for spvPath, srcPath in spvPaths.items():
179188
name = getName(spvPath)
180-
name_len = name + "_len"
181-
h += "extern const uint32_t {}[];\n".format(name)
182-
h += "extern const uint32_t {};\n".format(name_len)
183189

184-
shader_info = getShaderInfo(srcPath)
185-
name_layout = name + "_layout"
186-
h += "extern const std::vector<VkDescriptorType> {};\n".format(name_layout)
187-
188-
cpp += "const uint32_t " + name + "[] = {\n"
189-
sizeBytes = 0
190190
print("spvPath:{}".format(spvPath))
191191
with open(spvPath, 'rb') as f:
192-
for word in array.array('I', f.read()):
193-
cpp += "{},\n".format(word)
194-
sizeBytes += 4
195-
cpp += "};\n"
196-
cpp += "const uint32_t {} = {};\n".format(name_len, sizeBytes)
197-
198-
# Add layout
199-
cpp += "const std::vector<VkDescriptorType> {} = {{\n".format(name_layout)
200-
for descriptor in shader_info.layouts:
201-
cpp += " {},\n".format(descriptor)
202-
cpp += "};\n"
203-
204-
# Add tile size
205-
if (len(shader_info.tile_size) > 0):
206-
name_tile_size = name + "_tile_size"
207-
h += "extern const std::vector<uint32_t> {};\n".format(name_tile_size)
208-
cpp += "const std::vector<uint32_t> {} = {{\n".format(name_tile_size)
209-
for s in shader_info.tile_size:
210-
cpp += " {},\n".format(s)
211-
cpp += "};\n"
212-
213-
# Add weight type
214-
if (shader_info.weight_storage_type != ""):
215-
name_weight_storage_type = name + "_weight_storage_type"
216-
h += "extern const api::StorageType {};\n".format(name_weight_storage_type)
217-
cpp += "const api::StorageType {} = \n".format(name_weight_storage_type)
218-
cpp += " {};\n".format(storageTypeToEnum[shader_info.weight_storage_type])
219-
220-
# Add bias type
221-
if (shader_info.bias_storage_type != ""):
222-
name_bias_storage_type = name + "_bias_storage_type"
223-
h += "extern const api::StorageType {};\n".format(name_bias_storage_type)
224-
cpp += "const api::StorageType {} = \n".format(name_bias_storage_type)
225-
cpp += " {};\n".format(storageTypeToEnum[shader_info.bias_storage_type])
192+
next_bin = array.array('I', f.read())
193+
sizeBytes = 4 * len(next_bin)
194+
shader_info_bin_code.append(
195+
"const uint32_t {}_bin[] = {{\n {}\n}};".format(
196+
name,
197+
",\n ".join(str(x) for x in next_bin),
198+
)
199+
)
200+
201+
shader_info = getShaderInfo(srcPath)
202+
203+
tile_size = (
204+
"{{{}}}".format(", ".join(str(x) for x in shader_info.tile_size))
205+
if (len(shader_info.tile_size) > 0)
206+
else "std::vector<uint32_t>()"
207+
)
208+
209+
shader_info_args = [
210+
"\"vulkan.{}\"".format(name.replace("_spv", "")),
211+
"{}_bin".format(name),
212+
str(sizeBytes),
213+
"{{{}}}".format(", ".join(shader_info.layouts)),
214+
tile_size,
215+
storageTypeToEnum[shader_info.weight_storage_type],
216+
storageTypeToEnum[shader_info.bias_storage_type],
217+
]
218+
219+
shader_info_h_code.append("extern const api::ShaderInfo {};".format(name))
220+
shader_info_cpp_code.append(
221+
"const api::ShaderInfo {}(\n {}\n);".format(
222+
name,
223+
",\n ".join(shader_info_args),
224+
),
225+
)
226+
227+
cpp += "namespace {{\n{}\n}} // namespace\n".format("\n".join(shader_info_bin_code))
228+
cpp += "{}\n".format("\n".join(shader_info_cpp_code))
229+
h += "{}\n".format("\n".join(shader_info_h_code))
226230

227231
cpp += nsend
228232
h += nsend

0 commit comments

Comments
 (0)