Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/mp/proxy-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,11 @@ class Connection
//! ThreadMap.makeThread) used to service requests to clients.
::capnp::CapabilityServerSet<Thread> m_threads;

//! Thread pool populated by ThreadMap.makePool(). When a request arrives
//! with no context.thread set, PassField round-robins across these threads.
std::vector<Thread::Client> m_thread_pool;
size_t m_thread_pool_index{0};

//! Canceler for canceling promises that we want to discard when the
//! connection is destroyed. This is used to interrupt method calls that are
//! still executing at time of disconnection.
Expand Down
4 changes: 4 additions & 0 deletions include/mp/proxy.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ interface ThreadMap $count(0) {
# execute on. Clients create and name threads and pass the thread handle as
# a call parameter.
makeThread @0 (name :Text) -> (result :Thread);
# Pre-allocate a pool of server threads for implicit dispatch. When a
# request arrives with no context.thread set, the server dispatches it
# through this pool via a shared work queue.
makePool @1 (name :Text, count :UInt32) -> ();
}

interface Thread {
Expand Down
34 changes: 26 additions & 8 deletions include/mp/type-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,38 @@ auto PassField(Priority<1>, TypeList<>, ServerContext& server_context, const Fn&
const auto& params = server_context.call_context.getParams();
Context::Reader context_arg = Accessor::get(params);
auto thread_client = context_arg.getThread();
auto result = server.m_context.connection->m_threads.getLocalServer(thread_client)
.then([&loop, invoke = kj::mv(invoke), req](const kj::Maybe<Thread::Server&>& perhaps) mutable {
// Assuming the thread object is found, pass it a pointer to the
// `invoke` lambda above which will invoke the function on that
// thread.
auto* connection = server.m_context.connection;
auto result = connection->m_threads.getLocalServer(thread_client)
.then([&loop, invoke = kj::mv(invoke), req, connection](const kj::Maybe<Thread::Server&>& perhaps) mutable {
// If the client specified a thread, dispatch to it directly.
KJ_IF_MAYBE (thread_server, perhaps) {
auto& thread = static_cast<ProxyServer<Thread>&>(*thread_server);
MP_LOG(loop, Log::Debug)
<< "IPC server post request #" << req << " {" << thread.m_thread_context.thread_name << "}";
return thread.template post<typename ServerContext::CallContext>(std::move(invoke));
} else {
MP_LOG(loop, Log::Error)
<< "IPC server error request #" << req << ", missing thread to execute request";
throw std::runtime_error("invalid thread handle");
// No thread specified — fall back to the connection's thread
// pool (populated by ThreadMap.makePool). Error if no pool.
auto& pool = connection->m_thread_pool;
if (pool.empty()) {
MP_LOG(loop, Log::Error)
<< "IPC server error request #" << req << ", no thread specified and no pool configured";
throw std::runtime_error("no thread specified and no pool configured");
}
size_t idx = connection->m_thread_pool_index++ % pool.size();
return connection->m_threads.getLocalServer(pool[idx])
.then([&loop, invoke = kj::mv(invoke), req](const kj::Maybe<Thread::Server&>& pool_perhaps) mutable {
KJ_IF_MAYBE (pt, pool_perhaps) {
auto& pool_thread = static_cast<ProxyServer<Thread>&>(*pt);
MP_LOG(loop, Log::Debug)
<< "IPC server post request #" << req << " {" << pool_thread.m_thread_context.thread_name << "}";
return pool_thread.template post<typename ServerContext::CallContext>(std::move(invoke));
} else {
MP_LOG(loop, Log::Error)
<< "IPC server error request #" << req << ", pool thread not found";
throw std::runtime_error("pool thread not found");
}
});
}
});
// Use connection m_canceler object to cancel the result promise if the
Expand Down
1 change: 1 addition & 0 deletions include/mp/type-threadmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ struct ProxyServer<ThreadMap> final : public virtual ThreadMap::Server
public:
ProxyServer(Connection& connection);
kj::Promise<void> makeThread(MakeThreadContext context) override;
kj::Promise<void> makePool(MakePoolContext context) override;
Connection& m_connection;
};

Expand Down
27 changes: 27 additions & 0 deletions src/mp/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <kj/function.h>
#include <kj/memory.h>
#include <kj/string.h>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
Expand All @@ -36,6 +37,7 @@
#include <tuple>
#include <unistd.h>
#include <utility>
#include <vector>

namespace mp {

Expand Down Expand Up @@ -415,6 +417,31 @@ kj::Promise<void> ProxyServer<Thread>::getName(GetNameContext context)

ProxyServer<ThreadMap>::ProxyServer(Connection& connection) : m_connection(connection) {}

kj::Promise<void> ProxyServer<ThreadMap>::makePool(MakePoolContext context)
{
if (!m_connection.m_thread_pool.empty()) {
throw std::runtime_error("makePool called on connection with existing pool");
}
EventLoop& loop{*m_connection.m_loop};
const auto& params = context.getParams();
const std::string pool_name = params.getName();
const uint32_t count = params.getCount();
for (uint32_t i = 0; i < count; ++i) {
const std::string thread_name = pool_name + "/pool/" + std::to_string(i);
std::promise<ThreadContext*> thread_context;
std::thread thread([&loop, &thread_context, thread_name]() {
g_thread_context.thread_name = ThreadName(loop.m_exe_name) + " (from " + thread_name + ")";
g_thread_context.waiter = std::make_unique<Waiter>();
Lock lock(g_thread_context.waiter->m_mutex);
thread_context.set_value(&g_thread_context);
g_thread_context.waiter->wait(lock, [] { return !g_thread_context.waiter; });
});
auto thread_server = kj::heap<ProxyServer<Thread>>(m_connection, *thread_context.get_future().get(), std::move(thread));
m_connection.m_thread_pool.push_back(m_connection.m_threads.add(kj::mv(thread_server)));
}
return kj::READY_NOW;
}
Comment on lines +420 to +443
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeated calls silently grow the pool beyond the intended size, and since naming is based on the loop index starting from 0 each time, you'd end up with duplicate names like two threads both called pool/0 making logs ambiguous.

I'm not sure what would be the correct approach but I thought about:

  • Return an error if m_connection.m_thread_pool is not empty
  • Replace the old pool by cleaning the m_connection.m_thread_pool and create a fresh one with the new count. This makes sense if the client needs to resize its dedicated pool.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a guard statement:

    if (!m_connection.m_thread_pool.empty()) {
        throw std::runtime_error("makePool called on connection with existing pool");
    }

I think throwing is appropriate here as I would like to get user feedback first. If they anticipate pool resizing during runtime will be valuable we can add resizing APIs


kj::Promise<void> ProxyServer<ThreadMap>::makeThread(MakeThreadContext context)
{
EventLoop& loop{*m_connection.m_loop};
Expand Down
73 changes: 73 additions & 0 deletions test/mp/test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
#include <kj/async.h>
#include <kj/async-io.h>
#include <kj/common.h>
#include <kj/exception.h>
#include <kj/debug.h>
#include <kj/memory.h>
#include <kj/string.h>
#include <kj/test.h>
#include <memory>
#include <mp/proxy.h>
Expand Down Expand Up @@ -481,5 +483,76 @@ KJ_TEST("Make simultaneous IPC calls on single remote thread")
KJ_EXPECT(expected == 400);
}

KJ_TEST("Call async IPC method dispatched to pool thread")
{
TestSetup setup;
ProxyClient<messages::FooInterface>* foo = setup.client.get();

// Set up the thread map exchange so the client has the server's ThreadMap,
// then call makePool to pre-allocate two server threads.
foo->initThreadMap();
setup.server->m_impl->m_int_fn = [](int n) { return n * 2; };

ThreadContext& tc{g_thread_context};
std::atomic<size_t> running{3};
std::promise<void> pool_ready;
foo->m_context.loop->sync([&] {
auto pool_req = foo->m_context.connection->m_thread_map.makePoolRequest();
pool_req.setName("test");
pool_req.setCount(2);
foo->m_context.loop->m_task_set->add(
pool_req.send().then([&](auto&&) { pool_ready.set_value(); }));
});
pool_ready.get_future().get();

// Send three callIntFnAsync requests with no context.thread set.
// The server should dispatch each to a pool thread.
auto client{foo->m_client};
foo->m_context.loop->sync([&] {
for (size_t i = 0; i < running; ++i) {
auto request{client.callIntFnAsyncRequest()};
request.initContext(); // context present but thread unset
request.setArg(static_cast<int32_t>(i + 1));
foo->m_context.loop->m_task_set->add(request.send().then(
[&running, &tc, i](auto&& results) {
assert(results.getResult() == static_cast<int32_t>((i + 1) * 2));
running -= 1;
tc.waiter->m_cv.notify_all();
}));
}
});
{
Lock lock(tc.waiter->m_mutex);
tc.waiter->wait(lock, [&running] { return running == 0; });
}
}

KJ_TEST("Call async IPC method without thread or pool errors correctly")
{
TestSetup setup;
ProxyClient<messages::FooInterface>* foo = setup.client.get();
setup.server->m_impl->m_fn = [] {};

// Send a callFnAsync request with no context.thread and no pool configured.
// The server should throw the "no thread specified and no pool configured" error.
std::promise<void> done;
bool error_thrown{false};
foo->m_context.loop->sync([&] {
auto request{foo->m_client.callFnAsyncRequest()};
request.initContext();
foo->m_context.loop->m_task_set->add(
request.send().then(
[&](auto&&) { done.set_value(); },
[&](kj::Exception&& e) {
error_thrown = true;
KJ_EXPECT(std::string_view{e.getDescription().cStr()}.find(
"no thread specified and no pool configured") != std::string_view::npos);
done.set_value();
}));
});
done.get_future().get();
KJ_EXPECT(error_thrown);
}

} // namespace test
} // namespace mp
Loading