Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/specdec_bench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_ben
Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1 --postprocess gptoss

```

Expand Down
24 changes: 22 additions & 2 deletions examples/specdec_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@

import yaml
from specdec_bench import datasets, metrics, models, runners
from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base
from specdec_bench.utils import (
decode_chat,
encode_chat,
get_tokenizer,
postprocess_base,
postprocess_gptoss,
)

engines_available = {
"TRTLLM": models.TRTLLMPYTModel,
Expand Down Expand Up @@ -109,7 +115,12 @@ def run_simple(args):
metrics_list.insert(0, metrics.AcceptanceRate())
runner = runners.SimpleRunner(model, metrics=metrics_list)

postprocess = postprocess_base
if args.postprocess == "base":
postprocess = postprocess_base
elif args.postprocess == "gptoss":
postprocess = postprocess_gptoss
else:
raise ValueError(f"Invalid postprocess: {args.postprocess}")

asyncio.run(
run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
Expand Down Expand Up @@ -183,6 +194,15 @@ def run_simple(args):
help="Maximum number of concurrent requests",
)
parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric")
parser.add_argument(
"--postprocess",
type=str,
required=False,
default="base",
choices=["base", "gptoss"],
help="Postprocess to use",
)

args = parser.parse_args()

if args.runtime_params is not None:
Expand Down