@@ -70,6 +70,7 @@ def getBiasStorageType(lineStr):
70
70
"TEXTURE_2D" : "api::StorageType::TEXTURE_2D" ,
71
71
"TEXTURE_3D" : "api::StorageType::TEXTURE_3D" ,
72
72
"BUFFER" : "api::StorageType::BUFFER" ,
73
+ "" : "api::StorageType::UNKNOWN" ,
73
74
}
74
75
75
76
def determineDescriptorType (lineStr ):
@@ -165,64 +166,67 @@ def genCppH(hFilePath, cppFilePath, srcDirPath, glslcPath, tmpDirPath, env):
165
166
h += "#include <vector>\n "
166
167
h += "#include <string>\n "
167
168
h += "#include <ATen/native/vulkan/api/Types.h>\n "
168
- h += "#include <ATen/native/vulkan/api/vk_api.h>"
169
+ h += "#include <ATen/native/vulkan/api/vk_api.h>\n "
169
170
170
- nsbegin = "\n namespace at {\n namespace native {\n namespace vulkan {\n "
171
- nsend = "\n } \n }\n } //namespace at::native::vulkan \n "
171
+ nsbegin = "namespace at {\n namespace native {\n namespace vulkan {\n "
172
+ nsend = "} // namespace vulkan \n } // namespace native \n } // namespace at\n "
172
173
173
174
h += nsbegin
174
175
175
- cpp = "#include <ATen/native/vulkan/{}>" .format (H_NAME )
176
+ # Forward declaration of ShaderInfo
177
+ h += "namespace api {\n struct ShaderInfo;\n } // namespace api\n "
178
+
179
+ cpp = "#include <ATen/native/vulkan/{}>\n " .format (H_NAME )
180
+ cpp += "#include <ATen/native/vulkan/api/Shader.h>\n "
176
181
cpp += nsbegin
177
182
183
+ shader_info_bin_code = []
184
+ shader_info_cpp_code = []
185
+ shader_info_h_code = []
186
+
178
187
for spvPath , srcPath in spvPaths .items ():
179
188
name = getName (spvPath )
180
- name_len = name + "_len"
181
- h += "extern const uint32_t {}[];\n " .format (name )
182
- h += "extern const uint32_t {};\n " .format (name_len )
183
189
184
- shader_info = getShaderInfo (srcPath )
185
- name_layout = name + "_layout"
186
- h += "extern const std::vector<VkDescriptorType> {};\n " .format (name_layout )
187
-
188
- cpp += "const uint32_t " + name + "[] = {\n "
189
- sizeBytes = 0
190
190
print ("spvPath:{}" .format (spvPath ))
191
191
with open (spvPath , 'rb' ) as f :
192
- for word in array .array ('I' , f .read ()):
193
- cpp += "{},\n " .format (word )
194
- sizeBytes += 4
195
- cpp += "};\n "
196
- cpp += "const uint32_t {} = {};\n " .format (name_len , sizeBytes )
197
-
198
- # Add layout
199
- cpp += "const std::vector<VkDescriptorType> {} = {{\n " .format (name_layout )
200
- for descriptor in shader_info .layouts :
201
- cpp += " {},\n " .format (descriptor )
202
- cpp += "};\n "
203
-
204
- # Add tile size
205
- if (len (shader_info .tile_size ) > 0 ):
206
- name_tile_size = name + "_tile_size"
207
- h += "extern const std::vector<uint32_t> {};\n " .format (name_tile_size )
208
- cpp += "const std::vector<uint32_t> {} = {{\n " .format (name_tile_size )
209
- for s in shader_info .tile_size :
210
- cpp += " {},\n " .format (s )
211
- cpp += "};\n "
212
-
213
- # Add weight type
214
- if (shader_info .weight_storage_type != "" ):
215
- name_weight_storage_type = name + "_weight_storage_type"
216
- h += "extern const api::StorageType {};\n " .format (name_weight_storage_type )
217
- cpp += "const api::StorageType {} = \n " .format (name_weight_storage_type )
218
- cpp += " {};\n " .format (storageTypeToEnum [shader_info .weight_storage_type ])
219
-
220
- # Add bias type
221
- if (shader_info .bias_storage_type != "" ):
222
- name_bias_storage_type = name + "_bias_storage_type"
223
- h += "extern const api::StorageType {};\n " .format (name_bias_storage_type )
224
- cpp += "const api::StorageType {} = \n " .format (name_bias_storage_type )
225
- cpp += " {};\n " .format (storageTypeToEnum [shader_info .bias_storage_type ])
192
+ next_bin = array .array ('I' , f .read ())
193
+ sizeBytes = 4 * len (next_bin )
194
+ shader_info_bin_code .append (
195
+ "const uint32_t {}_bin[] = {{\n {}\n }};" .format (
196
+ name ,
197
+ ",\n " .join (str (x ) for x in next_bin ),
198
+ )
199
+ )
200
+
201
+ shader_info = getShaderInfo (srcPath )
202
+
203
+ tile_size = (
204
+ "{{{}}}" .format (", " .join (str (x ) for x in shader_info .tile_size ))
205
+ if (len (shader_info .tile_size ) > 0 )
206
+ else "std::vector<uint32_t>()"
207
+ )
208
+
209
+ shader_info_args = [
210
+ "\" vulkan.{}\" " .format (name .replace ("_spv" , "" )),
211
+ "{}_bin" .format (name ),
212
+ str (sizeBytes ),
213
+ "{{{}}}" .format (", " .join (shader_info .layouts )),
214
+ tile_size ,
215
+ storageTypeToEnum [shader_info .weight_storage_type ],
216
+ storageTypeToEnum [shader_info .bias_storage_type ],
217
+ ]
218
+
219
+ shader_info_h_code .append ("extern const api::ShaderInfo {};" .format (name ))
220
+ shader_info_cpp_code .append (
221
+ "const api::ShaderInfo {}(\n {}\n );" .format (
222
+ name ,
223
+ ",\n " .join (shader_info_args ),
224
+ ),
225
+ )
226
+
227
+ cpp += "namespace {{\n {}\n }} // namespace\n " .format ("\n " .join (shader_info_bin_code ))
228
+ cpp += "{}\n " .format ("\n " .join (shader_info_cpp_code ))
229
+ h += "{}\n " .format ("\n " .join (shader_info_h_code ))
226
230
227
231
cpp += nsend
228
232
h += nsend
0 commit comments