10
10
namespace vllm {
11
11
12
12
// Activation and gating kernel template.
13
- template <typename scalar_t , scalar_t (*ACT_FN)(const scalar_t &)>
13
+ template <typename scalar_t , scalar_t (*ACT_FN)(const scalar_t &)>
14
14
__global__ void act_and_mul_kernel (
15
- scalar_t * __restrict__ out, // [..., d]
16
- const scalar_t * __restrict__ input, // [..., 2, d]
17
- const int d) {
15
+ scalar_t * __restrict__ out, // [..., d]
16
+ const scalar_t * __restrict__ input, // [..., 2, d]
17
+ const int d) {
18
18
const int64_t token_idx = blockIdx .x ;
19
19
for (int64_t idx = threadIdx .x ; idx < d; idx += blockDim .x ) {
20
20
const scalar_t x = VLLM_LDG (&input[token_idx * 2 * d + idx]);
@@ -23,139 +23,128 @@ __global__ void act_and_mul_kernel(
23
23
}
24
24
}
25
25
26
// SiLU (a.k.a. swish): x * sigmoid(x).
// Computes in float regardless of T to avoid precision loss for
// half/bfloat16 inputs, then casts back to T.
template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}
31
31
32
// Exact (erf-based) GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Promote to float so half/bfloat16 inputs keep full precision
  // through the erf evaluation.
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;  // 1/sqrt(2)
  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
41
41
42
// Tanh-approximated GELU:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Compute in float for precision; cast back to T at the end.
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;  // sqrt(2/pi)
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
54
54
55
- } // namespace vllm
55
+ } // namespace vllm
56
56
57
57
// Launch activation and gating kernel.
// Expects `out` [..., d] and `input` [..., 2 * d] to be in scope.
// Grid layout: one block per token (flattened leading dims), up to 1024
// threads striding over the hidden dimension d.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
  int d = input.size(-1) / 2;                                            \
  int64_t num_tokens = input.numel() / input.size(-1);                   \
  dim3 grid(num_tokens);                                                 \
  dim3 block(std::min(d, 1024));                                         \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));      \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
  VLLM_DISPATCH_FLOATING_TYPES(                                          \
      input.scalar_type(), "act_and_mul_kernel", [&] {                   \
        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),       \
                                         input.data_ptr<scalar_t>(), d); \
      });
71
+
72
// SiLU-and-multiply gating: out[..., :d] = silu(input[..., :d]) * input[..., d:].
void silu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
81
77
82
// GELU(exact)-and-multiply gating: out = gelu(input[..., :d]) * input[..., d:].
void gelu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
88
83
89
// GELU(tanh)-and-multiply gating: out = gelu_tanh(input[..., :d]) * input[..., d:].
void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                       torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
}
95
89
96
90
namespace vllm {

// Element-wise activation kernel template.
// One block per token (blockIdx.x); threads grid-stride over the hidden
// dimension d. Applies ACT_FN independently to each element.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    // VLLM_LDG presumably routes through the read-only data cache
    // (__ldg) where supported — defined elsewhere in the project.
    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

}  // namespace vllm
112
106
113
107
// Launch element-wise activation kernel.
// Expects `out` and `input`, both [..., d], to be in scope.
// Grid layout: one block per token (flattened leading dims), up to 1024
// threads striding over the hidden dimension d.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  int d = input.size(-1);                                                      \
  int64_t num_tokens = input.numel() / d;                                      \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));            \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
    vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                        \
        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),                 \
                                     input.data_ptr<scalar_t>(), d);           \
  });
130
120
131
121
namespace vllm {

// "New" GELU (tanh approximation as used by GPT-2 style models):
//   0.5 * x * (1 + tanh(0.79788456 * (x + 0.044715 * x^3)))
// where 0.79788456 ~= sqrt(2/pi). Intermediate products round through T,
// matching the reference implementation's numerics.
template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float)(x * x * x);
  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
  return ((T)0.5) * x * (((T)1.0) + t);
}

// "Fast" GELU variant: algebraically rearranged tanh approximation,
//   0.5 * x * (1 + tanh(0.79788456 * x * (1 + 0.044715 * x^2))).
// Intermediate casts through T intentionally match the reference numerics.
template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const T t =
      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
  return ((T)0.5) * x * (((T)1.0) + t);
}

}  // namespace vllm
148
139
149
// Element-wise "new" GELU over the last dimension of `input`.
void gelu_new(torch::Tensor& out,    // [..., d]
              torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
155
145
156
// Element-wise "fast" GELU over the last dimension of `input`.
void gelu_fast(torch::Tensor& out,    // [..., d]
               torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
0 commit comments