From 2e2434a6bff433f6dc99dcd07ef9613f9c6e0bab Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 1 Mar 2025 01:38:51 -0800 Subject: [PATCH] Use Listable to prove equality of registers and opcodes Also add rip, eip, ip registers to ensure that we no longer depend on the exact number of registers. This is much faster, especially as we add more registers and opcodes --- src/Assembly/Equality.v | 20 +- src/Assembly/Equivalence.v | 4 +- src/Assembly/EquivalenceProofs.v | 4 +- src/Assembly/Parse.v | 20 -- src/Assembly/Symbolic.v | 24 +- src/Assembly/Syntax.v | 372 ++++++++++------------ src/Assembly/SyntaxTests.v | 155 +++++++++ src/Assembly/WithBedrock/Proofs.v | 36 ++- src/Assembly/WithBedrock/Semantics.v | 8 +- src/Assembly/WithBedrock/SymbolicProofs.v | 4 +- src/Util/Listable.v | 1 + 11 files changed, 376 insertions(+), 272 deletions(-) create mode 100644 src/Assembly/SyntaxTests.v diff --git a/src/Assembly/Equality.v b/src/Assembly/Equality.v index 4575629772..336da900e8 100644 --- a/src/Assembly/Equality.v +++ b/src/Assembly/Equality.v @@ -35,8 +35,8 @@ Bind Scope REG_scope with REG. Infix "=?" := REG_beq : REG_scope. Global Instance REG_beq_spec : reflect_rel (@eq REG) REG_beq | 10 - := reflect_of_beq internal_REG_dec_bl internal_REG_dec_lb. -Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@internal_REG_dec_bl _ _) (@internal_REG_dec_lb _ _). + := reflect_of_beq REG_dec_bl REG_dec_lb. +Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@REG_dec_bl _ _) (@REG_dec_lb _ _). Lemma REG_beq_neq x y : (x =? y)%REG = false <-> x <> y. Proof. rewrite <- REG_beq_eq; destruct (x =? y)%REG; intuition congruence. Qed. Global Instance REG_beq_compat : Proper (eq ==> eq ==> eq) REG_beq | 10. @@ -95,8 +95,8 @@ Bind Scope AccessSize_scope with AccessSize. Infix "=?" := AccessSize_beq : AccessSize_scope. Global Instance AccessSize_beq_spec : reflect_rel (@eq AccessSize) AccessSize_beq | 10 - := reflect_of_beq internal_AccessSize_dec_bl internal_AccessSize_dec_lb. -Definition AccessSize_beq_eq x y : (x =? y)%AccessSize = true <-> x = y := conj (@internal_AccessSize_dec_bl _ _) (@internal_AccessSize_dec_lb _ _). + := reflect_of_beq AccessSize_dec_bl AccessSize_dec_lb. +Definition AccessSize_beq_eq x y : (x =? y)%AccessSize = true <-> x = y := conj (@AccessSize_dec_bl _ _) (@AccessSize_dec_lb _ _). Lemma AccessSize_beq_neq x y : (x =? y)%AccessSize = false <-> x <> y. Proof. rewrite <- AccessSize_beq_eq; destruct (x =? y)%AccessSize; intuition congruence. Qed. Global Instance AccessSize_beq_compat : Proper (eq ==> eq ==> eq) AccessSize_beq | 10. @@ -141,8 +141,8 @@ Bind Scope FLAG_scope with FLAG. Infix "=?" := FLAG_beq : FLAG_scope. Global Instance FLAG_beq_spec : reflect_rel (@eq FLAG) FLAG_beq | 10 - := reflect_of_beq internal_FLAG_dec_bl internal_FLAG_dec_lb. -Definition FLAG_beq_eq x y : (x =? y)%FLAG = true <-> x = y := conj (@internal_FLAG_dec_bl _ _) (@internal_FLAG_dec_lb _ _). + := reflect_of_beq FLAG_dec_bl FLAG_dec_lb. +Definition FLAG_beq_eq x y : (x =? y)%FLAG = true <-> x = y := conj (@FLAG_dec_bl _ _) (@FLAG_dec_lb _ _). Lemma FLAG_beq_neq x y : (x =? y)%FLAG = false <-> x <> y. Proof. rewrite <- FLAG_beq_eq; destruct (x =? y)%FLAG; intuition congruence. Qed. Global Instance FLAG_beq_compat : Proper (eq ==> eq ==> eq) FLAG_beq | 10. @@ -155,8 +155,8 @@ Bind Scope OpCode_scope with OpCode. Infix "=?" := OpCode_beq : OpCode_scope. Global Instance OpCode_beq_spec : reflect_rel (@eq OpCode) OpCode_beq | 10 - := reflect_of_beq internal_OpCode_dec_bl internal_OpCode_dec_lb. -Definition OpCode_beq_eq x y : (x =? y)%OpCode = true <-> x = y := conj (@internal_OpCode_dec_bl _ _) (@internal_OpCode_dec_lb _ _). + := reflect_of_beq OpCode_dec_bl OpCode_dec_lb. +Definition OpCode_beq_eq x y : (x =? y)%OpCode = true <-> x = y := conj (@OpCode_dec_bl _ _) (@OpCode_dec_lb _ _). Lemma OpCode_beq_neq x y : (x =? y)%OpCode = false <-> x <> y. Proof. rewrite <- OpCode_beq_eq; destruct (x =? y)%OpCode; intuition congruence. Qed. Global Instance OpCode_beq_compat : Proper (eq ==> eq ==> eq) OpCode_beq | 10. @@ -169,8 +169,8 @@ Bind Scope OpPrefix_scope with OpPrefix. Infix "=?" := OpPrefix_beq : OpPrefix_scope. Global Instance OpPrefix_beq_spec : reflect_rel (@eq OpPrefix) OpPrefix_beq | 10 - := reflect_of_beq internal_OpPrefix_dec_bl internal_OpPrefix_dec_lb. -Definition OpPrefix_beq_eq x y : (x =? y)%OpPrefix = true <-> x = y := conj (@internal_OpPrefix_dec_bl _ _) (@internal_OpPrefix_dec_lb _ _). + := reflect_of_beq OpPrefix_dec_bl OpPrefix_dec_lb. +Definition OpPrefix_beq_eq x y : (x =? y)%OpPrefix = true <-> x = y := conj (@OpPrefix_dec_bl _ _) (@OpPrefix_dec_lb _ _). Lemma OpPrefix_beq_neq x y : (x =? y)%OpPrefix = false <-> x <> y. Proof. rewrite <- OpPrefix_beq_eq; destruct (x =? y)%OpPrefix; intuition congruence. Qed. Global Instance OpPrefix_beq_compat : Proper (eq ==> eq ==> eq) OpPrefix_beq | 10. diff --git a/src/Assembly/Equivalence.v b/src/Assembly/Equivalence.v index aff15463cd..bf5dc1f3fc 100644 --- a/src/Assembly/Equivalence.v +++ b/src/Assembly/Equivalence.v @@ -1275,10 +1275,10 @@ Definition init_symbolic_state_descr : description := Build_description "init_sy Definition init_symbolic_state (d : dag) : symbolic_state := let _ := init_symbolic_state_descr in - let '(initial_reg_idxs, d) := dag_gensym_n 16 d in + let '(initial_reg_idxs, d) := dag_gensym_n (List.length widest_registers) d in {| dag_state := d; - symbolic_reg_state := Tuple.from_list_default None 16 (List.map Some initial_reg_idxs); + symbolic_reg_state := Tuple.from_list_default None _ (List.map Some initial_reg_idxs); symbolic_mem_state := []; symbolic_flag_state := Tuple.repeat None 6; |}. diff --git a/src/Assembly/EquivalenceProofs.v b/src/Assembly/EquivalenceProofs.v index a4340e146f..f086330c74 100644 --- a/src/Assembly/EquivalenceProofs.v +++ b/src/Assembly/EquivalenceProofs.v @@ -1847,7 +1847,7 @@ Qed. (* TODO: this is Symbolic.get_reg; move to SymbolicProofs? *) Lemma get_reg_set_reg_full s rn rn' v : get_reg (set_reg s rn v) rn' - = if ((rn n) _ s)) && (rn =? rn'))%nat%bool + = if ((rn N.of_nat n) _ s)) && (rn =? rn'))%N%bool then Some v else get_reg s rn'. Proof. @@ -1863,7 +1863,7 @@ Qed. (* TODO: this is Symbolic.get_reg; move to SymbolicProofs? *) Local Lemma get_reg_set_reg_same s rn v - (H : (rn < (fun n (_ : Tuple.tuple _ n) => n) _ s)%nat) + (H : (rn < (fun n (_ : Tuple.tuple _ n) => N.of_nat n) _ s)%N) : get_reg (set_reg s rn v) rn = Some v. Proof. rewrite get_reg_set_reg_full; break_innermost_match; reflect_hyps; cbv beta in *; try reflexivity; lia. diff --git a/src/Assembly/Parse.v b/src/Assembly/Parse.v index 7b0f313d2a..486c886602 100644 --- a/src/Assembly/Parse.v +++ b/src/Assembly/Parse.v @@ -22,34 +22,18 @@ Local Open Scope list_scope. Local Open Scope string_scope. Local Open Scope parse_scope. -Derive REG_Listable SuchThat (@FinitelyListable REG REG_Listable) As REG_FinitelyListable. -Proof. prove_ListableDerive. Qed. -Global Existing Instances REG_Listable REG_FinitelyListable. - Global Instance show_REG : Show REG. Proof. prove_Show_enum (). Defined. Global Instance show_lvl_REG : ShowLevel REG := show_REG. -Derive FLAG_Listable SuchThat (@FinitelyListable FLAG FLAG_Listable) As FLAG_FinitelyListable. -Proof. prove_ListableDerive. Qed. -Global Existing Instances FLAG_Listable FLAG_FinitelyListable. - Global Instance show_FLAG : Show FLAG. Proof. prove_Show_enum (). Defined. Global Instance show_lvl_FLAG : ShowLevel FLAG := show_FLAG. -Derive OpCode_Listable SuchThat (@FinitelyListable OpCode OpCode_Listable) As OpCode_FinitelyListable. -Proof. prove_ListableDerive. Qed. -Global Existing Instances OpCode_Listable OpCode_FinitelyListable. - Global Instance show_OpCode : Show OpCode. Proof. prove_Show_enum (). Defined. Global Instance show_lvl_OpCode : ShowLevel OpCode := show_OpCode. -Derive OpPrefix_Listable SuchThat (@FinitelyListable OpPrefix OpPrefix_Listable) As OpPrefix_FinitelyListable. -Proof. prove_ListableDerive. Qed. -Global Existing Instances OpPrefix_Listable OpPrefix_FinitelyListable. - Global Instance show_OpPrefix : Show OpPrefix. Proof. prove_Show_enum (). Defined. Global Instance show_lvl_OpPrefix : ShowLevel OpPrefix := show_OpPrefix. @@ -72,10 +56,6 @@ Definition parse_FLAG_list : list (string * FLAG) Definition parse_FLAG : ParserAction FLAG := parse_strs parse_FLAG_list. -Derive AccessSize_Listable SuchThat (@FinitelyListable AccessSize AccessSize_Listable) As AccessSize_FinitelyListable. -Proof. prove_ListableDerive. Qed. -Global Existing Instances AccessSize_Listable AccessSize_FinitelyListable. - Global Instance show_AccessSize : Show AccessSize. Proof. prove_Show_enum (). Defined. Global Instance show_lvl_AccessSize : ShowLevel AccessSize := show_AccessSize. diff --git a/src/Assembly/Symbolic.v b/src/Assembly/Symbolic.v index 593a30c720..99eb5546a2 100644 --- a/src/Assembly/Symbolic.v +++ b/src/Assembly/Symbolic.v @@ -3829,7 +3829,7 @@ Definition simplify {opts : symbolic_options_computed_opt} (dag : dag) (e : node Lemma eval_simplify {opts : symbolic_options_computed_opt} G d n v : gensym_dag_ok G d -> eval_node G d n v -> eval G d (simplify d n) v. Proof using Type. eauto using Rewrite.eval_expr, eval_node_reveal_node_at_least. Qed. -Definition reg_state := Tuple.tuple (option idx) 16. +Definition reg_state := Tuple.tuple (option idx) (compute! (List.length widest_registers)). Definition flag_state := Tuple.tuple (option idx) 6. Definition mem_state := list (idx * idx). @@ -3863,16 +3863,20 @@ Definition reverse_lookup_flag (st : flag_state) (i : idx) : option FLAG (List.find (fun v => option_beq N.eqb (Some i) (fst v)) (Tuple.to_list _ (Tuple.map2 (@pair _ _) st (CF, PF, AF, ZF, SF, OF)))). -Definition get_reg (st : reg_state) (ri : nat) : option idx - := Tuple.nth_default None ri st. -Definition set_reg (st : reg_state) ri (i : idx) : reg_state +Definition is_ip_register_index (ri : N) : bool := + REG_beq (widest_register_of_index ri) rip. +Definition get_reg (st : reg_state) (ri : N) : option idx + := if is_ip_register_index ri + then None + else Tuple.nth_default None (N.to_nat ri) st. +Definition set_reg (st : reg_state) (ri : N) (i : idx) : reg_state := Tuple.from_list_default None _ (ListUtil.set_nth - ri + (N.to_nat ri) (Some i) (Tuple.to_list _ st)). Definition reverse_lookup_widest_reg (st : reg_state) (i : idx) : option REG := option_map - (fun v => widest_register_of_index (fst v)) + (fun v => widest_register_of_index (N.of_nat (fst v))) (List.find (fun v => option_beq N.eqb (Some i) (snd v)) (List.enumerate (Tuple.to_list _ st))). @@ -3906,7 +3910,7 @@ Definition update_mem_with (st : symbolic_state) (f : mem_state -> mem_state) : := {| dag_state := st.(dag_state); symbolic_reg_state := st.(symbolic_reg_state) ; symbolic_flag_state := st.(symbolic_flag_state) ; symbolic_mem_state := f st.(symbolic_mem_state) |}. Global Instance show_reg_state : Show reg_state := fun st => - show (List.map (fun '(n, v) => (widest_register_of_index n, v)) (ListUtil.List.enumerate (Option.List.map id (Tuple.to_list _ st)))). + show (List.combine widest_registers (Option.List.map id (Tuple.to_list _ st))). Global Instance show_flag_state : Show flag_state := fun '(cfv, pfv, afv, zfv, sfv, ofv) => ( @@ -3953,7 +3957,7 @@ Module error. Local Unset Decidable Equality Schemes. Variant error := | get_flag (f : FLAG) (s : flag_state) - | get_reg (r : nat + REG) (s : reg_state) + | get_reg (r : N + REG) (s : reg_state) | load (a : idx) (s : symbolic_state) | remove (a : idx) (s : symbolic_state) | remove_has_duplicates (a : idx) (vs : list idx) (s : symbolic_state) @@ -3977,7 +3981,7 @@ Module error. => ["In flag state " ++ show_flag_state s; "Flag " ++ show f ++ " was read without being set."] | get_reg (inl i) s - => ["Invalid reg index " ++ show_nat i] + => ["Invalid reg index " ++ show i] | get_reg (inr r) s => ["In reg state " ++ show_reg_state s; "Register " ++ show (r : REG) ++ " read without being set."] @@ -4042,7 +4046,7 @@ Definition mapM_ {A B} (f: A -> M B) l : M unit := _ <- mapM f l; ret tt. Definition error_get_reg_of_reg_index ri : symbolic_state -> error := error.get_reg (let r := widest_register_of_index ri in - if (reg_index r =? ri)%nat + if (reg_index r =? ri)%N then inr r else inl ri). diff --git a/src/Assembly/Syntax.v b/src/Assembly/Syntax.v index 50dd2e5046..d2a9cc5399 100644 --- a/src/Assembly/Syntax.v +++ b/src/Assembly/Syntax.v @@ -3,7 +3,11 @@ From Coq Require Import NArith. From Coq Require Import String. From Coq Require Import List. From Coq Require Import Derive. +Require Import Crypto.Util.Prod. Require Import Crypto.Util.Option. +Require Import Crypto.Util.Bool.Reflect. +Require Import Crypto.Util.Listable. +Require Import Crypto.Util.ListUtil. Require Crypto.Util.Tuple. Require Crypto.Util.OptionList. Import ListNotations. @@ -11,21 +15,36 @@ Import ListNotations. Local Open Scope list_scope. Local Set Implicit Arguments. -Local Set Boolean Equality Schemes. -Local Set Decidable Equality Schemes. Local Set Primitive Projections. Inductive REG := -| rax | rcx | rdx | rbx | rsp | rbp | rsi | rdi | r8 | r9 | r10 | r11 | r12 | r13 | r14 | r15 -| eax | ecx | edx | ebx | esp | ebp | esi | edi | r8d | r9d | r10d | r11d | r12d | r13d | r14d | r15d -| ax | cx | dx | bx | sp | bp | si | di | r8w | r9w | r10w | r11w | r12w | r13w | r14w | r15w +| rax | rcx | rdx | rbx | rsp | rbp | rsi | rdi | r8 | r9 | r10 | r11 | r12 | r13 | r14 | r15 | rip +| eax | ecx | edx | ebx | esp | ebp | esi | edi | r8d | r9d | r10d | r11d | r12d | r13d | r14d | r15d | eip +| ax | cx | dx | bx | sp | bp | si | di | r8w | r9w | r10w | r11w | r12w | r13w | r14w | r15w | ip | ah | al | ch | cl | dh | dl | bh | bl | spl | bpl | sil | dil | r8b | r9b | r10b | r11b | r12b | r13b | r14b | r15b . +Derive REG_Listable SuchThat (@FinitelyListable REG REG_Listable) As REG_FinitelyListable. +Proof. prove_ListableDerive. Qed. +Global Existing Instances REG_Listable REG_FinitelyListable. +Definition REG_beq : REG -> REG -> bool := eqb_of_listable. +Definition REG_dec_bl : forall x y, REG_beq x y = true -> x = y := eqb_of_listable_bl. +Definition REG_dec_lb : forall x y, x = y -> REG_beq x y = true := eqb_of_listable_lb. +Definition REG_eq_dec : forall x y : REG, {x = y} + {x <> y} := eq_dec_of_listable. + Definition CONST := Z. Coercion CONST_of_Z (x : Z) : CONST := x. Inductive AccessSize := byte | word | dword | qword. + +Derive AccessSize_Listable SuchThat (@FinitelyListable AccessSize AccessSize_Listable) As AccessSize_FinitelyListable. +Proof. prove_ListableDerive. Qed. +Global Existing Instances AccessSize_Listable AccessSize_FinitelyListable. +Definition AccessSize_beq : AccessSize -> AccessSize -> bool := eqb_of_listable. +Definition AccessSize_dec_bl : forall x y, AccessSize_beq x y = true -> x = y := eqb_of_listable_bl. +Definition AccessSize_dec_lb : forall x y, x = y -> AccessSize_beq x y = true := eqb_of_listable_lb. +Definition AccessSize_eq_dec : forall x y : AccessSize, {x = y} + {x <> y} := eq_dec_of_listable. + Coercion bits_of_AccessSize (x : AccessSize) : N := match x with | byte => 8 @@ -41,12 +60,28 @@ Definition mem_of_reg (r : REG) : MEM := Inductive FLAG := CF | PF | AF | ZF | SF | OF. +Derive FLAG_Listable SuchThat (@FinitelyListable FLAG FLAG_Listable) As FLAG_FinitelyListable. +Proof. prove_ListableDerive. Qed. +Global Existing Instances FLAG_Listable FLAG_FinitelyListable. +Definition FLAG_beq : FLAG -> FLAG -> bool := eqb_of_listable. +Definition FLAG_dec_bl : forall x y, FLAG_beq x y = true -> x = y := eqb_of_listable_bl. +Definition FLAG_dec_lb : forall x y, x = y -> FLAG_beq x y = true := eqb_of_listable_lb. +Definition FLAG_eq_dec : forall x y : FLAG, {x = y} + {x <> y} := eq_dec_of_listable. + Inductive OpPrefix := | rep | repz | repnz . +Derive OpPrefix_Listable SuchThat (@FinitelyListable OpPrefix OpPrefix_Listable) As OpPrefix_FinitelyListable. +Proof. prove_ListableDerive. Qed. +Global Existing Instances OpPrefix_Listable OpPrefix_FinitelyListable. +Definition OpPrefix_beq : OpPrefix -> OpPrefix -> bool := eqb_of_listable. +Definition OpPrefix_dec_bl : forall x y, OpPrefix_beq x y = true -> x = y := eqb_of_listable_bl. +Definition OpPrefix_dec_lb : forall x y, x = y -> OpPrefix_beq x y = true := eqb_of_listable_lb. +Definition OpPrefix_eq_dec : forall x y : OpPrefix, {x = y} + {x <> y} := eq_dec_of_listable. + Inductive OpCode := | adc | adcx @@ -62,10 +97,10 @@ Inductive OpCode := | cmovo | cmp | db +| dw | dd -| dec | dq -| dw +| dec | imul | inc | je @@ -95,6 +130,64 @@ Inductive OpCode := | xor . +Derive OpCode_Listable SuchThat (@FinitelyListable OpCode OpCode_Listable) As OpCode_FinitelyListable. +Proof. prove_ListableDerive. Qed. +Global Existing Instances OpCode_Listable OpCode_FinitelyListable. +Definition OpCode_beq : OpCode -> OpCode -> bool := eqb_of_listable. +Definition OpCode_dec_bl : forall x y, OpCode_beq x y = true -> x = y := eqb_of_listable_bl. +Definition OpCode_dec_lb : forall x y, x = y -> OpCode_beq x y = true := eqb_of_listable_lb. +Definition OpCode_eq_dec : forall x y : OpCode, {x = y} + {x <> y} := eq_dec_of_listable. + +Definition accesssize_of_declaration (opc : OpCode) : option AccessSize := + match opc with + | db => Some byte + | dd => Some dword + | dq => Some qword + | dw => Some word + | adc + | adcx + | add + | adox + | and + | bzhi + | call + | clc + | cmovb + | cmovc + | cmovnz + | cmovo + | cmp + | dec + | imul + | inc + | je + | jmp + | lea + | mov + | movzx + | mul + | mulx + | or + | pop + | push + | rcr + | ret + | sar + | sbb + | setc + | seto + | shl + | shlx + | shr + | shrx + | shrd + | sub + | test + | xchg + | xor + => None + end. + Record JUMP_LABEL := { jump_near : bool ; label_name : string }. Inductive ARG := reg (r : REG) | mem (m : MEM) | const (c : CONST) | label (l : JUMP_LABEL). @@ -119,11 +212,11 @@ Definition Lines := list Line. Definition reg_size (r : REG) : N := match r with - |( rax | rcx | rdx | rbx | rsp | rbp | rsi | rdi | r8 | r9 | r10 | r11 | r12 | r13 | r14 | r15 ) + |( rax | rcx | rdx | rbx | rsp | rbp | rsi | rdi | r8 | r9 | r10 | r11 | r12 | r13 | r14 | r15 | rip) => 64 - |( eax | ecx | edx | ebx | esp | ebp | esi | edi | r8d | r9d | r10d | r11d | r12d | r13d | r14d | r15d) + |( eax | ecx | edx | ebx | esp | ebp | esi | edi | r8d | r9d | r10d | r11d | r12d | r13d | r14d | r15d | eip) => 32 - |( ax | cx | dx | bx | sp | bp | si | di | r8w | r9w | r10w | r11w | r12w | r13w | r14w | r15w) + |( ax | cx | dx | bx | sp | bp | si | di | r8w | r9w | r10w | r11w | r12w | r13w | r14w | r15w | ip) => 16 |(ah | al | ch | cl | dh | dl | bh | bl | spl | bpl | sil | dil | r8b | r9b | r10b | r11b | r12b | r13b | r14b | r15b) => 8 @@ -172,154 +265,74 @@ Definition operand_size (x : ARG) (operation_size : N) : N := | None => operation_size end. - -Definition reg_index (r : REG) : nat - := match r with - | rax - | eax - | ax - |(ah | al) - => 0 - | rcx - | ecx - | cx - |(ch | cl) - => 1 - | rdx - | edx - | dx - |(dh | dl) - => 2 - | rbx - | ebx - | bx - |(bh | bl) - => 3 - | rsp - | esp - | sp - |( spl) - => 4 - | rbp - | ebp - | bp - |( bpl) - => 5 - | rsi - | esi - | si - |( sil) - => 6 - | rdi - | edi - | di - |( dil) - => 7 - | r8 - | r8d - | r8w - | r8b - => 8 - | r9 - | r9d - | r9w - | r9b - => 9 - | r10 - | r10d - | r10w - | r10b - => 10 - | r11 - | r11d - | r11w - | r11b - => 11 - | r12 - | r12d - | r12w - | r12b - => 12 - | r13 - | r13d - | r13w - | r13b - => 13 - | r14 - | r14d - | r14w - | r14b - => 14 - | r15 - | r15d - | r15w - | r15b - => 15 - end. Definition reg_offset (r : REG) : N := - match r with - |( rax | rcx | rdx | rbx | rsp | rbp | rsi | rdi | r8 | r9 | r10 | r11 | r12 | r13 | r14 | r15 ) - |( eax | ecx | edx | ebx | esp | ebp | esi | edi | r8d | r9d | r10d | r11d | r12d | r13d | r14d | r15d) - |( ax | cx | dx | bx | sp | bp | si | di | r8w | r9w | r10w | r11w | r12w | r13w | r14w | r15w) - |( al | cl | dl | bl | spl | bpl | sil | dil | r8b | r9b | r10b | r11b | r12b | r13b | r14b | r15b) - => 0 - |(ah | ch | dh | bh ) - => 8 - end. -Definition index_and_shift_and_bitcount_of_reg (r : REG) := - (reg_index r, reg_offset r, reg_size r). + match r with + |(ah | ch | dh | bh ) + => 8 + | _ => 0 + end. -Definition regs_of_index (index : nat) : list (list REG) := - match index with - | 0 => [ [ al ; ah] ; [ ax] ; [ eax] ; [rax] ] - | 1 => [ [ cl ; ch] ; [ cx] ; [ ecx] ; [rcx] ] - | 2 => [ [ dl ; dh] ; [ dx] ; [ edx] ; [rdx] ] - | 3 => [ [ bl ; bh] ; [ bx] ; [ ebx] ; [rbx] ] - | 4 => [ [ spl ] ; [ sp] ; [ esp] ; [rsp] ] - | 5 => [ [ bpl ] ; [ bp] ; [ ebp] ; [rbp] ] - | 6 => [ [ sil ] ; [ si] ; [ esi] ; [rsi] ] - | 7 => [ [ dil ] ; [ di] ; [ edi] ; [rdi] ] - | 8 => [ [ r8b ] ; [ r8w] ; [ r8d] ; [r8 ] ] - | 9 => [ [ r9b ] ; [ r9w] ; [ r9d] ; [r9 ] ] - | 10 => [ [r10b ] ; [r10w] ; [r10d] ; [r10] ] - | 11 => [ [r11b ] ; [r11w] ; [r11d] ; [r11] ] - | 12 => [ [r12b ] ; [r12w] ; [r12d] ; [r12] ] - | 13 => [ [r13b ] ; [r13w] ; [r13d] ; [r13] ] - | 14 => [ [r14b ] ; [r14w] ; [r14d] ; [r14] ] - | 15 => [ [r15b ] ; [r15w] ; [r15d] ; [r15] ] - | _ => [] +Definition widest_register_of (r : REG) : REG := + match r with + | ((al | ah) | ax | eax | rax) => rax + | ((cl | ch) | cx | ecx | rcx) => rcx + | ((dl | dh) | dx | edx | rdx) => rdx + | ((bl | bh) | bx | ebx | rbx) => rbx + | (spl | sp | esp | rsp) => rsp + | (bpl | bp | ebp | rbp) => rbp + | (sil | si | esi | rsi) => rsi + | (dil | di | edi | rdi) => rdi + | (r8b | r8w | r8d | r8) => r8 + | (r9b | r9w | r9d | r9) => r9 + | (r10b | r10w | r10d | r10) => r10 + | (r11b | r11w | r11d | r11) => r11 + | (r12b | r12w | r12d | r12) => r12 + | (r13b | r13w | r13d | r13) => r13 + | (r14b | r14w | r14d | r14) => r14 + | (r15b | r15w | r15d | r15) => r15 + | (ip | eip | rip) => rip end. +Definition widest_registers := Eval lazy in List.filter (fun x => REG_beq x (widest_register_of x)) (list_all REG). + +Definition wide_reg_index_pairs := Eval lazy in List.map (fun '(n, r) => (N.of_nat n, r)) (List.enumerate widest_registers). + +Definition eta_reg {A} : (REG -> A) -> (REG -> A). +Proof. + intros f r; pose (f r) as fr; destruct r. + all: let v := eval cbv in fr in exact v. +Defined. + +Definition reg_index (r : REG) : N := Eval lazy in + eta_reg (fun r => + Option.value + (option_map (@fst _ _) (find (fun '(n, r') => REG_beq (widest_register_of r) r') wide_reg_index_pairs)) + 0%N) + r. + +Definition widest_register_of_index_opt (n : N) : option REG + := List.nth_error (List.map (@snd _ _) wide_reg_index_pairs) (N.to_nat n). + (** convenience printing function *) -Definition widest_register_of_index (n : nat) : REG - := match n with - | 0 => rax - | 1 => rcx - | 2 => rdx - | 3 => rbx - | 4 => rsp - | 5 => rbp - | 6 => rsi - | 7 => rdi - | 8 => r8 - | 9 => r9 - | 10 => r10 - | 11 => r11 - | 12 => r12 - | 13 => r13 - | 14 => r14 - | 15 => r15 - | _ => rax - end%nat. +Definition widest_register_of_index (n : N) : REG + := Option.value (widest_register_of_index_opt n) rax. + +Definition widest_reg_size_of (r : REG) : N := + reg_size (widest_register_of_index (reg_index r)). + +Definition index_and_shift_and_bitcount_of_reg (r : REG) := + (reg_index r, reg_offset r, reg_size r). + +Definition overlapping_registers (r : REG) : list REG := Eval lazy in eta_reg + (fun r => List.filter (fun r' => REG_beq (widest_register_of r) (widest_register_of r')) (list_all REG)) + r. Definition reg_of_index_and_shift_and_bitcount_opt := fun '(index, offset, size) => - let sz := N.log2 (size / 8) in - let offset_n := (offset / 8)%N in - if ((8 * 2^sz =? size) && (offset =? offset_n * 8))%N%bool - then (rs <- nth_error (regs_of_index index) (N.to_nat sz); - nth_error rs (N.to_nat offset_n))%option - else None. + (wr <- widest_register_of_index_opt index; + let rs := overlapping_registers wr in + List.find (fun r => ((reg_size r =? size) && (reg_offset r =? offset))%N%bool) rs)%option. + Definition reg_of_index_and_shift_and_bitcount := fun '(index, offset, size) => match reg_of_index_and_shift_and_bitcount_opt (index, offset, size) with @@ -327,59 +340,6 @@ Definition reg_of_index_and_shift_and_bitcount := | None => widest_register_of_index index end. -Lemma widest_register_of_index_correct - : forall n, - (~exists r, reg_index r = n) - \/ (let r := widest_register_of_index n in reg_index r = n - /\ forall r', reg_index r' = n -> r = r' \/ (reg_size r' < reg_size r)%N). -Proof. - intro n; set (r := widest_register_of_index n). - cbv in r. - repeat match goal with r := context[match ?n with _ => _ end] |- _ => destruct n; [ right | ] end; - [ .. | left; intros [ [] H]; cbv in H; congruence ]. - all: subst r; split; [ reflexivity | ]. - all: intros [] H; cbv in H; try (exfalso; congruence). - all: try (left; reflexivity). - all: try (right; vm_compute; reflexivity). -Qed. - -Lemma reg_of_index_and_shift_and_bitcount_opt_correct v r - : reg_of_index_and_shift_and_bitcount_opt v = Some r <-> index_and_shift_and_bitcount_of_reg r = v. -Proof. - split; [ | intro; subst; destruct r; vm_compute; reflexivity ]. - cbv [index_and_shift_and_bitcount_of_reg]; destruct v as [ [index shift] bitcount ]. - cbv [reg_of_index_and_shift_and_bitcount_opt]. - generalize (shift / 8)%N (N.log2 (bitcount / 8)); intros *. - repeat first [ congruence - | progress subst - | match goal with - | [ H : _ /\ _ |- _ ] => destruct H - | [ H : N.to_nat _ = _ |- _ ] => apply (f_equal N.of_nat) in H; rewrite N2Nat.id in H; subst - | [ |- Some _ = Some _ -> _ ] => inversion 1; subst - | [ |- context[match ?x with _ => _ end] ] => destruct x eqn:?; subst - end - | progress cbv [regs_of_index] - | match goal with - | [ |- context[nth_error _ ?n] ] => destruct n eqn:?; cbn [nth_error Option.bind] - end - | rewrite Bool.andb_true_iff, ?N.eqb_eq in * |- ]. - all: vm_compute; reflexivity. -Qed. - -Lemma reg_of_index_and_shift_and_bitcount_of_reg r - : reg_of_index_and_shift_and_bitcount (index_and_shift_and_bitcount_of_reg r) = r. -Proof. destruct r; vm_compute; reflexivity. Qed. - -Lemma reg_of_index_and_shift_and_bitcount_eq v r - : reg_of_index_and_shift_and_bitcount v = r - -> (index_and_shift_and_bitcount_of_reg r = v - \/ ((~exists r, index_and_shift_and_bitcount_of_reg r = v) - /\ r = widest_register_of_index (fst (fst v)))). -Proof. - cbv [reg_of_index_and_shift_and_bitcount]. - destruct v as [ [index offset] size ]. - destruct reg_of_index_and_shift_and_bitcount_opt eqn:H; - [ left | right; split; [ intros [r' H'] | ] ]; subst; try reflexivity. - { rewrite reg_of_index_and_shift_and_bitcount_opt_correct in H; assumption. } - { rewrite <- reg_of_index_and_shift_and_bitcount_opt_correct in H'; congruence. } -Qed. +Class assembly_program_options := { + default_rel : bool ; +}. diff --git a/src/Assembly/SyntaxTests.v b/src/Assembly/SyntaxTests.v new file mode 100644 index 0000000000..4659c239df --- /dev/null +++ b/src/Assembly/SyntaxTests.v @@ -0,0 +1,155 @@ +From Coq Require Import ZArith. +From Coq Require Import NArith. +From Coq Require Import String. +From Coq Require Import List. +From Coq Require Import Derive. +Require Import Crypto.Util.Prod. +Require Import Crypto.Util.Option. +Require Import Crypto.Util.Bool.Reflect. +Require Import Crypto.Util.Listable. +Require Import Crypto.Util.ListUtil. +Require Import Crypto.Assembly.Syntax. +Require Crypto.Util.Tuple. +Require Crypto.Util.OptionList. +Import ListNotations. + +Local Open Scope list_scope. + +Local Set Implicit Arguments. +Local Set Primitive Projections. + +Local Coercion N.of_nat : nat >-> N. + +Lemma reg_of_index_and_shift_and_bitcount_opt_of_index_and_shift_and_bitcount_of_reg : forall r : REG, reg_of_index_and_shift_and_bitcount_opt (index_and_shift_and_bitcount_of_reg r) = Some r. +Proof. destruct r; vm_compute; try reflexivity. Defined. + +Lemma reg_of_index_and_shift_and_bitcount_of_index_and_shift_and_bitcount_of_reg : forall r : REG, reg_of_index_and_shift_and_bitcount (index_and_shift_and_bitcount_of_reg r) = r. +Proof. destruct r; vm_compute; reflexivity. Defined. + +Lemma reg_index_widest_register_of : forall r : REG, reg_index (widest_register_of r) = reg_index r. +Proof. destruct r; reflexivity. Defined. + +Lemma reg_index_widest_register_of_index_opt : forall index : N, (N.to_nat index option_map reg_index (widest_register_of_index_opt index) = Some index. +Proof. + intros index; cbv [widest_register_of_index_opt]. + rewrite <- (N2Nat.id index), Nat2N.id; generalize (N.to_nat index); clear index; intro index. + vm_compute List.map. + vm_compute List.length. + cbv [Nat.ltb]; cbn [Nat.leb]. + repeat lazymatch goal with + | [ |- false = true -> _ ] => discriminate + | [ |- (?index <=? _) = true -> _ ] => + is_var index; destruct index; [ reflexivity | cbn [Nat.leb] ] + end. +Qed. + +Lemma widest_register_of_index_opt_Some_length_iff : forall index : N, (exists r, widest_register_of_index_opt index = Some r) <-> (N.to_nat index _ ] => is_var index; destruct index; cbn [nth_error] + | [ |- _ <-> false = true ] => split; [ | discriminate ] + | [ |- _ <-> true = true ] => repeat esplit + | [ |- _ <-> (?index <=? _) = true ] => + is_var index; destruct index; cbn [Nat.leb nth_error] + end. + all: intros [? H]; discriminate. +Qed. + +Lemma reg_index_widest_register_of_index : forall index : N, (N.to_nat index reg_index (widest_register_of_index index) = index. +Proof. + intros index H; cbv [widest_register_of_index]. + apply reg_index_widest_register_of_index_opt in H. + destruct widest_register_of_index_opt; cbv [option_map] in *; inversion H; subst. + reflexivity. +Qed. + +Lemma reg_index_overlapping_registers : forall r r' n, nth_error (overlapping_registers r) n = Some r' -> reg_index r' = reg_index r. +Proof. + intros r r' n; destruct r. + all: vm_compute overlapping_registers. + all: repeat lazymatch goal with + | [ |- nth_error (_ :: _) ?v = Some _ -> _ ] => is_var v; destruct v; cbn [nth_error] + | [ |- nth_error [] ?v = Some _ -> _ ] => is_var v; destruct v; cbn [nth_error] + | [ |- Some _ = Some _ -> _ ] => let H := fresh in intro H; inversion H + | [ |- None = Some _ -> _ ] => let H := fresh in intro H; inversion H + end. + all: subst; reflexivity. +Qed. + +Lemma reg_of_index_and_shift_and_bitcount_of_reg r + : reg_of_index_and_shift_and_bitcount (index_and_shift_and_bitcount_of_reg r) = r. +Proof. destruct r; vm_compute; reflexivity. Qed. + +Lemma widest_register_of_index_opt_correct + : forall n r, widest_register_of_index_opt n = Some r -> + reg_index r = n + /\ forall r', reg_index r' = n -> r = r' \/ (reg_size r' < reg_size r)%N. +Proof. + intros n r H. + epose proof (proj1 (widest_register_of_index_opt_Some_length_iff _) (ex_intro _ _ H)) as H'. + pose proof H' as H''. + apply reg_index_widest_register_of_index_opt in H''. + rewrite H in H''; cbn in H''; inversion H''; subst. + split; [ reflexivity | ]. + destruct r, r'. + all: vm_compute; try (constructor; reflexivity); try discriminate. +Qed. + +Lemma widest_register_of_index_correct + : forall n, + (~exists r, reg_index r = n) + \/ (let r := widest_register_of_index n in reg_index r = n + /\ forall r', reg_index r' = n -> r = r' \/ (reg_size r' < reg_size r)%N). +Proof. + intro n; pose proof (widest_register_of_index_opt_correct n) as H. + cbv [widest_register_of_index]. + destruct (widest_register_of_index_opt n) as [r |] eqn:H'; [ right; apply H; reflexivity | left ]. + intros [ [] H'' ]; subst; cbv in H'. + all: inversion H'. +Qed. + +Lemma reg_of_index_and_shift_and_bitcount_opt_correct v r + : reg_of_index_and_shift_and_bitcount_opt v = Some r <-> index_and_shift_and_bitcount_of_reg r = v. +Proof. + split; [ | intro; subst; destruct r; vm_compute; reflexivity ]. + cbv [index_and_shift_and_bitcount_of_reg]; destruct v as [ [index shift] bitcount ]. + cbv [reg_of_index_and_shift_and_bitcount_opt]. + pose proof (reg_index_widest_register_of_index index) as H''. + cbv [widest_register_of_index] in H''. + rewrite <- widest_register_of_index_opt_Some_length_iff in H''. + destruct widest_register_of_index_opt eqn:H; [ | intro H'; cbv in H'; now inversion H' ]. + cbv [Option.bind Option.sequence_return] in *. + specialize (H'' (ex_intro _ _ eq_refl)). + subst. + rewrite find_some_iff. + repeat first + [ progress intros + | progress destruct_head'_ex + | progress destruct_head'_and + | progress reflect_hyps + | progress subst + | match goal with + | [ H : nth_error (overlapping_registers _) _ = Some _ |- _ ] => + apply reg_index_overlapping_registers in H; try rewrite H + end + | reflexivity ]. +Qed. + +Lemma reg_of_index_and_shift_and_bitcount_eq v r + : reg_of_index_and_shift_and_bitcount v = r + -> (index_and_shift_and_bitcount_of_reg r = v + \/ ((~exists r, index_and_shift_and_bitcount_of_reg r = v) + /\ r = widest_register_of_index (fst (fst v)))). +Proof. + cbv [reg_of_index_and_shift_and_bitcount]. + destruct v as [ [index offset] size ]. + destruct reg_of_index_and_shift_and_bitcount_opt eqn:H; + [ left | right; split; [ intros [r' H'] | ] ]; subst; try reflexivity. + { rewrite reg_of_index_and_shift_and_bitcount_opt_correct in H; assumption. } + { rewrite <- reg_of_index_and_shift_and_bitcount_opt_correct in H'; congruence. } +Qed. diff --git a/src/Assembly/WithBedrock/Proofs.v b/src/Assembly/WithBedrock/Proofs.v index aa0cf55824..08ec58e652 100644 --- a/src/Assembly/WithBedrock/Proofs.v +++ b/src/Assembly/WithBedrock/Proofs.v @@ -245,7 +245,7 @@ Definition init_symbolic_state_G (m : machine_state) let '(initial_reg_idxs, (G, d)) := dag_gensym_n_G vals st in (G, {| dag_state := d - ; symbolic_reg_state := Tuple.from_list_default None 16 (List.map Some initial_reg_idxs) + ; symbolic_reg_state := Tuple.from_list_default None _ (List.map Some initial_reg_idxs) ; symbolic_flag_state := Tuple.repeat None 6 ; symbolic_mem_state := [] |}). @@ -254,7 +254,7 @@ Lemma init_symbolic_state_eq_G G d m : init_symbolic_state d = snd (init_symbolic_state_G m (G, d)). Proof. cbv [init_symbolic_state init_symbolic_state_G]. - epose proof (dag_gensym_n_eq_G G d (Tuple.to_list 16 m.(machine_reg_state))) as H. + epose proof (dag_gensym_n_eq_G G d (Tuple.to_list _ m.(machine_reg_state))) as H. rewrite Tuple.length_to_list in H; rewrite H; clear H. eta_expand; cbn [fst snd]. reflexivity. @@ -266,7 +266,7 @@ Lemma init_symbolic_state_G_ok m G d G' ss (H : init_symbolic_state_G m (G, d) = (G', ss)) (d' := ss.(dag_state)) (Hframe : frame m) - (Hbounds : Forall (fun v => (0 <= v < 2^64)%Z) (Tuple.to_list 16 m.(machine_reg_state))) + (Hbounds : Forall (fun v => (0 <= v < 2^64)%Z) (Tuple.to_list _ m.(machine_reg_state))) : R frame G' ss m /\ (forall reg, Option.is_Some (Symbolic.get_reg ss.(symbolic_reg_state) (reg_index reg)) = true) /\ gensym_dag_ok G' d' @@ -291,11 +291,11 @@ Proof. by (eapply dag_gensym_n_G_ok; [ | eta_expand; reflexivity | ]; assumption). clear; cbv [reg_index]; break_innermost_match; lia. } set (v := dag_gensym_n_G _ _) in *; clearbody v; destruct_head'_prod; cbn [fst snd] in *. - eassert (H' : Tuple.to_list 16 m.(machine_reg_state) = _). + eassert (H' : Tuple.to_list _ m.(machine_reg_state) = _). { repeat match goal with H : _ |- _ => clear H end. cbv [Tuple.to_list Tuple.to_list']. set_evars; eta_expand; subst_evars; reflexivity. } - generalize dependent (Tuple.to_list 16 m.(machine_reg_state)); intros; subst. + generalize dependent (Tuple.to_list _ m.(machine_reg_state)); intros; subst. repeat match goal with H : context[?x :: _] |- _ => assert_fails is_var x; set x in * end. repeat match goal with H : Forall2 _ ?v (_ :: _) |- _ => is_var v; inversion H; clear H; subst end. repeat match goal with H : Forall2 _ ?v nil |- _ => is_var v; inversion H; clear H; subst end. @@ -313,7 +313,7 @@ Lemma init_symbolic_state_ok m G d (Hd : gensym_dag_ok G d) (ss := init_symbolic_state d) (d' := ss.(dag_state)) - (Hbounds : Forall (fun v => (0 <= v < 2^64)%Z) (Tuple.to_list 16 m.(machine_reg_state))) + (Hbounds : Forall (fun v => (0 <= v < 2^64)%Z) (Tuple.to_list _ m.(machine_reg_state))) (Hframe : frame m) : exists G', R frame G' ss m @@ -332,12 +332,12 @@ Lemma get_reg_of_R_regs G d r mr reg : forall idx', Symbolic.get_reg r (reg_index reg) = Some idx' -> eval_idx_Z G d idx' (Semantics.get_reg mr reg). Proof. assert (reg_offset reg = 0%N) by now destruct reg. - assert (reg_index reg < length (Tuple.to_list _ r)) + assert (N.to_nat (reg_index reg) < List.length (Tuple.to_list _ r)) by now rewrite Tuple.length_to_list; destruct reg; cbv [reg_index]; lia. cbv [Symbolic.get_reg Semantics.get_reg R_regs] in *. rewrite Tuple.fieldwise_to_list_iff in Hreg. erewrite @Forall2_forall_iff in Hreg by now rewrite !Tuple.length_to_list. - specialize (Hreg (reg_index reg) ltac:(assumption)); rewrite !@Tuple.nth_default_to_list in *. + specialize (Hreg (N.to_nat (reg_index reg)) ltac:(assumption)); rewrite !@Tuple.nth_default_to_list in *. cbv [index_and_shift_and_bitcount_of_reg] in *. generalize dependent (reg_size reg); intros; subst. generalize dependent (reg_offset reg); intros; subst. @@ -464,7 +464,7 @@ Definition R_regs_preserved_v rn (m : Semantics.reg_state) := Z.land (Tuple.nth_default 0%Z rn m) (Z.ones 64). Definition R_regs_preserved G d G' d' (m : Semantics.reg_state) rs rs' - := forall rn idx, Symbolic.get_reg rs' rn = Some idx -> exists idx', Symbolic.get_reg rs rn = Some idx' /\ let v := R_regs_preserved_v rn m in eval_idx_Z G d idx' v -> eval_idx_Z G' d' idx v. + := forall rn idx, Symbolic.get_reg rs' rn = Some idx -> exists idx', Symbolic.get_reg rs rn = Some idx' /\ let v := R_regs_preserved_v (N.to_nat rn) m in eval_idx_Z G d idx' v -> eval_idx_Z G' d' idx v. Global Instance R_regs_preserved_Reflexive G d m : Reflexive (R_regs_preserved G d G d m) | 100. Proof. intro; cbv [R_regs_preserved]; eauto. Qed. @@ -476,19 +476,22 @@ Lemma R_regs_subsumed_get_reg_same_eval G d G' d' rs rs' rm Proof. cbv [R_regs Symbolic.get_reg R_regs_preserved R_regs_preserved_v] in *. rewrite @Tuple.fieldwise_to_list_iff, @Forall2_forall_iff_nth_error in *. - intro i; specialize (H i); specialize (H_impl i). + intro i; specialize (H i); specialize (H_impl (N.of_nat i)). rewrite <- !@Tuple.nth_default_to_list in *. cbv [nth_default option_eq] in *. repeat first [ progress destruct_head'_and | progress destruct_head'_ex + | progress inversion_option | rewrite @Tuple.length_to_list in * | progress cbv [R_reg eval_idx_Z] in * | progress break_innermost_match | progress break_innermost_match_hyps + | rewrite !Nat2N.id in * | now auto | progress intros | progress subst | match goal with + | [ H : ?x = Some ?a, H' : ?x = Some ?b |- _ ] => rewrite H in H' | [ H : nth_error _ _ = None |- _ ] => apply nth_error_error_length in H | [ H : ?i >= ?n, H' : context[nth_error (Tuple.to_list ?n _) ?i] |- _ ] => rewrite nth_error_length_error in H' by now rewrite Tuple.length_to_list; lia @@ -501,12 +504,13 @@ Qed. Lemma R_regs_preserved_set_reg G d G' d' rs rs' ri rm v (H : R_regs_preserved G d G' d' rm rs rs') - (H_same : (ri < 16)%nat -> exists idx, Symbolic.get_reg rs ri = Some idx /\ let v' := R_regs_preserved_v ri rm in eval_idx_Z G d idx v' -> eval_idx_Z G' d' v v') + (H_same : (ri < N.of_nat (List.length widest_registers))%N -> exists idx, Symbolic.get_reg rs ri = Some idx /\ let v' := R_regs_preserved_v (N.to_nat ri) rm in eval_idx_Z G d idx v' -> eval_idx_Z G' d' v v') : R_regs_preserved G d G' d' rm rs (Symbolic.set_reg rs' ri v). Proof. cbv [R_regs_preserved] in *. intros rn idx; specialize (H rn). rewrite get_reg_set_reg_full; intro. + vm_compute (length widest_registers) in *. repeat first [ progress break_innermost_match_hyps | progress inversion_option | progress subst @@ -527,7 +531,7 @@ Qed. Lemma R_regs_preserved_fold_left_set_reg_index {T1 T2} G d G' d' rs rs' rm (r_idxs : list (_ * (_ * T1 + _ * T2))) (H : R_regs_preserved G d G' d' rm rs rs') - (H_same : Forall (fun '(r, v) => let v := match v with inl (v, _) => v | inr (v, _) => v end in exists idx, Symbolic.get_reg rs (reg_index r) = Some idx /\ let v' := R_regs_preserved_v (reg_index r) rm in eval_idx_Z G d idx v' -> eval_idx_Z G' d' v v') r_idxs) + (H_same : Forall (fun '(r, v) => let v := match v with inl (v, _) => v | inr (v, _) => v end in exists idx, Symbolic.get_reg rs (reg_index r) = Some idx /\ let v' := R_regs_preserved_v (N.to_nat (reg_index r)) rm in eval_idx_Z G d idx v' -> eval_idx_Z G' d' v v') r_idxs) : R_regs_preserved G d G' d' rm rs @@ -548,12 +552,12 @@ Qed. Lemma Semantics_get_reg_eq_nth_default_of_R_regs G d ss ms r (Hsz : reg_size r = 64%N) (HR : R_regs G d ss ms) - : Semantics.get_reg ms r = Tuple.nth_default 0%Z (reg_index r) (ms : Semantics.reg_state). + : Semantics.get_reg ms r = Tuple.nth_default 0%Z (N.to_nat (reg_index r)) (ms : Semantics.reg_state). Proof. assert (Hro : reg_offset r = 0%N) by now revert Hsz; clear; cbv; destruct r; lia. cbv [R_regs R_reg] in HR. rewrite Tuple.fieldwise_to_list_iff, Forall2_forall_iff_nth_error in HR. - specialize (HR (reg_index r)). + specialize (HR (N.to_nat (reg_index r))). cbv [Semantics.get_reg index_and_shift_and_bitcount_of_reg]. rewrite Hro, Hsz; change (Z.of_N 0) with 0%Z; change (Z.of_N 64) with 64%Z. rewrite Z.shiftr_0_r, <- Tuple.nth_default_to_list; cbv [nth_default option_eq] in *. @@ -563,7 +567,7 @@ Qed. Lemma Semantics_get_reg_eq_nth_default_of_R frame G ss ms r (Hsz : reg_size r = 64%N) (HR : R frame G ss ms) - : Semantics.get_reg ms r = Tuple.nth_default 0%Z (reg_index r) (ms : Semantics.reg_state). + : Semantics.get_reg ms r = Tuple.nth_default 0%Z (N.to_nat (reg_index r)) (ms : Semantics.reg_state). Proof. destruct ss, ms; eapply Semantics_get_reg_eq_nth_default_of_R_regs; try eassumption; apply HR. Qed. @@ -586,7 +590,7 @@ Lemma Forall2_R_regs_preserved_same_helper G d reg_available idxs m rs : Forall (fun '(r, v) => let v := match v with inl (v, _) => v | inr (v, _) => v end in exists idx, Symbolic.get_reg rs (reg_index r) = Some idx - /\ let v' := R_regs_preserved_v (reg_index r) m in eval_idx_Z G d idx v' -> eval_idx_Z G d v v') + /\ let v' := R_regs_preserved_v (N.to_nat (reg_index r)) m in eval_idx_Z G d idx v' -> eval_idx_Z G d v v') (List.combine reg_available idxs). Proof. cbv [get_asm_reg] in *. diff --git a/src/Assembly/WithBedrock/Semantics.v b/src/Assembly/WithBedrock/Semantics.v index f436d155d7..695e4aa6c2 100644 --- a/src/Assembly/WithBedrock/Semantics.v +++ b/src/Assembly/WithBedrock/Semantics.v @@ -53,23 +53,23 @@ Definition havoc_flag (st : flag_state) (f : FLAG) : flag_state Definition havoc_flags : flag_state := (None, None, None, None, None, None). -Definition reg_state := Tuple.tuple Z 16. +Definition reg_state := Tuple.tuple Z (compute! (List.length widest_registers)). Definition bitmask_of_reg (r : REG) : Z := let '(idx, shift, bitcount) := index_and_shift_and_bitcount_of_reg r in Z.shiftl (Z.ones (Z.of_N bitcount)) (Z.of_N shift). Definition get_reg (st : reg_state) (r : REG) : Z := let '(idx, shift, bitcount) := index_and_shift_and_bitcount_of_reg r in - let rv := Tuple.nth_default 0%Z idx st in + let rv := Tuple.nth_default 0%Z (N.to_nat idx) st in Z.land (Z.shiftr rv (Z.of_N shift)) (Z.ones (Z.of_N bitcount)). Definition set_reg (st : reg_state) (r : REG) (v : Z) : reg_state := let '(idx, shift, bitcount) := index_and_shift_and_bitcount_of_reg r in Tuple.from_list_default 0%Z _ (ListUtil.update_nth - idx + (N.to_nat idx) (fun curv => Z.lor (Z.shiftl (Z.land v (Z.ones (Z.of_N bitcount))) (Z.of_N shift)) (Z.ldiff curv (Z.shiftl (Z.ones (Z.of_N bitcount)) (Z.of_N shift)))) (Tuple.to_list _ st)). Definition annotate_reg_state (st : reg_state) : list (REG * Z) - := List.map (fun '(n, v) => (widest_register_of_index n, v)) (enumerate (Tuple.to_list _ st)). + := List.combine widest_registers (Tuple.to_list _ st). Ltac print_reg_state st := let st' := (eval cbv in (annotate_reg_state st)) in idtac st'. (* Kludge since [byte] isn't present in Coq 8.9 *) diff --git a/src/Assembly/WithBedrock/SymbolicProofs.v b/src/Assembly/WithBedrock/SymbolicProofs.v index 38e9c0ccaf..a5ffb70215 100644 --- a/src/Assembly/WithBedrock/SymbolicProofs.v +++ b/src/Assembly/WithBedrock/SymbolicProofs.v @@ -314,7 +314,7 @@ Qed. Lemma get_reg_R_regs d s m (HR : R_regs d s m) ri : forall i, Symbolic.get_reg s ri = Some i -> - exists v, eval d i v /\ Tuple.nth_default 0 ri m = v. + exists v, eval d i v /\ Tuple.nth_default 0 (N.to_nat ri) m = v. Proof using Type. cbv [Symbolic.get_reg]; intros. rewrite <-Tuple.nth_default_to_list in H. @@ -336,7 +336,7 @@ Qed. Lemma get_reg_R s m (HR : R s m) ri : forall i, Symbolic.get_reg s ri = Some i -> - exists v, eval s i v /\ Tuple.nth_default 0 ri (m : reg_state) = v. + exists v, eval s i v /\ Tuple.nth_default 0 (N.to_nat ri) (m : reg_state) = v. Proof using Type. destruct s, m; apply get_reg_R_regs, HR. Qed. diff --git a/src/Util/Listable.v b/src/Util/Listable.v index feaad2c888..f898c909e7 100644 --- a/src/Util/Listable.v +++ b/src/Util/Listable.v @@ -7,6 +7,7 @@ Require Import Crypto.Util.Decidable. Require Import Crypto.Util.Bool.Reflect. Import ListNotations. +(* TODO: move to using [N] for performance instead of [nat] *) Class Listable T := { list_all : list T ; find_index : T -> nat }. Arguments find_index {T} {_} _.