10
10
import numpy as np
11
11
from sortedcontainers import SortedSet
12
12
13
- model2plot = {
14
- # "gpt-3.5-finetuned": {
15
- # "label": "GPT-3.5-FT",
16
- # "color": "#a65628",
17
- # "marker": "s"
18
- # },
19
- "gpt-3.5-1106-function" : {
20
- "label" : "GPT-3.5-function" ,
21
- "color" : "#984ea3" ,
22
- "marker" : ">"
23
- },
24
- "gpt-4-turbo-function" : {
25
- "label" : "gpt-4-turbo-function" ,
26
- "color" : "#e41a1c" ,
27
- "marker" : "1"
28
- },
29
- # "gpt-4": {
30
- # "label": "GPT-4",
31
- # "color": "#ffff33",
32
- # "marker": "+"
33
- # },
34
- "codellama-7b-instruct" : {
35
- "label" : "CL-7B-Instruct" ,
36
- "color" : "#ff7f00" ,
37
- "marker" : "<"
38
- },
39
- }
40
-
41
13
42
14
def extract_result (file_path : str , model_name : str ) -> (dict , dict ):
43
15
average = {}
@@ -112,15 +84,15 @@ def extract_result(file_path: str, model_name: str) -> (dict, dict):
112
84
return to_plot_accuracy , to_plot_cost
113
85
114
86
115
- def plot_by_requirements (results_path : str , figures_path : str , requirements : SortedSet ) -> None :
87
+ def plot_by_requirements (results_path : str , figures_path : str , requirements : SortedSet , model2plot ) -> None :
88
+ # print(model2plot)
89
+ # exit()
116
90
model2result = {}
117
91
requirements_str = "_" .join (requirements )
118
92
119
93
for model_name in model2plot .keys ():
120
- # if include not in model_name or 'function' in model_name:
121
- # continue
122
94
123
- results_files_list = glob .glob (os .path .join ("." , results_path , f"result-{ model_name } -ad-hoc -{ requirements_str } -*.csv" ))
95
+ results_files_list = glob .glob (os .path .join ("../ " , results_path , f"result-{ model_name } -{ requirements_str } -*.csv" ))
124
96
if results_files_list :
125
97
results_file = results_files_list .pop ()
126
98
@@ -130,7 +102,7 @@ def plot_by_requirements(results_path: str, figures_path: str, requirements: Sor
130
102
model2result [model_name ]["accuracy" ], model2result [model_name ]["cost" ] = extract_result (results_file ,
131
103
model_name )
132
104
133
- base_figures_path = os .path .join ("." , figures_path )
105
+ base_figures_path = os .path .join ("../plot " , figures_path )
134
106
os .makedirs (base_figures_path , exist_ok = True )
135
107
136
108
# Accuracy
@@ -161,6 +133,9 @@ def plot_by_requirements(results_path: str, figures_path: str, requirements: Sor
161
133
plt .ylim ([- 0.1 , 1.2 ])
162
134
plt .yticks (np .arange (0 , 1.2 , 0.25 ))
163
135
plt .xscale ('log' , base = 10 )
136
+
137
+ # print(model2result)
138
+ # exit()
164
139
x_ticks = list (model2result .values ())[0 ]["accuracy" ]["x" ]
165
140
plt .xticks (x_ticks )
166
141
ax .xaxis .set_major_formatter (matplotlib .ticker .ScalarFormatter ())
@@ -215,7 +190,7 @@ def parse_args() -> argparse.Namespace:
215
190
parser = argparse .ArgumentParser ()
216
191
parser .add_argument ('--results_path' , type = str , required = False , default = "result" )
217
192
parser .add_argument ('--figures_path' , type = str , required = True )
218
- # parser.add_argument('--include' , type=str, choices=[' gpt', ' codellama' ])
193
+ parser .add_argument ("--models" , type = str , choices = [" gpt" , " codellama" ])
219
194
220
195
return parser .parse_args ()
221
196
@@ -226,11 +201,48 @@ def main(args: argparse.Namespace) -> None:
226
201
matplotlib .rcParams ['pdf.fonttype' ] = 42
227
202
matplotlib .rcParams ['ps.fonttype' ] = 42
228
203
229
- plot_by_requirements (args .results_path , args .figures_path , SortedSet ({"reachability" }))
230
- plot_by_requirements (args .results_path , args .figures_path , SortedSet ({"reachability" , "waypoint" }))
204
+ if args .models == "gpt" :
205
+ model2plot = {
206
+ "gpt-4-1106" : {
207
+ "label" : "GPT-4-Turbo" ,
208
+ "color" : "#377eb8" ,
209
+ "marker" : "o"
210
+ },
211
+ "gpt-3.5-finetuned" : {
212
+ "label" : "GPT-3.5-FT" ,
213
+ "color" : "#a65628" ,
214
+ "marker" : "s"
215
+ },
216
+ "gpt-3.5-0613" : {
217
+ "label" : "GPT-3.5-Turbo" ,
218
+ "color" : "#984ea3" ,
219
+ "marker" : ">"
220
+ },
221
+ }
222
+ if args .models == "codellama" :
223
+ model2plot = {
224
+ "codellama-13b-instruct" : {
225
+ "label" : "CL-13B-Instruct" ,
226
+ "color" : "#ff7f00" ,
227
+ "marker" : "<"
228
+ },
229
+ "codellama-7b-instruct-finetuned" : {
230
+ "label" : "CL-7B-Instruct-FT (QLoRA)" ,
231
+ "color" : "#4daf4a" ,
232
+ "marker" : ">"
233
+ },
234
+ "codellama-7b-instruct" : {
235
+ "label" : "CL-7B-Instruct" ,
236
+ "color" : "#f781bf" ,
237
+ "marker" : "^"
238
+ },
239
+ }
240
+
241
+ # plot_by_requirements(args.results_path, args.figures_path, SortedSet({"reachability"}), model2plot)
242
+ # plot_by_requirements(args.results_path, args.figures_path, SortedSet({"reachability", "waypoint"}), model2plot)
231
243
plot_by_requirements (args .results_path , args .figures_path ,
232
- SortedSet ({"loadbalancing" , "reachability" , "waypoint" }))
244
+ SortedSet ({"loadbalancing" , "reachability" , "waypoint" }), model2plot )
233
245
234
246
235
247
if __name__ == "__main__" :
236
- main (parse_args ())
248
+ main (parse_args ())
0 commit comments