Skip to content

Commit a2edf67

Browse files
committed
[UR][Offload] Global query/read/write support
The following entry points are implemented here: * `DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT` (is now true) * `urProgramGetGlobalVariablePointer` * `urEnqueueDeviceGlobalVariableRead` * `urEnqueueDeviceGlobalVariableWrite` In addition, the enqueue interface has been moved out into a helper function and the program handler now handles the `@global_id_mapping` property.
1 parent 8dea47b commit a2edf67

File tree

5 files changed

+137
-58
lines changed

5 files changed

+137
-58
lines changed

unified-runtime/source/adapters/offload/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
7777
case UR_DEVICE_INFO_MAX_WORK_ITEM_DIMENSIONS:
7878
return ReturnValue(uint32_t{3});
7979
case UR_DEVICE_INFO_COMPILER_AVAILABLE:
80+
case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT:
8081
return ReturnValue(true);
8182
// Unimplemented features
8283
case UR_DEVICE_INFO_PROGRAM_SET_SPECIALIZATION_CONSTANTS:
83-
case UR_DEVICE_INFO_GLOBAL_VARIABLE_SUPPORT:
8484
case UR_DEVICE_INFO_USM_POOL_SUPPORT:
8585
case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP:
8686
case UR_DEVICE_INFO_IMAGE_SUPPORT:

unified-runtime/source/adapters/offload/enqueue.cpp

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -93,26 +93,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
9393
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
9494
}
9595

96-
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
97-
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
98-
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
99-
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
100-
96+
namespace {
97+
ur_result_t doMemcpy(ur_queue_handle_t hQueue, void *DestPtr,
98+
ol_device_handle_t DestDevice, const void *SrcPtr,
99+
ol_device_handle_t SrcDevice, size_t size, bool blocking,
100+
uint32_t numEventsInWaitList,
101+
const ur_event_handle_t *phEventWaitList,
102+
ur_event_handle_t *phEvent) {
101103
// Ignore wait list for now
102104
(void)numEventsInWaitList;
103105
(void)phEventWaitList;
104106
//
105107

106108
ol_event_handle_t EventOut = nullptr;
107109

108-
char *DevPtr =
109-
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
110-
111-
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, pDst, Adapter->HostDevice,
112-
DevPtr + offset, hQueue->OffloadDevice, size,
113-
phEvent ? &EventOut : nullptr));
110+
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DestPtr, DestDevice, SrcPtr,
111+
SrcDevice, size, phEvent ? &EventOut : nullptr));
114112

115-
if (blockingRead) {
113+
if (blocking) {
116114
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
117115
}
118116

@@ -124,37 +122,63 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
124122

125123
return UR_RESULT_SUCCESS;
126124
}
125+
} // namespace
126+
127+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
128+
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingRead,
129+
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
130+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
131+
char *DevPtr =
132+
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
133+
134+
return doMemcpy(hQueue, pDst, Adapter->HostDevice, DevPtr + offset,
135+
hQueue->OffloadDevice, size, blockingRead,
136+
numEventsInWaitList, phEventWaitList, phEvent);
137+
}
127138

128139
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
129140
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
130141
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
131142
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
132-
133-
// Ignore wait list for now
134-
(void)numEventsInWaitList;
135-
(void)phEventWaitList;
136-
//
137-
138-
ol_event_handle_t EventOut = nullptr;
139-
140143
char *DevPtr =
141144
reinterpret_cast<char *>(std::get<BufferMem>(hBuffer->Mem).Ptr);
142145

143-
OL_RETURN_ON_ERR(olMemcpy(hQueue->OffloadQueue, DevPtr + offset,
144-
hQueue->OffloadDevice, pSrc, Adapter->HostDevice,
145-
size, phEvent ? &EventOut : nullptr));
146+
return doMemcpy(hQueue, DevPtr + offset, hQueue->OffloadDevice, pSrc,
147+
Adapter->HostDevice, size, blockingWrite, numEventsInWaitList,
148+
phEventWaitList, phEvent);
149+
}
146150

147-
if (blockingWrite) {
148-
OL_RETURN_ON_ERR(olWaitQueue(hQueue->OffloadQueue));
151+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
152+
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
153+
bool blockingRead, size_t count, size_t offset, void *pDst,
154+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
155+
ur_event_handle_t *phEvent) {
156+
void *Ptr;
157+
if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name,
158+
nullptr, &Ptr)) {
159+
return Err;
149160
}
150161

151-
if (phEvent) {
152-
auto *Event = new ur_event_handle_t_();
153-
Event->OffloadEvent = EventOut;
154-
*phEvent = Event;
162+
return doMemcpy(hQueue, pDst, Adapter->HostDevice,
163+
reinterpret_cast<const char *>(Ptr) + offset,
164+
hQueue->OffloadDevice, count, blockingRead,
165+
numEventsInWaitList, phEventWaitList, phEvent);
166+
}
167+
168+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
169+
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
170+
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
171+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
172+
ur_event_handle_t *phEvent) {
173+
void *Ptr;
174+
if (auto Err = urProgramGetGlobalVariablePointer(nullptr, hProgram, name,
175+
nullptr, &Ptr)) {
176+
return Err;
155177
}
156178

157-
return UR_RESULT_SUCCESS;
179+
return doMemcpy(hQueue, reinterpret_cast<char *>(Ptr) + offset,
180+
hQueue->OffloadDevice, pSrc, Adapter->HostDevice, count,
181+
blockingWrite, numEventsInWaitList, phEventWaitList, phEvent);
158182
}
159183

160184
ur_result_t enqueueNoOp(ur_queue_handle_t hQueue, ur_event_handle_t *phEvent) {

unified-runtime/source/adapters/offload/program.cpp

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace {
2929
#ifdef UR_CUDA_ENABLED
3030
ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext,
3131
const uint8_t *Binary, size_t Length,
32-
ur_program_handle_t *phProgram) {
32+
ur_program_handle_t hProgram) {
3333
uint8_t *RealBinary;
3434
size_t RealLength;
3535
CUlinkState State;
@@ -48,25 +48,17 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t hContext,
4848
fprintf(stderr, "Performed CUDA bin workaround (size = %lu)\n", RealLength);
4949
#endif
5050

51-
ur_program_handle_t Program = new ur_program_handle_t_();
5251
auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary,
53-
RealLength, &Program->OffloadProgram);
52+
RealLength, &hProgram->OffloadProgram);
5453

5554
// Program owns the linked module now
5655
cuLinkDestroy(State);
5756

58-
if (Res != OL_SUCCESS) {
59-
delete Program;
60-
return offloadResultToUR(Res);
61-
}
62-
63-
*phProgram = Program;
64-
65-
return UR_RESULT_SUCCESS;
57+
return offloadResultToUR(Res);
6658
}
6759
#else
6860
ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *,
69-
size_t, ur_program_handle_t *) {
61+
size_t, ur_program_handle_t) {
7062
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
7163
}
7264
#endif
@@ -76,7 +68,8 @@ ur_result_t ProgramCreateCudaWorkaround(ur_context_handle_t, const uint8_t *,
7668
UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
7769
ur_context_handle_t hContext, uint32_t numDevices,
7870
ur_device_handle_t *phDevices, size_t *pLengths, const uint8_t **ppBinaries,
79-
const ur_program_properties_t *, ur_program_handle_t *phProgram) {
71+
const ur_program_properties_t *pProperties,
72+
ur_program_handle_t *phProgram) {
8073
if (numDevices > 1) {
8174
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
8275
}
@@ -100,24 +93,55 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
10093
}
10194
}
10295

96+
ur_program_handle_t Program = new ur_program_handle_t_{};
97+
Program->URContext = hContext;
98+
Program->Binary = RealBinary;
99+
Program->BinarySizeInBytes = RealLength;
100+
101+
// Parse properties
102+
if (pProperties) {
103+
if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) {
104+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
105+
} else if (pProperties->count == 0 && pProperties->pMetadatas != nullptr) {
106+
return UR_RESULT_ERROR_INVALID_SIZE;
107+
}
108+
109+
auto Length = pProperties->count;
110+
auto Metadata = pProperties->pMetadatas;
111+
for (size_t i = 0; i < Length; ++i) {
112+
const ur_program_metadata_t MetadataElement = Metadata[i];
113+
std::string MetadataElementName{MetadataElement.pName};
114+
115+
auto [Prefix, Tag] = splitMetadataName(MetadataElementName);
116+
117+
if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
118+
const char *MetadataValPtr =
119+
reinterpret_cast<const char *>(MetadataElement.value.pData) +
120+
sizeof(std::uint64_t);
121+
const char *MetadataValPtrEnd =
122+
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
123+
Program->GlobalIDMD[Prefix] =
124+
std::string{MetadataValPtr, MetadataValPtrEnd};
125+
}
126+
}
127+
}
128+
129+
ur_result_t Res;
103130
ol_platform_backend_t Backend;
104131
olGetPlatformInfo(phDevices[0]->Platform->OffloadPlatform,
105132
OL_PLATFORM_INFO_BACKEND, sizeof(Backend), &Backend);
106133
if (Backend == OL_PLATFORM_BACKEND_CUDA) {
107-
return ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength,
108-
phProgram);
134+
Res =
135+
ProgramCreateCudaWorkaround(hContext, RealBinary, RealLength, Program);
136+
} else {
137+
Res = offloadResultToUR(olCreateProgram(hContext->Device->OffloadDevice,
138+
RealBinary, RealLength,
139+
&Program->OffloadProgram));
109140
}
110141

111-
ur_program_handle_t Program = new ur_program_handle_t_{};
112-
Program->URContext = hContext;
113-
Program->Binary = RealBinary;
114-
Program->BinarySizeInBytes = RealLength;
115-
auto Res = olCreateProgram(hContext->Device->OffloadDevice, RealBinary,
116-
RealLength, &Program->OffloadProgram);
117-
118-
if (Res != OL_SUCCESS) {
142+
if (Res != UR_RESULT_SUCCESS) {
119143
delete Program;
120-
return offloadResultToUR(Res);
144+
return Res;
121145
}
122146

123147
*phProgram = Program;
@@ -240,3 +264,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramSetSpecializationConstants(
240264
ur_program_handle_t, uint32_t, const ur_specialization_constant_info_t *) {
241265
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
242266
}
267+
268+
UR_APIEXPORT ur_result_t UR_APICALL urProgramGetGlobalVariablePointer(
269+
ur_device_handle_t, ur_program_handle_t hProgram,
270+
const char *pGlobalVariableName, size_t *pGlobalVariableSizeRet,
271+
void **ppGlobalVariablePointerRet) {
272+
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(pGlobalVariableName);
273+
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
274+
return UR_RESULT_ERROR_INVALID_VALUE;
275+
std::string DeviceGlobalName = DeviceGlobalNameIt->second;
276+
277+
ol_symbol_handle_t Symbol;
278+
auto Err = olGetSymbol(hProgram->OffloadProgram, DeviceGlobalName.c_str(),
279+
OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Symbol);
280+
if (Err && Err->Code == OL_ERRC_NOT_FOUND) {
281+
return UR_RESULT_ERROR_INVALID_VALUE;
282+
}
283+
OL_RETURN_ON_ERR(Err);
284+
285+
if (pGlobalVariableSizeRet) {
286+
OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol,
287+
OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
288+
sizeof(size_t), pGlobalVariableSizeRet));
289+
}
290+
OL_RETURN_ON_ERR(olGetSymbolInfo(Symbol,
291+
OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
292+
sizeof(void *), ppGlobalVariablePointerRet));
293+
294+
return UR_RESULT_SUCCESS;
295+
}

unified-runtime/source/adapters/offload/program.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ struct ur_program_handle_t_ : RefCounted {
2020
ur_context_handle_t URContext;
2121
const uint8_t *Binary;
2222
size_t BinarySizeInBytes;
23+
// A mapping from mangled global names -> names in the binary
24+
std::unordered_map<std::string, std::string> GlobalIDMD;
2325
};

unified-runtime/source/adapters/offload/ur_interface_loader.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramProcAddrTable(
9292
pDdiTable->pfnCreateWithNativeHandle = urProgramCreateWithNativeHandle;
9393
pDdiTable->pfnGetBuildInfo = nullptr;
9494
pDdiTable->pfnGetFunctionPointer = nullptr;
95-
pDdiTable->pfnGetGlobalVariablePointer = nullptr;
95+
pDdiTable->pfnGetGlobalVariablePointer = urProgramGetGlobalVariablePointer;
9696
pDdiTable->pfnGetInfo = urProgramGetInfo;
9797
pDdiTable->pfnGetNativeHandle = urProgramGetNativeHandle;
9898
pDdiTable->pfnLink = nullptr;
@@ -168,8 +168,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable(
168168
if (UR_RESULT_SUCCESS != result) {
169169
return result;
170170
}
171-
pDdiTable->pfnDeviceGlobalVariableRead = nullptr;
172-
pDdiTable->pfnDeviceGlobalVariableWrite = nullptr;
171+
pDdiTable->pfnDeviceGlobalVariableRead = urEnqueueDeviceGlobalVariableRead;
172+
pDdiTable->pfnDeviceGlobalVariableWrite = urEnqueueDeviceGlobalVariableWrite;
173173
pDdiTable->pfnEventsWait = nullptr;
174174
pDdiTable->pfnEventsWaitWithBarrier = nullptr;
175175
pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch;

0 commit comments

Comments
 (0)