
Commit d456aea

dtransposed authored
[Misc] Add Next Edit Prediction (NEP) datasets support in benchmark_serving.py (vllm-project#16839)
Signed-off-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
Signed-off-by: dtransposed <>
Co-authored-by: dtransposed <damian@damian-ml-machine.europe-west3-b.c.jetbrains-grazie.internal>
1 parent 621ca2c commit d456aea

File tree

3 files changed: +182 -2 lines changed


benchmarks/benchmark_dataset.py (+88)
```diff
@@ -887,6 +887,94 @@ def sample(self,
         return sampled_requests
 
 
+# -----------------------------------------------------------------------------
+# Next Edit Prediction Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+zeta_prompt = """### Instruction:
+You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
+
+### User Edits:
+
+{}
+
+### User Excerpt:
+
+{}
+
+### Response:
+
+"""  # noqa: E501
+
+
+def _format_zeta_prompt(
+        sample: dict,
+        original_start_marker: str = "<|editable_region_start|>") -> dict:
+    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
+
+    This function formats examples from the NEP dataset
+    into prompts and expected outputs. It could be
+    further extended to support more NEP datasets.
+
+    Args:
+        sample: The dataset sample containing events,
+            inputs, and outputs.
+        original_start_marker: The marker indicating the
+            start of the editable region. Defaults to
+            "<|editable_region_start|>".
+
+    Returns:
+        A dictionary with the formatted prompts and expected outputs.
+    """
+    events = sample["events"]
+    input = sample["input"]
+    output = sample["output"]
+    prompt = zeta_prompt.format(events, input)
+
+    # following the original implementation, extract the focused region
+    # from the raw output
+    output_start_index = output.find(original_start_marker)
+    output_focused_region = output[output_start_index:]
+    expected_output = output_focused_region
+
+    return {"prompt": prompt, "expected_output": expected_output}
+
+
+class NextEditPredictionDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing a Next Edit Prediction dataset.
+    """
+
+    SUPPORTED_DATASET_PATHS = {
+        "zed-industries/zeta",
+    }
+    MAPPING_PROMPT_FUNCS = {
+        "zed-industries/zeta": _format_zeta_prompt,
+    }
+
+    def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
+               **kwargs):
+        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
+            self.dataset_path)
+        if formatting_prompt_func is None:
+            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+        samples = []
+        for sample in self.data:
+            sample = formatting_prompt_func(sample)
+            samples.append(
+                SampleRequest(
+                    prompt=sample["prompt"],
+                    prompt_len=len(tokenizer(sample["prompt"]).input_ids),
+                    expected_output_len=len(
+                        tokenizer(sample["expected_output"]).input_ids),
+                ))
+            if len(samples) >= num_requests:
+                break
+        self.maybe_oversample_requests(samples, num_requests)
+        return samples
+
+
 # -----------------------------------------------------------------------------
 # ASR Dataset Implementation
 # -----------------------------------------------------------------------------
```
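For a concrete sense of what `_format_zeta_prompt` produces, here is a small sketch that runs it on a hypothetical zeta-style sample. The field values are invented for illustration; real zed-industries/zeta rows carry actual edit histories and code, and the import assumes you run from the `benchmarks/` directory so `benchmark_dataset` is importable.

```python
# Sketch only: toy sample invented to show the prompt/expected-output split.
from benchmark_dataset import _format_zeta_prompt  # assumes benchmarks/ on sys.path

toy_sample = {
    "events": "User edited file.py: renamed `foo` to `bar`",
    "input": "<|editable_region_start|>\ndef bar():\n    pass\n<|editable_region_end|>",
    "output": "header text\n<|editable_region_start|>\ndef bar():\n    return 42\n<|editable_region_end|>",
}

formatted = _format_zeta_prompt(toy_sample)

# The prompt is the zeta instruction template with events and input filled in.
print(formatted["prompt"])
# The expected output keeps only the text from the editable-region marker onward,
# dropping the "header text" prefix before <|editable_region_start|>.
print(formatted["expected_output"])
```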

benchmarks/benchmark_serving.py (+6 -2)
```diff
@@ -53,8 +53,9 @@
 from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset,
                                ConversationDataset, HuggingFaceDataset,
                                InstructCoderDataset, MTBenchDataset,
-                               RandomDataset, SampleRequest, ShareGPTDataset,
-                               SonnetDataset, VisionArenaDataset)
+                               NextEditPredictionDataset, RandomDataset,
+                               SampleRequest, ShareGPTDataset, SonnetDataset,
+                               VisionArenaDataset)
 from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
 
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@@ -603,6 +604,9 @@ def main(args: argparse.Namespace):
     elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS:
         dataset_class = AIMODataset
         args.hf_split = "train"
+    elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS:  # noqa: E501
+        dataset_class = NextEditPredictionDataset
+        args.hf_split = "train"
     elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS:
         dataset_class = ASRDataset
         args.hf_split = "train"
```

vllm/benchmarks/datasets.py (+88)
```diff
@@ -829,3 +829,91 @@ def sample(self,
                 ))
         self.maybe_oversample_requests(sampled_requests, num_requests)
         return sampled_requests
+
+
+# -----------------------------------------------------------------------------
+# Next Edit Prediction Dataset Implementation
+# -----------------------------------------------------------------------------
+
+
+zeta_prompt = """### Instruction:
+You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.
+
+### User Edits:
+
+{}
+
+### User Excerpt:
+
+{}
+
+### Response:
+
+"""  # noqa: E501
+
+
+def _format_zeta_prompt(
+        sample: dict,
+        original_start_marker: str = "<|editable_region_start|>") -> dict:
+    """Format the zeta prompt for the Next Edit Prediction (NEP) dataset.
+
+    This function formats examples from the NEP dataset
+    into prompts and expected outputs. It could be
+    further extended to support more NEP datasets.
+
+    Args:
+        sample: The dataset sample containing events,
+            inputs, and outputs.
+        original_start_marker: The marker indicating the
+            start of the editable region. Defaults to
+            "<|editable_region_start|>".
+
+    Returns:
+        A dictionary with the formatted prompts and expected outputs.
+    """
+    events = sample["events"]
+    input = sample["input"]
+    output = sample["output"]
+    prompt = zeta_prompt.format(events, input)
+
+    # following the original implementation, extract the focused region
+    # from the raw output
+    output_start_index = output.find(original_start_marker)
+    output_focused_region = output[output_start_index:]
+    expected_output = output_focused_region
+
+    return {"prompt": prompt, "expected_output": expected_output}
+
+
+class NextEditPredictionDataset(HuggingFaceDataset):
+    """
+    Dataset class for processing a Next Edit Prediction dataset.
+    """
+
+    SUPPORTED_DATASET_PATHS = {
+        "zed-industries/zeta",
+    }
+    MAPPING_PROMPT_FUNCS = {
+        "zed-industries/zeta": _format_zeta_prompt,
+    }
+
+    def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int,
+               **kwargs):
+        formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(
+            self.dataset_path)
+        if formatting_prompt_func is None:
+            raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
+        samples = []
+        for sample in self.data:
+            sample = formatting_prompt_func(sample)
+            samples.append(
+                SampleRequest(
+                    prompt=sample["prompt"],
+                    prompt_len=len(tokenizer(sample["prompt"]).input_ids),
+                    expected_output_len=len(
+                        tokenizer(sample["expected_output"]).input_ids),
+                ))
+            if len(samples) >= num_requests:
+                break
+        self.maybe_oversample_requests(samples, num_requests)
+        return samples
```
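The same 88 lines are mirrored here in vllm/benchmarks/datasets.py. Note how `sample()` sets the request lengths: both `prompt_len` and `expected_output_len` are simply the tokenized lengths of the formatted prompt and of the ground-truth edited region. A quick sketch of that computation with an arbitrary Hugging Face tokenizer; the strings and the "gpt2" model name are stand-ins, not values from this commit.

```python
# Illustrates the length bookkeeping in NextEditPredictionDataset.sample();
# "gpt2" is an arbitrary stand-in for whatever tokenizer the benchmark loads.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prompt = "### Instruction:\n...formatted zeta prompt...\n### Response:\n"
expected_output = "<|editable_region_start|>\ndef bar():\n    return 42\n"

prompt_len = len(tokenizer(prompt).input_ids)
expected_output_len = len(tokenizer(expected_output).input_ids)

# In the benchmark these become SampleRequest.prompt_len and
# SampleRequest.expected_output_len for each zeta example.
print(prompt_len, expected_output_len)
```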
