[WIP][Blackwell Kernels] Blackwell group gemm and dense gemms with Python Cutlass #1256

Status: Open. Wants to merge 26 commits into base: main.

Commits (26)
98df179
start pycutlass kernels
lessw2020 Jun 1, 2025
c6e19d3
dense gemm working
lessw2020 Jun 1, 2025
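
Note: the kernels in this branch are written with the CUTLASS Python (CuTe) DSL. For flavor only, a dense GEMM can also be expressed through CUTLASS's high-level Python interface; the sketch below is a stand-in under that assumption (requires the `nvidia-cutlass` package, constructor arguments vary by version, and this is not the PR's code):

```python
# Rough sketch: dense GEMM via CUTLASS's high-level Python interface, D = A @ B + C.
import cutlass  # nvidia-cutlass Python package (assumed installed)
import torch

M, N, K = 512, 256, 128  # illustrative shapes, not the PR's benchmark sizes
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.zeros(M, N, device="cuda", dtype=torch.float16)
D = torch.zeros_like(C)

# Build a GEMM plan for fp16 row-major operands, then launch it.
plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)
plan.run(A, B, C, D)
```
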
f042735
start group gemm
lessw2020 Jun 1, 2025
cebe5f8
working group gemm!
lessw2020 Jun 2, 2025
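
Note: a grouped GEMM runs many independent problems, each with its own shapes, in a single launch, which is exactly what MoE expert computation wants. A sketch using the high-level `cutlass.op.GroupedGemm` interface (an API assumption, and a stand-in for the CuTe DSL kernel added here):

```python
# Sketch: grouped GEMM = one launch over independent problems of varying shape.
import cutlass  # nvidia-cutlass Python package (assumed installed)
import torch

# Illustrative problems; each group may have a different M (tokens per expert).
shapes = [(128, 64, 32), (256, 64, 32), (64, 64, 32)]  # (M, N, K) per group
As = [torch.randn(m, k, device="cuda", dtype=torch.float16) for m, n, k in shapes]
Bs = [torch.randn(k, n, device="cuda", dtype=torch.float16) for m, n, k in shapes]
Cs = [torch.zeros(m, n, device="cuda", dtype=torch.float16) for m, n, k in shapes]
Ds = [torch.zeros_like(c) for c in Cs]

plan = cutlass.op.GroupedGemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)
plan.run(As, Bs, Cs, Ds)  # each Ds[i] = As[i] @ Bs[i] + Cs[i]
```
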
8d11678
just keep the working benchmarks
lessw2020 Jun 3, 2025
1d852d4
add triton_do_bench benchmarking
lessw2020 Jun 3, 2025
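
Note: `triton.testing.do_bench` times a callable with CUDA events, handling warmup and repeated runs for you. A minimal sketch of how such a GEMM benchmark can look (shapes and dtype are illustrative, not taken from this PR):

```python
# Benchmark a matmul with Triton's do_bench; returns mean runtime in ms.
import torch
from triton.testing import do_bench

M, N, K = 4096, 4096, 4096  # illustrative problem size
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

ms = do_bench(lambda: a @ b)  # warmup + timed repetitions via CUDA events
tflops = 2 * M * N * K / (ms * 1e-3) / 1e12  # a GEMM does 2*M*N*K flops
print(f"{ms:.3f} ms  ->  {tflops:.1f} TFLOPS")
```
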
47eb705
add persistent dense gemm kernel
lessw2020 Jun 3, 2025
95886d5
make equal comparison via tf32
lessw2020 Jun 4, 2025
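
Note: with the tensor cores computing in TF32, a bit-exact match against an fp32 PyTorch matmul is not expected; putting the reference on the same footing makes the comparison meaningful. A small sketch of that idea (names are illustrative, not the PR's code):

```python
# Compare a kernel's output against a TF32 reference matmul.
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # reference matmul also uses TF32

a = torch.randn(1024, 512, device="cuda")
b = torch.randn(512, 256, device="cuda")
ref = a @ b  # TF32 reference result

kernel_out = a @ b  # placeholder for the Blackwell kernel's output
# TF32 keeps ~10 mantissa bits, so tolerances are looser than fp32 defaults.
torch.testing.assert_close(kernel_out, ref, rtol=1e-3, atol=1e-3)
```
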
c9beca9
add full bench with persistent kernel
lessw2020 Jun 4, 2025
45e22e1
group gemm working with benchmarking
lessw2020 Jun 4, 2025
84b3894
start group gemm integration with ds
lessw2020 Jun 4, 2025
46508c6
bgg integrated, failing on some group sizes
lessw2020 Jun 4, 2025
dd1b229
try padding, but still failing
lessw2020 Jun 4, 2025
eeb440a
add sm100 group scheduler
lessw2020 Jun 4, 2025
5eaea74
add manual looping as default to get running on blackwell
lessw2020 Jun 5, 2025
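
Note: the fallback is conceptually simple: instead of one grouped kernel launch, iterate over experts and run a dense GEMM on each expert's contiguous slice of tokens. A minimal sketch of that routine (names and shapes are illustrative):

```python
# Manual-looping grouped GEMM: one dense matmul per expert group.
import torch

def grouped_mm_by_loop(
    tokens: torch.Tensor,                # (total_tokens, hidden), sorted by expert
    expert_weights: list[torch.Tensor],  # one (hidden, out) matrix per expert
    group_sizes: list[int],              # tokens routed to each expert
) -> torch.Tensor:
    outputs = []
    start = 0
    for w, n in zip(expert_weights, group_sizes):
        if n > 0:  # skip experts that received no tokens
            outputs.append(tokens[start : start + n] @ w)
        start += n
    return torch.cat(outputs, dim=0)
```

This trades the launch efficiency of a true grouped kernel for correctness on hardware where the grouped path is not yet working.
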
55c2175
symm memory + manual looping = blackwell inference working
lessw2020 Jun 5, 2025
ddc40ad
add tensor converter
lessw2020 Jun 5, 2025
3235025
cute dense looping group gemm start
lessw2020 Jun 5, 2025
3a5e52f
cute dense looping group gemm start
lessw2020 Jun 5, 2025
8028659
start dense cute gemm with looping - gate only
lessw2020 Jun 6, 2025
6a883e2
gate cute kernel prepped
lessw2020 Jun 6, 2025
1b90dd0
first cute blackwell python gemm running in deepseek inference!
lessw2020 Jun 6, 2025
ab9701d
use tma_store
lessw2020 Jun 6, 2025
171e526
gate and up now running as dense blackwell cute gemms
lessw2020 Jun 6, 2025
a6e033c
full MoE running on looping dense blackwell gemms
lessw2020 Jun 6, 2025
f2e9159
refine implementation of dense cute gemms
lessw2020 Jun 7, 2025

Files changed (2)
2 changes: 1 addition & 1 deletion run_train.sh
@@ -12,7 +12,7 @@ set -ex
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
-CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/llama4/train_configs/debug_model.toml"}

 overrides=""
 if [ $# -ne 0 ]; then
24 changes: 20 additions & 4 deletions torchtitan/experiments/deepseek_v3/generate.py
@@ -8,6 +8,7 @@

 # use inference.sh "Your Question Here?" to run inference with a single prompt.

+import logging
 import sys
 from dataclasses import dataclass
@@ -19,15 +20,26 @@
 from model_config import deepseek_config_registry
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
-from transformers import AutoTokenizer
+from torchtitan.tools.logging import init_logger, logger

 from torchtitan.tools.utils import Color
+from transformers import AutoTokenizer

 # Uncomment the model you want to run.
 model_id, mesh_shape = "deepseek-ai/DeepSeek-V2-Lite-Chat", (1, 4)
 # model_id, mesh_shape = "deepseek-ai/deepseek-v3", (8, 4)


+def remove_notset_root_handlers():
+    """
+    Remove handlers with level NOTSET from the root logger.
+    Titan's logger has an explicit level set, so its handlers can be
+    distinguished from the NOTSET ones that other libraries attach.
+    """
+    for handler in logger.handlers[:]:
+        if handler.level == logging.NOTSET:
+            logger.removeHandler(handler)
+
+
 def colorize_chat(text, user_color=None, assistant_color=None, output_color=None):
     """Parse and colorize chat output with optional colors for each role."""
     lines = text.split("\n")
@@ -127,7 +139,7 @@ def create_model(dist_config: DistConfig):
     model_args.ep_size = dist_config.ep_size
     model_args.num_stages = dist_config.pp_size
     model_args.stage_idx = dist_config.pp_rank
-    model_args.max_seq_len = 4096  # 16384
+    model_args.max_seq_len = 256  # 16384

     with dist_config.device, dist_config.mesh:
         model = DeepseekForCausalLM(model_args)
@@ -224,7 +236,7 @@ def generate(
     tokenizer,
     dist_config,
     messages: list[dict],
-    n_tokens: int = 200,
+    n_tokens: int = 50,
 ):
     rank = dist.get_rank()
     device = dist_config.device
@@ -353,6 +365,10 @@ generate_with_cuda_graph(


 if __name__ == "__main__":
+    # init_logger()
+    # get rid of HF duplicate logs
+    # remove_notset_root_handlers()
+
     # Get user prompt from command line arguments
     user_prompt = "What is 2+2?"  # Default prompt
     if len(sys.argv) > 1:
@@ -375,7 +391,7 @@ def generate_with_cuda_graph(
     ]

     generate(model, pp_schedule, tokenizer, dist_config, messages)
-    generate_with_cuda_graph(model, tokenizer, dist_config, messages)
+    # generate_with_cuda_graph(model, tokenizer, dist_config, messages)

     if rank == 0:
         print(f"\n{color.yellow}Closing inference mesh...{color.reset}")