Skip to content

Commit 539a314

Browse files
committed
feat: basic verifier
1 parent aab523c commit 539a314

File tree

3 files changed

+571
-1
lines changed

3 files changed

+571
-1
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
from typing import Any, Dict, List, Set, Union
2+
from z3 import Solver, Int, Bool, And, Or, Not, IntVal, BoolVal, simplify, Implies
3+
4+
from ..schemas.ast_nodes import *
5+
# Assuming imports from your specific file structure
6+
class FixedVerificationResult:
7+
def __init__(self):
8+
self.success: bool = False
9+
self.message: str = ""
10+
self.undefined_symbols: List[str] = []
11+
self.proof_steps: List[str] = []
12+
13+
def to_dict(self) -> Dict:
14+
return {
15+
"success": self.success,
16+
"message": self.message,
17+
"undefined_symbols": self.undefined_symbols,
18+
"proof_steps": self.proof_steps
19+
}
20+
from typing import Any, Dict, List, Set, Optional, Tuple
21+
from z3 import Solver, Int, Bool, And, Or, Not, IntVal, BoolVal, Implies, is_bool
22+
from evaluation_function.schemas.ast_nodes import (
23+
ProgramNode, BlockNode, AssignmentNode, VariableNode, LiteralNode,
24+
BinaryOpNode, UnaryOpNode, ConditionalNode, LoopNode,
25+
OperatorType, ExpressionNode, NodeType, LoopType
26+
)
27+
28+
class FixedALevelVerifier:
29+
def __init__(self):
30+
self.defined_variables: Set[str] = set()
31+
self.proof_steps: List[str] = []
32+
33+
def collect_symbols(self, precondition: ExpressionNode, program: ProgramNode) -> None:
34+
self.defined_variables.clear()
35+
self._collect_from_expr(precondition)
36+
if program.global_statements:
37+
self._collect_from_block(program.global_statements)
38+
39+
def _collect_from_expr(self, expr: Optional[ExpressionNode]):
40+
if not expr: return
41+
if isinstance(expr, VariableNode):
42+
self.defined_variables.add(expr.name)
43+
elif isinstance(expr, BinaryOpNode):
44+
self._collect_from_expr(expr.left)
45+
self._collect_from_expr(expr.right)
46+
elif isinstance(expr, UnaryOpNode):
47+
self._collect_from_expr(expr.operand)
48+
49+
def _collect_from_block(self, block: BlockNode):
50+
for stmt in block.statements:
51+
if isinstance(stmt, AssignmentNode) and isinstance(stmt.target, VariableNode):
52+
self.defined_variables.add(stmt.target.name)
53+
elif isinstance(stmt, ConditionalNode):
54+
if stmt.then_branch: self._collect_from_block(stmt.then_branch)
55+
for elif_b in stmt.elif_branches:
56+
if elif_b.then_branch: self._collect_from_block(elif_b.then_branch)
57+
if stmt.else_branch: self._collect_from_block(stmt.else_branch)
58+
elif isinstance(stmt, LoopNode) and stmt.body:
59+
self._collect_from_block(stmt.body)
60+
61+
# ------------------------
62+
# WP Logic
63+
# ------------------------
64+
def wp(self, statements: List[Any], post: ExpressionNode) -> ExpressionNode:
65+
wp_expr = post
66+
for stmt in reversed(statements):
67+
wp_expr = self._wp_stmt(stmt, wp_expr)
68+
self.proof_steps.append(f"WP after {self._stmt_to_str(stmt)}: {self._expr_to_str(wp_expr)}")
69+
return wp_expr
70+
71+
def _wp_stmt(self, stmt: Any, post: ExpressionNode) -> ExpressionNode:
72+
if isinstance(stmt, AssignmentNode) and isinstance(stmt.target, VariableNode):
73+
return self._substitute(post, stmt.target.name, stmt.value)
74+
75+
elif isinstance(stmt, ConditionalNode):
76+
current_wp = self.wp(stmt.else_branch.statements if stmt.else_branch else [], post)
77+
for elif_b in reversed(stmt.elif_branches):
78+
wp_elif = self.wp(elif_b.then_branch.statements if elif_b.then_branch else [], post)
79+
current_wp = self._combine_conditional(elif_b.condition, wp_elif, current_wp)
80+
wp_then = self.wp(stmt.then_branch.statements if stmt.then_branch else [], post)
81+
return self._combine_conditional(stmt.condition, wp_then, current_wp)
82+
83+
elif isinstance(stmt, LoopNode):
84+
# A-Level Simplification: Use provided invariant
85+
if "invariant" in stmt.metadata:
86+
return stmt.metadata["invariant"]
87+
return post
88+
89+
def _combine_conditional(self, cond: ExpressionNode, wp_true: ExpressionNode, wp_false: ExpressionNode) -> ExpressionNode:
90+
left = BinaryOpNode(operator=OperatorType.AND, left=cond, right=wp_true)
91+
not_c = UnaryOpNode(operator=OperatorType.NOT, operand=cond)
92+
right = BinaryOpNode(operator=OperatorType.AND, left=not_c, right=wp_false)
93+
return BinaryOpNode(operator=OperatorType.OR, left=left, right=right)
94+
95+
def _substitute(self, expr: ExpressionNode, var_name: str, replacement: ExpressionNode) -> ExpressionNode:
96+
if isinstance(expr, VariableNode) and expr.name == var_name:
97+
return replacement
98+
if isinstance(expr, BinaryOpNode):
99+
return BinaryOpNode(
100+
operator=expr.operator,
101+
left=self._substitute(expr.left, var_name, replacement),
102+
right=self._substitute(expr.right, var_name, replacement)
103+
)
104+
if isinstance(expr, UnaryOpNode):
105+
return UnaryOpNode(
106+
operator=expr.operator,
107+
operand=self._substitute(expr.operand, var_name, replacement)
108+
)
109+
return expr
110+
111+
# ------------------------
112+
# Z3 Integration
113+
# ------------------------
114+
def verify(self, precondition: ExpressionNode, program: ProgramNode, postcondition: ExpressionNode):
115+
from evaluation_function.equivalence.equivalence import FixedVerificationResult
116+
result = FixedVerificationResult()
117+
try:
118+
self.collect_symbols(precondition, program)
119+
wp_final = self.wp(program.global_statements.statements, postcondition) if program.global_statements else postcondition
120+
121+
success, msg = self._implies(precondition, wp_final)
122+
result.success, result.message = success, msg
123+
result.proof_steps, result.undefined_symbols = self.proof_steps, []
124+
except Exception as e:
125+
result.success, result.message = False, f"Internal Error: {str(e)}"
126+
return result
127+
128+
def _implies(self, pre: ExpressionNode, post: ExpressionNode) -> tuple[bool, str]:
129+
s = Solver()
130+
env = {}
131+
try:
132+
pre_z3 = self._expr_to_z3(pre, env)
133+
post_z3 = self._expr_to_z3(post, env)
134+
135+
# Ensure both are treated as Booleans in Z3 context
136+
# A-Level pseudocode usually treats 1 as True, 0 as False
137+
def ensure_bool(z3_expr):
138+
if not is_bool(z3_expr):
139+
return z3_expr != 0
140+
return z3_expr
141+
142+
s.add(And(ensure_bool(pre_z3), Not(ensure_bool(post_z3))))
143+
144+
if s.check().r == -1: # UNSAT
145+
return True, "✓ Success"
146+
return False, f"✗ Counter-example: {s.model()}"
147+
except Exception as e:
148+
raise ValueError(f"Z3 Error: {e}")
149+
150+
def _expr_to_z3(self, expr: ExpressionNode, env: Dict[str, Any]):
151+
if isinstance(expr, LiteralNode):
152+
if isinstance(expr.value, bool): return BoolVal(expr.value)
153+
# Check if it's a numeric string or actual int
154+
val = int(expr.value) if not isinstance(expr.value, bool) else expr.value
155+
return IntVal(val) if not isinstance(val, bool) else BoolVal(val)
156+
157+
if isinstance(expr, VariableNode):
158+
if expr.name not in env:
159+
# Default to Int for A-Level variables unless specified
160+
env[expr.name] = Int(expr.name)
161+
return env[expr.name]
162+
163+
if isinstance(expr, UnaryOpNode):
164+
operand = self._expr_to_z3(expr.operand, env)
165+
if expr.operator == OperatorType.NOT:
166+
if not is_bool(operand): operand = (operand != 0)
167+
return Not(operand)
168+
if expr.operator == OperatorType.SUBTRACT: return -operand
169+
170+
if isinstance(expr, BinaryOpNode):
171+
l = self._expr_to_z3(expr.left, env)
172+
r = self._expr_to_z3(expr.right, env)
173+
op = expr.operator
174+
175+
# Arithmetic
176+
if op == OperatorType.ADD: return l + r
177+
if op == OperatorType.SUBTRACT: return l - r
178+
if op == OperatorType.MULTIPLY: return l * r
179+
180+
# Comparisons (return Bool)
181+
if op == OperatorType.EQUAL: return l == r
182+
if op == OperatorType.NOT_EQUAL: return l != r
183+
if op == OperatorType.GREATER_THAN: return l > r
184+
if op == OperatorType.GREATER_EQUAL: return l >= r
185+
if op == OperatorType.LESS_THAN: return l < r
186+
if op == OperatorType.LESS_EQUAL: return l <= r
187+
188+
# Logical (ensure operands are Bool)
189+
if op in [OperatorType.AND, OperatorType.OR]:
190+
if not is_bool(l): l = (l != 0)
191+
if not is_bool(r): r = (r != 0)
192+
return And(l, r) if op == OperatorType.AND else Or(l, r)
193+
194+
raise ValueError(f"Unsupported node or operator: {type(expr)} {getattr(expr, 'operator', '')}")
195+
196+
def _expr_to_str(self, expr: ExpressionNode) -> str:
197+
if isinstance(expr, VariableNode): return expr.name
198+
if isinstance(expr, LiteralNode): return str(expr.value)
199+
if isinstance(expr, BinaryOpNode):
200+
return f"({self._expr_to_str(expr.left)} {expr.operator.value} {self._expr_to_str(expr.right)})"
201+
if isinstance(expr, UnaryOpNode):
202+
return f"{expr.operator.value}({self._expr_to_str(expr.operand)})"
203+
return "expr"
204+
205+
def _stmt_to_str(self, stmt) -> str:
206+
if isinstance(stmt, AssignmentNode): return f"{stmt.target.name} = ..."
207+
return str(type(stmt).__name__)

evaluation_function/schemas/ast_nodes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ class ProgramNode(ASTNode):
353353
functions: List[FunctionNode] = Field(default_factory=list)
354354
global_statements: Optional[BlockNode] = None
355355

356-
357356
# Update forward references for Pydantic
358357
BinaryOpNode.model_rebuild()
359358
UnaryOpNode.model_rebuild()

0 commit comments

Comments
 (0)