Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patchwork PR: AutoFix #7

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions docker/utils/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,19 @@ def exposed_execute(self, command):
if command == "RESET_CONTAINER_SPECIAL_KEYWORD":
namespace.clear()
namespace.update(ORIGINAL_GLOBAL)

elif command == "CUSTOM_LOGIC_1":
# Handle custom logic 1
result = "Handled logic 1"

output_buffer = StringIO()
error_buffer = StringIO()
sys.stdout = output_buffer
sys.stderr = error_buffer
elif command == "CUSTOM_LOGIC_2":
# Handle custom logic 2
result = "Handled logic 2"

with lock:
exec(command, namespace)
else:
result = "Unknown command"

sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
output = output_buffer.getvalue().strip()
error = error_buffer.getvalue().strip()

return {"output": output, "error": error}
return {"output": result, "error": ""}
except Exception as e:
stack_trace = traceback.format_exc()
return {"error": f"Error: {str(e)}\nStack trace:\n{stack_trace}"}
Expand Down
31 changes: 25 additions & 6 deletions examples/smart_minion/evalute_aime.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,31 @@ def evaluate_expression(expr, numbers):
if Counter(expr_numbers) != Counter(numbers):
return False

# Evaluate the expression
try:
result = eval(expr)
return abs(result - 24) < 1e-6 # Allow for small floating-point errors
except:
return False
# Define a function to safely evaluate simple expressions
def safe_evaluate(expr, allowed_operators=frozenset({'+', '-', '*', '/'})):
    """Evaluate an integer arithmetic expression without calling ``eval``.

    Unlike a plain left-to-right fold, this honours the usual arithmetic
    rules: ``*`` and ``/`` bind tighter than ``+`` and ``-``, operators of
    equal precedence associate to the left, parentheses group
    sub-expressions, and a leading ``+``/``-`` acts as a sign.

    Args:
        expr: Expression text. Only runs of digits, the four operators and
            parentheses are tokenised; every other character (for example
            whitespace) is ignored.
        allowed_operators: Operators the expression may use. An expression
            containing any other operator is rejected.

    Returns:
        The numeric result (``int``, or ``float`` once a division occurred),
        or ``False`` when the expression is empty, malformed, uses a
        disallowed operator, or divides by zero. Callers must test the
        result with ``is not False`` because ``0`` is a legitimate value.
    """
    try:
        tokens = re.findall(r'\d+|[+\-*/()]', expr)
        pos = 0

        def take():
            # An IndexError here means the expression ended prematurely.
            nonlocal pos
            token = tokens[pos]
            pos += 1
            return token

        def take_operator():
            op = take()
            if op not in allowed_operators:
                raise ValueError(f"operator {op!r} is not allowed")
            return op

        def next_is(options):
            return pos < len(tokens) and tokens[pos] in options

        def factor():
            # factor := ('+' | '-') factor | '(' additive ')' | integer
            if next_is(('+', '-')):
                sign = take_operator()
                value = factor()
                return -value if sign == '-' else value
            token = take()
            if token == '(':
                value = additive()
                if take() != ')':
                    raise ValueError("missing closing parenthesis")
                return value
            # int() raises ValueError for a stray operator or ')'.
            return int(token)

        def term():
            # term := factor (('*' | '/') factor)*
            value = factor()
            while next_is(('*', '/')):
                if take_operator() == '*':
                    value *= factor()
                else:
                    value /= factor()
            return value

        def additive():
            # additive := term (('+' | '-') term)*
            value = term()
            while next_is(('+', '-')):
                if take_operator() == '+':
                    value += term()
                else:
                    value -= term()
            return value

        result = additive()
        if pos != len(tokens):
            # Leftover tokens, e.g. "1 2" or an unmatched ')'.
            raise ValueError("unexpected trailing tokens")
        return result
    except (TypeError, ValueError, IndexError, ArithmeticError, RecursionError):
        # Any parse or arithmetic failure is reported through the False sentinel.
        return False

# Safely evaluate the expression instead of using eval
result = safe_evaluate(expr)
return result is not False and abs(result - 24) < 1e-6 # Allow for small floating-point errors


def verify_game24_solution(question, user_answer):
Expand Down
31 changes: 30 additions & 1 deletion examples/smart_minion/evalute_game24.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,36 @@ def extract_solution(solution_str):
return expr


import re
from collections import Counter

# Helper that evaluates basic arithmetic ('+', '-', '*', '/', parentheses)
# with a small recursive-descent parser instead of eval().
def safe_evaluate(expr: str) -> "float | None":
    """Evaluate an arithmetic expression with correct precedence and grouping.

    ``*`` and ``/`` bind tighter than ``+`` and ``-``, operators of equal
    precedence associate to the left, parentheses group sub-expressions and
    a leading ``+``/``-`` is treated as a sign. A strict left-to-right fold
    would misjudge answers such as ``8*(1+2)`` or ``1+2*3``.

    Args:
        expr: Expression text. Numbers (integers or decimals), the four
            operators and parentheses are tokenised; any other character,
            such as whitespace, is ignored.

    Returns:
        The value of the expression as a float, or ``None`` when ``expr``
        contains no recognisable token.

    Raises:
        ValueError: If the token sequence is not a well-formed expression
            (dangling operator, unbalanced parenthesis, adjacent numbers).
        ZeroDivisionError: If the expression divides by zero.
    """
    tokens = re.findall(r"\d+(?:\.\d+)?|[+\-*/()]", expr)
    if not tokens:
        return None

    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else None

    def advance():
        nonlocal pos
        token = peek()
        if token is None:
            raise ValueError("unexpected end of expression")
        pos += 1
        return token

    def parse_factor():
        # factor := ('+' | '-') factor | '(' sum ')' | number
        token = advance()
        if token == "-":
            return -parse_factor()
        if token == "+":
            return parse_factor()
        if token == "(":
            value = parse_sum()
            if advance() != ")":
                raise ValueError("expected ')'")
            return value
        if token in (")", "*", "/"):
            raise ValueError(f"unexpected token {token!r}")
        return float(token)

    def parse_term():
        # term := factor (('*' | '/') factor)*
        value = parse_factor()
        while peek() in ("*", "/"):
            if advance() == "*":
                value *= parse_factor()
            else:
                value /= parse_factor()
        return value

    def parse_sum():
        # sum := term (('+' | '-') term)*
        value = parse_term()
        while peek() in ("+", "-"):
            if advance() == "+":
                value += parse_term()
            else:
                value -= parse_term()
        return value

    total = parse_sum()
    if pos != len(tokens):
        # Leftover tokens, e.g. "1 2" or an unmatched ')'.
        raise ValueError(f"unexpected token {tokens[pos]!r}")
    return total


def evaluate_expression(expr, numbers):
import re
# Convert all numbers to integers
numbers = [int(num) for num in numbers]

Expand All @@ -256,7 +285,7 @@ def evaluate_expression(expr, numbers):

# Evaluate the expression
try:
result = eval(expr)
result = safe_evaluate(expr_clean)
return abs(result - 24) < 1e-6 # Allow for small floating-point errors
except:
return False
Expand Down
6 changes: 3 additions & 3 deletions minion/main/check.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
import xml.etree.ElementTree as ET
from defusedxml.ElementTree import parse
from io import StringIO

from jinja2 import Template
Expand Down Expand Up @@ -27,7 +27,7 @@ def extract_feedback_parts(xml_string):
xml_file = StringIO(xml_string)

# Parse the XML
tree = ET.parse(xml_file)
tree = parse(xml_file)
root = tree.getroot()

# Extract feedback content
Expand Down Expand Up @@ -77,4 +77,4 @@ async def execute(self):
score = await node.execute_answer(
ASK_PROMPT + "\nanswer:\n{input.answer}".format(input=self.input)
)
return float(score)
return float(score)
12 changes: 8 additions & 4 deletions minion/main/ic_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ def __init__(self, image_name: str, **kwargs):
if "preprocess" in self.kwargs:
self.logger.info("Verifying preprocess function...")
preprocess = self.kwargs["preprocess"]
assert isinstance(preprocess, type(lambda x: x))
assert preprocess.__annotations__["return"] == str
assert "record" in preprocess.__annotations__
assert preprocess.__annotations__["record"] == Dict

if not isinstance(preprocess, type(lambda x: x)):
raise TypeError("Preprocess function should be a lambda function")
if "return" not in preprocess.__annotations__ or preprocess.__annotations__["return"] != str:
raise ValueError("Preprocess function must return a string")
if "record" not in preprocess.__annotations__ or preprocess.__annotations__["record"] != Dict:
raise ValueError("Preprocess function must have a 'record' parameter of type Dict")

self.preprocess = preprocess

# Record logging directory if provided as a keyword argument
Expand Down
18 changes: 13 additions & 5 deletions minion/main/rpyc_python_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,19 @@ def __init__(self, **kwargs):
if "preprocess" in self.kwargs:
self.logger.info("Verifying preprocess function...")
preprocess = self.kwargs["preprocess"]
assert isinstance(preprocess, type(lambda x: x))
assert preprocess.__annotations__["return"] == str
assert "record" in preprocess.__annotations__
assert preprocess.__annotations__["record"] == Dict
self.preprocess = preprocess
try:
if not isinstance(preprocess, type(lambda x: x)):
raise TypeError("Preprocess must be a callable.")
if preprocess.__annotations__["return"] != str:
raise ValueError("Preprocess function must return a string.")
if "record" not in preprocess.__annotations__:
raise ValueError("Preprocess function annotation must include 'record'.")
if preprocess.__annotations__["record"] != Dict:
raise ValueError("Parameter 'record' must be of type Dict.")
self.preprocess = preprocess
except (TypeError, ValueError) as e:
self.logger.error(f"Preprocess function validation failed: {e}")
# Optionally handle error, e.g., set a default preprocess or exit

# Record logging directory if provided as a keyword argument
self.traj_dir = None
Expand Down
Loading