Skip to content

Commit

Permalink
Add support for more assembly
Browse files Browse the repository at this point in the history
I want to be able to handle the output of gcc/clang on our C code.

Right now, I have added support for parsing the assembly output.
Equivalence checking is still to come.
  • Loading branch information
JasonGross committed Mar 1, 2025
1 parent 847d6a7 commit 461c421
Show file tree
Hide file tree
Showing 32 changed files with 24,195 additions and 252 deletions.
32 changes: 30 additions & 2 deletions src/Assembly/Equality.v
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,39 @@ Declare Scope REG_scope.
Delimit Scope REG_scope with REG.
Bind Scope REG_scope with REG.


Definition REG_beq (r1 r2 : REG) : bool :=
(prod_beq _ _ (prod_beq _ _ N.eqb N.eqb) N.eqb)
(index_and_shift_and_bitcount_of_reg r1) (index_and_shift_and_bitcount_of_reg r2).

Lemma REG_dec_lb : forall r1 r2 : REG, r1 = r2 -> REG_beq r1 r2 = true.
Proof.
intros r1 r2 H.
subst r2; destruct r1.
all: reflexivity.
Defined.

Lemma REG_dec_bl : forall r1 r2 : REG, REG_beq r1 r2 = true -> r1 = r2.
Proof.
cbv [REG_beq].
intros r1 r2 H.
rewrite <- (reg_of_index_and_shift_and_bitcount_of_index_and_shift_and_bitcount_of_reg r1), <- (reg_of_index_and_shift_and_bitcount_of_index_and_shift_and_bitcount_of_reg r2).
reflect_hyps.
rewrite H.
reflexivity.
Defined.

Definition REG_eq_dec (x y : REG) : {x = y} + {x <> y} :=
(if REG_beq x y as b return (REG_beq x y = b -> _)
then fun pf => left (REG_dec_bl x y pf)
else fun pf => right (fun pf' => diff_false_true (eq_trans (eq_sym pf) (REG_dec_lb x y pf'))))
eq_refl.

Infix "=?" := REG_beq : REG_scope.

Global Instance REG_beq_spec : reflect_rel (@eq REG) REG_beq | 10
:= reflect_of_beq internal_REG_dec_bl internal_REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@internal_REG_dec_bl _ _) (@internal_REG_dec_lb _ _).
:= reflect_of_beq REG_dec_bl REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@REG_dec_bl _ _) (@REG_dec_lb _ _).
Lemma REG_beq_neq x y : (x =? y)%REG = false <-> x <> y.
Proof. rewrite <- REG_beq_eq; destruct (x =? y)%REG; intuition congruence. Qed.
Global Instance REG_beq_compat : Proper (eq ==> eq ==> eq) REG_beq | 10.
Expand Down
5 changes: 3 additions & 2 deletions src/Assembly/Equivalence.v
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ From Coq Require Import ZArith.
From Coq Require Import NArith.
Require Import Crypto.Assembly.Syntax.
Require Import Crypto.Assembly.Parse.
Require Import Crypto.Assembly.Equality.
Require Import Crypto.Assembly.Symbolic.
Require Import Crypto.Util.Strings.Parse.Common.
Require Import Crypto.Util.ErrorT.
Expand Down Expand Up @@ -277,8 +278,8 @@ Definition show_annotated_Line : Show AnnotatedLine
end)%string.

Global Instance show_lines_AnnotatedLines : ShowLines AnnotatedLines
:= fun '(ls, ss)
=> let d := dag.eager.force ss.(dag_state) in
:= fun '(ls, sst)
=> let d := dag.eager.force sst.(dag_state) in
List.map (fun l => show_annotated_Line (l, d)) ls.

Fixpoint remove_common_indices {T} (eqb : T -> T -> bool) (xs ys : list T) (start_idx : nat) : list (nat * T) * list T
Expand Down
2 changes: 1 addition & 1 deletion src/Assembly/EquivalenceProofs.v
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ Local Open Scope list_scope.

(* TODO: move to global settings *)
Local Set Keyed Unification.

Locate set_reg.
Definition eval_idx_Z (G : symbol -> option Z) (dag : dag) (idx : idx) (v : Z) : Prop
:= eval G dag (ExprRef idx) v.
Definition eval_idx_or_list_idx (G : symbol -> option Z) (d : dag) (v1 : idx + list idx) (v2 : Z + list Z)
Expand Down
56 changes: 42 additions & 14 deletions src/Assembly/Parse.v
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,26 @@ Definition parse_label : ParserAction string
(fun '(char, ls) => string_of_list_ascii (char :: ls))
(([a-zA-Z] || parse_any_ascii "._?$") ;;
(([a-zA-Z] || parse_any_ascii "0123456789_$#@~.?")* )).
Definition parse_non_access_size_label : ParserAction string
:= parse_lookahead_not parse_AccessSize ;;R parse_label.

Definition parse_MEM : ParserAction MEM
:= parse_map
(fun '(access_size, (br (*base reg*), sr (*scale reg, including z *), offset, base_label))
=> {| mem_bits_access_size := access_size:option AccessSize
; mem_base_reg := br:option REG
; mem_base_label := base_label
; mem_scale_reg := sr:option (Z * REG)
; mem_offset := offset:option Z |})
:= parse_option_list_map
(fun '(access_size, (constant_location_label, (br (*base reg*), sr (*scale reg, including z *), offset, base_label)))
=> match base_label, constant_location_label with
| Some _, Some _ => (* invalid? *) None
| Some _ as lbl, None
| None, Some _ as lbl
| None, None as lbl =>
Some
{| mem_bits_access_size := access_size:option AccessSize
; mem_base_reg := br:option REG
; mem_base_label := lbl
; mem_scale_reg := sr:option (Z * REG)
; mem_offset := offset:option Z |}
end)
(((strip_whitespace_after parse_AccessSize)?) ;;
(parse_non_access_size_label?) ;;
(parse_option_list_map
(fun '(offset, vars)
=> (vars <-- List.map (fun '(c, (v, e), vs) => match vs, e with [], 1%Z => Some (c, v) | _, _ => None end) vars;
Expand Down Expand Up @@ -160,7 +170,8 @@ Definition parse_OpCode_list : list (string * OpCode)
:= Eval vm_compute in
List.map
(fun r => (show r, r))
(list_all OpCode).
(list_all OpCode)
++ [(".quad", dq); (".word", dw); (".byte", db)].

Definition parse_OpCode : ParserAction OpCode
:= parse_strs_case_insensitive parse_OpCode_list.
Expand Down Expand Up @@ -254,7 +265,14 @@ Global Instance show_lvl_MEM : ShowLevel MEM
:= fun m
=> (match m.(mem_bits_access_size) with
| Some n
=> show_lvl_app (fun 'tt => if n =? 8 then "byte" else if n =? 64 then "QWORD PTR" else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
=> show_lvl_app (fun 'tt => if n =? 8 then "byte"
else if n =? 16 then "word"
else if n =? 32 then "dword"
else if n =? 64 then "QWORD PTR"
else if n =? 128 then "XMMWORD PTR"
else if n =? 256 then "YMMWORD PTR"
else if n =? 512 then "ZMMWORD PTR"
else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
| None => show_lvl
end)
(fun 'tt
Expand All @@ -275,11 +293,21 @@ Global Instance show_lvl_MEM : ShowLevel MEM
then "0x08 * " ++ Decimal.show_Z (offset / 8)
else Hex.show_Z offset)
end%Z) in
"[" ++ match m.(mem_base_label) with
| None => reg_part ++ offset_part
| Some l => "((" ++ l ++ offset_part ++ "))"
end
++ "]").
match m.(mem_base_label), m.(mem_base_reg), m.(mem_offset), m.(mem_scale_reg) with
| Some lbl, Some rip, None, None => lbl ++ "[" ++ reg_part ++ offset_part ++ "]"
| Some lbl, _, _, _ => let l_offset := lbl ++ offset_part in
"[" ++
(if reg_part =? ""
then "((" ++ l_offset ++ "))"
else reg_part ++ " + " ++ l_offset)
++ "]"
| None, _, _, _ =>
"[" ++
(if reg_part =? ""
then "((" ++ offset_part ++ "))"
else reg_part ++ " + " ++ offset_part)
++ "]"
end).
Global Instance show_MEM : Show MEM := show_lvl_MEM.

Global Instance show_lvl_JUMP_LABEL : ShowLevel JUMP_LABEL
Expand Down
Loading

0 comments on commit 461c421

Please sign in to comment.