diff --git a/magnetron/magnetron_cpu.c b/magnetron/magnetron_cpu.c
index 3df6f51..a4049c8 100644
--- a/magnetron/magnetron_cpu.c
+++ b/magnetron/magnetron_cpu.c
@@ -173,6 +173,7 @@ typedef struct mag_threadpool_t {
     volatile mag_atomic_t num_workers_online; /* Number of workers that are online */
     mag_worker_t* workers; /* Array of workers */
    const mag_kernel_registry_t* kernels; /* Specialized compute kernel registry */
+    mag_kernel_context_t* kernel_ctx; /* Kernel context */
     mag_thread_sched_prio_t sched_prio; /* Scheduling priority */
 } mag_threadpool_t;
 
@@ -189,6 +190,7 @@ typedef struct mag_cpu_device_t {
     mag_threadpool_t* pool; /* Thread pool. NULL if num_allocated_workers <= 1 */
     uint32_t num_allocated_workers; /* Amount of worker thread used. if == 1 then single threaded mode and thread pool is not created */
     mag_kernel_registry_t kernels; /* Compute kernels. Specialized by arch optimized version at boot (e.g. AVX, AVX512 etc..) */
+    mag_kernel_context_t kernel_ctx; /* Kernel context */
 } mag_cpu_device_t;
 
 /* Await signal to start work */
@@ -206,19 +208,19 @@ static bool mag_worker_await_work(mag_worker_t* worker, mag_threadpool_t* pool)
 }
 
 /* Execute the operation on the current thread */
-static void mag_worker_exec_thread_local(const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload) {
+static void mag_worker_exec_thread_local(const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
     if (mag_likely(payload->node)) { /* Do the work 🦾 */
         mag_op_t op = payload->node->op;
-        void (*kernel)(const mag_compute_payload_t*) = payload->is_fwd ? kernels->fwd[op] : kernels->bwd[op];
-        (*kernel)(payload);
+        void (*kernel)(const mag_compute_payload_t*, mag_kernel_context_t* ctx) = payload->is_fwd ? kernels->fwd[op] : kernels->bwd[op];
+        (*kernel)(payload, ctx);
         payload->node = NULL;
     }
 }
 
 /* Execute the operation and broadcast completion if last chunk was done */
-static void mag_worker_exec_and_broadcast(mag_threadpool_t* pool, const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload) {
+static void mag_worker_exec_and_broadcast(mag_threadpool_t* pool, const mag_kernel_registry_t* kernels, mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
     if (mag_likely(payload->thread_idx < pool->num_active_workers)) /* Execute the operation if we are an active thread. */
-        mag_worker_exec_thread_local(kernels, payload);
+        mag_worker_exec_thread_local(kernels, payload, ctx);
     mag_mutex_lock(&pool->mtx);
     if (++pool->num_completed == pool->num_allocated_workers) /* If we are the last to finish, wake the main thread */
         mag_cv_broadcast(&pool->cv);
@@ -231,19 +233,20 @@ static MAG_HOTPROC void* mag_worker_thread_exec_op(void* arg) {
     mag_threadpool_t* pool = worker->pool;
     mag_compute_payload_t* payload = &worker->payload;
     const mag_kernel_registry_t* kernels = pool->kernels;
+    mag_kernel_context_t* ctx = pool->kernel_ctx;
     char name[32];
     snprintf(name, sizeof(name), "mag_worker_%" PRIx64, payload->thread_idx);
     mag_thread_set_name(name);
     /*mag_thread_set_prio(pool->sched_prio);*/
     mag_atomic_fetch_add(&pool->num_workers_online, 1, MAG_MO_SEQ_CST);
     while (mag_likely(mag_worker_await_work(worker, pool))) /* Main work loop: wait, work, signal status */
-        mag_worker_exec_and_broadcast(pool, kernels, payload);
+        mag_worker_exec_and_broadcast(pool, kernels, payload, ctx);
     mag_atomic_fetch_sub(&pool->num_workers_online, 1, MAG_MO_SEQ_CST);
     return MAG_THREAD_RET_NONE;
 }
 
 /* Create thread pool and allocate threads */
-static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_kernel_registry_t* kernels, mag_thread_sched_prio_t prio) { /* Create a thread pool */
+static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_kernel_registry_t* kernels, mag_kernel_context_t* ctx, mag_thread_sched_prio_t prio) { /* Create a thread pool */
     mag_threadpool_t* pool = mag_alloc_aligned(sizeof(*pool), __alignof(mag_threadpool_t));
     memset(pool, 0, sizeof(*pool));
     mag_worker_t* workers = mag_alloc_aligned(num_workers*sizeof(*workers), __alignof(mag_worker_t));
@@ -257,6 +260,7 @@ static mag_threadpool_t* mag_threadpool_create(uint32_t num_workers, const mag_k
         .num_workers_online = 0, /* Main thread as worker 0 */
         .workers = workers,
         .kernels = kernels,
+        .kernel_ctx = ctx,
         .sched_prio = prio
     };
     mag_cv_create(&pool->cv);
@@ -324,27 +328,34 @@ static void mag_threadpool_barrier(mag_threadpool_t* pool) {
 
 /* Execute an operator tensor on the CPU */
 static MAG_HOTPROC void mag_threadpool_parallel_compute(mag_threadpool_t* pool, mag_tensor_t* node, bool is_fwd, uint32_t num_active_workers) {
     mag_assert2(pool != NULL);
-    mag_threadpool_kickoff(pool, node, is_fwd, num_active_workers); /* Kick off workers */
-    mag_cv_broadcast(&pool->cv); /* Wake up all workers */
-    mag_worker_exec_and_broadcast(pool, pool->kernels, &pool->workers->payload); /* Main thread does work too */
-    mag_threadpool_barrier(pool); /* Wait for all workers to finish */
+    mag_threadpool_kickoff(pool, node, is_fwd, num_active_workers); /* Kick off workers */
+    mag_cv_broadcast(&pool->cv); /* Wake up all workers */
+    mag_worker_exec_and_broadcast(pool, pool->kernels, &pool->workers->payload, pool->kernel_ctx); /* Main thread does work too */
+    mag_threadpool_barrier(pool); /* Wait for all workers to finish */
 }
 
 static uint32_t mag_cpu_dynamic_work_scaling(mag_cpu_device_t* dvc, mag_op_t op, int64_t numel);
 
 static MAG_HOTPROC void mag_cpu_exec(mag_compute_device_t* dvc, bool is_fwd, mag_tensor_t* node) {
     mag_cpu_device_t* cpu_dvc = dvc->impl;
-    uint32_t intraop_workers = mag_cpu_dynamic_work_scaling(cpu_dvc, node->op, node->numel);
+    const mag_kernel_registry_t* kernels = &cpu_dvc->kernels;
+    mag_kernel_context_t* kctx = &cpu_dvc->kernel_ctx; /* Setup pre/post kernel context */
+    kctx->node = node;
+    kctx->alloced_threads = cpu_dvc->num_allocated_workers;
+    uint32_t (*pre)(mag_kernel_context_t*) = is_fwd ? kernels->fwd_pre[node->op] : kernels->bwd_pre[node->op]; /* Fetch pre exec kernel */
+    void (*post)(mag_kernel_context_t*) = is_fwd ? kernels->fwd_post[node->op] : kernels->bwd_post[node->op]; /* Fetch post exec kernel */
+    uint32_t intraop_workers = pre ? (*pre)(kctx) : mag_cpu_dynamic_work_scaling(cpu_dvc, node->op, node->numel); /* Use thread count recommended by pre-kernel or compute general thread count heuristic. */
     if (intraop_workers <= 1) { /* Main thread does the work (single threaded mode). */
         mag_compute_payload_t payload = {
             .node = node,
             .thread_idx = 0,
             .thread_num = 1
         };
-        mag_worker_exec_thread_local(&cpu_dvc->kernels, &payload);
-        return; /* Done */
+        mag_worker_exec_thread_local(&cpu_dvc->kernels, &payload, kctx);
+        goto epilogue;
     }
-    mag_threadpool_parallel_compute(cpu_dvc->pool, node, is_fwd, intraop_workers); /* Multithreaded mode. */
+    mag_threadpool_parallel_compute(cpu_dvc->pool, node, is_fwd, intraop_workers); /* Multithreaded exec + barrier */
+    epilogue: if (post) (*post)(kctx); /* Post-exec */
 }
 
 static void mag_cpu_exec_fwd(mag_compute_device_t* dvc, mag_tensor_t* node) {
@@ -429,10 +440,11 @@ static mag_cpu_device_t* mag_cpu_init_device(mag_ctx_t* ctx, uint32_t num_thread
         .pool = NULL,
         .num_allocated_workers = 0,
         .kernels = {},
+        .kernel_ctx = {}
     };
     mag_blas_detect_optimal_specialization(ctx, &dvc->kernels);
     if (num_threads > 1) {
-        dvc->pool = mag_threadpool_create(num_threads, &dvc->kernels, sched_prio);
+        dvc->pool = mag_threadpool_create(num_threads, &dvc->kernels, &dvc->kernel_ctx, sched_prio);
         dvc->num_allocated_workers = num_threads;
     }
     return dvc;
diff --git a/magnetron/magnetron_cpu_blas.inl b/magnetron/magnetron_cpu_blas.inl
index 530deec..6d9427b 100644
--- a/magnetron/magnetron_cpu_blas.inl
+++ b/magnetron/magnetron_cpu_blas.inl
@@ -1093,9 +1093,10 @@ static void MAG_HOTPROC mag_vgelu_dv_f32( /* gelu' : ℝ -> ℝ, x |-> TODO */
     }
 }
 
-static void mag_blas_nop(const mag_compute_payload_t* payload) { (void)payload; }
+static void mag_blas_nop(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) { (void)payload; (void)ctx; }
 
-static void mag_blas_clone(const mag_compute_payload_t* payload) {
+static void mag_blas_clone(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    (void)ctx;
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* x = r->op_inputs[0];
     mag_assert2(mag_tensor_is_shape_eq(x, r));
@@ -1104,7 +1105,8 @@ static void mag_blas_clone(const mag_compute_payload_t* payload) {
     memcpy(b_r, b_x, mag_tensor_data_size(r));
 }
 
-static void MAG_HOTPROC mag_blas_mean_f32(const mag_compute_payload_t* payload) {
+static void MAG_HOTPROC mag_blas_mean_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    (void)ctx;
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* x = r->op_inputs[0];
     mag_f32_t* b_r = mag_f32p_mut(r);
@@ -1133,7 +1135,8 @@ static void MAG_HOTPROC mag_blas_mean_f32(const mag_compute_payload_t* payload)
     *b_r = (mag_f32_t)sum;
 }
 
-static void MAG_HOTPROC mag_blas_min_f32(const mag_compute_payload_t* payload) {
+static void MAG_HOTPROC mag_blas_min_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    (void)ctx;
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* const x = r->op_inputs[0];
     mag_f32_t* b_r = mag_f32p_mut(r);
@@ -1158,7 +1161,8 @@ static void MAG_HOTPROC mag_blas_min_f32(const mag_compute_payload_t* payload) {
     *b_r = min;
 }
 
-static void MAG_HOTPROC mag_blas_max_f32(const mag_compute_payload_t* payload) {
+static void MAG_HOTPROC mag_blas_max_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    (void)ctx;
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* const x = r->op_inputs[0];
     mag_f32_t* b_r = mag_f32p_mut(r);
@@ -1183,7 +1187,8 @@ static void MAG_HOTPROC mag_blas_max_f32(const mag_compute_payload_t* payload) {
     *b_r = max;
 }
 
-static void MAG_HOTPROC mag_blas_sum_f32(const mag_compute_payload_t* payload) {
+static void MAG_HOTPROC mag_blas_sum_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    (void)ctx;
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* const x = r->op_inputs[0];
     mag_f32_t* b_r = mag_f32p_mut(r);
@@ -1209,7 +1214,8 @@ static void MAG_HOTPROC mag_blas_sum_f32(const mag_compute_payload_t* payload) {
 }
 
 #define mag_cpu_blas_impl_unary(T, name) \
-    static void MAG_HOTPROC mag_blas_##name##_##T(const mag_compute_payload_t* payload) { \
+    static void MAG_HOTPROC mag_blas_##name##_##T(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) { \
+        (void)ctx; \
         mag_tensor_t* r = payload->node; \
         const mag_tensor_t* x = r->op_inputs[0]; \
         mag_##T##_t* br = mag_##T##p_mut(r); \
@@ -1256,7 +1262,8 @@ mag_cpu_blas_impl_unary(f32, gelu_dv)
 #undef mag_cpu_blas_impl_unary
 
 #define mag_cpu_blas_impl_unary_scalar(T, name) \
-    static void MAG_HOTPROC mag_blas_##name##s_##T(const mag_compute_payload_t* payload) { \
+    static void MAG_HOTPROC mag_blas_##name##s_##T(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) { \
+        (void)ctx; \
         mag_tensor_t* r = payload->node; \
         const mag_tensor_t* x = r->op_inputs[0]; \
         mag_##T##_t xi = r->op_params->x.T; \
@@ -1287,7 +1294,8 @@ mag_cpu_blas_impl_unary_scalar(f32, pow)
 #undef mag_cpu_blas_impl_unary_scalar
 
 #define mag_cpu_blas_impl_binary(T, name, op) \
-    static void MAG_HOTPROC mag_blas_##name##_##T(const mag_compute_payload_t* payload) { \
+    static void MAG_HOTPROC mag_blas_##name##_##T(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) { \
+        (void)ctx; \
         mag_tensor_t* r = payload->node; \
         const mag_tensor_t* x = r->op_inputs[0]; \
         const mag_tensor_t* y = r->op_inputs[1]; \
@@ -1378,11 +1386,7 @@ mag_cpu_blas_impl_binary(f32, sub, -)
 mag_cpu_blas_impl_binary(f32, mul, *)
 mag_cpu_blas_impl_binary(f32, div, /)
 
-/*
-** Matrix multiplication.
-** R = A x B
-*/
-static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload) {
+static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
     mag_tensor_t* r = payload->node;
     const mag_tensor_t* x = r->op_inputs[0];
     const mag_tensor_t* y = r->op_inputs[1];
@@ -1404,24 +1408,42 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
     int64_t ra = chunk*ti;
     int64_t rb = mag_xmin(ra+chunk, numel);
     bool tx = mag_tensor_is_transposed(x);
-    for (int64_t i = ra; i < rb; ++i) { /* Rows */
+
+    // For each row index (i) in the result.
+    for (int64_t i = ra; i < rb; ++i) {
+        // Initialize row i in R to 0.
         for (int64_t j = 0; j < yd1; ++j) {
-            float* xo = br + rd1*i + j;
-            mag_bnd_chk(xo, br, mag_tensor_data_size(r));
-            *xo = 0.0f;
+            /* Using the computed strides:
+             * rs[0] == 1 and rs[1] == M.
+             * Thus, element (i,j) is at offset: i + M*j.
+             */
+            float* pr = br + rs0 * i + rs1 * j;
+            mag_bnd_chk(pr, br, mag_tensor_data_size(r));
+            *pr = 0.0f;
         }
-        for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
-            const mag_f32_t* px = bx + (tx ? k*xd0 + i : xd1*i + k);
+        // Multiply: R(i,j) += X(i,k) * Y(k,j)
+        for (int64_t k = 0; k < xd1; ++k) {
+            const mag_f32_t* px;
+            if (tx) {
+                // If X is transposed, treat it as X^T,
+                // so element (i,k) is stored at (k,i).
+                px = bx + xs0 * k + xs1 * i;
+            } else {
+                // Otherwise, access X(i,k) as: i + M*k.
+                px = bx + xs0 * i + xs1 * k;
+            }
             mag_bnd_chk(px, bx, mag_tensor_data_size(x));
-            for (int64_t j = 0; j < yd1; ++j) { /* Columns */
-                mag_f32_t* pr = br + rd1*i + j;
-                const mag_f32_t* py = by + yd1*k + j;
+            for (int64_t j = 0; j < yd1; ++j) {
+                float* pr = br + rs0 * i + rs1 * j; // R(i,j) = i + M*j.
+                // Y is [K, N] stored with strides: ys0 == 1 and ys1 == K.
+                const mag_f32_t* py = by + ys0 * k + ys1 * j; // Y(k,j) = k + K*j.
                 mag_bnd_chk(pr, br, mag_tensor_data_size(r));
                 mag_bnd_chk(py, by, mag_tensor_data_size(y));
                 *pr += (*px) * (*py);
             }
         }
     }
+
 }
 
 #ifndef MAG_BLAS_SPECIALIZATION
@@ -1587,7 +1609,7 @@ uint64_t MAG_BLAS_SPECIALIZATION_FEAT_REQUEST(void) {
 
 #endif
 
-static void (*const forward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*) = {
+static void (*const forward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*, mag_kernel_context_t* ctx) = {
     [MAG_OP_NOP] = &mag_blas_nop,
     [MAG_OP_CLONE] = &mag_blas_clone,
     [MAG_OP_VIEW] = &mag_blas_nop,
@@ -1631,7 +1653,95 @@ static void (*const forward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*)
     [MAG_OP_MATMUL] = &mag_blas_matmul_f32,
 };
 
-static void (*const backward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*) = {
+static uint32_t (*const pre_forward_kernels[MAG_OP__NUM])(mag_kernel_context_t*) = {
+    [MAG_OP_NOP] = NULL,
+    [MAG_OP_CLONE] = NULL,
+    [MAG_OP_VIEW] = NULL,
+    [MAG_OP_TRANSPOSE] = NULL,
+    [MAG_OP_PERMUTE] = NULL,
+    [MAG_OP_MEAN] = NULL,
+    [MAG_OP_MIN] = NULL,
+    [MAG_OP_MAX] = NULL,
+    [MAG_OP_SUM] = NULL,
+    [MAG_OP_ABS] = NULL,
+    [MAG_OP_NEG] = NULL,
+    [MAG_OP_LOG] = NULL,
+    [MAG_OP_SQR] = NULL,
+    [MAG_OP_SQRT] = NULL,
+    [MAG_OP_SIN] = NULL,
+    [MAG_OP_COS] = NULL,
+    [MAG_OP_STEP] = NULL,
+    [MAG_OP_EXP] = NULL,
+    [MAG_OP_SOFTMAX] = NULL,
+    [MAG_OP_SOFTMAX_DV] = NULL,
+    [MAG_OP_SIGMOID] = NULL,
+    [MAG_OP_SIGMOID_DV] = NULL,
+    [MAG_OP_HARD_SIGMOID] = NULL,
+    [MAG_OP_SILU] = NULL,
+    [MAG_OP_SILU_DV] = NULL,
+    [MAG_OP_TANH] = NULL,
+    [MAG_OP_TANH_DV] = NULL,
+    [MAG_OP_RELU] = NULL,
+    [MAG_OP_RELU_DV] = NULL,
+    [MAG_OP_GELU] = NULL,
+    [MAG_OP_GELU_DV] = NULL,
+    [MAG_OP_ADD] = NULL,
+    [MAG_OP_SUB] = NULL,
+    [MAG_OP_MUL] = NULL,
+    [MAG_OP_DIV] = NULL,
+    [MAG_OP_ADDS] = NULL,
+    [MAG_OP_SUBS] = NULL,
+    [MAG_OP_MULS] = NULL,
+    [MAG_OP_DIVS] = NULL,
+    [MAG_OP_POWS] = NULL,
+    [MAG_OP_MATMUL] = NULL,
+};
+
+static void (*const post_forward_kernels[MAG_OP__NUM])(mag_kernel_context_t*) = {
+    [MAG_OP_NOP] = NULL,
+    [MAG_OP_CLONE] = NULL,
+    [MAG_OP_VIEW] = NULL,
+    [MAG_OP_TRANSPOSE] = NULL,
+    [MAG_OP_PERMUTE] = NULL,
+    [MAG_OP_MEAN] = NULL,
+    [MAG_OP_MIN] = NULL,
+    [MAG_OP_MAX] = NULL,
+    [MAG_OP_SUM] = NULL,
+    [MAG_OP_ABS] = NULL,
+    [MAG_OP_NEG] = NULL,
+    [MAG_OP_LOG] = NULL,
+    [MAG_OP_SQR] = NULL,
+    [MAG_OP_SQRT] = NULL,
+    [MAG_OP_SIN] = NULL,
+    [MAG_OP_COS] = NULL,
+    [MAG_OP_STEP] = NULL,
+    [MAG_OP_EXP] = NULL,
+    [MAG_OP_SOFTMAX] = NULL,
+    [MAG_OP_SOFTMAX_DV] = NULL,
+    [MAG_OP_SIGMOID] = NULL,
+    [MAG_OP_SIGMOID_DV] = NULL,
+    [MAG_OP_HARD_SIGMOID] = NULL,
+    [MAG_OP_SILU] = NULL,
+    [MAG_OP_SILU_DV] = NULL,
+    [MAG_OP_TANH] = NULL,
+    [MAG_OP_TANH_DV] = NULL,
+    [MAG_OP_RELU] = NULL,
+    [MAG_OP_RELU_DV] = NULL,
+    [MAG_OP_GELU] = NULL,
+    [MAG_OP_GELU_DV] = NULL,
+    [MAG_OP_ADD] = NULL,
+    [MAG_OP_SUB] = NULL,
+    [MAG_OP_MUL] = NULL,
+    [MAG_OP_DIV] = NULL,
+    [MAG_OP_ADDS] = NULL,
+    [MAG_OP_SUBS] = NULL,
+    [MAG_OP_MULS] = NULL,
+    [MAG_OP_DIVS] = NULL,
+    [MAG_OP_POWS] = NULL,
+    [MAG_OP_MATMUL] = NULL,
+};
+
+static void (*const backward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*, mag_kernel_context_t* ctx) = {
     [MAG_OP_NOP] = &mag_blas_nop,
     [MAG_OP_CLONE] = &mag_blas_clone,
     [MAG_OP_VIEW] = &mag_blas_nop,
@@ -1675,7 +1785,101 @@ static void (*const backward_kernels[MAG_OP__NUM])(const mag_compute_payload_t*)
     [MAG_OP_MATMUL] = &mag_blas_matmul_f32,
 };
 
-void MAG_BLAS_SPECIALIZATION(mag_kernel_registry_t* kernels) {
-    memcpy(kernels->fwd, forward_kernels, sizeof(forward_kernels));
-    memcpy(kernels->bwd, backward_kernels, sizeof(backward_kernels));
+static uint32_t (*const pre_backward_kernels[MAG_OP__NUM])(mag_kernel_context_t*) = {
+    [MAG_OP_NOP] = NULL,
+    [MAG_OP_CLONE] = NULL,
+    [MAG_OP_VIEW] = NULL,
+    [MAG_OP_TRANSPOSE] = NULL,
+    [MAG_OP_PERMUTE] = NULL,
+    [MAG_OP_MEAN] = NULL,
+    [MAG_OP_MIN] = NULL,
+    [MAG_OP_MAX] = NULL,
+    [MAG_OP_SUM] = NULL,
+    [MAG_OP_ABS] = NULL,
+    [MAG_OP_NEG] = NULL,
+    [MAG_OP_LOG] = NULL,
+    [MAG_OP_SQR] = NULL,
+    [MAG_OP_SQRT] = NULL,
+    [MAG_OP_SIN] = NULL,
+    [MAG_OP_COS] = NULL,
+    [MAG_OP_STEP] = NULL,
+    [MAG_OP_EXP] = NULL,
+    [MAG_OP_SOFTMAX] = NULL,
+    [MAG_OP_SOFTMAX_DV] = NULL,
+    [MAG_OP_SIGMOID] = NULL,
+    [MAG_OP_SIGMOID_DV] = NULL,
+    [MAG_OP_HARD_SIGMOID] = NULL,
+    [MAG_OP_SILU] = NULL,
+    [MAG_OP_SILU_DV] = NULL,
+    [MAG_OP_TANH] = NULL,
+    [MAG_OP_TANH_DV] = NULL,
+    [MAG_OP_RELU] = NULL,
+    [MAG_OP_RELU_DV] = NULL,
+    [MAG_OP_GELU] = NULL,
+    [MAG_OP_GELU_DV] = NULL,
+    [MAG_OP_ADD] = NULL,
+    [MAG_OP_SUB] = NULL,
+    [MAG_OP_MUL] = NULL,
+    [MAG_OP_DIV] = NULL,
+    [MAG_OP_ADDS] = NULL,
+    [MAG_OP_SUBS] = NULL,
+    [MAG_OP_MULS] = NULL,
+    [MAG_OP_DIVS] = NULL,
+    [MAG_OP_POWS] = NULL,
+    [MAG_OP_MATMUL] = NULL,
+};
+
+static void (*const post_backward_kernels[MAG_OP__NUM])(mag_kernel_context_t*) = {
+    [MAG_OP_NOP] = NULL,
+    [MAG_OP_CLONE] = NULL,
+    [MAG_OP_VIEW] = NULL,
+    [MAG_OP_TRANSPOSE] = NULL,
+    [MAG_OP_PERMUTE] = NULL,
+    [MAG_OP_MEAN] = NULL,
+    [MAG_OP_MIN] = NULL,
+    [MAG_OP_MAX] = NULL,
+    [MAG_OP_SUM] = NULL,
+    [MAG_OP_ABS] = NULL,
+    [MAG_OP_NEG] = NULL,
+    [MAG_OP_LOG] = NULL,
+    [MAG_OP_SQR] = NULL,
+    [MAG_OP_SQRT] = NULL,
+    [MAG_OP_SIN] = NULL,
+    [MAG_OP_COS] = NULL,
+    [MAG_OP_STEP] = NULL,
+    [MAG_OP_EXP] = NULL,
+    [MAG_OP_SOFTMAX] = NULL,
+    [MAG_OP_SOFTMAX_DV] = NULL,
+    [MAG_OP_SIGMOID] = NULL,
+    [MAG_OP_SIGMOID_DV] = NULL,
+    [MAG_OP_HARD_SIGMOID] = NULL,
+    [MAG_OP_SILU] = NULL,
+    [MAG_OP_SILU_DV] = NULL,
+    [MAG_OP_TANH] = NULL,
+    [MAG_OP_TANH_DV] = NULL,
+    [MAG_OP_RELU] = NULL,
+    [MAG_OP_RELU_DV] = NULL,
+    [MAG_OP_GELU] = NULL,
+    [MAG_OP_GELU_DV] = NULL,
+    [MAG_OP_ADD] = NULL,
+    [MAG_OP_SUB] = NULL,
+    [MAG_OP_MUL] = NULL,
+    [MAG_OP_DIV] = NULL,
+    [MAG_OP_ADDS] = NULL,
+    [MAG_OP_SUBS] = NULL,
+    [MAG_OP_MULS] = NULL,
+    [MAG_OP_DIVS] = NULL,
+    [MAG_OP_POWS] = NULL,
+    [MAG_OP_MATMUL] = NULL,
+};
+
+void MAG_BLAS_SPECIALIZATION(mag_kernel_registry_t* kernels, mag_kernel_context_t* ctx) {
+    for (unsigned i=0; i < MAG_OP__NUM; ++i) {
+        kernels->fwd_pre[i] = pre_forward_kernels[i];
+        kernels->fwd[i] = forward_kernels[i];
+        kernels->fwd_post[i] = post_forward_kernels[i];
+        kernels->bwd_pre[i] = pre_backward_kernels[i];
+        kernels->bwd[i] = backward_kernels[i];
+        kernels->bwd_post[i] = post_backward_kernels[i];
+    }
 }
diff --git a/magnetron/magnetron_internal.h b/magnetron/magnetron_internal.h
index c1dce57..75775ab 100644
--- a/magnetron/magnetron_internal.h
+++ b/magnetron/magnetron_internal.h
@@ -259,6 +259,8 @@ defined(__WIN32__) || defined(__TOS_WIN__) || defined(__WINDOWS__)
 /* Hardware destructive interference size. */
 #define MAG_CACHE_LINE_SIZE 64
 #endif
+#define PAGE_SIZE_4k 0x1000
+#define PAGE_SIZE_2m 0x200000
 
 static uint32_t MAG_AINLINE mag_bswap32(uint32_t x) { /* Swap bytes for endianess switch. Should be optimized to a (bswap/rev) instruction on modern compilers. */
 #ifdef MAG_BE
@@ -814,16 +816,42 @@ struct mag_tensor_t {
     (void)prefix##4; \
     (void)prefix##5
 
-typedef struct mag_compute_payload_t {
+typedef struct mag_compute_payload_t { /* Compute payload for kernel execution. */
     int64_t thread_num;
     int64_t thread_idx;
     mag_tensor_t* node;
     bool is_fwd;
 } mag_compute_payload_t;
 
-typedef struct mag_kernel_registry_t {
-    void (*fwd[MAG_OP__NUM])(const mag_compute_payload_t*);
-    void (*bwd[MAG_OP__NUM])(const mag_compute_payload_t*);
+typedef struct mag_kctx_mm_t { /* Matmul kernel context. */
+    float* c_buffers;
+    float* ws_buffers;
+    int64_t nthr_m;
+    int64_t nthr_n;
+    int64_t nthr_k;
+    int64_t nthr_mn;
+    int64_t ws_size_per_thr;
+    int64_t MB;
+    int64_t NB;
+    int64_t KB;
+    bool do_copy;
+} mag_kctx_mm_t;
+
+typedef struct mag_kernel_context_t { /* General op kernel context. */
+    mag_tensor_t* node;
+    int64_t alloced_threads;
+    union {
+        mag_kctx_mm_t mm;
+    } per_op;
+} mag_kernel_context_t;
+
+typedef struct mag_kernel_registry_t { /* Kernel registry for operators. */
+    uint32_t (*fwd_pre[MAG_OP__NUM])(mag_kernel_context_t*);
+    void (*fwd[MAG_OP__NUM])(const mag_compute_payload_t*, mag_kernel_context_t*);
+    void (*fwd_post[MAG_OP__NUM])(mag_kernel_context_t*);
+    uint32_t (*bwd_pre[MAG_OP__NUM])(mag_kernel_context_t*);
+    void (*bwd[MAG_OP__NUM])(const mag_compute_payload_t*, mag_kernel_context_t*);
+    void (*bwd_post[MAG_OP__NUM])(mag_kernel_context_t*);
 } mag_kernel_registry_t;
 
 #define mag_load_local_storage_group(xk, prefix, var) mag_load_local_storage_group_arr((xk)->var, prefix)
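
Note on the stride comments in mag_blas_matmul_f32: the patch's comments (rs0 == 1, rs1 == M; ys0 == 1, ys1 == K) describe a column-major layout, where element (i,j) of R lives at offset i + M*j. As a quick sanity check of that arithmetic, the following standalone snippet (an illustration only, not part of the patch) computes the offset of R(2,3) for M == 4 rows:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    int64_t M = 4;            /* number of rows of R */
    int64_t rs0 = 1, rs1 = M; /* strides as claimed by the patch comments */
    int64_t i = 2, j = 3;
    /* offset of R(2,3) = rs0*i + rs1*j = 2 + 4*3 = 14 */
    printf("offset = %lld\n", (long long)(rs0*i + rs1*j));
    return 0;
}

The transposed path in the patch applies the same formula with the roles of the two strides swapped (xs0*k + xs1*i instead of xs0*i + xs1*k), which is exactly what the new if (tx) branch does.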
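
To illustrate how the new fwd_pre/fwd_post hooks and mag_kctx_mm_t are meant to compose, here is a hypothetical pre/post pair for MAG_OP_MATMUL. This is a minimal sketch, not part of this patch: the function names, the blocking factor, and mag_free_aligned() (assumed to exist as the counterpart of mag_alloc_aligned()) are all illustrative.

/* Hypothetical: pre-kernel sizes per-thread scratch space in the shared
 * mag_kernel_context_t and returns its preferred intra-op thread count;
 * the post-kernel releases the scratch again after the barrier. */
static uint32_t mag_blas_matmul_f32_pre(mag_kernel_context_t* ctx) {
    mag_kctx_mm_t* mm = &ctx->per_op.mm;
    int64_t nthr = ctx->alloced_threads;  /* Upper bound set by mag_cpu_exec */
    mm->KB = 64;                          /* Example blocking factor */
    mm->ws_size_per_thr = mm->KB*mm->KB*(int64_t)sizeof(float);
    mm->ws_buffers = mag_alloc_aligned(nthr*mm->ws_size_per_thr, MAG_CACHE_LINE_SIZE);
    mm->do_copy = false;
    return (uint32_t)nthr;                /* Recommend using all allocated workers */
}

static void mag_blas_matmul_f32_post(mag_kernel_context_t* ctx) {
    mag_free_aligned(ctx->per_op.mm.ws_buffers); /* Assumed allocator counterpart */
    ctx->per_op.mm.ws_buffers = NULL;
}

Wiring such a pair up would amount to replacing the corresponding NULL entries, e.g. [MAG_OP_MATMUL] = &mag_blas_matmul_f32_pre in pre_forward_kernels and [MAG_OP_MATMUL] = &mag_blas_matmul_f32_post in post_forward_kernels. Because mag_cpu_exec runs the pre-kernel before dispatch and the post-kernel only after the barrier (or after the single-threaded path), allocation and release never race with the workers.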
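
On the worker side, a kernel receiving the shared context could then carve out its private slice of the scratch space by thread index. Again a sketch under the same assumptions, with a hypothetical kernel name:

static void MAG_HOTPROC mag_blas_matmul_f32_blocked(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
    const mag_kctx_mm_t* mm = &ctx->per_op.mm;
    /* Each worker indexes its own disjoint region, so no synchronization is needed. */
    float* ws = mm->ws_buffers + payload->thread_idx*(mm->ws_size_per_thr/(int64_t)sizeof(float));
    /* ... pack a KB x KB tile of the inputs into ws and multiply block-wise ... */
    (void)ws;
}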