upload to spanner and add min input and output len #43

Open · wants to merge 2 commits into main
252 changes: 244 additions & 8 deletions benchmark_serving.py
@@ -42,6 +42,201 @@

from google.protobuf.timestamp_pb2 import Timestamp

import os
import sys
import uuid
import traceback
from google.cloud import spanner
import math
from google.api_core import exceptions as gcp_exceptions

def safe_json_value(value, default=0.0):
"""Convert value to JSON-safe format, handling NaN and Infinity."""
if value is None:
return default
if isinstance(value, (int, float)):
if math.isnan(value) or math.isinf(value):
return default
return value
return value

def extract_proto_fields(data, run_type):
"""Extract and structure relevant fields for Spanner insertion, including `run_type`."""

config = {
'model': data.get('config', {}).get('model', ''),
'num_models': safe_json_value(data.get('config', {}).get('num_models', 0), 0),
'model_server': data.get('config', {}).get('model_server', ''),
'backend': data.get('dimensions', {}).get('backend', ''),
'model_id': data.get('dimensions', {}).get('model_id', ''),
'tokenizer_id': data.get('dimensions', {}).get('tokenizer_id', ''),
'request_rate': safe_json_value(data.get('metrics', {}).get('request_rate', 0), 0),
'benchmark_time': safe_json_value(data.get('metrics', {}).get('benchmark_time', 0), 0),
'run_type': run_type
}

infrastructure = {
'model_server': config['model_server'],
'backend': config['backend'],
'gpu_cache_usage_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:gpu_cache_usage_perc', {}).get('P90', 0.0)),
'num_requests_waiting_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:num_requests_waiting', {}).get('P90', 0.0)),
'gpu_cache_usage_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:gpu_cache_usage_perc', {}).get('Mean', 0.0)),
'num_requests_waiting_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:num_requests_waiting', {}).get('Mean', 0.0)),
}

metrics = data.get('metrics', {})
prompt_dataset = {
'num_prompts_attempted': safe_json_value(metrics.get('num_prompts_attempted', 0), 0),
'num_prompts_succeeded': safe_json_value(metrics.get('num_prompts_succeeded', 0), 0),
'avg_input_len': safe_json_value(metrics.get('avg_input_len', 0.0)),
'median_input_len': safe_json_value(metrics.get('median_input_len', 0.0)),
'p90_input_len': safe_json_value(metrics.get('p90_input_len', 0.0)),
'avg_output_len': safe_json_value(metrics.get('avg_output_len', 0.0)),
'median_output_len': safe_json_value(metrics.get('median_output_len', 0.0)),
'p90_output_len': safe_json_value(metrics.get('p90_output_len', 0.0))
}

summary_stats = {
'p90_normalized_time_per_output_token_ms': safe_json_value(metrics.get('p90_normalized_time_per_output_token_ms', 0.0)),
'avg_normalized_time_per_output_token_ms': safe_json_value(metrics.get('avg_normalized_time_per_output_token_ms', 0.0)),
'throughput': safe_json_value(metrics.get('throughput', 0.0)),
'input_tokens_per_sec': safe_json_value(metrics.get('input_tokens_per_sec', 0.0)),
'benchmark_time': safe_json_value(metrics.get('benchmark_time', 0.0)),
'date': data.get('dimensions', {}).get('date', ''),
'avg_latency_ms': safe_json_value(metrics.get('avg_latency_ms', 0.0)),
'median_latency_ms': safe_json_value(metrics.get('median_latency_ms', 0.0)),
'p90_latency_ms': safe_json_value(metrics.get('p90_latency_ms', 0.0)),
'p99_latency_ms': safe_json_value(metrics.get('p99_latency_ms', 0.0)),
'time_per_output_token_seconds_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_per_output_token_seconds', {}).get('P90', 0.0)),
'time_to_first_token_seconds_p90': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_to_first_token_seconds', {}).get('P90', 0.0)),
'time_per_output_token_seconds_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_per_output_token_seconds', {}).get('Mean', 0.0)),
'time_to_first_token_seconds_mean': safe_json_value(data.get('metrics', {}).get('server_metrics', {}).get('vllm:time_to_first_token_seconds', {}).get('Mean', 0.0)),
}

return config, infrastructure, prompt_dataset, summary_stats
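
For orientation, a hedged sketch of the kind of result JSON this function expects. The key names are taken from the lookups above; the concrete values, model names, and the run_type label are placeholders, not real benchmark output.

# Illustrative only: a minimal result dict shaped like the per-run JSON this
# script saves, with placeholder values.
example_result = {
    "config": {"model": "llama-3", "num_models": 1, "model_server": "vllm"},
    "dimensions": {"backend": "vllm", "model_id": "meta-llama/Llama-3-8B",
                   "tokenizer_id": "meta-llama/Llama-3-8B", "date": "2025-01-01"},
    "metrics": {"request_rate": 5.0, "benchmark_time": 120.0, "throughput": 42.0,
                "avg_input_len": 256.0, "avg_output_len": 128.0,
                "server_metrics": {
                    "vllm:gpu_cache_usage_perc": {"P90": 0.7, "Mean": 0.5},
                    "vllm:num_requests_waiting": {"P90": 2.0, "Mean": 1.0},
                }},
}
config, infra, prompt_dataset, summary_stats = extract_proto_fields(
    example_result, run_type="nightly")  # run_type is an arbitrary label here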

def clean_for_json(obj):
"""Recursively clean an object for JSON serialization."""
if isinstance(obj, dict):
return {k: clean_for_json(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [clean_for_json(item) for item in obj]
elif isinstance(obj, float):
if math.isnan(obj) or math.isinf(obj):
return 0.0
return obj
elif obj is None:
return 0.0
else:
return obj
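
A small illustrative example of how these two helpers normalize values that would otherwise break json.dumps (the sample values are arbitrary):

# NaN, Infinity and None all collapse to the JSON-safe default of 0.0;
# ordinary values pass through unchanged.
assert safe_json_value(float("nan")) == 0.0
assert safe_json_value(3.5) == 3.5
assert clean_for_json({"p90": float("inf"), "errors": None, "latencies": [1.2, float("nan")]}) == \
    {"p90": 0.0, "errors": 0.0, "latencies": [1.2, 0.0]}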

def upload_to_spanner_batch_with_retry(instance_id, database_id, json_files, gcs_base_uri, run_type, max_retries=3):
"""
    Upload JSON files to Spanner in batches with retry logic.
    More efficient than per-file commits, but a failed batch commit writes nothing;
    files that fail local processing are skipped individually.
"""
spanner_client = spanner.Client()
instance = spanner_client.instance(instance_id)
database = instance.database(database_id)

print(f"📊 Uploading {len(json_files)} JSON files to Spanner with run_type='{run_type}'...")

retry_count = 0
success = False
processed_files = []

while retry_count <= max_retries and not success:
try:
processed_files = [] # Reset on each retry

with database.batch() as batch:
for json_file in json_files:
try:
with open(json_file, 'r') as f:
data = json.load(f)

config, infra, prompt, stats = extract_proto_fields(data, run_type)
filename = os.path.basename(json_file)
gcs_uri = f"{gcs_base_uri}/{filename}"
latency_profile_id = str(uuid.uuid4())

# Test JSON serialization before inserting
try:
config_clean = clean_for_json(config)
infra_clean = clean_for_json(infra)
prompt_clean = clean_for_json(prompt)
stats_clean = clean_for_json(stats)

config_json = json.dumps(config_clean)
infra_json = json.dumps(infra_clean)
prompt_json = json.dumps(prompt_clean)
stats_json = json.dumps(stats_clean)
except (TypeError, ValueError) as json_error:
print(f"❌ JSON serialization failed for {json_file}: {json_error}")
continue

batch.insert(
table='LatencyProfiles',
columns=['Id', 'Config', 'Infrastructure', 'PromptDataset', 'SummaryStats', 'GcsUri', 'InsertedAt'],
values=[
(latency_profile_id, config_json, infra_json, prompt_json, stats_json, gcs_uri, spanner.COMMIT_TIMESTAMP)
]
)

if 'core_deployment_artifacts' in data or 'extension_deployment_artifacts' in data:
core_json = json.dumps(data.get('core_deployment_artifacts', {}))
ext_json = json.dumps(data.get('extension_deployment_artifacts', {}))
batch.insert(
table='DeploymentArtifacts',
columns=['Id', 'LatencyProfileId', 'CoreDeploymentArtifacts', 'ExtensionDeploymentArtifacts'],
values=[
(str(uuid.uuid4()), latency_profile_id, core_json, ext_json)
]
)

processed_files.append((json_file, latency_profile_id))

except Exception as e:
print(f"❌ Failed to process {json_file}: {e}")
continue

# If we get here, the batch committed successfully
for json_file, profile_id in processed_files:
print(f"✅ {json_file} uploaded (ID: {profile_id})")
success = True

except (gcp_exceptions.DeadlineExceeded,
gcp_exceptions.ServiceUnavailable,
gcp_exceptions.InternalServerError,
gcp_exceptions.TooManyRequests) as retryable_error:
retry_count += 1
if retry_count <= max_retries:
wait_time = (2 ** retry_count) + random.uniform(0, 1)
print(f"⚠️ Batch upload failed (attempt {retry_count}/{max_retries}): {retryable_error}")
print(f"⏳ Retrying entire batch in {wait_time:.1f} seconds...")
time.sleep(wait_time)
else:
print(f"❌ Batch upload failed after {max_retries} retries: {retryable_error}")

except Exception as batch_error:
retry_count += 1
if retry_count <= max_retries:
wait_time = (2 ** retry_count) + random.uniform(0, 1)
print(f"⚠️ Unexpected batch error (attempt {retry_count}/{max_retries}): {batch_error}")
print(f"⏳ Retrying entire batch in {wait_time:.1f} seconds...")
time.sleep(wait_time)
else:
print(f"❌ Batch upload failed after {max_retries} retries: {batch_error}")
traceback.print_exc()

if success:
print("✅ All files uploaded successfully.")
else:
print("❌ Upload process failed after all retries.")

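The batch inserts above assume the LatencyProfiles and DeploymentArtifacts tables already exist in the target database. Below is a minimal sketch of a compatible schema created via the client library; the table and column names come from the insert calls, but the column types, the commit-timestamp option, and the GoogleSQL dialect are assumptions, not taken from this change.

def create_latency_profile_tables(instance_id, database_id):
    """One-off setup sketch: create tables compatible with the inserts above (assumed schema)."""
    client = spanner.Client()
    database = client.instance(instance_id).database(database_id)
    # Assumes a GoogleSQL-dialect database; a PostgreSQL-dialect instance would need JSONB.
    operation = database.update_ddl([
        """CREATE TABLE LatencyProfiles (
               Id STRING(36) NOT NULL,
               Config JSON,
               Infrastructure JSON,
               PromptDataset JSON,
               SummaryStats JSON,
               GcsUri STRING(MAX),
               InsertedAt TIMESTAMP OPTIONS (allow_commit_timestamp = true)
           ) PRIMARY KEY (Id)""",
        """CREATE TABLE DeploymentArtifacts (
               Id STRING(36) NOT NULL,
               LatencyProfileId STRING(36) NOT NULL,
               CoreDeploymentArtifacts JSON,
               ExtensionDeploymentArtifacts JSON
           ) PRIMARY KEY (Id)""",
    ])
    operation.result()  # wait for the schema change to complete

STRING(MAX) columns would work just as well for the JSON payloads, since the code inserts json.dumps output.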

MIN_SEQ_LEN = 4
NEW_TEXT_KEY = "\nOutput:\n"
PROMETHEUS_PORT = 9090
@@ -101,6 +296,8 @@ def get_filtered_dataset(
dataset_path: str,
max_input_len: int,
max_output_len: int,
min_input_len: int,
min_output_len: int,
tokenizer: PreTrainedTokenizerBase,
use_dummy_text: bool,
) -> List[Tuple[str, int, int]]:
@@ -139,7 +336,7 @@ def get_filtered_dataset(
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < MIN_SEQ_LEN or output_len < MIN_SEQ_LEN:
if prompt_len < min_input_len or output_len < min_output_len:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
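
Combined with the maximum-length pruning that sits in the collapsed lines of this function, the filter's intent reduces to keeping only sequences whose lengths fall inside both bounds; a condensed sketch follows, assuming the max-length check keeps its original form.

keep = (min_input_len <= prompt_len <= max_input_len
        and min_output_len <= output_len <= max_output_len)
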
@@ -471,7 +668,7 @@ async def benchmark(
Also saves results separately for each model.
"""
input_requests = get_filtered_dataset(
args.dataset, args.max_input_length, args.max_output_length, tokenizer, args.use_dummy_text)
        args.dataset, args.max_input_length, args.max_output_length, args.min_input_length, args.min_output_length, tokenizer, args.use_dummy_text)

# Combine the models list and traffic split list into a dict

@@ -534,13 +731,13 @@ async def benchmark(
await print_and_save_result(args, benchmark_duration_sec, prompts_sent, "weighted",
overall_results["latencies"], overall_results["ttfts"],
overall_results["itls"], overall_results["tpots"],
overall_results["errors"])
overall_results["errors"], spanner_upload=True, server_metrics_scrape=True)
for model, data in per_model_results.items():
await print_and_save_result(args, benchmark_duration_sec, len(data["latencies"]), model,
data["latencies"], data["ttfts"], data["itls"],
data["tpots"], data["errors"])

def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors):
def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors, spanner_upload: bool = False):
# Setup
start_dt_proto = Timestamp()
start_dt_proto.FromDatetime(args.start_datetime)
@@ -636,6 +833,17 @@ def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics
print(f"File {file_name} uploaded to gs://{args.output_bucket}/{args.output_bucket_filepath}")
except google.cloud.exceptions.NotFound:
print(f"GS Bucket (gs://{args.output_bucket}) does not exist")

if args.spanner_instance_id and args.spanner_database_id and spanner_upload:
# Upload to Spanner
try:
upload_to_spanner_batch_with_retry(
args.spanner_instance_id, args.spanner_database_id, [file_name],
args.output_bucket, args.file_prefix)
print(f"File {file_name} uploaded to Spanner")
except Exception as e:
print(f"Failed to upload {file_name} to Spanner: {e}")


def metrics_to_scrape(backend: str) -> List[str]:
# Each key in the map is a metric, it has a corresponding 'stats' object
@@ -815,7 +1023,7 @@ def get_stats_for_set(name, description, points)
f'p99_{name}': p99,
}

async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors):
async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec, total_requests, model, request_latencies, ttfts, itls, tpots, errors, spanner_upload=False, server_metrics_scrape=False):
benchmark_result = {}

print(f"====Result for Model: {model}====")
@@ -882,10 +1090,10 @@ async def print_and_save_result(args: argparse.Namespace, benchmark_duration_sec
}

server_metrics = {}
if args.scrape_server_metrics:
if args.scrape_server_metrics and server_metrics_scrape:
server_metrics = print_metrics(metrics_to_scrape(args.backend), benchmark_duration_sec, args.pm_namespace, args.pm_job)
if args.save_json_results:
save_json_results(args, benchmark_result, server_metrics, model, errors)
save_json_results(args, benchmark_result, server_metrics, model, errors, spanner_upload)

async def main(args: argparse.Namespace):
print(args)
@@ -1022,7 +1230,23 @@ def parse_traffic_split(arg)
type=int,
default=1024,
help=(
"Maximum number of input tokens for filtering the benchmark dataset."
"Maximum number of output tokens for filtering the benchmark dataset."
),
)
parser.add_argument(
"--min-input-length",
type=int,
default=4,
help=(
"Minimum number of input tokens for filtering the benchmark dataset."
),
)
parser.add_argument(
"--min-output-length",
type=int,
default=4,
help=(
"Minimum number of output tokens for filtering the benchmark dataset."
),
)
parser.add_argument(
@@ -1118,6 +1342,18 @@ def parse_traffic_split(arg)
action="store_true",
help="Whether to scrape server metrics.",
)
parser.add_argument(
"--spanner-instance-id",
type=str,
default=None,
help="Spanner instance ID to upload results to.",
)
parser.add_argument(
"--spanner-database-id",
type=str,
default=None,
help="Spanner database ID to upload results to.",
)
parser.add_argument("--pm-namespace", type=str, default="default", help="namespace of the pod monitoring object, ignored if scrape-server-metrics is false")
parser.add_argument("--pm-job", type=str, default="vllm-podmonitoring", help="name of the pod monitoring object, ignored if scrape-server-metrics is false")
parser.add_argument("--tcp-conn-limit", type=int, default=100, help="Max number of tcp connections allowed per aiohttp ClientSession")
7 changes: 7 additions & 0 deletions latency_throughput_curve.sh
@@ -41,6 +41,9 @@ BASE_PYTHON_OPTS=(
"--pm-job=$PM_JOB"
)

[[ "$MIN_INPUT_LENGTH" ]] && BASE_PYTHON_OPTS+=("--min-input-length=$MIN_INPUT_LENGTH")
[[ "$MIN_OUTPUT_LENGTH" ]] && BASE_PYTHON_OPTS+=("--min-output-length=$MIN_OUTPUT_LENGTH")
[[ "$TRAFFIC_SPLIT" ]] && BASE_PYTHON_OPTS+=("--traffic-split=$TRAFFIC_SPLIT")
[[ "$OUTPUT_BUCKET" ]] && BASE_PYTHON_OPTS+=("--output-bucket=$OUTPUT_BUCKET")
[[ "$SCRAPE_SERVER_METRICS" = "true" ]] && BASE_PYTHON_OPTS+=("--scrape-server-metrics")
@@ -49,6 +52,10 @@
[[ "$IGNORE_EOS" = "true" ]] && BASE_PYTHON_OPTS+=("--ignore-eos")
[[ "$OUTPUT_BUCKET_FILEPATH" ]] && BASE_PYTHON_OPTS+=("--output-bucket-filepath" "$OUTPUT_BUCKET_FILEPATH")
[[ "$TCP_CONN_LIMIT" ]] && BASE_PYTHON_OPTS+=("--tcp-conn-limit" "$TCP_CONN_LIMIT")
[[ "$SPANNER_INSTANCE_ID" ]] && BASE_PYTHON_OPTS+=("--spanner-instance-id" "$SPANNER_INSTANCE_ID")
[[ "$SPANNER_DATABASE_ID" ]] && BASE_PYTHON_OPTS+=("--spanner-database-id" "$SPANNER_DATABASE_ID")

SLEEP_TIME=${SLEEP_TIME:-0}
POST_BENCHMARK_SLEEP_TIME=${POST_BENCHMARK_SLEEP_TIME:-infinity}
4 changes: 4 additions & 0 deletions requirements.txt
@@ -34,6 +34,10 @@ aioprometheus[starlette]
pynvml == 11.5.0
accelerate
aiohttp

# For Google Cloud Storage
google-auth
google-cloud-storage >= 2.18.2
prometheus_client >= 0.21.0
google-cloud-spanner
google-api-core
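
A quick, hedged smoke test that the new Spanner dependencies resolve once this requirements file is installed; it only exercises imports and module-level attributes already used by benchmark_serving.py, not credentials or network access.

# Verifies the imports used by the new Spanner upload path.
from google.api_core import exceptions as gcp_exceptions
from google.cloud import spanner

print(spanner.COMMIT_TIMESTAMP)         # sentinel used for the InsertedAt column
print(gcp_exceptions.DeadlineExceeded)  # one of the retryable error types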