Skip to content

Commit

Permalink
Kernel pre post hooks
Browse files — browse the repository at this point in the history
  • Loading branch information
MarioSieg committed Feb 13, 2025
1 parent 624fbed commit 816d2c1
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 48 deletions.
44 changes: 28 additions & 16 deletions magnetron/magnetron_cpu.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -173,6 +173,7 @@ typedef struct mag_threadpool_t {
volatile mag_atomic_t num_workers_online; /* Number of workers that are online */
mag_worker_t* workers; /* Array of workers */
const mag_kernel_registry_t* kernels; /* Specialized compute kernel registry */
mag_kernel_context_t* kernel_ctx; /* Kernel context */
mag_thread_sched_prio_t sched_prio; /* Scheduling priority */
} mag_threadpool_t;

Expand All @@ -189,6 +190,7 @@ typedef struct mag_cpu_device_t {
mag_threadpool_t* pool; /* Thread pool. NULL if num_allocated_workers <= 1 */
uint32_t num_allocated_workers; /* Amount of worker thread used. if == 1 then single threaded mode and thread pool is not created */
mag_kernel_registry_t kernels; /* Compute kernels. Specialized by arch optimized version at boot (e.g. AVX, AVX512 etc..) */
mag_kernel_context_t kernel_ctx; /* Kernel context */
} mag_cpu_device_t;

/* Await signal to start work */
Expand All @@ -206,19 +208,19 @@ static bool mag_worker_await_work(mag_worker_t* worker, mag_threadpool_t* pool)
}

/* Execute the operation on the current thread */
static void mag_worker_exec_thread_local(const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload) {
static void mag_worker_exec_thread_local(const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
if (mag_likely(payload->node)) { /* Do the work 🦾 */
mag_op_t op = payload->node->op;
void (*kernel)(const mag_compute_payload_t*) = payload->is_fwd ? kernels->fwd[op] : kernels->bwd[op];
(*kernel)(payload);
void (*kernel)(const mag_compute_payload_t*, mag_kernel_context_t* ctx) = payload->is_fwd ? kernels->fwd[op] : kernels->bwd[op];
(*kernel)(payload, ctx);
payload->node = NULL;
}
}

/* Execute the operation and broadcast completion if last chunk was done */
static void mag_worker_exec_and_broadcast(mag_threadpool_t* pool, const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload) {
static void mag_worker_exec_and_broadcast(mag_threadpool_t* pool, const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
if (mag_likely(payload->thread_idx < pool->num_active_workers)) /* Execute the operation if we are an active thread. */
mag_worker_exec_thread_local(kernels, payload);
mag_worker_exec_thread_local(kernels, payload, ctx);
mag_mutex_lock(&pool->mtx);
if (++pool->num_completed == pool->num_allocated_workers) /* If we are the last to finish, wake the main thread */
mag_cv_broadcast(&pool->cv);
Expand All @@ -231,19 +233,20 @@ static MAG_HOTPROC void* mag_worker_thread_exec_op(void* arg) {
mag_threadpool_t* pool = worker->pool;
mag_compute_payload_t* payload = &worker->payload;
const mag_kernel_registry_t* kernels = pool->kernels;
mag_kernel_context_t* ctx = pool->kernel_ctx;
char name[32];
snprintf(name, sizeof(name), "mag_worker_%" PRIx64, payload->thread_idx);
mag_thread_set_name(name);
/*mag_thread_set_prio(pool->sched_prio);*/
mag_atomic_fetch_add(&pool->num_workers_online, 1, MAG_MO_SEQ_CST);
while (mag_likely(mag_worker_await_work(worker, pool))) /* Main work loop: wait, work, signal status */
mag_worker_exec_and_broadcast(pool, kernels, payload);
mag_worker_exec_and_broadcast(pool, kernels, payload, ctx);
mag_atomic_fetch_sub(&pool->num_workers_online, 1, MAG_MO_SEQ_CST);
return MAG_THREAD_RET_NONE;
}

/* Create thread pool and allocate threads */
static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_kernel_registry_t* kernels, mag_thread_sched_prio_t prio) { /* Create a thread pool */
static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_kernel_registry_t* kernels, mag_kernel_context_t* ctx, mag_thread_sched_prio_t prio) { /* Create a thread pool */
mag_threadpool_t* pool = mag_alloc_aligned(sizeof(*pool), __alignof(mag_threadpool_t));
memset(pool, 0, sizeof(*pool));
mag_worker_t* workers = mag_alloc_aligned(num_workers*sizeof(*workers), __alignof(mag_worker_t));
Expand All @@ -257,6 +260,7 @@ static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_k
.num_workers_online = 0, /* Main thread as worker 0 */
.workers = workers,
.kernels = kernels,
.kernel_ctx = ctx,
.sched_prio = prio
};
mag_cv_create(&pool->cv);
Expand Down Expand Up @@ -324,27 +328,34 @@ static void mag_threadpool_barrier(mag_threadpool_t* pool) {
/* Execute an operator tensor on the CPU */
static MAG_HOTPROC void mag_threadpool_parallel_compute(mag_threadpool_t* pool, mag_tensor_t* node, bool is_fwd, uint32_t num_active_workers) {
mag_assert2(pool != NULL);
mag_threadpool_kickoff(pool, node, is_fwd, num_active_workers); /* Kick off workers */
mag_cv_broadcast(&pool->cv); /* Wake up all workers */
mag_worker_exec_and_broadcast(pool, pool->kernels, &pool->workers->payload); /* Main thread does work too */
mag_threadpool_barrier(pool); /* Wait for all workers to finish */
mag_threadpool_kickoff(pool, node, is_fwd, num_active_workers); /* Kick off workers */
mag_cv_broadcast(&pool->cv); /* Wake up all workers */
mag_worker_exec_and_broadcast(pool, pool->kernels, &pool->workers->payload, pool->kernel_ctx); /* Main thread does work too */
mag_threadpool_barrier(pool); /* Wait for all workers to finish */
}

static uint32_t mag_cpu_dynamic_work_scaling(mag_cpu_device_t* dvc, mag_op_t op, int64_t numel);

static MAG_HOTPROC void mag_cpu_exec(mag_compute_device_t* dvc, bool is_fwd, mag_tensor_t* node) {
mag_cpu_device_t* cpu_dvc = dvc->impl;
uint32_t intraop_workers = mag_cpu_dynamic_work_scaling(cpu_dvc, node->op, node->numel);
const mag_kernel_registry_t* kernels = &cpu_dvc->kernels;
mag_kernel_context_t* kctx = &cpu_dvc->kernel_ctx; /* Setup pre/post kernel context */
kctx->node = node;
kctx->alloced_threads = cpu_dvc->num_allocated_workers;
uint32_t (*pre)(mag_kernel_context_t*) = is_fwd ? kernels->fwd_pre[node->op] : kernels->bwd_pre[node->op]; /* Fetch pre exec kernel */
void (*post)(mag_kernel_context_t*) = is_fwd ? kernels->fwd_post[node->op] : kernels->bwd_post[node->op]; /* Fetch post exec kernel */
uint32_t intraop_workers = pre ? (*pre)(kctx) : mag_cpu_dynamic_work_scaling(cpu_dvc, node->op, node->numel); /* Use thread count recommended by pre-kernel or compute general thread count heuristic. */
if (intraop_workers <= 1) { /* Main thread does the work (single threaded mode). */
mag_compute_payload_t payload = {
.node = node,
.thread_idx = 0,
.thread_num = 1
};
mag_worker_exec_thread_local(&cpu_dvc->kernels, &payload);
return; /* Done */
mag_worker_exec_thread_local(&cpu_dvc->kernels, &payload, kctx);
goto epilogue;
}
mag_threadpool_parallel_compute(cpu_dvc->pool, node, is_fwd, intraop_workers); /* Multithreaded mode. */
mag_threadpool_parallel_compute(cpu_dvc->pool, node, is_fwd, intraop_workers); /* Multithreaded exec + barrier */
epilogue: if (post) (*post)(kctx); /* Post-exec */
}

static void mag_cpu_exec_fwd(mag_compute_device_t* dvc, mag_tensor_t* node) {
Expand Down Expand Up @@ -429,10 +440,11 @@ static mag_cpu_device_t* mag_cpu_init_device(mag_ctx_t* ctx, uint32_t num_thread
.pool = NULL,
.num_allocated_workers = 0,
.kernels = {},
.kernel_ctx = {}
};
mag_blas_detect_optimal_specialization(ctx, &dvc->kernels);
if (num_threads > 1) {
dvc->pool = mag_threadpool_create(num_threads, &dvc->kernels, sched_prio);
dvc->pool = mag_threadpool_create(num_threads, &dvc->kernels, &dvc->kernel_ctx, sched_prio);
dvc->num_allocated_workers = num_threads;
}
return dvc;
Expand Down
Loading

0 comments on commit 816d2c1

Please sign in to comment.