Skip to content

Commit

Permalink
Add support for more assembly
Browse files Browse the repository at this point in the history
I want to be able to handle the output of gcc/clang on our C code.

Right now, I have added support for parsing the assembly output.
Equivalence checking is still to come.
  • Loading branch information
JasonGross committed Mar 1, 2025
1 parent 847d6a7 commit e6c475f
Show file tree
Hide file tree
Showing 30 changed files with 24,328 additions and 121 deletions.
33 changes: 31 additions & 2 deletions src/Assembly/Equality.v
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,39 @@ Declare Scope REG_scope.
Delimit Scope REG_scope with REG.
Bind Scope REG_scope with REG.


Definition REG_beq (r1 r2 : REG) : bool :=
(prod_beq _ _ (prod_beq _ _ N.eqb N.eqb) N.eqb)
(index_and_shift_and_bitcount_of_reg r1) (index_and_shift_and_bitcount_of_reg r2).

Lemma REG_dec_lb : forall r1 r2 : REG, r1 = r2 -> REG_beq r1 r2 = true.
Proof.
intros r1 r2 H.
subst r2; destruct r1.
all: reflexivity.
Defined.

Lemma REG_dec_bl : forall r1 r2 : REG, REG_beq r1 r2 = true -> r1 = r2.
Proof.
cbv [REG_beq].
intros r1 r2 H.
rewrite <- (reg_of_index_and_shift_and_bitcount_of_index_and_shift_and_bitcount_of_reg r1), <- (reg_of_index_and_shift_and_bitcount_of_index_and_shift_and_bitcount_of_reg r2).
reflect_hyps.
rewrite H.
reflexivity.
Defined.

Definition REG_eq_dec (x y : REG) : {x = y} + {x <> y} :=
(if REG_beq x y as b return (REG_beq x y = b -> _)
then fun pf => left (REG_dec_bl x y pf)
else fun pf => right (fun pf' => diff_false_true (eq_trans (eq_sym pf) (REG_dec_lb x y pf'))))
eq_refl.

Infix "=?" := REG_beq : REG_scope.

Global Instance REG_beq_spec : reflect_rel (@eq REG) REG_beq | 10
:= reflect_of_beq internal_REG_dec_bl internal_REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@internal_REG_dec_bl _ _) (@internal_REG_dec_lb _ _).
:= reflect_of_beq REG_dec_bl REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@REG_dec_bl _ _) (@REG_dec_lb _ _).
Lemma REG_beq_neq x y : (x =? y)%REG = false <-> x <> y.
Proof. rewrite <- REG_beq_eq; destruct (x =? y)%REG; intuition congruence. Qed.
Global Instance REG_beq_compat : Proper (eq ==> eq ==> eq) REG_beq | 10.
Expand Down Expand Up @@ -108,6 +136,7 @@ Bind Scope MEM_scope with MEM.

Definition MEM_beq (x y : MEM) : bool
:= ((option_beq AccessSize_beq x.(mem_bits_access_size) y.(mem_bits_access_size))
&& option_beq String.eqb x.(mem_constant_location_label) y.(mem_constant_location_label)
&& option_beq String.eqb x.(mem_base_label) y.(mem_base_label)
&& (option_beq REG_beq x.(mem_base_reg) y.(mem_base_reg))
&& (option_beq (prod_beq _ _ Z.eqb REG_beq) x.(mem_scale_reg) y.(mem_scale_reg))
Expand Down
5 changes: 3 additions & 2 deletions src/Assembly/Equivalence.v
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ From Coq Require Import ZArith.
From Coq Require Import NArith.
Require Import Crypto.Assembly.Syntax.
Require Import Crypto.Assembly.Parse.
Require Import Crypto.Assembly.Equality.
Require Import Crypto.Assembly.Symbolic.
Require Import Crypto.Util.Strings.Parse.Common.
Require Import Crypto.Util.ErrorT.
Expand Down Expand Up @@ -277,8 +278,8 @@ Definition show_annotated_Line : Show AnnotatedLine
end)%string.

Global Instance show_lines_AnnotatedLines : ShowLines AnnotatedLines
:= fun '(ls, ss)
=> let d := dag.eager.force ss.(dag_state) in
:= fun '(ls, sst)
=> let d := dag.eager.force sst.(dag_state) in
List.map (fun l => show_annotated_Line (l, d)) ls.

Fixpoint remove_common_indices {T} (eqb : T -> T -> bool) (xs ys : list T) (start_idx : nat) : list (nat * T) * list T
Expand Down
33 changes: 27 additions & 6 deletions src/Assembly/Parse.v
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,20 @@ Definition parse_label : ParserAction string
(fun '(char, ls) => string_of_list_ascii (char :: ls))
(([a-zA-Z] || parse_any_ascii "._?$") ;;
(([a-zA-Z] || parse_any_ascii "0123456789_$#@~.?")* )).
Definition parse_non_access_size_label : ParserAction string
:= parse_lookahead_not parse_AccessSize ;;R parse_label.

Definition parse_MEM : ParserAction MEM
:= parse_map
(fun '(access_size, (br (*base reg*), sr (*scale reg, including z *), offset, base_label))
(fun '(access_size, (constant_location_label, (br (*base reg*), sr (*scale reg, including z *), offset, base_label)))
=> {| mem_bits_access_size := access_size:option AccessSize
; mem_base_reg := br:option REG
; mem_constant_location_label := constant_location_label:option string
; mem_base_label := base_label
; mem_scale_reg := sr:option (Z * REG)
; mem_offset := offset:option Z |})
(((strip_whitespace_after parse_AccessSize)?) ;;
(parse_non_access_size_label?) ;;
(parse_option_list_map
(fun '(offset, vars)
=> (vars <-- List.map (fun '(c, (v, e), vs) => match vs, e with [], 1%Z => Some (c, v) | _, _ => None end) vars;
Expand Down Expand Up @@ -160,7 +164,8 @@ Definition parse_OpCode_list : list (string * OpCode)
:= Eval vm_compute in
List.map
(fun r => (show r, r))
(list_all OpCode).
(list_all OpCode)
++ [(".quad", dq); (".word", dw); (".byte", db)].

Definition parse_OpCode : ParserAction OpCode
:= parse_strs_case_insensitive parse_OpCode_list.
Expand Down Expand Up @@ -254,11 +259,23 @@ Global Instance show_lvl_MEM : ShowLevel MEM
:= fun m
=> (match m.(mem_bits_access_size) with
| Some n
=> show_lvl_app (fun 'tt => if n =? 8 then "byte" else if n =? 64 then "QWORD PTR" else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
=> show_lvl_app (fun 'tt => if n =? 8 then "byte"
else if n =? 16 then "word"
else if n =? 32 then "dword"
else if n =? 64 then "QWORD PTR"
else if n =? 128 then "XMMWORD PTR"
else if n =? 256 then "YMMWORD PTR"
else if n =? 512 then "ZMMWORD PTR"
else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
| None => show_lvl
end)
(fun 'tt
=> let reg_part
=> let label_part :=
match m.(mem_constant_location_label) with
| None => ""
| Some l => l
end in
let reg_part
:= (match m.(mem_base_reg), m.(mem_scale_reg) with
| (*"[Reg]" *) Some br, None => show_REG br
| (*"[Reg + Z * Reg]"*) Some br, Some (z, sr) => show_REG br ++ " + " ++ Decimal.show_Z z ++ " * " ++ show_REG sr (*only matching '+' here, because there cannot be a negative scale. *)
Expand All @@ -275,9 +292,13 @@ Global Instance show_lvl_MEM : ShowLevel MEM
then "0x08 * " ++ Decimal.show_Z (offset / 8)
else Hex.show_Z offset)
end%Z) in
"[" ++ match m.(mem_base_label) with
label_part ++ "[" ++ match m.(mem_base_label) with
| None => reg_part ++ offset_part
| Some l => "((" ++ l ++ offset_part ++ "))"
| Some l =>
let l_offset := l ++ offset_part in
if reg_part =? ""
then "((" ++ l_offset ++ "))"
else reg_part ++ " + " ++ l_offset
end
++ "]").
Global Instance show_MEM : Show MEM := show_lvl_MEM.
Expand Down
Loading

0 comments on commit e6c475f

Please sign in to comment.