Adds continuous batching #850

Open · wants to merge 15 commits into main
8 changes: 7 additions & 1 deletion examples/model_configs/transformers_model.yaml
@@ -5,8 +5,14 @@ model_parameters:
compile: false
model_parallel: false
batch_size: 1
multichoice_continuations_start_space: null # If true/false, will force multiple choice continuations to start/not start with a space. If none, will do nothing
use_chat_template: true
continuous_batching: false
model_loading_kwargs:
attn_implementation: "eager"
#tp_plan: "auto"
generation_parameters:
#num_blocks: 4096
#block_size: 64
#max_new_tokens: 256
temperature: 0.0
top_p: 0.9
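
For reference, the same setup can be sketched programmatically. This is a minimal sketch, assuming the YAML keys above map one-to-one onto TransformersModelConfig and GenerationParameters fields; the model name is a placeholder.

```python
from lighteval.models.model_input import GenerationParameters
from lighteval.models.transformers.transformers_model import TransformersModelConfig

# Continuous batching enabled, with the paged cache sized through
# num_blocks / block_size (values taken from the commented-out YAML defaults).
config = TransformersModelConfig(
    model_name="my-org/my-model",  # placeholder
    batch_size=1,
    use_chat_template=True,
    continuous_batching=True,
    model_loading_kwargs={"attn_implementation": "eager"},
    generation_parameters=GenerationParameters(
        num_blocks=4096,
        block_size=64,
        max_new_tokens=256,
        temperature=0.0,
        top_p=0.9,
    ),
)
```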
5 changes: 5 additions & 0 deletions src/lighteval/models/model_input.py
@@ -25,6 +25,9 @@


class GenerationParameters(BaseModel, extra="forbid"):
num_blocks: NonNegativeInt | None = None # transformers
block_size: NonNegativeInt | None = None # transformers

early_stopping: bool | None = None # transformers
repetition_penalty: NonNegativeFloat | None = None # vllm, transformers, tgi, sglang
frequency_penalty: NonNegativeFloat | None = None # vllm, tgi, sglang
@@ -186,6 +189,8 @@ def to_transformers_dict(self) -> dict:
"repetition_penalty": self.repetition_penalty,
"length_penalty": self.length_penalty,
"output_scores": True,
"num_blocks": self.num_blocks,
"block_size": self.block_size,
"return_dict_in_generate": True,
}
return {k: v for k, v in args.items() if v is not None}
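
A quick sketch of how the new fields flow through (not part of the diff): since the final dict comprehension drops None values, num_blocks and block_size only reach transformers when they are explicitly set.

```python
from lighteval.models.model_input import GenerationParameters

params = GenerationParameters(num_blocks=4096, block_size=64)
gen_kwargs = params.to_transformers_dict()

# Only explicitly-set values survive the final None-filtering, so unset knobs
# never override the transformers generation defaults.
assert gen_kwargs["num_blocks"] == 4096
assert gen_kwargs["block_size"] == 64
assert "early_stopping" not in gen_kwargs
```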
172 changes: 163 additions & 9 deletions src/lighteval/models/transformers/transformers_model.py
@@ -23,7 +23,7 @@
import logging
import os
from datetime import timedelta
from typing import Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
Expand All @@ -41,6 +41,7 @@
BitsAndBytesConfig,
PretrainedConfig,
)
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.utils import GenerateOutput
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

@@ -110,6 +111,8 @@ class TransformersModelConfig(ModelConfig):
True forces adding space, False removes leading space if present.
pairwise_tokenization (bool):
Whether to tokenize context and continuation separately or together. Defaults to False.
continuous_batching (bool):
Whether to use continuous batching for generation. Defaults to False.

Example:
```python
@@ -147,6 +150,7 @@
compile: bool = False
multichoice_continuations_start_space: bool | None = None
pairwise_tokenization: bool = False
continuous_batching: bool = False

def model_post_init(self, __context):
if self.multichoice_continuations_start_space is True:
@@ -190,7 +194,9 @@ def __init__(
self._add_special_tokens = config.add_special_tokens or False
self.pairwise_tokenization = config.pairwise_tokenization
self.batch_size = config.batch_size
self.continuous_batching = config.continuous_batching
self.transformers_config = config.get_transformers_config()
self.generation_config_dict = config.generation_parameters.to_transformers_dict()

self.model_sha = config.get_model_sha()
self._max_length = self._init_max_length()
@@ -210,8 +216,6 @@ def __init__(

self.model_name = _simplify_name(config.model_name)

self.generation_config_dict = config.generation_parameters.to_transformers_dict()

if is_accelerate_available():
model_size, _ = calculate_maximum_sizes(self.model)
model_size = convert_bytes(model_size)
@@ -256,14 +260,15 @@ def from_model(

# Instantiate the object without using __init__
self = cls.__new__(cls)
self.config = config
self.transformers_config = model.config
self.generation_config_dict = config.generation_parameters.to_transformers_dict()
self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
Review comment (Copilot AI, Jul 8, 2025):

The from_model constructor does not set self.continuous_batching, so models loaded via from_model will always default to False. Add self.continuous_batching = self.config.continuous_batching after setting self.config.

Suggested change:
self.config = config if config is not None else TransformersModelConfig(model_name=model.config.name_or_path)
self.continuous_batching = self.config.continuous_batching if config and hasattr(self.config, 'continuous_batching') else False

if config is not None:
self.generation_config_dict = config.generation_parameters.to_transformers_dict()
self._max_length = self._init_max_length()
self._tokenizer = self._create_auto_tokenizer()
self.batch_size = config.batch_size
self.batch_size = getattr(config, "batch_size", None)
self.model_name = _simplify_name(model.name_or_path)
self.model_sha = config.get_model_sha()
self.model_sha = self.config.get_model_sha()

# If model_parallel is not set we compare the number of processes with the number of GPUs
self.model = model
@@ -402,6 +407,11 @@ def _create_auto_model(self) -> transformers.PreTrainedModel:
# model.to(self.device)
model.eval()
torch.set_grad_enabled(False)
if self.continuous_batching:
generation_config = GenerationConfig(
**self.generation_config_dict,
)
model.generation_config = generation_config

if self.config.compile:
try:
@@ -504,7 +514,110 @@ def forward_batch(batch_size):
logger.info(f"Determined largest batch size: {batch_size}")
return batch_size

def greedy_until(
def _continuous_greedy_until(
Review comment (Member):

Is there any way to factorize more between continuous and padded greedy until? (Otherwise, there's a risk we end up having different input management, for example, like we had in the past across generation models.)

self,
docs: list[Doc],
) -> list[ModelResponse]:
"""
Generates responses using a greedy decoding strategy until certain ending conditions are met.

Args:
docs (list[Doc]): list of documents containing the prompts and ending conditions.

Returns:
list[ModelResponse]: list of generated responses.
"""
dataset = GenerativeTaskDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
results = []

for split in tqdm(
dataset.splits_iterator(),
total=dataset.num_dataset_splits,
desc="Splits",
position=0,
disable=False, # self.disable_tqdm,
):
# For chat models, generation stops with EOS token, so we don't need to specify stop tokens
if self.use_chat_template:
stop_tokens = []
else:
# NOTE: we are assuming all items in a batch behave similarly (same
# stop_tokens and max_tokens generated) which is not necessarily
# the case! Because of that we only use a batch size of 1
stop_tokens = split[0].stop_sequence

max_new_tokens = self.config.generation_parameters.max_new_tokens or split[0].generation_size
returns_logits = split[0].use_logits
num_samples = split[0].num_samples
contexts = [self.prompt_manager.prepare_prompt(doc) for doc in split]
tokenized = self.tokenizer(contexts, add_special_tokens=self.add_special_tokens)

# The main question for this step is the following:
# Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
# of losing some meaning, or have some generations that are exceedingly short?
# The choice we go for here is to avoid truncating the prompt if we can, since it
# should have been managed by the prompt creator/few shot manager if requested by the user.
inputs = tokenized["input_ids"]
context_size = len(inputs[0])

# left truncate the inputs to the maximum length
if max_new_tokens is not None:
if context_size + max_new_tokens > self.max_length:
logger.warning(
f"{context_size + max_new_tokens=} which is greater than {self.max_length=}. Truncating context to {self.max_length - max_new_tokens} tokens."
)
context_size = self.max_length - max_new_tokens
if context_size < 0:
logger.critical(
f"{context_size=} is less than 0, either reduce the max_new_tokens or increase model max length."
)
raise ValueError("Context size is less than 0.")
inputs = [input[-context_size:] for input in inputs]
else:
if context_size > self.max_length:
logger.warning(
f"{context_size=} which is greater than {self.max_length=}. Truncating context to {self.max_length} tokens."
)
context_size = self.max_length
inputs = [input[-context_size:] for input in inputs]

_outputs = self._generate(
inputs=inputs,
max_new_tokens=max_new_tokens,
stop_tokens=stop_tokens,
returns_logits=returns_logits,
num_samples=num_samples,
continuous_batching=True,
)

for req_id, _output in _outputs.items():
output_token_ids = []
logprobs_raw = []
result = []

# for output in _output.outputs:
output_token_ids.append(_output.generated_tokens)
# logprobs_raw.append(output.logprobs)
result.append(self.tokenizer.decode(_output.generated_tokens))
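# Logprobs are not collected on this path yet; the branch below is explicitly disabled with `and False`.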

if logprobs_raw and output_token_ids and False:
logprobs = [logprobs_raw[0][token_id].logprob for token_id in output_token_ids[0]]
else:
logprobs = []

input_token_ids = _output.prompt_ids
cur_response = ModelResponse(
text=result,
logprobs=logprobs,
output_tokens=output_token_ids,
input_tokens=input_token_ids,
)
results.append(cur_response)

return dataset.get_original_order(results)

def _padded_greedy_until(
self,
docs: list[Doc],
) -> list[ModelResponse]:
@@ -617,12 +730,43 @@ def greedy_until(
stop_tokens=stop_tokens,
returns_logits=False,
num_samples=num_samples,
continuous_batching=False,
)
results.extend(cur_reponses)

return dataset.get_original_order(results)

def _generate(
def greedy_until(
self,
docs: list[Doc],
) -> list[ModelResponse]:
if self.continuous_batching:
return self._continuous_greedy_until(docs)
else:
return self._padded_greedy_until(docs)

def _generate_fast(
Review comment (Member):

Is _generate_fast for continuous batching only? If yes, call it _generate_continuous then, since the other is _generate_padded and not _generate_slow (for homogeneity).

self,
inputs: list[list[int]],
max_new_tokens: Optional[int] = None,
stop_tokens: Optional[list[str]] = None,
returns_logits: Optional[bool] = False,
num_samples: int = 1,
generate: bool = True,
) -> Dict[str, ModelResponse]:
# Compute model generation
self.model.generation_config.use_cuda_graph = False  # Disable CUDA graphs for batch generation
self.model.generation_config.max_batch_tokens = 256  # Cap the number of tokens handled per batch
# self.model.generation_config.do_sample = False  # Would force greedy decoding
batch_outputs = self.model.generate_batch(
inputs=inputs,
generation_config=self.model.generation_config,
# You can pass request-specific overrides here, e.g., max_new_tokens=100
)

return batch_outputs

def _generate_padded(
self,
batch: Batch,
max_new_tokens: int,
Expand Down Expand Up @@ -708,6 +852,16 @@ def _generate(

return all_responses

def _generate(
self,
continuous_batching: bool,
**kwargs,
) -> list[ModelResponse]:
if continuous_batching:
return self._generate_fast(**kwargs)
else:
return self._generate_padded(**kwargs)

def loglikelihood(
self,
docs: list[Doc],
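
Taken together, the continuous-batching path amounts to handing pre-tokenized prompts to generate_batch with a GenerationConfig that carries the num_blocks / block_size settings. The snippet below is a rough standalone sketch of that call, not part of the diff: the model name and prompts are placeholders, and generate_batch, use_cuda_graph and max_batch_tokens are assumed to exist only in the transformers version this PR targets.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import GenerationConfig

model_name = "my-org/my-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager").eval()

# Mirror what TransformersModel does when continuous_batching is enabled:
# the generation config carries the paged-cache sizing and decoding settings.
model.generation_config = GenerationConfig(num_blocks=4096, block_size=64, max_new_tokens=256)
model.generation_config.use_cuda_graph = False   # assumed attribute, as set in _generate_fast
model.generation_config.max_batch_tokens = 256   # assumed attribute, as set in _generate_fast

prompts = ["What is continuous batching?", "Name a prime number."]
inputs = tokenizer(prompts, add_special_tokens=True)["input_ids"]

with torch.no_grad():
    # generate_batch returns a dict keyed by request id; each entry exposes
    # prompt_ids and generated_tokens, which lighteval wraps into ModelResponse.
    batch_outputs = model.generate_batch(inputs=inputs, generation_config=model.generation_config)

for req_id, output in batch_outputs.items():
    print(req_id, tokenizer.decode(output.generated_tokens))
```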
2 changes: 2 additions & 0 deletions tests/models/endpoints/test_endpoint_model.py
@@ -52,6 +52,8 @@ class TestInferenceEndpointModelConfig:
"add_special_tokens": True,
"system_prompt": None,
"generation_parameters": {
"num_blocks": None,
"block_size": None,
Comment on lines +55 to +56
Review comment (Copilot AI, Jul 8, 2025):

[nitpick] Here the parameters are added as num_blocks then block_size, which is the reverse of the other test. Align ordering to maintain consistency.

Suggested change:
"block_size": None,
"num_blocks": None,
"early_stopping": None,
"frequency_penalty": None,
"length_penalty": None,
2 changes: 2 additions & 0 deletions tests/models/endpoints/test_tgi_model.py
@@ -38,6 +38,8 @@ class TestTGIModelConfig:
"model_name": None,
"system_prompt": None,
"generation_parameters": {
"block_size": None,
"num_blocks": None,
Comment on lines +41 to +42
Review comment (Copilot AI, Jul 8, 2025):

[nitpick] Test insertion order of block_size then num_blocks differs from the other endpoint test. Consider keeping parameter order consistent across tests to avoid confusion.

Suggested change:
"num_blocks": None,
"block_size": None,
"early_stopping": None,
"frequency_penalty": None,
"length_penalty": None,