diff --git a/src/Assembly/Equivalence.v b/src/Assembly/Equivalence.v index f3f65f3170..2006770e83 100644 --- a/src/Assembly/Equivalence.v +++ b/src/Assembly/Equivalence.v @@ -1452,30 +1452,61 @@ Section check_equivalence. Local Notation map_err_None v := (ErrorT.map_error (fun e => (None, e)) v). Local Notation map_err_Some label v := (ErrorT.map_error (fun e => (Some label, e)) v). - Definition check_equivalence : ErrorT (option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError) unit := + Definition map_symex_asm (inputs : list (idx + list idx)) (output_types : type_spec) (d : dag) + : ErrorT + (option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError) + (list ((string (* fname *) * Lines (* asm lines *)) * (list (idx + list idx) * symbolic_state))) := let reg_available := assembly_calling_registers (* registers available for calling conventions *) in + (ls <-- (List.map + (fun '((fname, asm) as label) + => (asm <- map_err_Some label (strip_ret asm); + let stack_size : nat := N.to_nat (assembly_stack_size asm) in + symevaled_asm <- map_err_Some label (symex_asm_func (dereference_output_scalars:=false) d assembly_callee_saved_registers output_types stack_size inputs reg_available asm); + Success (label, symevaled_asm))) + asm); + Success ls)%error. + + Definition check_equivalence : ErrorT (option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError) unit := let d := dag.empty in input_types <- map_err_None (simplify_input_type t arg_bounds); output_types <- map_err_None (simplify_base_type (type.final_codomain t) out_bounds); - let '(inputs, d) := build_inputs (descr:=Build_description "build_inputs" true ) input_types d in + let '(inputs, d) := build_inputs (descr:=Build_description "build_inputs" true) input_types d in - PHOAS_output <- map_err_None (symex_PHOAS expr inputs d); - let '(PHOAS_output, d) := PHOAS_output in + ls <- ( + if negb debug_symex_asm_first then ( + PHOAS_output <- map_err_None (symex_PHOAS expr inputs d); + let '(PHOAS_output, d) := PHOAS_output in - let first_new_idx_after_all_old_idxs : option idx := Some (dag.size d) in + let first_new_idx_after_all_old_idxs : option idx := Some (dag.size d) in - _ <-- (List.map - (fun '((fname, asm) as label) - => (asm <- map_err_Some label (strip_ret asm); - let stack_size : nat := N.to_nat (assembly_stack_size asm) in - symevaled_asm <- map_err_Some label (symex_asm_func (dereference_output_scalars:=false) d assembly_callee_saved_registers output_types stack_size inputs reg_available asm); - let '(asm_output, s) := symevaled_asm in + asm_output <- map_symex_asm inputs output_types d; + + let ls := List.map (fun '(lbl, (asm_output, s)) => (lbl, asm_output, PHOAS_output, s, first_new_idx_after_all_old_idxs)) asm_output in + Success ls + ) else ( (* debug version, do asm first *) + asm_output <- map_symex_asm inputs output_types d; + + ls <-- (List.map (fun '(lbl, (asm_output, s)) => + let d := s.(dag_state) in + let first_new_idx_after_all_old_idxs : option idx := Some (dag.size d) in + + PHOAS_output <- map_err_None (symex_PHOAS expr inputs d); + let '(PHOAS_output, d) := PHOAS_output in + + let s := {| dag_state := d; symbolic_reg_state := s.(symbolic_reg_state); symbolic_flag_state := s.(symbolic_flag_state); symbolic_mem_state := s.(symbolic_mem_state) |} in + + Success (lbl, asm_output, PHOAS_output, s, first_new_idx_after_all_old_idxs)) + asm_output); + Success ls + )); + + _ <-- List.map (fun '(lbl, asm_output, PHOAS_output, s, first_new_idx_after_all_old_idxs) => + if list_beq _ (sum_beq _ _ N.eqb (list_beq _ N.eqb)) asm_output PHOAS_output + then Success tt + else Error (Some lbl, Unable_to_unify asm_output PHOAS_output first_new_idx_after_all_old_idxs s)) + ls; + Success tt. - if list_beq _ (sum_beq _ _ N.eqb (list_beq _ N.eqb)) asm_output PHOAS_output - then Success tt - else Error (Some label, Unable_to_unify asm_output PHOAS_output first_new_idx_after_all_old_idxs s))) - asm); - Success tt. (** We don't actually generate assembly, we just check equivalence and pass assembly through unchanged *) Definition generate_assembly_of_hinted_expr : ErrorT (option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError) (list (string * Lines)) diff --git a/src/Assembly/Symbolic.v b/src/Assembly/Symbolic.v index 5849e9547a..fd6d6cdca3 100644 --- a/src/Assembly/Symbolic.v +++ b/src/Assembly/Symbolic.v @@ -401,6 +401,8 @@ Module Export Options. than every time, because it is (currently) quadratic to compute in the number of passes *) Class rewriting_passes_opt := rewriting_passes : list rewrite_pass. + (** Should we symex the assembly first, even though this may be more inefficient? *) + Class debug_symex_asm_first_opt := debug_symex_asm_first : bool. Definition default_rewriting_passes {rewriting_pipeline : rewriting_pipeline_opt} {rewriting_pass_filter : rewriting_pass_filter_opt} @@ -410,17 +412,20 @@ Module Export Options. Class symbolic_options_opt := { asm_rewriting_pipeline : rewriting_pipeline_opt ; asm_rewriting_pass_filter : rewriting_pass_filter_opt + ; asm_debug_symex_asm_first : debug_symex_asm_first_opt }. (* This holds the list of computed options, which are passed around between methods *) Class symbolic_options_computed_opt := { asm_rewriting_passes : rewriting_passes_opt + ; asm_debug_symex_asm_first_computed : debug_symex_asm_first_opt }. (* N.B. The default rewriting pass filter should not be changed here, but instead changed in CLI.v where it is derived from a default string *) Definition default_symbolic_options : symbolic_options_opt := {| asm_rewriting_pipeline := default_rewrite_pass_order ; asm_rewriting_pass_filter := fun _ => true + ; asm_debug_symex_asm_first := false |}. End Options. Module Export Hints. @@ -431,6 +436,8 @@ Module Export Hints. asm_rewriting_pipeline asm_rewriting_pass_filter asm_rewriting_passes + asm_debug_symex_asm_first + asm_debug_symex_asm_first_computed . #[global] Hint Cut [ @@ -438,6 +445,8 @@ Module Export Hints. (asm_rewriting_pipeline | asm_rewriting_pass_filter | asm_rewriting_passes + | asm_debug_symex_asm_first + | asm_debug_symex_asm_first_computed ) ( _ * ) (Build_symbolic_options_opt | Build_symbolic_options_computed_opt diff --git a/src/Assembly/WithBedrock/Proofs.v b/src/Assembly/WithBedrock/Proofs.v index 08ec58e652..ff7e9fc1f1 100644 --- a/src/Assembly/WithBedrock/Proofs.v +++ b/src/Assembly/WithBedrock/Proofs.v @@ -3346,9 +3346,9 @@ Theorem check_equivalence_correct /\ R_runtime_output (output_scalars_are_pointers:=output_scalars_are_pointers) frame retvals (type_spec_of_runtime args) stack_size stack_base asm_args_out asm_args_in assembly_callee_saved_registers runtime_callee_saved_registers st') asm. Proof. - cbv beta delta [check_equivalence ErrorT.bind] in H. + cbv beta delta [check_equivalence map_symex_asm ErrorT.bind] in H. repeat - first [ rewrite List.ErrorT.List.bind_list_cps_id, <- List.ErrorT.List.eq_bind_list_lift in H; + first [ rewrite List.ErrorT.List.bind_list_cps_id, <- !List.ErrorT.List.eq_bind_list_lift in H; cbv beta delta [ErrorT.bind] in H | match type of H with | (let n := ?v in _) = _ @@ -3373,26 +3373,52 @@ Proof. destruct v eqn:?; [ change (T = rhs) in H | change (F = rhs) in H ]; cbv beta in H end ]; try discriminate; []. - cbv beta delta [map_error ErrorT.map2 id] in *. - break_innermost_match_hyps; inversion_ErrorT; subst. - rewrite @List.ErrorT.List.lift_Success_Forall2_iff in *. - progress rewrite ?@Forall2_map_l_iff, ?@Forall2_map_r_iff in *. - Foralls_to_nth_error. - intros; inversion_ErrorT; subst. - progress reflect_hyps. - subst. - pose proof empty_gensym_dag_ok. - let H := fresh in pose proof Hargs as H; eapply build_input_runtime_ok_nil in H; [ | eassumption .. ]. - pose proof (word_args_to_Z_args_bounded word_args). - repeat first [ assumption + progress cbv beta delta [map_error ErrorT.map2 id] in *. + repeat first [ match goal with + | [ H : context G[let x := ?y in @?P x] |- _ ] => + tryif is_var y then + let G' := context G[P y] in + progress change G' in H + else + let h := fresh x in + set (h := y) in *; + let G' := context G[P h] in + progress change G' in H + | [ H : context[let x := ?y in _] |- _ ] => tryif is_var y then fail else let h := fresh x in set (h := y) in * + | [ H : context[List.ErrorT.List.bind_list] |- _ ] + => rewrite List.ErrorT.List.bind_list_cps_id, <- !List.ErrorT.List.eq_bind_list_lift in H; + cbv beta delta [ErrorT.bind] in H + | [ H := Some _ |- _ ] => subst H + | [ H := None |- _ ] => subst H + | [ H := List.map _ _ |- _ ] => subst H + | [ H : context[match List.ErrorT.List.lift ?x with _ => _ end] |- _ ] + => destruct (List.ErrorT.List.lift x) eqn:? + end + | progress cbv beta in * + | progress inversion_ErrorT + | progress subst + | progress break_innermost_match_hyps + | progress rewrite @List.ErrorT.List.lift_Success_Forall2_iff in * + | progress rewrite ?@Forall2_map_l_iff, ?@Forall2_map_r_iff in * ]. + all: Foralls_to_nth_error. + all: intros; inversion_ErrorT; inversion_pair; subst. + all: progress reflect_hyps. + all: subst. + all: pose proof empty_gensym_dag_ok. + all: let H := fresh in pose proof Hargs as H; eapply build_input_runtime_ok_nil in H; [ | eassumption .. ]. + all: pose proof (word_args_to_Z_args_bounded word_args). + all: repeat + first [ assumption | match goal with | [ H : build_inputs _ _ = _ |- _ ] => move H at bottom; eapply build_inputs_ok in H; [ | eassumption .. ] | [ H : symex_PHOAS ?expr ?inputs ?d = Success _, H' : build_input_runtime _ ?ri = Some _ |- _ ] - => move H at bottom; eapply symex_PHOAS_correct with (runtime_inputs:=ri) in H; [ | eassumption .. ] + => move H at bottom; eapply symex_PHOAS_correct with (runtime_inputs:=ri) in H; + [ | try eassumption .. ]; + [ | first [ assumption | eapply Forall2_weaken; [ apply lift_eval_idx_or_list_idx_impl | eassumption ] ] .. ] | [ H : symex_asm_func _ _ _ _ _ _ _ = Success _ |- _ ] => move H at bottom; eapply symex_asm_func_correct in H; - [ | try (eassumption + apply surjective_pairing + reflexivity) .. ]; - [ | clear H; eapply Forall2_weaken; [ apply lift_eval_idx_or_list_idx_impl | eassumption ] ] + [ | (eassumption + apply surjective_pairing + reflexivity + trivial) .. ]; + [ | clear H; first [ assumption | eapply Forall2_weaken; [ apply lift_eval_idx_or_list_idx_impl | eassumption ] ] ] end | progress destruct_head'_ex | progress destruct_head'_and @@ -3402,6 +3428,7 @@ Proof. | match goal with | [ H : ?x = Some ?a, H' : ?x = Some ?b |- _ ] => rewrite H in H'; inversion_option + | [ H : ?x = Some _, H' : ?x = None |- _ ] => exfalso; clear -H H'; rewrite H in H'; inversion_option | [ H : forall args, Forall2 ?P args ?v -> Forall2 _ _ _, H' : Forall2 ?P _ ?v |- _ ] => specialize (H _ H') | [ Himpl : forall e n, eval ?G1 ?d1 e n -> eval ?G2 ?d2 e n, @@ -3413,7 +3440,7 @@ Proof. subst | [ H := _ |- _ ] => subst H end ]. - do 3 eexists; repeat first [ eassumption | apply conj ]; trivial. + all: do 3 eexists; repeat first [ eassumption | apply conj ]; trivial. Qed. Theorem generate_assembly_of_hinted_expr_correct diff --git a/src/CLI.v b/src/CLI.v index 58fff1c8ce..06f74cf12a 100644 --- a/src/CLI.v +++ b/src/CLI.v @@ -547,6 +547,10 @@ Module ForExtraction. := ([Arg.long_key "asm-rewriting-passes"], Arg.String, ["A comma-separated list of rewriting passes to enable. Prefix with - to disable a pass. This list only impacts passes listed in --asm-rewriting-pipeline. Default : " ++ (if default_asm_rewriting_passes =? "" then "(none)" else default_asm_rewriting_passes)]%string ++ describe_flag_options "rewriting pass" "Enable all rewriting passes" special_asm_rewriting_pass_flags known_asm_rewriting_pass_flags_with_spec)%list. + Definition asm_debug_symex_asm_first_spec : named_argT + := ([Arg.long_key "debug-asm-symex-first"], + Arg.Unit, + ["Debug option: If true, the assembly equivalence checker will symex the assembly first, even though this may be more inefficient. This may be useful for having a more concise description of errors in assembly symbolic execution."]). Definition doc_text_before_function_name_spec : named_argT := ([Arg.long_key "doc-text-before-function-name"], Arg.String, @@ -730,6 +734,7 @@ Module ForExtraction. ; asm_error_on_unique_names_mismatch_spec ; asm_rewriting_pipeline_spec ; asm_rewriting_passes_spec + ; asm_debug_symex_asm_first_spec ; doc_text_before_function_name_spec ; doc_text_before_type_name_spec ; doc_newline_before_package_declaration_spec @@ -788,6 +793,7 @@ Module ForExtraction. , asm_error_on_unique_names_mismatchv , asm_rewriting_pipelinev , asm_rewriting_passesv + , asm_debug_symex_asm_firstv , doc_text_before_function_namev , doc_text_before_type_namev , doc_newline_before_package_declarationv @@ -875,6 +881,7 @@ Module ForExtraction. ; symbolic_options_ := {| asm_rewriting_pipeline := to_rewriting_pipeline_list asm_rewriting_pipelinev ; asm_rewriting_pass_filter := fun p => asm_rewriting_pass_filterv (show_rewrite_pass p) + ; asm_debug_symex_asm_first := to_bool asm_debug_symex_asm_firstv |} |} ; ignore_unique_asm_names := negb (to_bool asm_error_on_unique_names_mismatchv)