@@ -6040,7 +6040,7 @@ static __device__ void rope_yarn(
6040
6040
// rope == RoPE == rotary positional embedding
6041
6041
template <typename T, bool has_pos>
6042
6042
static __global__ void rope (
6043
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
6043
+ const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
6044
6044
float ext_factor, float attn_factor, rope_corr_dims corr_dims
6045
6045
) {
6046
6046
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
@@ -6053,7 +6053,7 @@ static __global__ void rope(
6053
6053
const int i = row*ncols + col;
6054
6054
const int i2 = row/p_delta_rows;
6055
6055
6056
- const int p = has_pos ? pos[i2] : 0 ;
6056
+ const float p = has_pos ? pos[i2] : 0 . 0f ;
6057
6057
const float theta_base = p*powf (freq_base, -float (col)/ncols);
6058
6058
6059
6059
float cos_theta, sin_theta;
@@ -6068,7 +6068,7 @@ static __global__ void rope(
6068
6068
6069
6069
template <typename T, bool has_pos>
6070
6070
static __global__ void rope_neox (
6071
- const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
6071
+ const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows,
6072
6072
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
6073
6073
) {
6074
6074
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
@@ -6095,7 +6095,7 @@ static __global__ void rope_neox(
6095
6095
6096
6096
float cur_rot = inv_ndims * ic - ib;
6097
6097
6098
- const int p = has_pos ? pos[i2] : 0 ;
6098
+ const float p = has_pos ? pos[i2] : 0 . 0f ;
6099
6099
const float theta_base = p*freq_scale*powf (theta_scale, col/2 .0f );
6100
6100
6101
6101
float cos_theta, sin_theta;
@@ -6109,7 +6109,7 @@ static __global__ void rope_neox(
6109
6109
}
6110
6110
6111
6111
static __global__ void rope_glm_f32 (
6112
- const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
6112
+ const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
6113
6113
int n_ctx
6114
6114
) {
6115
6115
const int col = blockDim .x *blockIdx .x + threadIdx .x ;
@@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32(
6124
6124
const int i2 = row/p_delta_rows;
6125
6125
6126
6126
const float col_theta_scale = powf (freq_base, -2 .0f *col/ncols);
6127
- // FIXME: this is likely wrong
6128
- const int p = pos != nullptr ? pos[i2] : 0 ;
6129
6127
6130
- const float theta = min (p, n_ctx - 2 )*freq_scale*col_theta_scale;
6128
+ const float p = pos != nullptr ? pos[i2] : 0 .0f ;
6129
+
6130
+ const float theta = min (p, (float ) n_ctx - 2 )*freq_scale*col_theta_scale;
6131
6131
const float sin_theta = sinf (theta);
6132
6132
const float cos_theta = cosf (theta);
6133
6133
@@ -6137,7 +6137,7 @@ static __global__ void rope_glm_f32(
6137
6137
dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
6138
6138
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
6139
6139
6140
- const float block_theta = (( float ) max (p - n_ctx - 2 , 0 ) )*col_theta_scale;
6140
+ const float block_theta = max (p - n_ctx - 2 , 0 . 0f )*col_theta_scale;
6141
6141
const float sin_block_theta = sinf (block_theta);
6142
6142
const float cos_block_theta = cosf (block_theta);
6143
6143
@@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const
7688
7688
7689
7689
template <typename T>
7690
7690
static void rope_cuda (
7691
- const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
7691
+ const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
7692
7692
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
7693
7693
) {
7694
7694
GGML_ASSERT (ncols % 2 == 0 );
@@ -7708,7 +7708,7 @@ static void rope_cuda(
7708
7708
7709
7709
template <typename T>
7710
7710
static void rope_neox_cuda (
7711
- const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
7711
+ const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows,
7712
7712
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
7713
7713
) {
7714
7714
GGML_ASSERT (ncols % 2 == 0 );
@@ -7733,7 +7733,7 @@ static void rope_neox_cuda(
7733
7733
}
7734
7734
7735
7735
static void rope_glm_f32_cuda (
7736
- const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
7736
+ const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
7737
7737
float freq_base, int n_ctx, cudaStream_t stream
7738
7738
) {
7739
7739
GGML_ASSERT (ncols % 4 == 0 );
@@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope(
9035
9035
memcpy (&beta_fast, (int32_t *) dst->op_params + 9 , sizeof (float ));
9036
9036
memcpy (&beta_slow, (int32_t *) dst->op_params + 10 , sizeof (float ));
9037
9037
9038
- const int32_t * pos = nullptr ;
9038
+ const float * pos = nullptr ;
9039
9039
if ((mode & 1 ) == 0 ) {
9040
- GGML_ASSERT (src1->type == GGML_TYPE_I32 );
9040
+ GGML_ASSERT (src1->type == GGML_TYPE_F32 );
9041
9041
GGML_ASSERT (src1->ne [0 ] == ne2);
9042
- pos = (const int32_t *) src1_dd;
9042
+ pos = (const float *) src1_dd;
9043
9043
}
9044
9044
9045
9045
const bool is_neox = mode & 2 ;
0 commit comments