Skip to content

[Offload] olLaunchHostFunction #152482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion offload/liboffload/API/APIDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class IsHandleType<string Type> {
!ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
}

// Does the type end with '_cb_t'?
class IsCallbackType<string Type> {
// size("_cb_t") == 5
bit ret = !if(!lt(!size(Type), 5), 0,
!ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
}

// Does the type end with '*'?
class IsPointerType<string Type> {
bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
Expand Down Expand Up @@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
TypeInfo type_info = TypeInfo<"", "">;
bit IsHandle = IsHandleType<type>.ret;
bit IsPointer = IsPointerType<type>.ret;
bit IsCallback = IsCallbackType<type>.ret;
}

// A parameter whose range is described by other parameters in the function.
Expand All @@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
}

class ShouldCheckPointer<Param P> {
bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
}

// For a list of returns that contains a specific return code, find and append
Expand Down
26 changes: 26 additions & 0 deletions offload/liboffload/API/Queue.td
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,29 @@ def : Function {
Return<"OL_ERRC_INVALID_QUEUE">
];
}

def : FptrTypedef {
let name = "ol_host_function_cb_t";
let desc = "Host function for use by `olLaunchHostFunction`.";
let params = [
Param<"void *", "UserData", "user specified data passed into `olLaunchHostFunction`.", PARAM_IN>,
];
let return = "void";
}

def : Function {
let name = "olLaunchHostFunction";
let desc = "Enqueue a callback function on the host.";
let details = [
"The provided function will be called from the same process as the one that called `olLaunchHostFunction`.",
"The callback will not run until all previous work submitted to the queue has completed.",
"The callback must return before any work submitted to the queue after it is started.",
"The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
];
let params = [
Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
Param<"ol_host_function_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
];
let returns = [];
}
7 changes: 7 additions & 0 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -830,5 +830,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
}

Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
ol_host_function_cb_t Callback,
void *UserData) {
return Queue->Device->Device->enqueueHostCall(Callback, UserData,
Queue->AsyncInfo);
}

} // namespace offload
} // namespace llvm
48 changes: 48 additions & 0 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy {
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;

/// Wrapper function for implementing host callbacks
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
AMDGPUSignalTy *OutputSignal,
void (*Callback)(void *), void *UserData) {
if (InputSignal)
if (auto Err = InputSignal->wait())
// Wait shouldn't report an error
reportFatalInternalError(std::move(Err));

Callback(UserData);

OutputSignal->signal();
}

/// Return the current number of asynchronous operations on the stream.
uint32_t size() const { return NextSlot; }

Expand Down Expand Up @@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy {
OutputSignal->get());
}

Error pushHostCallback(void (*Callback)(void *), void *UserData) {
// Retrieve an available signal for the operation's output.
AMDGPUSignalTy *OutputSignal = nullptr;
if (auto Err = SignalManager.getResource(OutputSignal))
return Err;
OutputSignal->reset();
OutputSignal->increaseUseCount();

AMDGPUSignalTy *InputSignal;
{
std::lock_guard<std::mutex> Lock(Mutex);

// Consume stream slot and compute dependencies.
InputSignal = consume(OutputSignal).second;
}

// "Leaking" the thread here is consistent with other work added to the
// queue. The input and output signals will remain valid until the output is
// signaled.
std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth making a pool for these?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we already had callback handling in the HSA signal handler. It's called schedCallback.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

schedCallback enqueues tasks to be completed synchronously once the queue has finished all of its tasks (e.g. as part of synchronise). They don't participate in the input/output signal dependency resolution, which is what we need for olLaunchFunction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Although, why don't we replace schedCallback with olLaunchHostFunction? And maybe only use one thread per queue rather than per host function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would introduce std::thread overhead, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, unless we had one thread per queue that could be sent work through a pipe?

But I think that's probably future work, are you okay with thread overhead just for olLaunchHostFunction?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd leave the schedCallback as they are, being executed by the thread that synchronizes/queries the stream, so no overhead. Making another thread execute these callbacks adds unnecessary inter-thread synchronization overhead in my opinion.

Efficiently implementing a per-stream thread that handles the host callbacks of that stream can be tricky: (1) the thread will need to sleep as much time as possible to avoid disturbing the other threads; (2) the synchronization mechanism between this thread and the ones enqueuing callbacks; (3) possible on-demand thread creation to avoid creating a thread for each stream when not needed.

.detach();

return Plugin::success();
}

/// Synchronize with the stream. The current thread waits until all operations
/// are finalized and it performs the pending post actions (i.e., releasing
/// intermediate buffers).
Expand Down Expand Up @@ -2554,6 +2593,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}

Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
AMDGPUStreamTy *Stream = nullptr;
if (auto Err = getStream(AsyncInfo, Stream))
return Err;

return Stream->pushHostCallback(Callback, UserData);
};

/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);
Expand Down
6 changes: 6 additions & 0 deletions offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error initDeviceInfo(__tgt_device_info *DeviceInfo);
virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;

/// Enqueue a host call to AsyncInfo
Error enqueueHostCall(void (*Callback)(void *), void *UserData,
__tgt_async_info *AsyncInfo);
virtual Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
AsyncInfoWrapperTy &AsyncInfo) = 0;

/// Create an event.
Error createEvent(void **EventPtrStorage);
virtual Error createEventImpl(void **EventPtrStorage) = 0;
Expand Down
9 changes: 9 additions & 0 deletions offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
return Err;
}

Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData,
__tgt_async_info *AsyncInfo) {
AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo);

auto Err = enqueueHostCallImpl(Callback, UserData, AsyncInfoWrapper);
AsyncInfoWrapper.finalize(Err);
return Err;
}

Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
assert(DeviceInfo && "Invalid device info");

Expand Down
13 changes: 13 additions & 0 deletions offload/plugins-nextgen/cuda/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
return Plugin::success();
}

Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
if (auto Err = setContext())
return Err;

CUstream Stream;
if (auto Err = getStream(AsyncInfo, Stream))
return Err;

CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
return Plugin::check(Res, "error in cuStreamLaunchHostFunc: %s");
};

/// Create an event.
Error createEventImpl(void **EventPtrStorage) override {
CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);
Expand Down
6 changes: 6 additions & 0 deletions offload/plugins-nextgen/host/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
"initDeviceInfoImpl not supported");
}

Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
Callback(UserData);
return Plugin::success();
};

/// This plugin does not support the event API. Do nothing without failing.
Error createEventImpl(void **EventPtrStorage) override {
*EventPtrStorage = nullptr;
Expand Down
3 changes: 2 additions & 1 deletion offload/unittests/OffloadAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ add_offload_unittest("queue"
queue/olDestroyQueue.cpp
queue/olGetQueueInfo.cpp
queue/olGetQueueInfoSize.cpp
queue/olWaitEvents.cpp)
queue/olWaitEvents.cpp
queue/olLaunchHostFunction.cpp)

add_offload_unittest("symbol"
symbol/olGetSymbol.cpp
Expand Down
107 changes: 107 additions & 0 deletions offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
//===------- Offload API tests - olLaunchHostFunction ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <gtest/gtest.h>
#include <thread>

struct olLaunchHostFunctionTest : OffloadQueueTest {};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionTest);

struct olLaunchHostFunctionKernelTest : OffloadKernelTest {};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionKernelTest);

TEST_P(olLaunchHostFunctionTest, Success) {
ASSERT_SUCCESS(olLaunchHostFunction(Queue, [](void *) {}, nullptr));
}

TEST_P(olLaunchHostFunctionTest, SuccessSequence) {
uint32_t Buff[16] = {1, 1};

for (auto BuffPtr = &Buff[2]; BuffPtr != &Buff[16]; BuffPtr++) {
ASSERT_SUCCESS(olLaunchHostFunction(
Queue,
[](void *BuffPtr) {
uint32_t *AsU32 = reinterpret_cast<uint32_t *>(BuffPtr);
AsU32[0] = AsU32[-1] + AsU32[-2];
},
BuffPtr));
}

ASSERT_SUCCESS(olSyncQueue(Queue));

for (uint32_t i = 2; i < 16; i++) {
ASSERT_EQ(Buff[i], Buff[i - 1] + Buff[i - 2]);
}
}

TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
// Verify that a host kernel can block execution - A host task is created that
// only resolves when Block is set to false.
ol_kernel_launch_size_args_t LaunchArgs;
LaunchArgs.Dimensions = 1;
LaunchArgs.GroupSize = {64, 1, 1};
LaunchArgs.NumGroups = {1, 1, 1};
LaunchArgs.DynSharedMemory = 0;

ol_queue_handle_t Queue;
ASSERT_SUCCESS(olCreateQueue(Device, &Queue));

void *Mem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));

uint32_t *Data = (uint32_t *)Mem;
for (uint32_t i = 0; i < 64; i++) {
Data[i] = 0;
}

volatile bool Block = true;
ASSERT_SUCCESS(olLaunchHostFunction(
Queue,
[](void *Ptr) {
volatile bool *Block =
reinterpret_cast<volatile bool *>(reinterpret_cast<bool *>(Ptr));

while (*Block)
std::this_thread::yield();
},
const_cast<bool *>(&Block)));

struct {
void *Mem;
} Args{Mem};
ASSERT_SUCCESS(
olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), &LaunchArgs));

std::this_thread::sleep_for(std::chrono::milliseconds(500));
for (uint32_t i = 0; i < 64; i++) {
ASSERT_EQ(Data[i], 0);
}

Block = false;
ASSERT_SUCCESS(olSyncQueue(Queue));

for (uint32_t i = 0; i < 64; i++) {
ASSERT_EQ(Data[i], i);
}

ASSERT_SUCCESS(olDestroyQueue(Queue));
ASSERT_SUCCESS(olMemFree(Mem));
}

TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
olLaunchHostFunction(Queue, nullptr, nullptr));
}

TEST_P(olLaunchHostFunctionTest, InvalidNullQueue) {
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
olLaunchHostFunction(nullptr, [](void *) {}, nullptr));
}
Loading