
Commit

readme
femto committed Oct 17, 2024
1 parent 62814e3 commit b804b34
Showing 8 changed files with 1,292 additions and 13 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -4,6 +4,8 @@

Minion is designed to execute and analyze complex queries, offering a variety of features that demonstrate its flexibility and intelligence.

![Minion](assets/minion1.webp)

## Minion Design

The core logic of Minion is implemented in `examples/smart_minion/brain.py`. You can experiment with different examples by modifying the code, as various scenarios are commented out for easy testing.
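
For orientation, here is a minimal sketch of the call pattern used in `brain.py`, excerpted from the diff further down in this commit; `brain` is assumed to be the Brain instance configured earlier in that file, so this is an illustrative fragment rather than a complete script:

```python
# Sketch only: excerpt of the pattern inside async def smart_brain() in
# examples/smart_minion/brain.py; `brain` is the Brain instance configured
# earlier in that file (its construction is outside this commit's diff).
obs, score, *_ = await brain.step(
    query="Real numbers $x$ and $y$ with $x,y>1$ satisfy $\\log_x(y^x)=\\log_y(x^{4y})=10.$ What is the value of $xy$?",
    route="cot",          # reasoning route; "dot" and "cot" both appear in brain.py
    dataset="aime 2024",  # dataset tag, as used in brain.py
)
print(obs)
```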
Binary file added assets/minion1.webp
19 changes: 19 additions & 0 deletions examples/smart_minion/basic/base_eval_script.sh
@@ -0,0 +1,19 @@
# Evaluation args (fill in these values before running)
model_path=""
inference_file=""
output_file=""
task_name=""
seed="42"

# environment parameters
max_round=""
env_server_base=""

python -u base_eval_template.py \
--model_path "${model_path}" \
--inference_file "${inference_file}" \
--output_file "${output_file}" \
--task_name "${task_name}" \
--seed "${seed}" \
--max_round "${max_round}" \
--env_server_base "${env_server_base}"
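
For reference, a filled-in invocation could look like the sketch below; the file paths are placeholders, while the model id and environment-server address are taken from the other examples in this commit:

```sh
model_path="THUDM/agentlm-7b"               # or a local checkpoint directory
inference_file="data/webshop_test.json"     # placeholder path
output_file="outputs/webshop_eval.jsonl"    # placeholder path
task_name="webshop"
seed="42"
max_round="6"
env_server_base="http://127.0.0.1:36001"    # address of a running environment server

python -u base_eval_template.py \
  --model_path "${model_path}" \
  --inference_file "${inference_file}" \
  --output_file "${output_file}" \
  --task_name "${task_name}" \
  --seed "${seed}" \
  --max_round "${max_round}" \
  --env_server_base "${env_server_base}"
```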
147 changes: 147 additions & 0 deletions examples/smart_minion/basic/base_eval_template.py
@@ -0,0 +1,147 @@
# Sequential Evaluation template

import json
import time
from dataclasses import dataclass, field

import jsonlines
import transformers
from agentenv.controller import Agent, Evaluator
from agentenv.envs import (
AcademiaTask,
AlfWorldTask,
BabyAITask,
MazeTask,
MovieTask,
SciworldTask,
SheetTask,
SqlGymTask,
TextCraftTask,
TodoTask,
WeatherTask,
WebarenaTask,
WebshopTask,
WordleTask,
)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


@dataclass
class EvalArguments:
    model_path: str
    inference_file: str = field(metadata={"help": "Test dataset."})
    output_file: str
    task_name: str = field(default="webshop", metadata={"help": "Task name for evaluation"})
    seed: int = field(default=42)

    # conversation rounds
    max_round: int = field(
        default=6,
        metadata={"help": "Interaction rounds between agents and environment"},
    )

    # environment parameters
    env_server_base: str = field(default=None)
    data_len: int = field(default=200)
    timeout: int = field(default=2400)


def main():
    parser = transformers.HfArgumentParser(EvalArguments)
    (args,) = parser.parse_args_into_dataclasses()
    args = vars(args)
    print(args)
    print(json.dumps(args, indent=2, ensure_ascii=False))

    MODEL_PATH = args["model_path"]
    DATA_PATH = args["inference_file"]

    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True).eval()

    # task_name -> task class mapping
    task_classes = {
        "webshop": WebshopTask,
        "alfworld": AlfWorldTask,
        "babyai": BabyAITask,
        "sciworld": SciworldTask,
        "textcraft": TextCraftTask,
        "webarena": WebarenaTask,
        "sqlgym": SqlGymTask,
        "maze": MazeTask,
        "wordle": WordleTask,
        "weather": WeatherTask,
        "todo": TodoTask,
        "movie": MovieTask,
        "sheet": SheetTask,
        "academia": AcademiaTask,
    }

    # select the task class by name
    task_class = task_classes.get(args["task_name"].lower(), None)
    if task_class is None:
        raise ValueError(f"Unsupported task name: {args['task_name']}")

    # set environment parameters
    env_args = {
        "env_server_base": args["env_server_base"],
        "data_len": args["data_len"],
        "timeout": args["timeout"],
    }

    # set up the environment client and evaluator
    evaluator = Evaluator(
        Agent(model, tokenizer),
        [task_class(client_args=env_args, n_clients=1)],
    )

    with open(DATA_PATH, "r") as file:
        test_data = json.load(file)

    # each entry's item_id ends with the environment index, e.g. "webshop_12"
    data_idxs = [[int(item["item_id"].split("_")[-1])] for item in test_data]

    total_score = 0.0
    total_success = 0.0
    start_time = time.time()
    for data_idx in tqdm(data_idxs, total=len(data_idxs), desc="[Evaluation Loop]"):
        exps = evaluator.eval(
            generation_config=GenerationConfig(
                max_length=4096,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.unk_token_id,
            ),
            max_rounds=args["max_round"],
            idxs=data_idx,
        )
        total_score += exps.score
        total_success += exps.success

        cur_experiences = exps.experiences
        # append inference results to the output file
        with jsonlines.open(args["output_file"], mode="a") as f:
            for exp in cur_experiences:
                conversation = exp.conversation
                cur_reward = exp.reward
                cur_success = 1 if exp.reward == 1 else 0
                item_id = f"{args['task_name']}_{data_idx[0]}"
                f.write(
                    {
                        "conversations": conversation,
                        "item_id": item_id,
                        "reward": cur_reward,
                        "success": cur_success,
                    }
                )
    process_time = time.time() - start_time

    Score = total_score / len(data_idxs)
    Success = total_success / len(data_idxs)
    print("\n\n==== EVALUATION ====\n")
    print(f"Score: {Score}")
    print(f"Success: {Success}")
    print(f"Time: {process_time} seconds")


if __name__ == "__main__":
    main()
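
The template only reads `item_id` from each entry of the inference file, expecting it to end with the integer index passed to the evaluator. A minimal sketch of that assumed input layout (the file name and ids here are placeholders):

```python
import json

# Illustrative input for --inference_file: a JSON list in which each item's
# "item_id" ends with the environment index; other fields are ignored by the script.
test_data = [
    {"item_id": "webshop_0"},
    {"item_id": "webshop_1"},
    {"item_id": "webshop_2"},
]
with open("webshop_test.json", "w") as f:  # placeholder file name
    json.dump(test_data, f, indent=2)
```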
59 changes: 59 additions & 0 deletions examples/smart_minion/basic/eval_agentlm_webshop.py
@@ -0,0 +1,59 @@
# Evaluate AgentLM-7B with agentenv.
# This variant targets the AlfWorld task (the WebShop import is left commented out);
# 200 data pieces, max_length=4096 (max_rounds is set below).


import os

from agentenv.controller import Agent, Evaluator

# from agentenv.envs import WebshopTask
from agentenv.envs import AlfWorldTask
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Disable the MPS high-watermark limit (relevant when running on Apple Silicon).
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

MODEL_PATH = "THUDM/agentlm-7b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", trust_remote_code=True).eval()

evaluator = Evaluator(
    Agent(model, tokenizer),
    [
        AlfWorldTask(
            client_args={
                "env_server_base": "http://127.0.0.1:36001",  # If you have modified the port, modify it here.
                "data_len": 200,  # Currently, the data_len argument is of no use. It will be removed in future versions.
                "timeout": 300,
            },
            # The n_clients argument is reserved for subsequent implementations of batch generation. Please leave it at 1.
            n_clients=1,
        )
    ],
)

exps = evaluator.eval(
    generation_config=GenerationConfig(
        max_length=4096,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    ),
    max_rounds=7,
    idxs=list(range(200)),
)

print("\n\n==== EVALUATION ====\n")
print(f"Score: {exps.score}")
print(f"Success: {exps.success}")

print("\n\n==== EXPERIENCES ====\n")
for idx, exp in enumerate(exps.experiences[:3]):
print(f"\n\n==== EXP {idx} ====\n")
for message in exp.conversation:
if message["from"] == "gpt":
print(f"\n### Agent\n{message['value']}")
else:
print(f"\n### Env\n{message['value']}")
23 changes: 10 additions & 13 deletions examples/smart_minion/brain.py
@@ -97,23 +97,20 @@ async def smart_brain():
# )
# print(obs)

# llm.model = "z3-" + llm.model

# llm.model = "re2-" + llm.model
# Query (translated from Chinese): a plane lands on water and is found to be leaking; some water
# has already entered and more enters at a constant rate. If 10 people bail, it takes 3 hours;
# if 5 people bail, it takes 8 hours. How many people are needed to finish bailing in 2 hours?
obs, score, *_ = await brain.step(
query="一架飞机降落在水面发现漏水时,已经进了一些水,水匀速进入飞机内.如果10人淘水,3小时淘完;如5人淘水8小时淘完.如果要求2小时淘完,要安排多少人淘水?",
route="dot",
)
print(obs)

cache_plan = os.path.join(current_file_dir, "aime", "plan_gpt4o.7.json")
# cache_plan = os.path.join(current_file_dir, "aime", "plan_gpt4o.7.json")
#
# obs, score, *_ = await brain.step(
# query="Find the largest possible real part of\[(75+117i)z+\frac{96+144i}{z}\]where $z$ is a complex number with $|z|=4$.",
# route="cot",
# dataset="aime 2024",
# cache_plan=cache_plan,
# )
# print(obs)

llm.model = "re2-" + llm.model
obs, score, *_ = await brain.step(
query="Find the largest possible real part of\[(75+117i)z+\frac{96+144i}{z}\]where $z$ is a complex number with $|z|=4$.",
query="Real numbers $x$ and $y$ with $x,y>1$ satisfy $\log_x(y^x)=\log_y(x^{4y})=10.$ What is the value of $xy$?",
route="cot",
dataset="aime 2024",
cache_plan=cache_plan,
)
print(obs)

File renamed without changes.

