1+ from typing import Any , Dict , List , Set , Union
2+ from z3 import Solver , Int , Bool , And , Or , Not , IntVal , BoolVal , simplify , Implies
3+
4+ from ..schemas .ast_nodes import *
5+ # Assuming imports from your specific file structure
6+ class FixedVerificationResult :
7+ def __init__ (self ):
8+ self .success : bool = False
9+ self .message : str = ""
10+ self .undefined_symbols : List [str ] = []
11+ self .proof_steps : List [str ] = []
12+
13+ def to_dict (self ) -> Dict :
14+ return {
15+ "success" : self .success ,
16+ "message" : self .message ,
17+ "undefined_symbols" : self .undefined_symbols ,
18+ "proof_steps" : self .proof_steps
19+ }
20+ from typing import Any , Dict , List , Set , Optional , Tuple
21+ from z3 import Solver , Int , Bool , And , Or , Not , IntVal , BoolVal , Implies , is_bool
22+ from evaluation_function .schemas .ast_nodes import (
23+ ProgramNode , BlockNode , AssignmentNode , VariableNode , LiteralNode ,
24+ BinaryOpNode , UnaryOpNode , ConditionalNode , LoopNode ,
25+ OperatorType , ExpressionNode , NodeType , LoopType
26+ )
27+
28+ class FixedALevelVerifier :
29+ def __init__ (self ):
30+ self .defined_variables : Set [str ] = set ()
31+ self .proof_steps : List [str ] = []
32+
33+ def collect_symbols (self , precondition : ExpressionNode , program : ProgramNode ) -> None :
34+ self .defined_variables .clear ()
35+ self ._collect_from_expr (precondition )
36+ if program .global_statements :
37+ self ._collect_from_block (program .global_statements )
38+
39+ def _collect_from_expr (self , expr : Optional [ExpressionNode ]):
40+ if not expr : return
41+ if isinstance (expr , VariableNode ):
42+ self .defined_variables .add (expr .name )
43+ elif isinstance (expr , BinaryOpNode ):
44+ self ._collect_from_expr (expr .left )
45+ self ._collect_from_expr (expr .right )
46+ elif isinstance (expr , UnaryOpNode ):
47+ self ._collect_from_expr (expr .operand )
48+
49+ def _collect_from_block (self , block : BlockNode ):
50+ for stmt in block .statements :
51+ if isinstance (stmt , AssignmentNode ) and isinstance (stmt .target , VariableNode ):
52+ self .defined_variables .add (stmt .target .name )
53+ elif isinstance (stmt , ConditionalNode ):
54+ if stmt .then_branch : self ._collect_from_block (stmt .then_branch )
55+ for elif_b in stmt .elif_branches :
56+ if elif_b .then_branch : self ._collect_from_block (elif_b .then_branch )
57+ if stmt .else_branch : self ._collect_from_block (stmt .else_branch )
58+ elif isinstance (stmt , LoopNode ) and stmt .body :
59+ self ._collect_from_block (stmt .body )
60+
61+ # ------------------------
62+ # WP Logic
63+ # ------------------------
64+ def wp (self , statements : List [Any ], post : ExpressionNode ) -> ExpressionNode :
65+ wp_expr = post
66+ for stmt in reversed (statements ):
67+ wp_expr = self ._wp_stmt (stmt , wp_expr )
68+ self .proof_steps .append (f"WP after { self ._stmt_to_str (stmt )} : { self ._expr_to_str (wp_expr )} " )
69+ return wp_expr
70+
71+ def _wp_stmt (self , stmt : Any , post : ExpressionNode ) -> ExpressionNode :
72+ if isinstance (stmt , AssignmentNode ) and isinstance (stmt .target , VariableNode ):
73+ return self ._substitute (post , stmt .target .name , stmt .value )
74+
75+ elif isinstance (stmt , ConditionalNode ):
76+ current_wp = self .wp (stmt .else_branch .statements if stmt .else_branch else [], post )
77+ for elif_b in reversed (stmt .elif_branches ):
78+ wp_elif = self .wp (elif_b .then_branch .statements if elif_b .then_branch else [], post )
79+ current_wp = self ._combine_conditional (elif_b .condition , wp_elif , current_wp )
80+ wp_then = self .wp (stmt .then_branch .statements if stmt .then_branch else [], post )
81+ return self ._combine_conditional (stmt .condition , wp_then , current_wp )
82+
83+ elif isinstance (stmt , LoopNode ):
84+ # A-Level Simplification: Use provided invariant
85+ if "invariant" in stmt .metadata :
86+ return stmt .metadata ["invariant" ]
87+ return post
88+
89+ def _combine_conditional (self , cond : ExpressionNode , wp_true : ExpressionNode , wp_false : ExpressionNode ) -> ExpressionNode :
90+ left = BinaryOpNode (operator = OperatorType .AND , left = cond , right = wp_true )
91+ not_c = UnaryOpNode (operator = OperatorType .NOT , operand = cond )
92+ right = BinaryOpNode (operator = OperatorType .AND , left = not_c , right = wp_false )
93+ return BinaryOpNode (operator = OperatorType .OR , left = left , right = right )
94+
95+ def _substitute (self , expr : ExpressionNode , var_name : str , replacement : ExpressionNode ) -> ExpressionNode :
96+ if isinstance (expr , VariableNode ) and expr .name == var_name :
97+ return replacement
98+ if isinstance (expr , BinaryOpNode ):
99+ return BinaryOpNode (
100+ operator = expr .operator ,
101+ left = self ._substitute (expr .left , var_name , replacement ),
102+ right = self ._substitute (expr .right , var_name , replacement )
103+ )
104+ if isinstance (expr , UnaryOpNode ):
105+ return UnaryOpNode (
106+ operator = expr .operator ,
107+ operand = self ._substitute (expr .operand , var_name , replacement )
108+ )
109+ return expr
110+
111+ # ------------------------
112+ # Z3 Integration
113+ # ------------------------
114+ def verify (self , precondition : ExpressionNode , program : ProgramNode , postcondition : ExpressionNode ):
115+ from evaluation_function .equivalence .equivalence import FixedVerificationResult
116+ result = FixedVerificationResult ()
117+ try :
118+ self .collect_symbols (precondition , program )
119+ wp_final = self .wp (program .global_statements .statements , postcondition ) if program .global_statements else postcondition
120+
121+ success , msg = self ._implies (precondition , wp_final )
122+ result .success , result .message = success , msg
123+ result .proof_steps , result .undefined_symbols = self .proof_steps , []
124+ except Exception as e :
125+ result .success , result .message = False , f"Internal Error: { str (e )} "
126+ return result
127+
128+ def _implies (self , pre : ExpressionNode , post : ExpressionNode ) -> tuple [bool , str ]:
129+ s = Solver ()
130+ env = {}
131+ try :
132+ pre_z3 = self ._expr_to_z3 (pre , env )
133+ post_z3 = self ._expr_to_z3 (post , env )
134+
135+ # Ensure both are treated as Booleans in Z3 context
136+ # A-Level pseudocode usually treats 1 as True, 0 as False
137+ def ensure_bool (z3_expr ):
138+ if not is_bool (z3_expr ):
139+ return z3_expr != 0
140+ return z3_expr
141+
142+ s .add (And (ensure_bool (pre_z3 ), Not (ensure_bool (post_z3 ))))
143+
144+ if s .check ().r == - 1 : # UNSAT
145+ return True , "✓ Success"
146+ return False , f"✗ Counter-example: { s .model ()} "
147+ except Exception as e :
148+ raise ValueError (f"Z3 Error: { e } " )
149+
150+ def _expr_to_z3 (self , expr : ExpressionNode , env : Dict [str , Any ]):
151+ if isinstance (expr , LiteralNode ):
152+ if isinstance (expr .value , bool ): return BoolVal (expr .value )
153+ # Check if it's a numeric string or actual int
154+ val = int (expr .value ) if not isinstance (expr .value , bool ) else expr .value
155+ return IntVal (val ) if not isinstance (val , bool ) else BoolVal (val )
156+
157+ if isinstance (expr , VariableNode ):
158+ if expr .name not in env :
159+ # Default to Int for A-Level variables unless specified
160+ env [expr .name ] = Int (expr .name )
161+ return env [expr .name ]
162+
163+ if isinstance (expr , UnaryOpNode ):
164+ operand = self ._expr_to_z3 (expr .operand , env )
165+ if expr .operator == OperatorType .NOT :
166+ if not is_bool (operand ): operand = (operand != 0 )
167+ return Not (operand )
168+ if expr .operator == OperatorType .SUBTRACT : return - operand
169+
170+ if isinstance (expr , BinaryOpNode ):
171+ l = self ._expr_to_z3 (expr .left , env )
172+ r = self ._expr_to_z3 (expr .right , env )
173+ op = expr .operator
174+
175+ # Arithmetic
176+ if op == OperatorType .ADD : return l + r
177+ if op == OperatorType .SUBTRACT : return l - r
178+ if op == OperatorType .MULTIPLY : return l * r
179+
180+ # Comparisons (return Bool)
181+ if op == OperatorType .EQUAL : return l == r
182+ if op == OperatorType .NOT_EQUAL : return l != r
183+ if op == OperatorType .GREATER_THAN : return l > r
184+ if op == OperatorType .GREATER_EQUAL : return l >= r
185+ if op == OperatorType .LESS_THAN : return l < r
186+ if op == OperatorType .LESS_EQUAL : return l <= r
187+
188+ # Logical (ensure operands are Bool)
189+ if op in [OperatorType .AND , OperatorType .OR ]:
190+ if not is_bool (l ): l = (l != 0 )
191+ if not is_bool (r ): r = (r != 0 )
192+ return And (l , r ) if op == OperatorType .AND else Or (l , r )
193+
194+ raise ValueError (f"Unsupported node or operator: { type (expr )} { getattr (expr , 'operator' , '' )} " )
195+
196+ def _expr_to_str (self , expr : ExpressionNode ) -> str :
197+ if isinstance (expr , VariableNode ): return expr .name
198+ if isinstance (expr , LiteralNode ): return str (expr .value )
199+ if isinstance (expr , BinaryOpNode ):
200+ return f"({ self ._expr_to_str (expr .left )} { expr .operator .value } { self ._expr_to_str (expr .right )} )"
201+ if isinstance (expr , UnaryOpNode ):
202+ return f"{ expr .operator .value } ({ self ._expr_to_str (expr .operand )} )"
203+ return "expr"
204+
205+ def _stmt_to_str (self , stmt ) -> str :
206+ if isinstance (stmt , AssignmentNode ): return f"{ stmt .target .name } = ..."
207+ return str (type (stmt ).__name__ )
0 commit comments