
Commit 4512e87

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b8e0b80 commit 4512e87

3 files changed: +73, -42 lines changed


transformer_engine/jax/cpp_extensions/router.py

Lines changed: 39 additions & 20 deletions
@@ -19,20 +19,30 @@
     "map_score_function",
 ]
 
+
 def map_score_function(score_function: str) -> int:
     score_function_map = {"sigmoid": 0, "softmax": 1}
-    assert score_function in score_function_map, \
-        f"score_function must be 'sigmoid' or 'softmax', got {score_function}"
+    assert (
+        score_function in score_function_map
+    ), f"score_function must be 'sigmoid' or 'softmax', got {score_function}"
     return score_function_map[score_function]
 
+
 class FusedTopkWithScoreFunctionFwdPrimitive(BasePrimitive):
     """
     Fused TopK with Score Function Forward Primitive
     """
 
     name = "te_fused_topk_with_score_function_forward_ffi"
     multiple_results = True  # Returns (probs, routing_map, intermediate_output)
-    impl_static_args = (2, 3, 4, 5, 6, 7,)  # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
+    impl_static_args = (
+        2,
+        3,
+        4,
+        5,
+        6,
+        7,
+    )  # topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function,
     inner_primitive = None
     outer_primitive = None
 
@@ -52,7 +62,7 @@ def abstract(
         te_fused_topk_with_score_function_forward abstract
         """
         dtype = dtypes.canonicalize_dtype(logits_aval.dtype)
-        assert len(logits_aval.shape) == 3 # (batch, seqlen, num_experts)
+        assert len(logits_aval.shape) == 3  # (batch, seqlen, num_experts)
 
         probs_aval = logits_aval.update(shape=logits_aval.shape, dtype=dtype)
         routing_map_aval = logits_aval.update(shape=logits_aval.shape, dtype=jnp.bool_)
@@ -78,14 +88,14 @@ def lowering(
         """
         logits_type = ir.RankedTensorType(logits.type)
         logits_shape = logits_type.shape
-        assert len(logits_shape) == 3 # (batch, seqlen, num_experts)
+        assert len(logits_shape) == 3  # (batch, seqlen, num_experts)
         (batch, seqlen, num_experts) = logits_shape
 
         return ffi.ffi_lowering(FusedTopkWithScoreFunctionFwdPrimitive.name)(
             ctx,
             logits,
             expert_bias,
-            num_tokens=batch*seqlen,
+            num_tokens=batch * seqlen,
             num_experts=num_experts,
             topk=topk,
             use_pre_softmax=use_pre_softmax,
@@ -107,17 +117,20 @@ def impl(
         score_function,
     ):
         assert FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive is not None
-        (probs, routing_map, intermediate_output) = FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
-            logits,
-            expert_bias,
-            topk=topk,
-            use_pre_softmax=use_pre_softmax,
-            num_groups=num_groups,
-            group_topk=group_topk,
-            scaling_factor=scaling_factor,
-            score_function=score_function,
+        (probs, routing_map, intermediate_output) = (
+            FusedTopkWithScoreFunctionFwdPrimitive.inner_primitive.bind(
+                logits,
+                expert_bias,
+                topk=topk,
+                use_pre_softmax=use_pre_softmax,
+                num_groups=num_groups,
+                group_topk=group_topk,
+                scaling_factor=scaling_factor,
+                score_function=score_function,
+            )
         )
         return probs, routing_map, intermediate_output
+
     @staticmethod
     def batcher(
         batched_args,
@@ -129,7 +142,9 @@ def batcher(
         scaling_factor,
         score_function,
     ):
-        raise NotImplementedError("Batcher not implemented for FusedTopkWithScoreFunctionFwdPrimitive")
+        raise NotImplementedError(
+            "Batcher not implemented for FusedTopkWithScoreFunctionFwdPrimitive"
+        )
 
     @staticmethod
     def infer_sharding_from_operands(
@@ -169,7 +184,8 @@ def partition(
        del result_infos
        out_shardings = (arg_infos[0].sharding, arg_infos[0].sharding, arg_infos[0].sharding)
        arg_shardings = (arg_infos[0].sharding, arg_infos[1].sharding)
-        impl = partial(FusedTopkWithScoreFunctionFwdPrimitive.impl,
+        impl = partial(
+            FusedTopkWithScoreFunctionFwdPrimitive.impl,
            topk=topk,
            use_pre_softmax=use_pre_softmax,
            num_groups=num_groups,
@@ -261,15 +277,15 @@ def lowering(
         """
         intermediate_output_type = ir.RankedTensorType(intermediate_output.type)
         intermediate_output_shape = intermediate_output_type.shape
-        assert len(intermediate_output_shape) == 3 # (batch, seqlen, num_experts)
+        assert len(intermediate_output_shape) == 3  # (batch, seqlen, num_experts)
         (batch, seqlen, num_experts) = intermediate_output_shape
 
         return ffi.ffi_lowering(FusedTopkWithScoreFunctionBwdPrimitive.name)(
             ctx,
             routing_map,
             intermediate_output,
             grad_probs,
-            num_tokens=batch*seqlen,
+            num_tokens=batch * seqlen,
             num_experts=num_experts,
             topk=topk,
             use_pre_softmax=use_pre_softmax,
@@ -307,7 +323,9 @@ def batcher(
         scaling_factor,
         score_function,
     ):
-        raise NotImplementedError("Batcher not implemented for FusedTopkWithScoreFunctionBwdPrimitive")
+        raise NotImplementedError(
+            "Batcher not implemented for FusedTopkWithScoreFunctionBwdPrimitive"
+        )
 
     @staticmethod
     def infer_sharding_from_operands(
@@ -411,6 +429,7 @@ def fused_topk_with_score_function_fwd(
         score_function=score_function,
     )
 
+
 def fused_topk_with_score_function_bwd(
     routing_map: jnp.ndarray,
     intermediate_output: jnp.ndarray,
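
A quick usage sketch for the reformatted map_score_function helper above. This is illustrative only: it assumes the module is importable as transformer_engine.jax.cpp_extensions.router (the file path shown in this diff); the mapping and the assertion message are taken directly from the function body.

from transformer_engine.jax.cpp_extensions.router import map_score_function

# Mapping comes straight from the function body: {"sigmoid": 0, "softmax": 1}.
assert map_score_function("sigmoid") == 0
assert map_score_function("softmax") == 1

# Any other name trips the reformatted assert with the message shown above.
try:
    map_score_function("relu")
except AssertionError as err:
    print(err)  # score_function must be 'sigmoid' or 'softmax', got relu

The reformatting only changes how the assert is wrapped; the accepted values and the returned integer codes are unchanged.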

transformer_engine/jax/csrc/extensions/router.cpp

Lines changed: 17 additions & 19 deletions
@@ -19,14 +19,13 @@ constexpr int kScoreFunctionSigmoid = 0;
 constexpr int kScoreFunctionSoftmax = 1;
 
 Error_Type FusedTopkWithScoreFunctionForwardFFI(
-    cudaStream_t stream, Buffer_Type logits_buf, Buffer_Type expert_bias_buf,
-    Result_Type probs_buf, Result_Type routing_map_buf, Result_Type intermediate_output_buf,
-    int64_t num_tokens, int64_t num_experts, int64_t topk, bool use_pre_softmax,
-    int64_t num_groups, int64_t group_topk, double scaling_factor, int64_t score_function) {
-
+    cudaStream_t stream, Buffer_Type logits_buf, Buffer_Type expert_bias_buf, Result_Type probs_buf,
+    Result_Type routing_map_buf, Result_Type intermediate_output_buf, int64_t num_tokens,
+    int64_t num_experts, int64_t topk, bool use_pre_softmax, int64_t num_groups, int64_t group_topk,
+    double scaling_factor, int64_t score_function) {
   auto logits_dtype = convert_ffi_datatype_to_te_dtype(logits_buf.element_type());
-  auto logits_shape = std::vector<size_t>{static_cast<size_t>(num_tokens),
-                                          static_cast<size_t>(num_experts)};
+  auto logits_shape =
+      std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};
 
   auto *logits = logits_buf.untyped_data();
   auto logits_tensor = TensorWrapper(logits, logits_shape, logits_dtype);
@@ -47,28 +46,28 @@ Error_Type FusedTopkWithScoreFunctionForwardFFI(
   auto routing_map_tensor = TensorWrapper(routing_map, logits_shape, DType::kByte);
 
   auto *intermediate_output = intermediate_output_buf->untyped_data();
-  auto intermediate_output_tensor =
-      TensorWrapper(intermediate_output, logits_shape, logits_dtype);
+  auto intermediate_output_tensor = TensorWrapper(intermediate_output, logits_shape, logits_dtype);
 
   nvte_fused_topk_with_score_function_forward(
       logits_tensor.data(), static_cast<int>(num_tokens), static_cast<int>(num_experts),
       static_cast<int>(topk), static_cast<int>(use_pre_softmax), static_cast<int>(num_groups),
       static_cast<int>(group_topk), static_cast<float>(scaling_factor),
-      static_cast<int>(score_function),
-      expert_bias_tensor.data(), probs_tensor.data(),
+      static_cast<int>(score_function), expert_bias_tensor.data(), probs_tensor.data(),
       routing_map_tensor.data(), intermediate_output_tensor.data(), stream);
 
   return ffi_with_cuda_error_check();
 }
 
-Error_Type FusedTopkWithScoreFunctionBackwardFFI(
-    cudaStream_t stream, Buffer_Type routing_map_buf, Buffer_Type intermediate_output_buf,
-    Buffer_Type grad_probs_buf, Result_Type grad_logits_buf, int64_t num_tokens,
-    int64_t num_experts, int64_t topk, bool use_pre_softmax, double scaling_factor,
-    int64_t score_function) {
+Error_Type FusedTopkWithScoreFunctionBackwardFFI(cudaStream_t stream, Buffer_Type routing_map_buf,
+                                                 Buffer_Type intermediate_output_buf,
+                                                 Buffer_Type grad_probs_buf,
+                                                 Result_Type grad_logits_buf, int64_t num_tokens,
+                                                 int64_t num_experts, int64_t topk,
+                                                 bool use_pre_softmax, double scaling_factor,
+                                                 int64_t score_function) {
   auto grad_probs_dtype = convert_ffi_datatype_to_te_dtype(grad_probs_buf.element_type());
-  auto tensor_shape = std::vector<size_t>{static_cast<size_t>(num_tokens),
-                                          static_cast<size_t>(num_experts)};
+  auto tensor_shape =
+      std::vector<size_t>{static_cast<size_t>(num_tokens), static_cast<size_t>(num_experts)};
 
   auto *routing_map = routing_map_buf.untyped_data();
   auto routing_map_tensor = TensorWrapper(routing_map, tensor_shape, DType::kByte);
@@ -129,4 +128,3 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedTopkWithScoreFunctionBackwardHandler,
 
 }  // namespace jax
 }  // namespace transformer_engine
-
transformer_engine/jax/router.py

Lines changed: 17 additions & 3 deletions
@@ -57,13 +57,27 @@ def _fused_topk_with_score_function(
     score_function,
 ):
     outputs, _ = _fused_topk_fwd_rule(
-        logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function
+        logits,
+        expert_bias,
+        topk,
+        use_pre_softmax,
+        num_groups,
+        group_topk,
+        scaling_factor,
+        score_function,
     )
     return outputs
 
 
 def _fused_topk_fwd_rule(
-    logits, expert_bias, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function
+    logits,
+    expert_bias,
+    topk,
+    use_pre_softmax,
+    num_groups,
+    group_topk,
+    scaling_factor,
+    score_function,
 ):
     probs, routing_map, intermediate_output = tex.fused_topk_with_score_function_fwd(
         logits,
@@ -84,7 +98,7 @@ def _fused_topk_bwd_rule(
     del num_groups, group_topk
     routing_map, intermediate_output = ctx
     grad_probs, _, _ = grads
-
+
     grad_logits = tex.fused_topk_with_score_function_bwd(
         routing_map,
         intermediate_output,
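
The _fused_topk_with_score_function / _fused_topk_fwd_rule / _fused_topk_bwd_rule trio reformatted above follows the usual jax.custom_vjp shape: the forward rule returns the outputs plus residuals (routing_map, intermediate_output), and the backward rule unpacks those residuals and the incoming grads. The sketch below is a minimal, self-contained analogue of that pattern, not the actual TransformerEngine wiring; the toy scaled_softmax function, its shapes, and the nondiff_argnums choice are assumptions for illustration only.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_softmax(logits, scaling_factor):
    return jax.nn.softmax(logits, axis=-1) * scaling_factor


def scaled_softmax_fwd(logits, scaling_factor):
    probs = jax.nn.softmax(logits, axis=-1)
    return probs * scaling_factor, probs  # second element plays the role of ctx


def scaled_softmax_bwd(scaling_factor, probs, grad_out):
    # Softmax VJP with the static scaling folded in: p * (g - sum(g * p)).
    g = grad_out * scaling_factor
    grad_logits = probs * (g - jnp.sum(g * probs, axis=-1, keepdims=True))
    return (grad_logits,)


scaled_softmax.defvjp(scaled_softmax_fwd, scaled_softmax_bwd)

# (batch, seqlen, num_experts) logits, matching the shape asserted by the primitives above.
logits = jnp.ones((2, 4, 8))
probs, vjp_fn = jax.vjp(lambda x: scaled_softmax(x, 2.0), logits)
(grad_logits,) = vjp_fn(jnp.ones_like(probs))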
