@@ -3214,6 +3214,8 @@ static void ggml_compute_forward_reglu_f32(
3214
3214
GGML_ASSERT (dst->ne [0 ] == nc);
3215
3215
GGML_ASSERT (ggml_nrows (dst) == nr);
3216
3216
3217
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3218
+
3217
3219
// rows per thread
3218
3220
const int dr = (nr + nth - 1 )/nth;
3219
3221
@@ -3224,7 +3226,8 @@ static void ggml_compute_forward_reglu_f32(
3224
3226
for (int i1 = ir0; i1 < ir1; i1++) {
3225
3227
ggml_vec_reglu_f32 (nc,
3226
3228
(float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3227
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3229
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3230
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3228
3231
3229
3232
#ifndef NDEBUG
3230
3233
for (int k = 0 ; k < nc; k++) {
@@ -3255,6 +3258,8 @@ static void ggml_compute_forward_reglu_f16(
3255
3258
GGML_ASSERT (dst->ne [0 ] == nc);
3256
3259
GGML_ASSERT (ggml_nrows (dst) == nr);
3257
3260
3261
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3262
+
3258
3263
// rows per thread
3259
3264
const int dr = (nr + nth - 1 )/nth;
3260
3265
@@ -3265,7 +3270,8 @@ static void ggml_compute_forward_reglu_f16(
3265
3270
for (int i1 = ir0; i1 < ir1; i1++) {
3266
3271
ggml_vec_reglu_f16 (nc,
3267
3272
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3268
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3273
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3274
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3269
3275
3270
3276
#ifndef NDEBUG
3271
3277
for (int k = 0 ; k < nc; k++) {
@@ -3321,6 +3327,8 @@ static void ggml_compute_forward_geglu_f32(
3321
3327
GGML_ASSERT (dst->ne [0 ] == nc);
3322
3328
GGML_ASSERT (ggml_nrows (dst) == nr);
3323
3329
3330
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3331
+
3324
3332
// rows per thread
3325
3333
const int dr = (nr + nth - 1 )/nth;
3326
3334
@@ -3331,7 +3339,8 @@ static void ggml_compute_forward_geglu_f32(
3331
3339
for (int i1 = ir0; i1 < ir1; i1++) {
3332
3340
ggml_vec_geglu_f32 (nc,
3333
3341
(float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3334
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3342
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3343
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3335
3344
3336
3345
#ifndef NDEBUG
3337
3346
for (int k = 0 ; k < nc; k++) {
@@ -3362,6 +3371,8 @@ static void ggml_compute_forward_geglu_f16(
3362
3371
GGML_ASSERT (dst->ne [0 ] == nc);
3363
3372
GGML_ASSERT (ggml_nrows (dst) == nr);
3364
3373
3374
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3375
+
3365
3376
// rows per thread
3366
3377
const int dr = (nr + nth - 1 )/nth;
3367
3378
@@ -3372,7 +3383,8 @@ static void ggml_compute_forward_geglu_f16(
3372
3383
for (int i1 = ir0; i1 < ir1; i1++) {
3373
3384
ggml_vec_geglu_f16 (nc,
3374
3385
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3375
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3386
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3387
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3376
3388
3377
3389
#ifndef NDEBUG
3378
3390
for (int k = 0 ; k < nc; k++) {
@@ -3428,6 +3440,8 @@ static void ggml_compute_forward_swiglu_f32(
3428
3440
GGML_ASSERT (dst->ne [0 ] == nc);
3429
3441
GGML_ASSERT (ggml_nrows (dst) == nr);
3430
3442
3443
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3444
+
3431
3445
// rows per thread
3432
3446
const int dr = (nr + nth - 1 )/nth;
3433
3447
@@ -3438,7 +3452,8 @@ static void ggml_compute_forward_swiglu_f32(
3438
3452
for (int i1 = ir0; i1 < ir1; i1++) {
3439
3453
ggml_vec_swiglu_f32 (nc,
3440
3454
(float *) ((char *) dst->data + i1*( dst->nb [1 ])),
3441
- (float *) ((char *) src0->data + i1*(src0->nb [1 ])));
3455
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3456
+ (float *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3442
3457
3443
3458
#ifndef NDEBUG
3444
3459
for (int k = 0 ; k < nc; k++) {
@@ -3469,6 +3484,8 @@ static void ggml_compute_forward_swiglu_f16(
3469
3484
GGML_ASSERT (dst->ne [0 ] == nc);
3470
3485
GGML_ASSERT (ggml_nrows (dst) == nr);
3471
3486
3487
+ const int32_t swapped = ggml_get_op_params_i32 (dst, 1 );
3488
+
3472
3489
// rows per thread
3473
3490
const int dr = (nr + nth - 1 )/nth;
3474
3491
@@ -3479,7 +3496,8 @@ static void ggml_compute_forward_swiglu_f16(
3479
3496
for (int i1 = ir0; i1 < ir1; i1++) {
3480
3497
ggml_vec_swiglu_f16 (nc,
3481
3498
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb [1 ])),
3482
- (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])));
3499
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? nc : 0 ),
3500
+ (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb [1 ])) + (swapped ? 0 : nc));
3483
3501
3484
3502
#ifndef NDEBUG
3485
3503
for (int k = 0 ; k < nc; k++) {
0 commit comments