Add CLI arg generate_until_token to support reasoning and CoT models #617

Open · wants to merge 3 commits into main
8 changes: 8 additions & 0 deletions src/lighteval/main_accelerate.py
@@ -70,6 +70,9 @@ def accelerate( # noqa C901
load_responses_from_details_date_id: Annotated[
Optional[str], Option(help="Load responses from details directory.", rich_help_panel=HELP_PANEL_NAME_1)
] = None,
generate_until_token: Annotated[
Optional[str], Option(help="Continue generating reasoning or chain-of-thought after system prompt and until this stop token.", rich_help_panel=HELP_PANEL_NAME_4)
] = None,
# === saving ===
output_dir: Annotated[
str, Option(help="Output directory for evaluation results.", rich_help_panel=HELP_PANEL_NAME_2)
@@ -121,6 +124,9 @@ def accelerate( # noqa C901

env_config = EnvConfig(token=TOKEN, cache_dir=cache_dir)

if (not use_chat_template) and (generate_until_token is not None):
raise Exception("`generate_until_token` must be used with `use_chat_template` flag.")

evaluation_tracker = EvaluationTracker(
output_dir=output_dir,
save_details=save_details,
@@ -141,6 +147,7 @@ def accelerate( # noqa C901
use_chat_template=use_chat_template,
system_prompt=system_prompt,
load_responses_from_details_date_id=load_responses_from_details_date_id,
generate_until_token=generate_until_token,
)

# TODO (nathan): better handling of model_args
@@ -172,6 +179,7 @@ def accelerate( # noqa C901
"multichoice_continuations_start_space"
]
args_dict["use_chat_template"] = use_chat_template
args_dict["generate_until_token"] = generate_until_token

# Keeping only non null params
args_dict = {k: v for k, v in args_dict.items() if v is not None}
2 changes: 2 additions & 0 deletions src/lighteval/pipeline.py
@@ -108,6 +108,7 @@ class PipelineParameters:
use_chat_template: bool = False
system_prompt: str | None = None
load_responses_from_details_date_id: str | None = None
generate_until_token: str | None = None

def __post_init__(self): # noqa C901
if self.launcher_type == ParallelismManager.ACCELERATE:
@@ -234,6 +235,7 @@ def _init_tasks_and_requests(self, tasks: str):
evaluation_tracker=self.evaluation_tracker,
use_chat_template=self.pipeline_parameters.use_chat_template,
system_prompt=self.pipeline_parameters.system_prompt,
generate_until_token=self.pipeline_parameters.generate_until_token,
)

self.task_names_list = task_names_list
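
For readers who drive lighteval programmatically rather than from the CLI, a minimal sketch of threading the new field through PipelineParameters (the values below are hypothetical and not part of this PR):

from lighteval.pipeline import ParallelismManager, PipelineParameters

# Hypothetical configuration for illustration; only use_chat_template and
# generate_until_token are directly relevant to this change.
pipeline_params = PipelineParameters(
    launcher_type=ParallelismManager.ACCELERATE,
    use_chat_template=True,            # required, since the CLI guard rejects generate_until_token without it
    generate_until_token="</think>",   # stop token that closes the pre-generated reasoning block
)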
2 changes: 2 additions & 0 deletions src/lighteval/tasks/lighteval_task.py
@@ -582,6 +582,7 @@ def create_requests_from_tasks( # noqa: C901
evaluation_tracker: "EvaluationTracker",
use_chat_template: bool,
system_prompt: str | None,
generate_until_token: str | None,
) -> Tuple[dict[RequestType, list[Request]], dict[SampleUid, Doc]]:
"""
Takes a task dict and a fewshot dict and returns a dict of requests, a dict
@@ -646,6 +647,7 @@ def create_requests_from_tasks( # noqa: C901
truncate_few_shots=truncate_few_shots,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
generate_until_token=generate_until_token,
)

# Constructing the requests
35 changes: 31 additions & 4 deletions src/lighteval/tasks/prompt_manager.py
@@ -30,6 +30,8 @@

from lighteval.models.abstract_model import LightevalModel
from lighteval.models.litellm_model import LiteLLMClient
from lighteval.models.model_output import Batch
from lighteval.models.transformers.transformers_model import TransformersModel
from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list

@@ -105,7 +107,8 @@ def add_context_to_doc(
sampler: Optional[random.Random] = None,
truncate_few_shots: bool = False,
use_chat_template=False,
system_prompt: str = None,
system_prompt: Optional[str] = None,
generate_until_token: Optional[str] = None,
) -> Doc:
is_multi_turn = doc.specific is not None and len(doc.specific.get("multi_turn_queries", [])) > 0
if is_multi_turn:
@@ -120,6 +123,7 @@
sampler=sampler,
use_chat_template=use_chat_template,
system_prompt=system_prompt,
generate_until_token=generate_until_token,
)
doc.num_effective_few_shots = num_effective_few_shots
doc.num_asked_few_shots = num_fewshot
@@ -173,7 +177,8 @@ def _single_turn_context(
sampler: Optional[random.Random] = None,
truncate_few_shots: bool = False,
use_chat_template=False,
system_prompt: str = None,
system_prompt: Optional[str] = None,
generate_until_token: Optional[str] = None,
):
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -238,9 +243,31 @@
return output, num_effective_fewshots

elif use_chat_template:
return self.model.tokenizer.apply_chat_template(
chat_preview = self.model.tokenizer.apply_chat_template(
output, tokenize=False, add_generation_prompt=True
), num_effective_fewshots
)
if generate_until_token is not None:
if not isinstance(self.model, TransformersModel):
raise Exception("`generate_until_token` only implemented for `TransformerModel` class")
tokenized = self.model.tokenizer(chat_preview, return_tensors="pt").to(self.model.device)
prepared_batch = Batch(
input_ids=tokenized["input_ids"],
input_mask=tokenized["attention_mask"],
input_lengths=[len(tokenized["input_ids"][0])],
truncated=[False],
padded=[False],
)
response = self.model._generate(
batch=prepared_batch,
do_sample=True,
max_new_tokens=2048,
stop_tokens=[generate_until_token],
)
logger.debug(response[0].result[0])
full_start = chat_preview + response[0].result[0] + generate_until_token
return full_start, num_effective_fewshots
else:
return chat_preview, num_effective_fewshots

return output, num_effective_fewshots
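
To make the concatenation above concrete, a toy illustration (all strings are hypothetical; in practice chat_preview comes from the tokenizer's chat template and the reasoning text from self.model._generate):

# Toy example of what _single_turn_context returns when generate_until_token is set:
# the templated prompt, followed by the model's own reasoning, closed by the stop token.
chat_preview = "<|user|>\nWhat is 2 + 2?\n<|assistant|>\n<think>"  # hypothetical chat-template output
reasoning = "The user asks for a simple sum. 2 + 2 equals 4."      # hypothetical generated reasoning
generate_until_token = "</think>"

full_start = chat_preview + reasoning + generate_until_token
# full_start is then used as the context for the actual evaluation request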
