Error Analysis API #208
Open: recursix wants to merge 27 commits into main from error-analysis
Commits:
048a622 Add initial implementation of ChangeSummarizer and EpisodeAnalysis cl… (recursix)
fd8fd95 Added chain summarizer prompt (Megh-Thakkar)
b8c85b1 Added error classification prompt (Megh-Thakkar)
5cb6cc2 Fix typo (jardinetsouffleton)
9f531cc Update error_analysis.py (Megh-Thakkar)
31e5bf5 added pipeline and tests (TLSDC)
000893d quick parsing to run from cligit push (TLSDC)
4727a9e even more parsing and making imports absolute (TLSDC)
42f0362 . (TLSDC)
394999b chat_models can take str as input (TLSDC)
e0e786c typing (TLSDC)
46d2c8c keep this here bc it's going to pop back up (TLSDC)
8a882ad pipeline mvp (TLSDC)
3fab5b4 added a specific tab and viz for it in xray (TLSDC)
2be23e5 added formatting options (TLSDC)
41f8f69 Update summarizer_prompts.py (Megh-Thakkar)
6163b47 xml parsing (TLSDC)
a1f3416 fix (TLSDC)
a455d0d add error analysis prediction validation script (jardinetsouffleton)
82dbaba black version update (TLSDC)
5fbbe57 phony command, joblib stuff, took think out of prompt (TLSDC)
3a3d602 task_info (TLSDC)
5bf1bac added flag to oracle success or no (TLSDC)
c7b1c5a Merge branch 'main' into error-analysis (TLSDC)
0972130 darglint (TLSDC)
ec1395b Merge branch 'error-analysis' of github.com:ServiceNow/AgentLab into … (TLSDC)
46d1075 tests (TLSDC)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
- black[jupyter]>=24.2.0 | ||
+ black[jupyter]>=24.2.0,<25 | ||
blacken-docs | ||
pre-commit | ||
pytest==7.3.2 | ||
|
Empty file.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import json | ||
import re | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Generator | ||
|
from bgym import ExpResult | ||
|
from agentlab.analyze.error_analysis.summarizer import ( | ||
    ChangeSummarizer, | ||
    EpisodeErrorSummarizer, | ||
    EpisodeSummarizer, | ||
) | ||
from agentlab.analyze.inspect_results import yield_all_exp_results | ||
|
|
@dataclass | ||
class Analyzer: | ||
    prompt: str | ||
    llm = None | ||
|
    def __call__(self, *args, **kwds): | ||
        return "analysis" | ||
|
|
def analyze(exp_result, episode_summarizer, save_analysis_func): | ||
    error_analysis = episode_summarizer(exp_result) | ||
    save_analysis_func(exp_result, error_analysis) | ||
|
|
@dataclass | ||
class ErrorAnalysisPipeline: | ||
    exp_dir: Path | ||
    filter: str = None | ||
    episode_summarizer: EpisodeSummarizer = None | ||
|
    def filter_exp_results(self) -> Generator[ExpResult, None, None]: | ||
        # TODO:(thibault) improve filtering | ||
        exp_results = yield_all_exp_results(self.exp_dir) | ||
        for exp_result in exp_results: | ||
            if self.filter is None or self.filter in str(exp_result.exp_dir): | ||
                yield exp_result | ||
|
    def run_analysis(self, parallel=False, jobs=-1): | ||
        filtered_results = self.filter_exp_results() | ||
|
        if parallel: | ||
            import joblib | ||
|
            joblib.Parallel(n_jobs=jobs, backend="threading")( | ||
                joblib.delayed(analyze)(exp_result, self.episode_summarizer, self.save_analysis) | ||
                for exp_result in filtered_results | ||
            ) | ||
|
        else: | ||
            for exp_result in filtered_results: | ||
                error_analysis = self.episode_summarizer(exp_result) | ||
                self.save_analysis(exp_result, error_analysis) | ||
|
    def save_analysis(self, exp_result: ExpResult, error_analysis: dict, exists_ok=True): | ||
        """Save the analysis to json""" | ||
        analysis_path = exp_result.exp_dir / "error_analysis.json" | ||
        if not exists_ok and analysis_path.exists(): | ||
            raise FileExistsError(f"{analysis_path} already exists") | ||
        with analysis_path.open("w") as f: | ||
            json.dump(error_analysis, f, indent=4) | ||
|
|
AXTREE_FORMATTER = lambda x: x.get("axtree_txt", "No AXTREE available") | ||
HTML_FORMATTER = lambda x: x.get("pruned_html", "No HTML available") | ||
|
|
def main(): | ||
    import argparse | ||
|
    parser = argparse.ArgumentParser() | ||
    parser.add_argument("-e", "--exp_dir", type=str) | ||
    parser.add_argument("-f", "--filter", type=str, default=None) | ||
    parser.add_argument("-p", "--parallel", action="store_true") | ||
    parser.add_argument("-j", "--jobs", type=int, default=-1) | ||
    parser.add_argument("-g", "--guess_success", action="store_true") | ||
|
    args = parser.parse_args() | ||
|
    assert args.exp_dir is not None, "Please provide an exp_dir, e.g., -e /path/to/exp_dir" | ||
Review comment: required=True? |
|
    exp_dir = Path(args.exp_dir) | ||
    filter = args.filter | ||
    parallel = args.parallel | ||
    jobs = args.jobs | ||
    guess_success = args.guess_success | ||
|
    from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT | ||
|
    llm = CHAT_MODEL_ARGS_DICT["azure/gpt-4o-2024-08-06"].make_model() | ||
|
    pipeline = ErrorAnalysisPipeline( | ||
        exp_dir=exp_dir, | ||
        filter=filter, | ||
        episode_summarizer=EpisodeErrorSummarizer( | ||
            ChangeSummarizer(llm, AXTREE_FORMATTER), llm, guess_success=guess_success | ||
        ), | ||
    ) | ||
|
    pipeline.run_analysis(parallel=parallel, jobs=jobs) | ||
|
|
if __name__ == "__main__": | ||
|
    main() |
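Following the review comment on the `assert`, here is a minimal sketch of how argparse's `required=True` could replace it. The `parse_args` helper and the `/tmp/exp` path are hypothetical, introduced only for illustration; the flags mirror those defined in `main()`.

```python
import argparse


def parse_args(argv=None):
    """Parse the CLI flags; argparse itself rejects a missing --exp_dir."""
    parser = argparse.ArgumentParser()
    # required=True makes argparse print a usage error and exit when -e is absent,
    # so no separate assert is needed after parsing.
    parser.add_argument("-e", "--exp_dir", type=str, required=True)
    parser.add_argument("-f", "--filter", type=str, default=None)
    parser.add_argument("-p", "--parallel", action="store_true")
    parser.add_argument("-j", "--jobs", type=int, default=-1)
    parser.add_argument("-g", "--guess_success", action="store_true")
    return parser.parse_args(argv)


# Hypothetical invocation, equivalent to: -e /tmp/exp -p
args = parse_args(["-e", "/tmp/exp", "-p"])
```

With this change a missing `-e` ends the program through argparse's own error path (`SystemExit`) rather than an `AssertionError`, and the check still runs under `python -O`, which strips asserts.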
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from dataclasses import dataclass | ||
|
from bgym import ExpResult, StepInfo | ||
|
from agentlab.analyze.error_analysis.summarizer_prompts import ( | ||
    CHANGE_SUMMARIZER_PROMPT, | ||
    ERROR_CLASSIFICATION_PROMPT, | ||
    ERROR_CLASSIFICATION_PROMPT_SUCCESS_OR_NOT, | ||
) | ||
from agentlab.llm.llm_utils import json_parser, parse_html_tags | ||
from agentlab.llm.tracking import set_tracker | ||
|
|
def _diff(past_obs, current_obs): | ||
    """TODO: Implement the diff function. | ||
|
    Returns a diff of current_obs compared to past_obs, unless there are too many changes. | ||
|
    Args: | ||
        past_obs: The past observation. | ||
        current_obs: The current observation. | ||
|
    Raises: | ||
        ValueError: Not implemented yet. | ||
    """ | ||
    raise ValueError("Not implemented yet.") | ||
|
|
@dataclass | ||
class ChangeSummarizer: | ||
|
    llm: callable  # language model | ||
    obs_formatter: callable = lambda x: x.get("dom_txt", "No AXTREE available") | ||
    use_diff: bool = False | ||
|
    def summarize(self, obs: StepInfo, next_obs: StepInfo, past_summaries: list[str]) -> dict: | ||
        """Produces a summary of the effect of an action.""" | ||
        obs_message = self.obs_formatter(obs.obs) | ||
        next_obs_message = self.obs_formatter(next_obs.obs) | ||
|
        action = obs.action | ||
|
        goal = obs.obs["goal"]  # Use goal object from agentlab | ||
        # TODO(thibault): switch to 'goal_object' | ||
        # Outsource everything to formatter | ||
|
        if self.use_diff: | ||
            next_obs_message = _diff(obs_message, next_obs_message) | ||
|
        return self.parse( | ||
            self.llm( | ||
                self.make_prompt( | ||
                    obs_message, | ||
                    action, | ||
                    next_obs_message, | ||
                    past_summaries, | ||
                    goal, | ||
                    obs.obs.get("plan", "No plan available"), | ||
                ) | ||
            )["content"] | ||
        ) | ||
|
    def make_prompt( | ||
        self, past_obs_message, action, current_obs_message, past_summaries, goal, plan | ||
    ): | ||
        """TODO: Implement the prompt.""" | ||
        return CHANGE_SUMMARIZER_PROMPT.format( | ||
            goal=goal, | ||
            plan=plan, | ||
            past_observation=past_obs_message, | ||
            current_observation=current_obs_message, | ||
            past_summaries=past_summaries, | ||
            action=action, | ||
        ) | ||
|
    def parse(self, raw_output: str) -> dict: | ||
        parsed_result = parse_html_tags( | ||
            raw_output, keys=["changeSummary", "actionAssessment", "explanation", "suggestion"] | ||
        )[0] | ||
        return parsed_result | ||
|
|
@dataclass | ||
class EpisodeAnalysis: | ||
    analysis: str  # complete analysis of the episode | ||
    summary: str  # short summary of the analysis | ||
    categories: dict[str, float]  # score for each category e.g. type of error or difficulty levels | ||
|
|
@dataclass | ||
class EpisodeSummarizer: | ||
|
    change_summarizer: ChangeSummarizer = None | ||
    llm: callable = None | ||
    parser: callable = lambda x: json_parser(x)[0] | ||
    guess_success: bool = False | ||
|
    def make_prompt(self, exp_results: ExpResult, summaries: list[str]): ... | ||
|
    def __call__(self, exp_results: ExpResult) -> EpisodeAnalysis: | ||
        """Run Change Summarizer for every step in the episode or extract a pre-computed one.""" | ||
|
        if not self.guess_success: | ||
            if exp_results.steps_info[-1].reward == 1: | ||
                return {"analysis": "Success", "summaries": {}} | ||
|
        with set_tracker("summary") as summaries_tracker: | ||
            summaries = self.make_change_summaries(exp_results) | ||
        prompt = self.make_prompt(exp_results, summaries) | ||
|
        with set_tracker("analysis") as analysis_tracker: | ||
            raw_analysis = self.llm(prompt)["content"] | ||
        analysis = self.parse(raw_analysis) | ||
        res = { | ||
            "analysis": analysis, | ||
            "summaries": {i: a for i, a in enumerate(summaries)}, | ||
        } | ||
        res.update(analysis_tracker.stats) | ||
        res.update(summaries_tracker.stats) | ||
        return res | ||
|
    def make_change_summaries(self, exp_result: ExpResult) -> list[str]: | ||
        summaries = []  # type: list[str] | ||
        # this assumes that there is always an extra step at the end of the episode | ||
        # it is generally the case, but exps can sometimes fail in a weird way and not save the last step_info | ||
        # TODO:(thibault) make some checks or w/e | ||
        for step, next_step in zip(exp_result.steps_info[:-1], exp_result.steps_info[1:]): | ||
            summaries.append(self.change_summarizer.summarize(step, next_step, summaries)) | ||
        return summaries | ||
|
    def parse(self, raw_output: str) -> dict: | ||
        parsed_result = parse_html_tags(raw_output, keys=["explanation", "errorCategory"])[0] | ||
        return parsed_result | ||
|
|
@dataclass | ||
class EpisodeErrorSummarizer(EpisodeSummarizer): | ||
|
    change_summarizer: ChangeSummarizer = None | ||
|
    def make_prompt(self, exp_results: ExpResult, summaries: list[str]): | ||
        """TODO: Implement the prompt.""" | ||
        goal = exp_results.steps_info[0].obs["goal"] | ||
|
        def format_summary(summary): | ||
            res = "" | ||
            for key, value in summary.items(): | ||
                res += f"{key}: {value}\n" | ||
            return res | ||
|
        txt_summaries = "\n".join([format_summary(summary) for summary in summaries]) | ||
|
        actions = [step.action for step in exp_results.steps_info[:-1]] | ||
        # one error per action, taken from the following step's observation | ||
        action_errors = [ | ||
            step.obs["last_action_error"] for step in exp_results.steps_info[1:] | ||
        ] | ||
|
        txt_actions = "\n".join( | ||
            [ | ||
                f"Action: {action}\nAction Error: {action_error}" | ||
                for action, action_error in zip(actions, action_errors) | ||
            ] | ||
        ) | ||
|
        extra_info = exp_results.steps_info[-1].task_info | ||
|
        prompt = ( | ||
            ERROR_CLASSIFICATION_PROMPT_SUCCESS_OR_NOT | ||
            if self.guess_success | ||
            else ERROR_CLASSIFICATION_PROMPT | ||
        ) | ||
|
        return prompt.format( | ||
            goal=goal, | ||
            historical_summaries=txt_summaries, | ||
            action_history=txt_actions, | ||
            extra_info=extra_info, | ||
        ) |
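The action history in `make_prompt` pairs each action with the error reported on the following step. Below is a minimal sketch of that pairing on made-up data. The `format_action_history` helper and the sample actions are hypothetical, not part of the PR. It also shows why the errors should stay a list: `zip` over a newline-joined string pairs each action with a single character.

```python
def format_action_history(actions, action_errors):
    """Pair each action with the error observed on the following step."""
    return "\n".join(
        f"Action: {action}\nAction Error: {error}"
        for action, error in zip(actions, action_errors)
    )


# Made-up episode data, for illustration only.
actions = ["click('12')", "fill('7', 'hello')"]
errors = ["", "Element 7 is not editable"]

# Zipping two lists gives one (action, error) pair per step.
history = format_action_history(actions, errors)

# Zipping against a joined string iterates over its characters instead,
# so each action ends up paired with a single character of the error text.
broken = format_action_history(actions, "\n".join(errors))
```

Here `history` ends with the full error message for the second action, while `broken` ends with only its first letter.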
nicee