@@ -416,6 +416,18 @@ void cblas_cgeadd(OPENBLAS_CONST enum CBLAS_ORDER CORDER,OPENBLAS_CONST blasint
416
416
/* Prototype for cblas_zgeadd.
 * Parameters as declared: storage order (CORDER), row/column counts
 * (crows, ccols), scalars passed by pointer (calpha, cbeta), matrix `a`
 * with leading dimension clda, and matrix `c` with leading dimension cldc.
 * Both matrices are passed as `double *`; `a` is not OPENBLAS_CONST-qualified.
 * NOTE(review): presumably computes c = alpha*a + beta*c on double-precision
 * complex data stored as interleaved doubles -- confirm against the
 * implementation. */
void cblas_zgeadd (OPENBLAS_CONST enum CBLAS_ORDER CORDER ,OPENBLAS_CONST blasint crows , OPENBLAS_CONST blasint ccols , OPENBLAS_CONST double * calpha , double * a , OPENBLAS_CONST blasint clda , OPENBLAS_CONST double * cbeta ,
417
417
double * c , OPENBLAS_CONST blasint cldc );
418
418
419
/* Prototype for the single-precision (float) batched GEMM entry point.
 * Apart from Order and group_count, every argument is passed by pointer:
 * transpose flags, M/N/K dimensions, alpha/beta scalars and leading
 * dimensions as arrays, and A_array/B_array/C_array as arrays of matrix
 * pointers (A and B point to const float, C to writable float).
 * NOTE(review): presumably uses the grouped-batch convention (one entry per
 * group in each *_array, group_size[i] matrices in group i, group_count
 * groups in total), as in Intel MKL's cblas_sgemm_batch -- confirm against
 * the implementation. */
+ void cblas_sgemm_batch (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array , OPENBLAS_CONST blasint * M_array , OPENBLAS_CONST blasint * N_array , OPENBLAS_CONST blasint * K_array ,
420
+ OPENBLAS_CONST float * alpha_array , OPENBLAS_CONST float * * A_array , OPENBLAS_CONST blasint * lda_array , OPENBLAS_CONST float * * B_array , OPENBLAS_CONST blasint * ldb_array , OPENBLAS_CONST float * beta_array , float * * C_array , OPENBLAS_CONST blasint * ldc_array , OPENBLAS_CONST blasint group_count , OPENBLAS_CONST blasint * group_size );
421
+
422
/* Prototype for the double-precision batched GEMM entry point.
 * Same parameter layout as the float variant, with double for the
 * alpha/beta arrays and for the A/B/C matrix pointer arrays.
 * NOTE(review): presumably the *_array arguments hold one entry per group
 * and group_size[i] gives the number of matrices in group i -- confirm
 * against the implementation. */
+ void cblas_dgemm_batch (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array , OPENBLAS_CONST blasint * M_array , OPENBLAS_CONST blasint * N_array , OPENBLAS_CONST blasint * K_array ,
423
+ OPENBLAS_CONST double * alpha_array , OPENBLAS_CONST double * * A_array , OPENBLAS_CONST blasint * lda_array , OPENBLAS_CONST double * * B_array , OPENBLAS_CONST blasint * ldb_array , OPENBLAS_CONST double * beta_array , double * * C_array , OPENBLAS_CONST blasint * ldc_array , OPENBLAS_CONST blasint group_count , OPENBLAS_CONST blasint * group_size );
424
+
425
/* Prototype for cblas_cgemm_batch.
 * The scalar arrays (alpha_array, beta_array) and the matrix pointer arrays
 * (A_array, B_array, C_array) are untyped (`void *` / `void * *`), so the
 * element type is not visible in this declaration.
 * NOTE(review): presumably single-precision complex data stored as
 * interleaved (real, imag) floats, with grouped-batch semantics for
 * group_count/group_size -- confirm against the implementation. */
+ void cblas_cgemm_batch (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array , OPENBLAS_CONST blasint * M_array , OPENBLAS_CONST blasint * N_array , OPENBLAS_CONST blasint * K_array ,
426
+ OPENBLAS_CONST void * alpha_array , OPENBLAS_CONST void * * A_array , OPENBLAS_CONST blasint * lda_array , OPENBLAS_CONST void * * B_array , OPENBLAS_CONST blasint * ldb_array , OPENBLAS_CONST void * beta_array , void * * C_array , OPENBLAS_CONST blasint * ldc_array , OPENBLAS_CONST blasint group_count , OPENBLAS_CONST blasint * group_size );
427
+
428
/* Prototype for cblas_zgemm_batch.
 * Signature is identical to cblas_cgemm_batch: scalars and matrix pointer
 * arrays are untyped (`void *` / `void * *`).
 * NOTE(review): presumably double-precision complex data stored as
 * interleaved (real, imag) doubles, with grouped-batch semantics for
 * group_count/group_size -- confirm against the implementation. */
+ void cblas_zgemm_batch (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array , OPENBLAS_CONST blasint * M_array , OPENBLAS_CONST blasint * N_array , OPENBLAS_CONST blasint * K_array ,
429
+ OPENBLAS_CONST void * alpha_array , OPENBLAS_CONST void * * A_array , OPENBLAS_CONST blasint * lda_array , OPENBLAS_CONST void * * B_array , OPENBLAS_CONST blasint * ldb_array , OPENBLAS_CONST void * beta_array , void * * C_array , OPENBLAS_CONST blasint * ldc_array , OPENBLAS_CONST blasint group_count , OPENBLAS_CONST blasint * group_size );
430
+
419
431
/*** BFLOAT16 and INT8 extensions ***/
420
432
/* convert float array to BFLOAT16 array by rounding */
421
433
/* Reads from the const float array `in` and writes to the bfloat16 array `out`.
 * NOTE(review): presumably n is the element count and incin/incout are the
 * element strides of `in` and `out` -- confirm against the implementation. */
void cblas_sbstobf16 (OPENBLAS_CONST blasint n , OPENBLAS_CONST float * in , OPENBLAS_CONST blasint incin , bfloat16 * out , OPENBLAS_CONST blasint incout );
@@ -431,6 +443,9 @@ void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum
431
443
432
444
/* Prototype for cblas_sbgemm: a GEMM-style entry point with mixed types.
 * A and B are const bfloat16 matrices; alpha and beta are float scalars
 * passed by value; C is a writable float matrix. lda/ldb/ldc are the
 * respective leading dimensions.
 * NOTE(review): presumably computes C = alpha*op(A)*op(B) + beta*C with
 * float accumulation -- confirm against the implementation. */
void cblas_sbgemm (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA , OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB , OPENBLAS_CONST blasint M , OPENBLAS_CONST blasint N , OPENBLAS_CONST blasint K ,
433
445
OPENBLAS_CONST float alpha , OPENBLAS_CONST bfloat16 * A , OPENBLAS_CONST blasint lda , OPENBLAS_CONST bfloat16 * B , OPENBLAS_CONST blasint ldb , OPENBLAS_CONST float beta , float * C , OPENBLAS_CONST blasint ldc );
446
/* Prototype for the batched counterpart of cblas_sbgemm.
 * A_array and B_array are arrays of pointers to const bfloat16 matrices;
 * alpha_array and beta_array are float arrays; C_array is an array of
 * pointers to writable float matrices. Transpose flags, dimensions and
 * leading dimensions are likewise passed as arrays.
 * NOTE(review): presumably the *_array arguments hold one entry per group
 * and group_size[i] gives the number of matrices in group i -- confirm
 * against the implementation. */
+ void cblas_sbgemm_batch (OPENBLAS_CONST enum CBLAS_ORDER Order , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransA_array , OPENBLAS_CONST enum CBLAS_TRANSPOSE * TransB_array , OPENBLAS_CONST blasint * M_array , OPENBLAS_CONST blasint * N_array , OPENBLAS_CONST blasint * K_array ,
447
+ OPENBLAS_CONST float * alpha_array , OPENBLAS_CONST bfloat16 * * A_array , OPENBLAS_CONST blasint * lda_array , OPENBLAS_CONST bfloat16 * * B_array , OPENBLAS_CONST blasint * ldb_array , OPENBLAS_CONST float * beta_array , float * * C_array , OPENBLAS_CONST blasint * ldc_array , OPENBLAS_CONST blasint group_count , OPENBLAS_CONST blasint * group_size );
448
+
434
449
#ifdef __cplusplus
435
450
}
436
451
#endif /* __cplusplus */
0 commit comments