127 changes: 57 additions & 70 deletions src/arbiter_v4.py
@@ -11,6 +11,7 @@
from dataclasses import dataclass
from google.genai.types import Tool, GenerateContentConfig, GoogleSearch
from shared_resources import SpacyModelSingleton # Import the correct singleton class
from bs4 import BeautifulSoup


# Third-party imports
@@ -420,23 +421,64 @@ def ground_assertions(
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        gemini_filename = f"arbiter_output_{timestamp}.html"

        # Create the HTML content by inserting the raw response inside the body tag
        html_parts = gemini_template.split("%s")
        if len(html_parts) == 2:
            full_html = html_parts[0] + raw_response + html_parts[1]
            with open(gemini_filename, "w") as f:
                f.write(full_html)
            logger.debug(f"Raw Gemini output saved as {gemini_filename}")
        else:
            logger.warning("Template format for gemini_output.html doesn't contain a single '%s' placeholder")
    except Exception as e:
        logger.error(f"Failed to save raw Gemini output: {e}")

    # --- Refactor HTML parsing using BeautifulSoup ---
    soup = BeautifulSoup(response_full, 'html.parser')

    # --- Inject Model Information ---
    ai_model_name = getattr(self, "_ai_model", os.environ.get("AI_MODEL", "N/A"))
    human_model_name = getattr(self, "_human_model", os.environ.get("HUMAN_MODEL", "N/A"))

    h3_texts_to_update = {
        "Conversation 1 (AI-AI Meta-Prompted)": f"Conversation 1 (AI-AI Meta-Prompted) - Models: {human_model_name} & {ai_model_name}",
        "Conversation 2 (Human-AI Meta-Prompted)": f"Conversation 2 (Human-AI Meta-Prompted) - Models: {human_model_name} & {ai_model_name}",
        "Conversation 3 (Non-Metaprompted)": f"Conversation 3 (Non-Metaprompted) - Models: {human_model_name} & {ai_model_name}",
    }

    for h3_tag in soup.find_all('h3'):
        original_text = h3_tag.string
        if original_text and original_text.strip() in h3_texts_to_update:
            h3_tag.string = h3_texts_to_update[original_text.strip()]
            logger.debug(f"Updated h3 tag: '{original_text}' to '{h3_tag.string}'")

    # --- Extract Winner Statement using BeautifulSoup ---
    extracted_winner = "No clear winner determined"  # Default
    try:
        comparative_analysis_h2 = soup.find('h2', string=lambda text: text and "Comparative Analysis" in text)
        if comparative_analysis_h2:
            section_div = comparative_analysis_h2.find_parent('div', class_='section')
            if section_div:
                first_p_tag = section_div.find('p')
                if first_p_tag and first_p_tag.string:
                    winner_text = first_p_tag.string.strip()
                    if winner_text:
                        extracted_winner = winner_text[:250] + ('...' if len(winner_text) > 250 else '')
                        logger.debug(f"Extracted winner statement with BeautifulSoup: {extracted_winner}")
                    else:
                        logger.warning("Found <p> tag in Comparative Analysis section, but it was empty.")
                else:
                    logger.warning("Could not find a <p> tag with content in the Comparative Analysis section.")
            else:
                logger.warning("Could not find parent 'div.section' for 'Comparative Analysis' h2.")
        else:
            logger.warning("Could not find '<h2>Comparative Analysis</h2>' section using BeautifulSoup.")
    except Exception as parse_err:
        logger.error(f"BeautifulSoup parsing error for winner: {parse_err}", exc_info=True)

    modified_response_html = str(soup)  # Get the modified HTML string
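Reviewer note: for anyone unfamiliar with the find_parent-based lookup above, here is a minimal, self-contained sketch of the same extraction pattern against toy HTML. The markup is illustrative only, not the actual report structure.

    from bs4 import BeautifulSoup

    # Toy report fragment (hypothetical; real reports come from the templates).
    sample = """
    <div class="section">
      <h2>Comparative Analysis</h2>
      <p>Conversation 2 wins on coherence and grounding.</p>
    </div>
    """
    soup = BeautifulSoup(sample, "html.parser")
    h2 = soup.find("h2", string=lambda t: t and "Comparative Analysis" in t)
    section = h2.find_parent("div", class_="section")
    # Note: <p>.string is None when the tag has child elements, which is
    # why the real code guards with `first_p_tag and first_p_tag.string`.
    print(section.find("p").string.strip())
    # -> Conversation 2 wins on coherence and grounding.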

    # Save the formatted arbiter report with proper styling
    try:
        # Try to use the new template first
        template_path = "templates/new_arbiter_report.html"
        if not os.path.exists(template_path):
            template_path = "templates/simple_arbiter_report.html"
@@ -448,86 +490,31 @@
        formatted_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        formatted_filename = f"arbiter_report_{timestamp}.html"

        # (Deleted by this PR: the previous string-based implementation, which
        # injected model names via str.replace() on hard-coded <h3> headers and
        # extracted the winner by manual find() offset scanning for the first <p>
        # after '<h2>Comparative Analysis</h2>'. The BeautifulSoup logic above
        # supersedes all of it.)

        # For simple template, insert content directly without string formatting
        if template_path == "templates/simple_arbiter_report.html":
            html_parts = report_template.split("%s")
            if len(html_parts) == 2:
                full_html = html_parts[0] + modified_response_html + html_parts[1]
                with open(formatted_filename, "w") as f:
                    f.write(full_html)
            else:
                logger.warning("Simple template format doesn't contain a single '%s' placeholder")
        else:
            # For new template, use safer string.Template
            import string
            template = string.Template(report_template)
            with open(formatted_filename, "w") as f:
                f.write(template.safe_substitute(
                    gemini_content=modified_response_html,  # Use BS modified HTML
                    winner=extracted_winner,
                    timestamp=formatted_timestamp,
                    # Ensure model_names_summary is available or add it to template_vars if needed
                    model_names_summary=f"Models: AI Model - {ai_model_name}, Human Model - {human_model_name}"
                ))

        logger.info(f"Formatted arbiter report saved as {formatted_filename}")
    except Exception as e:
        logger.error(f"Failed to save formatted arbiter report: {e}")

    return modified_response_html  # Return the BS modified HTML

except Exception as e:
    logger.error(f"Error grounding assertion with Gemini: {e}")
11 changes: 11 additions & 0 deletions src/configdataclasses.py
@@ -411,6 +411,11 @@ class ModelConfig:
    type: str
    role: str
    persona: Optional[str] = field(default=None)
    max_tokens: int = 3192  # Default max tokens, moved from global MAX_TOKENS
    temperature: float = 0.8
    stop_sequences: List[str] = field(default_factory=list)
    seed: Optional[int] = None  # Allow None; specific clients may generate a random seed if None
    timeout: Optional[int] = 90  # Default timeout in seconds

    def __post_init__(self):
        # Validate model type
@@ -431,6 +436,12 @@ def __post_init__(self):
        # Validate persona if provided
        if self.persona and not isinstance(self.persona, str):
            raise ValueError("Persona must be a string")
        if not (0 <= self.temperature <= 2.0):  # Typical range for temperature
            raise ValueError("Temperature must be between 0.0 and 2.0")
        if self.max_tokens <= 0:
            raise ValueError("Max tokens must be positive")
        if self.timeout is not None and self.timeout <= 0:
            raise ValueError("Timeout must be positive if specified")


@dataclass