main.py
import argparse
import os

import boto3
import pandas as pd
from dynaconf import Dynaconf
from tqdm import tqdm
from transformers import AutoTokenizer

from src.action_generate import generate_all_datasets, generate_specific_dataset
from src.action_inference import run_inference
from src.action_llm_judge import judge
from src.aws_utils import get_current_aws_account_id
from src.completion import batch_get_completions, invoke_sagemaker_endpoint
from src.data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
from src.format import format_prompt_as_xml, format_prompt
from src.prompt_tones import master_sys_prompt, get_prompt, get_all_tones, Tone


def setup_argparse() -> argparse.ArgumentParser:
    """Set up the command line argument parser."""
    parser = argparse.ArgumentParser(
        description="WRAVAL - Writing Assistant Evaluation Tool"
    )
    parser.add_argument(
        "action",
        choices=[
            "generate",
            "inference",
            "llm_judge",
            "human_judge_upload",
            "human_judge_parsing",
        ],
        help="Action to perform",
    )
    parser.add_argument(
        "--type",
        "-t",
        choices=get_all_tones() + ["all"],
        default="all",
        help="Type of dataset to generate (default: all)",
    )
    parser.add_argument(
        "--model", "-m", default="haiku-3", help="Model to use (default: haiku-3)"
    )
    parser.add_argument(
        "--number-of-samples",
        "-n",
        type=int,
        default=100,
        help="Number of samples to generate (default: 100)",
    )
    parser.add_argument(
        "--aws-account", required=False, help="AWS account number for Bedrock ARN"
    )
    parser.add_argument(
        "--upload-s3", action="store_true", help="Upload generated datasets to S3"
    )
    parser.add_argument(
        "--endpoint-type",
        choices=["bedrock", "sagemaker"],
        default="bedrock",
        help="Type of endpoint to use (default: bedrock)",
    )
    parser.add_argument(
        "--data-dir", default="~/data", help="Where the data files are stored"
    )
    return parser


def main():
    parser = setup_argparse()
    args = parser.parse_args()

    # Resolve the AWS account used to build the Bedrock model ARN.
    if args.aws_account is None:
        aws_account = get_current_aws_account_id()
    else:
        aws_account = args.aws_account

    # Load model-specific settings from settings.toml (one environment per model).
    settings = Dynaconf(
        settings_files=["settings.toml"], env=args.model, environments=True
    )
    settings.model = settings.model.format(aws_account=aws_account)
    bedrock_client = boto3.client(
        service_name="bedrock-runtime", region_name=settings.region
    )

    # Note: "human_judge_upload" and "human_judge_parsing" are accepted by the
    # parser but are not dispatched here.
    if args.action == "generate":
        if args.type == "all":
            generate_all_datasets(settings, bedrock_client, args.model, args.upload_s3)
        else:
            generate_specific_dataset(
                settings, bedrock_client, args.type, args.model, args.upload_s3
            )
    elif args.action == "inference":
        if args.endpoint_type == "bedrock":
            inference_model = settings.model
            client = bedrock_client
        else:  # sagemaker
            inference_model = args.model  # Use the model name directly as the endpoint name
            client = None  # Not needed for a SageMaker endpoint
        run_inference(
            settings,
            client,
            inference_model,
            args.upload_s3,
            args.data_dir,
            args.endpoint_type,
        )
    elif args.action == "llm_judge":
        if args.endpoint_type == "bedrock":
            judge_model = settings.model
            client = bedrock_client
        else:  # sagemaker
            judge_model = args.model  # Use the model name directly as the endpoint name
            client = None  # Not needed for a SageMaker endpoint
        judge(
            settings,
            client,
            judge_model,
            args.upload_s3,
            args.data_dir,
            args.endpoint_type,
        )


if __name__ == "__main__":
    main()
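

# Example invocations (a sketch, not exhaustive: valid --type values come from
# src.prompt_tones.get_all_tones(), and --model is expected to match an environment
# defined in settings.toml; "my-sagemaker-endpoint" below is a hypothetical
# endpoint name used for illustration only):
#
#   python main.py generate --type all --model haiku-3 --upload-s3
#   python main.py inference --model haiku-3 --endpoint-type bedrock
#   python main.py llm_judge --model my-sagemaker-endpoint --endpoint-type sagemaker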