Skip to content

Commit 8804e8a

Browse files
Add zelReloadDrivers(flags) API
Provides a means to re-initialize all of the drivers' library handles and DDI tables. The value of flags must match what was provided to zeInit(flags). Signed-off-by: Lisanna Dettwyler <[email protected]>
1 parent 519eed2 commit 8804e8a

File tree

3 files changed

+274
-0
lines changed

3 files changed

+274
-0
lines changed

Diff for: source/lib/ze_lib.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,22 @@ zelLoaderGetVersions(
171171
#endif
172172
}
173173

174+
ze_result_t ZE_APICALL
175+
zelReloadDrivers(
176+
ze_init_flags_t flags)
177+
{
178+
#ifdef DYNAMIC_LOAD_LOADER
179+
if(nullptr == ze_lib::context->loader)
180+
return ZE_RESULT_ERROR;
181+
typedef ze_result_t (ZE_APICALL *zelReloadDriver_t)(ze_driver_handle_t hDriver);
182+
auto reloadDrivers = reinterpret_cast<zelReloadDriver_t>(
183+
GET_FUNCTION_PTR(ze_lib::context->loader, "zelReloadDriversInternal") );
184+
return reloadDrivers(flags);
185+
#else
186+
return zelReloadDriversInternal(flags);
187+
#endif
188+
}
189+
174190

175191
ze_result_t ZE_APICALL
176192
zelLoaderTranslateHandle(

Diff for: source/loader/ze_loader_api.cpp

+253
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,259 @@ zelLoaderGetVersionsInternal(
7373
return ZE_RESULT_SUCCESS;
7474
}
7575

76+
ZE_DLLEXPORT ze_result_t ZE_APICALL
77+
zelReloadDriversInternal(
78+
ze_init_flags_t flags)
79+
{
80+
for( auto& drv : loader::context->zeDrivers ) {
81+
if(drv.initStatus != ZE_RESULT_SUCCESS)
82+
continue;
83+
84+
if (drv.handle) {
85+
std::string freeLibraryErrorValue;
86+
auto free_result = FREE_DRIVER_LIBRARY( drv.handle );
87+
auto failure = FREE_DRIVER_LIBRARY_FAILURE_CHECK(free_result);
88+
if (failure)
89+
return ZE_RESULT_ERROR_UNINITIALIZED;
90+
}
91+
92+
drv.handle = LOAD_DRIVER_LIBRARY( drv.name.c_str() );
93+
if (NULL == drv.handle)
94+
return ZE_RESULT_ERROR_UNINITIALIZED;
95+
96+
auto zeGetGlobalProcAddrTable = reinterpret_cast<ze_pfnGetGlobalProcAddrTable_t>(
97+
GET_FUNCTION_PTR( drv.handle, "zeGetGlobalProcAddrTable") );
98+
if (!zeGetGlobalProcAddrTable)
99+
return ZE_RESULT_ERROR_UNINITIALIZED;
100+
auto zeGetGlobalProcAddrTableResult = zeGetGlobalProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Global);
101+
if (zeGetGlobalProcAddrTableResult != ZE_RESULT_SUCCESS)
102+
return zeGetGlobalProcAddrTableResult;
103+
104+
auto zeGetRTASBuilderExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASBuilderExpProcAddrTable_t>(
105+
GET_FUNCTION_PTR( drv.handle, "zeGetRTASBuilderExpProcAddrTable") );
106+
if (!zeGetRTASBuilderExpProcAddrTable)
107+
return ZE_RESULT_ERROR_UNINITIALIZED;
108+
auto zeGetRTASBuilderExpProcAddrTableResult = zeGetRTASBuilderExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASBuilderExp);
109+
if (zeGetRTASBuilderExpProcAddrTableResult != ZE_RESULT_SUCCESS)
110+
return zeGetRTASBuilderExpProcAddrTableResult;
111+
112+
auto zeGetRTASParallelOperationExpProcAddrTable = reinterpret_cast<ze_pfnGetRTASParallelOperationExpProcAddrTable_t>(
113+
GET_FUNCTION_PTR( drv.handle, "zeGetRTASParallelOperationExpProcAddrTable") );
114+
if (!zeGetRTASParallelOperationExpProcAddrTable)
115+
return ZE_RESULT_ERROR_UNINITIALIZED;
116+
auto zeGetRTASParallelOperationExpProcAddrTableResult = zeGetRTASParallelOperationExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.RTASParallelOperationExp);
117+
if (zeGetRTASParallelOperationExpProcAddrTableResult != ZE_RESULT_SUCCESS)
118+
return zeGetRTASParallelOperationExpProcAddrTableResult;
119+
120+
auto zeGetDriverProcAddrTable = reinterpret_cast<ze_pfnGetDriverProcAddrTable_t>(
121+
GET_FUNCTION_PTR( drv.handle, "zeGetDriverProcAddrTable") );
122+
if (!zeGetDriverProcAddrTable)
123+
return ZE_RESULT_ERROR_UNINITIALIZED;
124+
auto zeGetDriverProcAddrTableResult = zeGetDriverProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Driver);
125+
if (zeGetDriverProcAddrTableResult != ZE_RESULT_SUCCESS)
126+
return zeGetDriverProcAddrTableResult;
127+
128+
auto zeGetDriverExpProcAddrTable = reinterpret_cast<ze_pfnGetDriverExpProcAddrTable_t>(
129+
GET_FUNCTION_PTR( drv.handle, "zeGetDriverExpProcAddrTable") );
130+
if (!zeGetDriverExpProcAddrTable)
131+
return ZE_RESULT_ERROR_UNINITIALIZED;
132+
auto zeGetDriverExpProcAddrTableResult = zeGetDriverExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DriverExp);
133+
if (zeGetDriverExpProcAddrTableResult != ZE_RESULT_SUCCESS)
134+
return zeGetDriverExpProcAddrTableResult;
135+
136+
auto zeGetDeviceProcAddrTable = reinterpret_cast<ze_pfnGetDeviceProcAddrTable_t>(
137+
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceProcAddrTable") );
138+
if (!zeGetDeviceProcAddrTable)
139+
return ZE_RESULT_ERROR_UNINITIALIZED;
140+
auto zeGetDeviceProcAddrTableResult = zeGetDeviceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Device);
141+
if (zeGetDeviceProcAddrTableResult != ZE_RESULT_SUCCESS)
142+
return zeGetDeviceProcAddrTableResult;
143+
144+
auto zeGetDeviceExpProcAddrTable = reinterpret_cast<ze_pfnGetDeviceExpProcAddrTable_t>(
145+
GET_FUNCTION_PTR( drv.handle, "zeGetDeviceExpProcAddrTable") );
146+
if (!zeGetDeviceExpProcAddrTable)
147+
return ZE_RESULT_ERROR_UNINITIALIZED;
148+
auto zeGetDeviceExpProcAddrTableResult = zeGetDeviceExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.DeviceExp);
149+
if (zeGetDeviceExpProcAddrTableResult != ZE_RESULT_SUCCESS)
150+
return zeGetDeviceExpProcAddrTableResult;
151+
152+
auto zeGetContextProcAddrTable = reinterpret_cast<ze_pfnGetContextProcAddrTable_t>(
153+
GET_FUNCTION_PTR( drv.handle, "zeGetContextProcAddrTable") );
154+
if (!zeGetContextProcAddrTable)
155+
return ZE_RESULT_ERROR_UNINITIALIZED;
156+
auto zeGetContextProcAddrTableResult = zeGetContextProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Context);
157+
if (zeGetContextProcAddrTableResult != ZE_RESULT_SUCCESS)
158+
return zeGetContextProcAddrTableResult;
159+
160+
auto zeGetCommandQueueProcAddrTable = reinterpret_cast<ze_pfnGetCommandQueueProcAddrTable_t>(
161+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandQueueProcAddrTable") );
162+
if (!zeGetCommandQueueProcAddrTable)
163+
return ZE_RESULT_ERROR_UNINITIALIZED;
164+
auto zeGetCommandQueueProcAddrTableResult = zeGetCommandQueueProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandQueue);
165+
if (zeGetCommandQueueProcAddrTableResult != ZE_RESULT_SUCCESS)
166+
return zeGetCommandQueueProcAddrTableResult;
167+
168+
auto zeGetCommandListProcAddrTable = reinterpret_cast<ze_pfnGetCommandListProcAddrTable_t>(
169+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListProcAddrTable") );
170+
if (!zeGetCommandListProcAddrTable)
171+
return ZE_RESULT_ERROR_UNINITIALIZED;
172+
auto zeGetCommandListProcAddrTableResult = zeGetCommandListProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandList);
173+
if (zeGetCommandListProcAddrTableResult != ZE_RESULT_SUCCESS)
174+
return zeGetCommandListProcAddrTableResult;
175+
176+
auto zeGetCommandListExpProcAddrTable = reinterpret_cast<ze_pfnGetCommandListExpProcAddrTable_t>(
177+
GET_FUNCTION_PTR( drv.handle, "zeGetCommandListExpProcAddrTable") );
178+
if (!zeGetCommandListExpProcAddrTable)
179+
return ZE_RESULT_ERROR_UNINITIALIZED;
180+
auto zeGetCommandListExpProcAddrTableResult = zeGetCommandListExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.CommandListExp);
181+
if (zeGetCommandListExpProcAddrTableResult != ZE_RESULT_SUCCESS)
182+
return zeGetCommandListExpProcAddrTableResult;
183+
184+
auto zeGetEventProcAddrTable = reinterpret_cast<ze_pfnGetEventProcAddrTable_t>(
185+
GET_FUNCTION_PTR( drv.handle, "zeGetEventProcAddrTable") );
186+
if (!zeGetEventProcAddrTable)
187+
return ZE_RESULT_ERROR_UNINITIALIZED;
188+
auto zeGetEventProcAddrTableResult = zeGetEventProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Event);
189+
if (zeGetEventProcAddrTableResult != ZE_RESULT_SUCCESS)
190+
return zeGetEventProcAddrTableResult;
191+
192+
auto zeGetEventExpProcAddrTable = reinterpret_cast<ze_pfnGetEventExpProcAddrTable_t>(
193+
GET_FUNCTION_PTR( drv.handle, "zeGetEventExpProcAddrTable") );
194+
if (!zeGetEventExpProcAddrTable)
195+
return ZE_RESULT_ERROR_UNINITIALIZED;
196+
auto zeGetEventExpProcAddrTableResult = zeGetEventExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventExp);
197+
if (zeGetEventExpProcAddrTableResult != ZE_RESULT_SUCCESS)
198+
return zeGetEventExpProcAddrTableResult;
199+
200+
auto zeGetEventPoolProcAddrTable = reinterpret_cast<ze_pfnGetEventPoolProcAddrTable_t>(
201+
GET_FUNCTION_PTR( drv.handle, "zeGetEventPoolProcAddrTable") );
202+
if (!zeGetEventPoolProcAddrTable)
203+
return ZE_RESULT_ERROR_UNINITIALIZED;
204+
auto zeGetEventPoolProcAddrTableResult = zeGetEventPoolProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.EventPool);
205+
if (zeGetEventPoolProcAddrTableResult != ZE_RESULT_SUCCESS)
206+
return zeGetEventPoolProcAddrTableResult;
207+
208+
auto zeGetFenceProcAddrTable = reinterpret_cast<ze_pfnGetFenceProcAddrTable_t>(
209+
GET_FUNCTION_PTR( drv.handle, "zeGetFenceProcAddrTable") );
210+
if (!zeGetFenceProcAddrTable)
211+
return ZE_RESULT_ERROR_UNINITIALIZED;
212+
auto zeGetFenceProcAddrTableResult = zeGetFenceProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Fence);
213+
if (zeGetFenceProcAddrTableResult != ZE_RESULT_SUCCESS)
214+
return zeGetFenceProcAddrTableResult;
215+
216+
auto zeGetImageProcAddrTable = reinterpret_cast<ze_pfnGetImageProcAddrTable_t>(
217+
GET_FUNCTION_PTR( drv.handle, "zeGetImageProcAddrTable") );
218+
if (!zeGetImageProcAddrTable)
219+
return ZE_RESULT_ERROR_UNINITIALIZED;
220+
auto zeGetImageProcAddrTableResult = zeGetImageProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Image);
221+
if (zeGetImageProcAddrTableResult != ZE_RESULT_SUCCESS)
222+
return zeGetImageProcAddrTableResult;
223+
224+
auto zeGetImageExpProcAddrTable = reinterpret_cast<ze_pfnGetImageExpProcAddrTable_t>(
225+
GET_FUNCTION_PTR( drv.handle, "zeGetImageExpProcAddrTable") );
226+
if (!zeGetImageExpProcAddrTable)
227+
return ZE_RESULT_ERROR_UNINITIALIZED;
228+
auto zeGetImageExpProcAddrTableResult = zeGetImageExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ImageExp);
229+
if (zeGetImageExpProcAddrTableResult != ZE_RESULT_SUCCESS)
230+
return zeGetImageExpProcAddrTableResult;
231+
232+
auto zeGetKernelProcAddrTable = reinterpret_cast<ze_pfnGetKernelProcAddrTable_t>(
233+
GET_FUNCTION_PTR( drv.handle, "zeGetKernelProcAddrTable") );
234+
if (!zeGetKernelProcAddrTable)
235+
return ZE_RESULT_ERROR_UNINITIALIZED;
236+
auto zeGetKernelProcAddrTableResult = zeGetKernelProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Kernel);
237+
if (zeGetKernelProcAddrTableResult != ZE_RESULT_SUCCESS)
238+
return zeGetKernelProcAddrTableResult;
239+
240+
auto zeGetKernelExpProcAddrTable = reinterpret_cast<ze_pfnGetKernelExpProcAddrTable_t>(
241+
GET_FUNCTION_PTR( drv.handle, "zeGetKernelExpProcAddrTable") );
242+
if (!zeGetKernelExpProcAddrTable)
243+
return ZE_RESULT_ERROR_UNINITIALIZED;
244+
auto zeGetKernelExpProcAddrTableResult = zeGetKernelExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.KernelExp);
245+
if (zeGetKernelExpProcAddrTableResult != ZE_RESULT_SUCCESS)
246+
return zeGetKernelExpProcAddrTableResult;
247+
248+
auto zeGetMemProcAddrTable = reinterpret_cast<ze_pfnGetMemProcAddrTable_t>(
249+
GET_FUNCTION_PTR( drv.handle, "zeGetMemProcAddrTable") );
250+
if (!zeGetMemProcAddrTable)
251+
return ZE_RESULT_ERROR_UNINITIALIZED;
252+
auto zeGetMemProcAddrTableResult = zeGetMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Mem);
253+
if (zeGetMemProcAddrTableResult != ZE_RESULT_SUCCESS)
254+
return zeGetMemProcAddrTableResult;
255+
256+
auto zeGetMemExpProcAddrTable = reinterpret_cast<ze_pfnGetMemExpProcAddrTable_t>(
257+
GET_FUNCTION_PTR( drv.handle, "zeGetMemExpProcAddrTable") );
258+
if (!zeGetMemExpProcAddrTable)
259+
return ZE_RESULT_ERROR_UNINITIALIZED;
260+
auto zeGetMemExpProcAddrTableResult = zeGetMemExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.MemExp);
261+
if (zeGetMemExpProcAddrTableResult != ZE_RESULT_SUCCESS)
262+
return zeGetMemExpProcAddrTableResult;
263+
264+
auto zeGetModuleProcAddrTable = reinterpret_cast<ze_pfnGetModuleProcAddrTable_t>(
265+
GET_FUNCTION_PTR( drv.handle, "zeGetModuleProcAddrTable") );
266+
if (!zeGetModuleProcAddrTable)
267+
return ZE_RESULT_ERROR_UNINITIALIZED;
268+
auto zeGetModuleProcAddrTableResult = zeGetModuleProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Module);
269+
if (zeGetModuleProcAddrTableResult != ZE_RESULT_SUCCESS)
270+
return zeGetModuleProcAddrTableResult;
271+
272+
auto zeGetModuleBuildLogProcAddrTable = reinterpret_cast<ze_pfnGetModuleBuildLogProcAddrTable_t>(
273+
GET_FUNCTION_PTR( drv.handle, "zeGetModuleBuildLogProcAddrTable") );
274+
if (!zeGetModuleBuildLogProcAddrTable)
275+
return ZE_RESULT_ERROR_UNINITIALIZED;
276+
auto zeGetModuleBuildLogProcAddrTableResult = zeGetModuleBuildLogProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.ModuleBuildLog);
277+
if (zeGetModuleBuildLogProcAddrTableResult != ZE_RESULT_SUCCESS)
278+
return zeGetModuleBuildLogProcAddrTableResult;
279+
280+
auto zeGetPhysicalMemProcAddrTable = reinterpret_cast<ze_pfnGetPhysicalMemProcAddrTable_t>(
281+
GET_FUNCTION_PTR( drv.handle, "zeGetPhysicalMemProcAddrTable") );
282+
if (!zeGetPhysicalMemProcAddrTable)
283+
return ZE_RESULT_ERROR_UNINITIALIZED;
284+
auto zeGetPhysicalMemProcAddrTableResult = zeGetPhysicalMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.PhysicalMem);
285+
if (zeGetPhysicalMemProcAddrTableResult != ZE_RESULT_SUCCESS)
286+
return zeGetPhysicalMemProcAddrTableResult;
287+
288+
auto zeGetSamplerProcAddrTable = reinterpret_cast<ze_pfnGetSamplerProcAddrTable_t>(
289+
GET_FUNCTION_PTR( drv.handle, "zeGetSamplerProcAddrTable") );
290+
if (!zeGetSamplerProcAddrTable)
291+
return ZE_RESULT_ERROR_UNINITIALIZED;
292+
auto zeGetSamplerProcAddrTableResult = zeGetSamplerProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.Sampler);
293+
if (zeGetSamplerProcAddrTableResult != ZE_RESULT_SUCCESS)
294+
return zeGetSamplerProcAddrTableResult;
295+
296+
auto zeGetVirtualMemProcAddrTable = reinterpret_cast<ze_pfnGetVirtualMemProcAddrTable_t>(
297+
GET_FUNCTION_PTR( drv.handle, "zeGetVirtualMemProcAddrTable") );
298+
if (!zeGetVirtualMemProcAddrTable)
299+
return ZE_RESULT_ERROR_UNINITIALIZED;
300+
auto zeGetVirtualMemProcAddrTableResult = zeGetVirtualMemProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.VirtualMem);
301+
if (zeGetVirtualMemProcAddrTableResult != ZE_RESULT_SUCCESS)
302+
return zeGetVirtualMemProcAddrTableResult;
303+
304+
auto zeGetFabricEdgeExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricEdgeExpProcAddrTable_t>(
305+
GET_FUNCTION_PTR( drv.handle, "zeGetFabricEdgeExpProcAddrTable") );
306+
if (!zeGetFabricEdgeExpProcAddrTable)
307+
return ZE_RESULT_ERROR_UNINITIALIZED;
308+
auto zeGetFabricEdgeExpProcAddrTableResult = zeGetFabricEdgeExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricEdgeExp);
309+
if (zeGetFabricEdgeExpProcAddrTableResult != ZE_RESULT_SUCCESS)
310+
return zeGetFabricEdgeExpProcAddrTableResult;
311+
312+
auto zeGetFabricVertexExpProcAddrTable = reinterpret_cast<ze_pfnGetFabricVertexExpProcAddrTable_t>(
313+
GET_FUNCTION_PTR( drv.handle, "zeGetFabricVertexExpProcAddrTable") );
314+
if (!zeGetFabricVertexExpProcAddrTable)
315+
return ZE_RESULT_ERROR_UNINITIALIZED;
316+
auto zeGetFabricVertexExpProcAddrTableResult = zeGetFabricVertexExpProcAddrTable(ZE_API_VERSION_CURRENT, &drv.dditable.ze.FabricVertexExp);
317+
if (zeGetFabricVertexExpProcAddrTableResult != ZE_RESULT_SUCCESS)
318+
return zeGetFabricVertexExpProcAddrTableResult;
319+
320+
auto initResult = drv.dditable.ze.Global.pfnInit(flags);
321+
// Bail out if any drivers that previously succeeded fail
322+
if (initResult != ZE_RESULT_SUCCESS)
323+
return initResult;
324+
}
325+
326+
return ZE_RESULT_SUCCESS;
327+
}
328+
76329

77330
ZE_DLLEXPORT ze_result_t ZE_APICALL
78331
zelLoaderTranslateHandleInternal(

Diff for: source/loader/ze_loader_api.h

+5
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ zelLoaderGetVersionsInternal(
6868
zel_component_version_t *versions); //Pointer to array of versions. If set to NULL, num_elems is returned
6969

7070

71+
ZE_DLLEXPORT ze_result_t ZE_APICALL
72+
zelReloadDriversInternal(
73+
ze_init_flags_t flags);
74+
75+
7176
ZE_DLLEXPORT ze_result_t ZE_APICALL
7277
zelLoaderTranslateHandleInternal(
7378
zel_handle_type_t handleType, //Handle type

0 commit comments

Comments
 (0)