Showing 8 changed files with 1,292 additions and 13 deletions.
Binary file not shown.
@@ -0,0 +1,19 @@
# Evaluation args
model_path=""
inference_file=""
output_file=""
task_name=""
seed="42"

# environment parameters
max_round=""
env_server_base=""

python -u base_eval_template.py \
    --model_path "${model_path}" \
    --inference_file "${inference_file}" \
    --output_file "${output_file}" \
    --task_name "${task_name}" \
    --seed "${seed}" \
    --max_round "${max_round}" \
    --env_server_base "${env_server_base}"
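For illustration, the empty variables above might be filled in as follows before running the script. Every value here is a hypothetical placeholder, not something taken from this commit; the task name must match one of the keys handled by base_eval_template.py below.

# Hypothetical example values; adjust them to your own setup.
model_path="/path/to/model/checkpoint"      # local directory or Hugging Face model id
inference_file="data/webshop_test.json"     # JSON test set with item_id fields
output_file="outputs/webshop_eval.jsonl"    # per-item results are appended here
task_name="webshop"                         # one of the task names in task_classes
seed="42"
max_round="10"                              # interaction rounds per episode
env_server_base="http://127.0.0.1:36001"    # address of the running environment server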
@@ -0,0 +1,147 @@
# Sequential Evaluation template

import json
import time
from dataclasses import dataclass, field

import jsonlines
import transformers
from agentenv.controller import Agent, Evaluator
from agentenv.envs import (
    AcademiaTask,
    AlfWorldTask,
    BabyAITask,
    MazeTask,
    MovieTask,
    SciworldTask,
    SheetTask,
    SqlGymTask,
    TextCraftTask,
    TodoTask,
    WeatherTask,
    WebarenaTask,
    WebshopTask,
    WordleTask,
)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


@dataclass
class EvalArguments:
    model_path: str
    inference_file: str = field(metadata={"help": "Test dataset."})
    output_file: str
    task_name: str = field(default="webshop", metadata={"help": "Task name for evaluation"})
    seed: int = field(default=42)

    # conversation rounds
    max_round: int = field(
        default=6,
        metadata={"help": "Interaction rounds between agents and environment"},
    )

    # environment parameters
    env_server_base: str = field(default=None)
    data_len: int = field(default=200)
    timeout: int = field(default=2400)


def main():
    parser = transformers.HfArgumentParser(EvalArguments)
    (args,) = parser.parse_args_into_dataclasses()
    args = vars(args)
    print(json.dumps(args, indent=2, ensure_ascii=False))

    MODEL_PATH = args["model_path"]
    DATA_PATH = args["inference_file"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True).eval()

    # task_name -> task class mapping
    task_classes = {
        "webshop": WebshopTask,
        "alfworld": AlfWorldTask,
        "babyai": BabyAITask,
        "sciworld": SciworldTask,
        "textcraft": TextCraftTask,
        "webarena": WebarenaTask,
        "sqlgym": SqlGymTask,
        "maze": MazeTask,
        "wordle": WordleTask,
        "weather": WeatherTask,
        "todo": TodoTask,
        "movie": MovieTask,
        "sheet": SheetTask,
        "academia": AcademiaTask,
    }

    # select the task class according to the task name
    task_class = task_classes.get(args["task_name"].lower(), None)
    if task_class is None:
        raise ValueError(f"Unsupported task name: {args['task_name']}")

    # set environment parameters
    env_args = {
        "env_server_base": args["env_server_base"],
        "data_len": args["data_len"],
        "timeout": args["timeout"],
    }

    # set up the environment client
    evaluator = Evaluator(
        Agent(model, tokenizer),
        [task_class(client_args=env_args, n_clients=1)],
    )

    with open(DATA_PATH, "r") as file:
        test_data = json.load(file)

    # the numeric suffix of each item_id is the environment index to evaluate
    data_idxs = [[int(item["item_id"].split("_")[-1])] for item in test_data]

    total_score = 0.0
    total_success = 0.0
    start_time = time.time()
    for data_idx in tqdm(data_idxs, total=len(data_idxs), desc="[Evaluation Loop]"):
        exps = evaluator.eval(
            generation_config=GenerationConfig(
                max_length=4096,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.unk_token_id,
            ),
            max_rounds=args["max_round"],
            idxs=data_idx,
        )
        total_score += exps.score
        total_success += exps.success

        cur_experiences = exps.experiences
        # append inference results for this item to the output file
        with jsonlines.open(args["output_file"], mode="a") as f:
            for exp in cur_experiences:
                conversation = exp.conversation
                cur_reward = exp.reward
                cur_success = 1 if exp.reward == 1 else 0
                item_id = f"{args['task_name']}_{data_idx[0]}"  # reconstruct the original item_id, e.g. "webshop_39"
                f.write(
                    {
                        "conversations": conversation,
                        "item_id": item_id,
                        "reward": cur_reward,
                        "success": cur_success,
                    }
                )
    process_time = time.time() - start_time

    Score = total_score / len(data_idxs)
    Success = total_success / len(data_idxs)
    print("\n\n==== EVALUATION ====\n")
    print(f"Score: {Score}")
    print(f"Success: {Success}")
    print(f"Time: {process_time} seconds")


if __name__ == "__main__":
    main()
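A note on the expected input: base_eval_template.py loads inference_file with json.load and uses only the numeric suffix after the last underscore of each item_id as the environment index passed to Evaluator.eval. A minimal sketch of such a file, with hypothetical ids:

# Minimal sketch of an inference file; the ids are hypothetical, and only the
# trailing number after the last "_" is used as the environment index.
cat > data/webshop_test.json <<'EOF'
[
  {"item_id": "webshop_0"},
  {"item_id": "webshop_1"},
  {"item_id": "webshop_2"}
]
EOF

Results are appended to output_file as JSON lines, one record per experience with conversations, item_id, reward, and success fields.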
@@ -0,0 +1,59 @@
# Evaluate AgentLM-7B on the AlfWorld task (adapted from the WebShop example).
# 200 data pieces, max_length=4096, max_rounds=7


import os

from agentenv.controller import Agent, Evaluator

# from agentenv.envs import WebshopTask
from agentenv.envs import AlfWorldTask
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

MODEL_PATH = "THUDM/agentlm-7b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True).eval()

evaluator = Evaluator(
    Agent(model, tokenizer),
    [
        AlfWorldTask(
            client_args={
                "env_server_base": "http://127.0.0.1:36001",  # If you have modified the port, modify it here.
                "data_len": 200,  # Currently, the data_len argument is of no use. It will be removed in future versions.
                "timeout": 300,
            },
            # The n_clients argument is reserved for subsequent implementations of batch generation. Please leave it at 1.
            n_clients=1,
        )
    ],
)

exps = evaluator.eval(
    generation_config=GenerationConfig(
        max_length=4096,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    ),
    max_rounds=7,
    idxs=list(range(200)),
)

print("\n\n==== EVALUATION ====\n")
print(f"Score: {exps.score}")
print(f"Success: {exps.success}")

print("\n\n==== EXPERIENCES ====\n")
for idx, exp in enumerate(exps.experiences[:3]):
    print(f"\n\n==== EXP {idx} ====\n")
    for message in exp.conversation:
        if message["from"] == "gpt":
            print(f"\n### Agent\n{message['value']}")
        else:
            print(f"\n### Env\n{message['value']}")
File renamed without changes.