22 changes: 22 additions & 0 deletions src/core/orchestrator.py
@@ -386,6 +386,9 @@ async def run_sub_agent(

# Start new sub-agent session
self.task_log.start_sub_agent_session(sub_agent_name, task_description)

# Reset sub-agent usage stats for independent tracking
self.sub_agent_llm_client.reset_usage_stats()

# Simplified initial user content (no file attachments)
initial_user_content = [{"type": "text", "text": task_description}]
@@ -661,6 +664,14 @@ async def run_sub_agent(
] = {"system_prompt": system_prompt, "message_history": message_history} # type: ignore
self.task_log.save()

# Record sub-agent cumulative usage
usage_log = self.sub_agent_llm_client.get_usage_log()
self.task_log.log_step(
"usage_calculation",
usage_log,
metadata={"session_id": self.task_log.current_sub_agent_session_id},
)

self.task_log.end_sub_agent_session(sub_agent_name)
self.task_log.log_step(
"sub_agent_completed", f"Sub agent {sub_agent_name} completed", "info"
@@ -682,6 +693,9 @@ async def run_main_agent(
if task_file_name:
logger.debug(f"Associated File: {task_file_name}")

# Reset main agent usage stats for independent tracking
self.llm_client.reset_usage_stats()

# 1. Process input
initial_user_content, task_description = process_input(
task_description, task_file_name
@@ -1089,6 +1103,14 @@ async def run_main_agent(
"task_completed", f"Main agent task {task_id} completed successfully"
)

# Record main agent cumulative usage
usage_log = self.llm_client.get_usage_log()
self.task_log.log_step(
"usage_calculation",
usage_log,
metadata={"session_id": "main_agent"},
)

if "browsecomp-zh" in self.cfg.benchmark.name:
return final_summary, final_summary
else:
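Taken together, the orchestrator changes follow a reset-at-start / log-at-end pattern for each agent session. A minimal sketch of that lifecycle, assuming a client and task log with the same interface as `self.llm_client` and `self.task_log` in the diff (the `run_session` wrapper itself is hypothetical, not part of the PR):

# Hypothetical wrapper illustrating the per-session usage lifecycle.
async def run_session(llm_client, task_log, session_id: str) -> None:
    # 1. Reset counters so this session's usage is tracked independently.
    llm_client.reset_usage_stats()

    # 2. ... agent loop runs here; every create_message() call accumulates usage ...

    # 3. Record the cumulative usage string once the session finishes.
    task_log.log_step(
        "usage_calculation",
        llm_client.get_usage_log(),
        metadata={"session_id": session_id},
    )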
65 changes: 65 additions & 0 deletions src/llm/provider_client_base.py
@@ -34,6 +34,11 @@ class LLMProviderClientBase(ABC):

# post_init
client: Any = dataclasses.field(init=False)
# Usage tracking - cumulative for each agent session
total_input_tokens: int = dataclasses.field(init=False, default=0)
total_input_cached_tokens: int = dataclasses.field(init=False, default=0)
total_output_tokens: int = dataclasses.field(init=False, default=0)
total_output_reasoning_tokens: int = dataclasses.field(init=False, default=0)

def __post_init__(self):
# Explicitly assign from cfg object
@@ -195,6 +200,19 @@ async def create_message(
tool_definitions,
keep_tool_result=keep_tool_result,
)

# Accumulate usage for agent session
if response:
try:
usage = self._extract_usage_from_response(response)
if usage:
self.total_input_tokens += usage.get("input_tokens", 0)
self.total_input_cached_tokens += usage.get("cached_tokens", 0)
self.total_output_tokens += usage.get("output_tokens", 0)
self.total_output_reasoning_tokens += usage.get("reasoning_tokens", 0)
except Exception as e:
logger.warning(f"Failed to accumulate usage: {e}")

return response

@staticmethod
@@ -314,3 +332,50 @@ def handle_max_turns_reached_summary_prompt(
self, message_history: list[dict[str, Any]], summary_prompt: str
):
raise NotImplementedError("must implement in subclass")

def _extract_usage_from_response(self, response):
"""Default Extract usage - OpenAI Chat Completions format"""
if not hasattr(response, 'usage'):
return {
"input_tokens": 0,
"cached_tokens": 0,
"output_tokens": 0,
"reasoning_tokens": 0
}

usage = response.usage
prompt_tokens_details = getattr(usage, 'prompt_tokens_details', {}) or {}
if hasattr(prompt_tokens_details, "to_dict"):
prompt_tokens_details = prompt_tokens_details.to_dict()
completion_tokens_details = getattr(usage, 'completion_tokens_details', {}) or {}
if hasattr(completion_tokens_details, "to_dict"):
completion_tokens_details = completion_tokens_details.to_dict()

usage_dict = {
"input_tokens": getattr(usage, 'prompt_tokens', 0),
"cached_tokens": prompt_tokens_details.get('cached_tokens', 0),
"output_tokens": getattr(usage, 'completion_tokens', 0),
"reasoning_tokens": completion_tokens_details.get('reasoning_tokens', 0)
}

return usage_dict

def get_usage_log(self) -> str:
"""Get cumulative usage for current agent session as formatted string"""
# Format: [Provider | Model] Total Input: X, Cache Input: Y, Output: Z, ...
provider_model = f"[{self.provider_class} | {self.model_name}]"
input_uncached = self.total_input_tokens - self.total_input_cached_tokens
output_response = self.total_output_tokens - self.total_output_reasoning_tokens
total_tokens = self.total_input_tokens + self.total_output_tokens

return (f"Usage log: {provider_model}, "
f"Total Input: {self.total_input_tokens} (Cached: {self.total_input_cached_tokens}, Uncached: {input_uncached}), "
f"Total Output: {self.total_output_tokens} (Reasoning: {self.total_output_reasoning_tokens}, Response: {output_response}), "
f"Total Tokens: {total_tokens}")

def reset_usage_stats(self):
"""Reset usage stats for new agent session"""
self.total_input_tokens = 0
self.total_input_cached_tokens = 0
self.total_output_tokens = 0
self.total_output_reasoning_tokens = 0
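The reporting in `get_usage_log` reduces to two subtractions and one sum over the accumulated counters. A worked example with made-up numbers (the provider/model label depends on the client's configuration):

# Illustrative values only; real numbers come from accumulated API responses.
total_input_tokens = 1000            # sum of prompt tokens across calls
total_input_cached_tokens = 400      # of which served from the prompt cache
total_output_tokens = 300            # sum of completion tokens across calls
total_output_reasoning_tokens = 120  # of which spent on reasoning

input_uncached = total_input_tokens - total_input_cached_tokens        # 600
output_response = total_output_tokens - total_output_reasoning_tokens  # 180
total_tokens = total_input_tokens + total_output_tokens                # 1300
# -> "Usage log: [SomeProvider | some-model], Total Input: 1000 (Cached: 400, Uncached: 600),
#     Total Output: 300 (Reasoning: 120, Response: 180), Total Tokens: 1300"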
27 changes: 26 additions & 1 deletion src/llm/providers/claude_anthropic_client.py
@@ -29,7 +29,7 @@ def __post_init__(self):

def _create_client(self, config: DictConfig):
"""Create Anthropic client"""
api_key = config.env.anthropic_api_key
api_key = self.cfg.llm.anthropic_api_key

if self.async_client:
return AsyncAnthropic(
@@ -183,6 +183,31 @@ def handle_max_turns_reached_summary_prompt(self, message_history, summary_prompt
else:
return summary_prompt

def _extract_usage_from_response(self, response):
"""Extract usage - Anthropic format"""
if not hasattr(response, 'usage'):
return {
"input_tokens": 0,
"cached_tokens": 0,
"output_tokens": 0,
"reasoning_tokens": 0
}

usage = response.usage
cache_creation_input_tokens = getattr(usage, 'cache_creation_input_tokens', 0)
cache_read_input_tokens = getattr(usage, 'cache_read_input_tokens', 0)
input_tokens = getattr(usage, 'input_tokens', 0)
output_tokens = getattr(usage, 'output_tokens', 0)

usage_dict = {
"input_tokens": cache_creation_input_tokens + cache_read_input_tokens + input_tokens,
"cached_tokens": cache_read_input_tokens,
"output_tokens": output_tokens,
"reasoning_tokens": 0
}

return usage_dict

def _apply_cache_control(self, messages):
"""Apply cache control to the last user message and system message (if applicable)"""
cached_messages = []
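The Anthropic override exists because, with prompt caching, Anthropic reports cache-creation and cache-read tokens separately from `usage.input_tokens`, so the total input is the sum of all three. A sketch of the mapping with illustrative values (not taken from a real response):

# Hypothetical Anthropic usage fields (illustrative values only).
input_tokens = 120                 # uncached input reported by the API
cache_creation_input_tokens = 800  # tokens written to the prompt cache
cache_read_input_tokens = 2000     # tokens served from the prompt cache

usage_dict = {
    "input_tokens": cache_creation_input_tokens + cache_read_input_tokens + input_tokens,  # 2920
    "cached_tokens": cache_read_input_tokens,  # only cache reads count as cached here
    "output_tokens": 350,
    "reasoning_tokens": 0,         # no separate reasoning count in this format
}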
8 changes: 4 additions & 4 deletions src/llm/providers/claude_newapi_client.py
@@ -35,13 +35,13 @@ class ClaudeNewAPIClient(LLMProviderClientBase):
def _create_client(self, config: DictConfig):
if self.async_client:
return AsyncOpenAI(
api_key=config.env.newapi_api_key,
base_url=config.env.newapi_base_url,
api_key=self.cfg.llm.newapi_api_key,
base_url=self.cfg.llm.newapi_base_url,
)
else:
return OpenAI(
api_key=config.env.newapi_api_key,
base_url=config.env.newapi_base_url,
api_key=self.cfg.llm.newapi_api_key,
base_url=self.cfg.llm.newapi_base_url,
)

# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
4 changes: 4 additions & 0 deletions src/llm/providers/claude_openrouter_client.py
@@ -133,6 +133,10 @@ async def _create_message(
if self.repetition_penalty != 1.0:
extra_body["repetition_penalty"] = self.repetition_penalty

extra_body["usage"] = {
"include": True
}

params = {
"model": self.model_name,
"temperature": temperature,
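OpenRouter returns detailed token accounting only when the request opts in, which is what the added `extra_body` entry does. A sketch of where that flag ends up, assuming the OpenAI SDK client used elsewhere in this file (the model slug and variable names are illustrative):

# The OpenAI SDK merges extra_body into the request payload sent to OpenRouter.
response = await client.chat.completions.create(
    model="anthropic/claude-sonnet-4",        # assumed slug, for illustration only
    messages=messages,
    extra_body={"usage": {"include": True}},  # ask OpenRouter to include usage accounting
)
# response.usage then carries the token counts consumed by the base-class accumulator.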
8 changes: 4 additions & 4 deletions src/llm/providers/deepseek_newapi_client.py
@@ -32,13 +32,13 @@ def _create_client(self, config: DictConfig):
"""Create configured OpenAI client"""
if self.async_client:
return AsyncOpenAI(
api_key=config.env.newapi_api_key,
base_url=config.env.newapi_base_url,
api_key=self.cfg.llm.newapi_api_key,
base_url=self.cfg.llm.newapi_base_url,
)
else:
return OpenAI(
api_key=config.env.newapi_api_key,
base_url=config.env.newapi_base_url,
api_key=self.cfg.llm.newapi_api_key,
base_url=self.cfg.llm.newapi_base_url,
)

# @retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
8 changes: 4 additions & 4 deletions src/llm/providers/gpt_openai_client.py
@@ -29,13 +29,13 @@ def _create_client(self, config: DictConfig):
"""Create configured OpenAI client"""
if self.async_client:
return AsyncOpenAI(
api_key=config.env.openai_api_key,
base_url=config.env.openai_base_url,
api_key=self.cfg.llm.openai_api_key,
base_url=self.cfg.llm.openai_base_url,
)
else:
return OpenAI(
api_key=config.env.openai_api_key,
base_url=config.env.openai_base_url,
api_key=self.cfg.llm.openai_api_key,
base_url=self.cfg.llm.openai_base_url,
)

@retry(wait=wait_fixed(10), stop=stop_after_attempt(5))
33 changes: 30 additions & 3 deletions src/llm/providers/gpt_openai_response_client.py
@@ -29,12 +29,12 @@ def _create_client(self, config: DictConfig):
"""Create configured OpenAI client"""
if self.async_client:
return AsyncOpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
api_key=self.cfg.llm.openai_api_key,
base_url=self.cfg.llm.openai_base_url,
)
else:
return OpenAI(
api_key=os.environ.get("OPENAI_API_KEY"),
api_key=self.cfg.llm.openai_api_key,
base_url=self.cfg.llm.openai_base_url,
)

@@ -93,7 +93,7 @@ async def _create_message(
response = self._convert_response_to_serializable(response)

# Update token count
self._update_token_usage(response.get("usage", None))
# self._update_token_usage(response.get("usage", None))

Copilot AI (Sep 26, 2025):
This commented-out code should be removed rather than left as a comment. The functionality has been replaced by the new usage tracking system in the base class.

Suggested change (delete the line):
# self._update_token_usage(response.get("usage", None))

logger.debug(
f"LLM Response API call status: {response.get('error', 'N/A')}"
)
@@ -269,3 +269,30 @@ def _convert_response_to_serializable(self, response):
}

return serializable_response

def _extract_usage_from_response(self, response):
"""Extract usage - OpenAI Responses API format"""
if not response or not response.get('usage'):
return {
"input_tokens": 0,
"cached_tokens": 0,
"output_tokens": 0,
"reasoning_tokens": 0
}

usage = response.get('usage', {}) or {}
input_tokens_details = usage.get('input_tokens_details', {}) or {}
if hasattr(input_tokens_details, "to_dict"):
input_tokens_details = input_tokens_details.to_dict()
output_tokens_details = usage.get('output_tokens_details', {}) or {}
if hasattr(output_tokens_details, "to_dict"):
output_tokens_details = output_tokens_details.to_dict()

usage_dict = {
"input_tokens": usage.get('input_tokens', 0),
"cached_tokens": input_tokens_details.get('cached_tokens', 0),
"output_tokens": usage.get('output_tokens', 0),
"reasoning_tokens": output_tokens_details.get('reasoning_tokens', 0)
}

return usage_dict
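Because this client converts the Responses API result into a plain dict before returning it, the extractor reads dict keys rather than object attributes. A sketch of the shape it expects, with illustrative values:

# Hypothetical serialized response as produced upstream in this client (values made up).
serialized = {
    "usage": {
        "input_tokens": 900,
        "input_tokens_details": {"cached_tokens": 300},
        "output_tokens": 250,
        "output_tokens_details": {"reasoning_tokens": 100},
    }
}
# _extract_usage_from_response(serialized)
# -> {"input_tokens": 900, "cached_tokens": 300, "output_tokens": 250, "reasoning_tokens": 100}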
8 changes: 4 additions & 4 deletions src/llm/providers/qwen_sglang_client.py
@@ -25,13 +25,13 @@ def _create_client(self, config: DictConfig):
"""Create configured OpenAI client"""
if self.async_client:
return AsyncOpenAI(
api_key=config.env.qwen_api_key,
base_url=config.env.qwen_base_url,
api_key=self.cfg.llm.qwen_api_key,
base_url=self.cfg.llm.qwen_base_url,
)
else:
return OpenAI(
api_key=config.env.qwen_api_key,
base_url=config.env.qwen_base_url,
api_key=self.cfg.llm.qwen_api_key,
base_url=self.cfg.llm.qwen_base_url,
)

@retry(wait=wait_fixed(10), stop=stop_after_attempt(5))