Skip to content

Commit

Permalink
Add support for more assembly
Browse files Browse the repository at this point in the history
I want to be able to handle the output of gcc/clang on our C code.

Right now, I have added support for parsing the assembly output.
Equivalence checking is still to come.
  • Loading branch information
JasonGross committed Mar 2, 2025
1 parent 4c8add8 commit 8460918
Show file tree
Hide file tree
Showing 33 changed files with 24,468 additions and 345 deletions.
20 changes: 10 additions & 10 deletions src/Assembly/Equality.v
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ Bind Scope REG_scope with REG.
Infix "=?" := REG_beq : REG_scope.

Global Instance REG_beq_spec : reflect_rel (@eq REG) REG_beq | 10
:= reflect_of_beq internal_REG_dec_bl internal_REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@internal_REG_dec_bl _ _) (@internal_REG_dec_lb _ _).
:= reflect_of_beq REG_dec_bl REG_dec_lb.
Definition REG_beq_eq x y : (x =? y)%REG = true <-> x = y := conj (@REG_dec_bl _ _) (@REG_dec_lb _ _).
(* The boolean equality test on [REG] returns [false] exactly when the two
   registers are propositionally distinct.  Proved by rewriting the right-hand
   side with [REG_beq_eq] and case-splitting on the boolean. *)
Lemma REG_beq_neq x y : (x =? y)%REG = false <-> x <> y.
Proof. rewrite <- REG_beq_eq; destruct (x =? y)%REG; intuition congruence. Qed.
Global Instance REG_beq_compat : Proper (eq ==> eq ==> eq) REG_beq | 10.
Expand Down Expand Up @@ -95,8 +95,8 @@ Bind Scope AccessSize_scope with AccessSize.
Infix "=?" := AccessSize_beq : AccessSize_scope.

Global Instance AccessSize_beq_spec : reflect_rel (@eq AccessSize) AccessSize_beq | 10
:= reflect_of_beq internal_AccessSize_dec_bl internal_AccessSize_dec_lb.
Definition AccessSize_beq_eq x y : (x =? y)%AccessSize = true <-> x = y := conj (@internal_AccessSize_dec_bl _ _) (@internal_AccessSize_dec_lb _ _).
:= reflect_of_beq AccessSize_dec_bl AccessSize_dec_lb.
Definition AccessSize_beq_eq x y : (x =? y)%AccessSize = true <-> x = y := conj (@AccessSize_dec_bl _ _) (@AccessSize_dec_lb _ _).
(* The boolean equality test on [AccessSize] returns [false] exactly when the
   two values are propositionally distinct.  Proved by rewriting the right-hand
   side with [AccessSize_beq_eq] and case-splitting on the boolean. *)
Lemma AccessSize_beq_neq x y : (x =? y)%AccessSize = false <-> x <> y.
Proof. rewrite <- AccessSize_beq_eq; destruct (x =? y)%AccessSize; intuition congruence. Qed.
Global Instance AccessSize_beq_compat : Proper (eq ==> eq ==> eq) AccessSize_beq | 10.
Expand Down Expand Up @@ -141,8 +141,8 @@ Bind Scope FLAG_scope with FLAG.
Infix "=?" := FLAG_beq : FLAG_scope.

Global Instance FLAG_beq_spec : reflect_rel (@eq FLAG) FLAG_beq | 10
:= reflect_of_beq internal_FLAG_dec_bl internal_FLAG_dec_lb.
Definition FLAG_beq_eq x y : (x =? y)%FLAG = true <-> x = y := conj (@internal_FLAG_dec_bl _ _) (@internal_FLAG_dec_lb _ _).
:= reflect_of_beq FLAG_dec_bl FLAG_dec_lb.
Definition FLAG_beq_eq x y : (x =? y)%FLAG = true <-> x = y := conj (@FLAG_dec_bl _ _) (@FLAG_dec_lb _ _).
(* The boolean equality test on [FLAG] returns [false] exactly when the two
   flags are propositionally distinct.  Proved by rewriting the right-hand side
   with [FLAG_beq_eq] and case-splitting on the boolean. *)
Lemma FLAG_beq_neq x y : (x =? y)%FLAG = false <-> x <> y.
Proof. rewrite <- FLAG_beq_eq; destruct (x =? y)%FLAG; intuition congruence. Qed.
Global Instance FLAG_beq_compat : Proper (eq ==> eq ==> eq) FLAG_beq | 10.
Expand All @@ -155,8 +155,8 @@ Bind Scope OpCode_scope with OpCode.
Infix "=?" := OpCode_beq : OpCode_scope.

Global Instance OpCode_beq_spec : reflect_rel (@eq OpCode) OpCode_beq | 10
:= reflect_of_beq internal_OpCode_dec_bl internal_OpCode_dec_lb.
Definition OpCode_beq_eq x y : (x =? y)%OpCode = true <-> x = y := conj (@internal_OpCode_dec_bl _ _) (@internal_OpCode_dec_lb _ _).
:= reflect_of_beq OpCode_dec_bl OpCode_dec_lb.
Definition OpCode_beq_eq x y : (x =? y)%OpCode = true <-> x = y := conj (@OpCode_dec_bl _ _) (@OpCode_dec_lb _ _).
(* The boolean equality test on [OpCode] returns [false] exactly when the two
   opcodes are propositionally distinct.  Proved by rewriting the right-hand
   side with [OpCode_beq_eq] and case-splitting on the boolean. *)
Lemma OpCode_beq_neq x y : (x =? y)%OpCode = false <-> x <> y.
Proof. rewrite <- OpCode_beq_eq; destruct (x =? y)%OpCode; intuition congruence. Qed.
Global Instance OpCode_beq_compat : Proper (eq ==> eq ==> eq) OpCode_beq | 10.
Expand All @@ -169,8 +169,8 @@ Bind Scope OpPrefix_scope with OpPrefix.
Infix "=?" := OpPrefix_beq : OpPrefix_scope.

Global Instance OpPrefix_beq_spec : reflect_rel (@eq OpPrefix) OpPrefix_beq | 10
:= reflect_of_beq internal_OpPrefix_dec_bl internal_OpPrefix_dec_lb.
Definition OpPrefix_beq_eq x y : (x =? y)%OpPrefix = true <-> x = y := conj (@internal_OpPrefix_dec_bl _ _) (@internal_OpPrefix_dec_lb _ _).
:= reflect_of_beq OpPrefix_dec_bl OpPrefix_dec_lb.
Definition OpPrefix_beq_eq x y : (x =? y)%OpPrefix = true <-> x = y := conj (@OpPrefix_dec_bl _ _) (@OpPrefix_dec_lb _ _).
(* The boolean equality test on [OpPrefix] returns [false] exactly when the two
   prefixes are propositionally distinct.  Proved by rewriting the right-hand
   side with [OpPrefix_beq_eq] and case-splitting on the boolean. *)
Lemma OpPrefix_beq_neq x y : (x =? y)%OpPrefix = false <-> x <> y.
Proof. rewrite <- OpPrefix_beq_eq; destruct (x =? y)%OpPrefix; intuition congruence. Qed.
Global Instance OpPrefix_beq_compat : Proper (eq ==> eq ==> eq) OpPrefix_beq | 10.
Expand Down
9 changes: 5 additions & 4 deletions src/Assembly/Equivalence.v
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ From Coq Require Import ZArith.
From Coq Require Import NArith.
Require Import Crypto.Assembly.Syntax.
Require Import Crypto.Assembly.Parse.
Require Import Crypto.Assembly.Equality.
Require Import Crypto.Assembly.Symbolic.
Require Import Crypto.Util.Strings.Parse.Common.
Require Import Crypto.Util.ErrorT.
Expand Down Expand Up @@ -277,8 +278,8 @@ Definition show_annotated_Line : Show AnnotatedLine
end)%string.

Global Instance show_lines_AnnotatedLines : ShowLines AnnotatedLines
:= fun '(ls, ss)
=> let d := dag.eager.force ss.(dag_state) in
:= fun '(ls, sst)
=> let d := dag.eager.force sst.(dag_state) in
List.map (fun l => show_annotated_Line (l, d)) ls.

Fixpoint remove_common_indices {T} (eqb : T -> T -> bool) (xs ys : list T) (start_idx : nat) : list (nat * T) * list T
Expand Down Expand Up @@ -1275,10 +1276,10 @@ Definition init_symbolic_state_descr : description := Build_description "init_sy

Definition init_symbolic_state (d : dag) : symbolic_state
:= let _ := init_symbolic_state_descr in
let '(initial_reg_idxs, d) := dag_gensym_n 16 d in
let '(initial_reg_idxs, d) := dag_gensym_n (List.length widest_registers) d in
{|
dag_state := d;
symbolic_reg_state := Tuple.from_list_default None 16 (List.map Some initial_reg_idxs);
symbolic_reg_state := Tuple.from_list_default None _ (List.map Some initial_reg_idxs);
symbolic_mem_state := [];
symbolic_flag_state := Tuple.repeat None 6;
|}.
Expand Down
4 changes: 2 additions & 2 deletions src/Assembly/EquivalenceProofs.v
Original file line number Diff line number Diff line change
Expand Up @@ -1847,7 +1847,7 @@ Qed.
(* TODO: this is Symbolic.get_reg; move to SymbolicProofs? *)
Lemma get_reg_set_reg_full s rn rn' v
: get_reg (set_reg s rn v) rn'
= if ((rn <? ((fun n (_ : Tuple.tuple _ n) => n) _ s)) && (rn =? rn'))%nat%bool
= if ((rn <? ((fun n (_ : Tuple.tuple _ n) => N.of_nat n) _ s)) && (rn =? rn'))%N%bool
then Some v
else get_reg s rn'.
Proof.
Expand All @@ -1863,7 +1863,7 @@ Qed.

(* TODO: this is Symbolic.get_reg; move to SymbolicProofs? *)
Local Lemma get_reg_set_reg_same s rn v
(H : (rn < (fun n (_ : Tuple.tuple _ n) => n) _ s)%nat)
(H : (rn < (fun n (_ : Tuple.tuple _ n) => N.of_nat n) _ s)%N)
: get_reg (set_reg s rn v) rn = Some v.
Proof.
rewrite get_reg_set_reg_full; break_innermost_match; reflect_hyps; cbv beta in *; try reflexivity; lia.
Expand Down
157 changes: 118 additions & 39 deletions src/Assembly/Parse.v
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,18 @@ Local Open Scope list_scope.
Local Open Scope string_scope.
Local Open Scope parse_scope.

Derive REG_Listable SuchThat (@FinitelyListable REG REG_Listable) As REG_FinitelyListable.
Proof. prove_ListableDerive. Qed.
Global Existing Instances REG_Listable REG_FinitelyListable.

(* [Show] instance for [REG], constructed by the [prove_Show_enum] tactic.
   The proof is closed with [Defined] (not [Qed]), so the resulting term stays
   transparent and can be unfolded and computed with. *)
Global Instance show_REG : Show REG.
Proof. prove_Show_enum (). Defined.
(* [ShowLevel] instance for [REG]; it simply reuses [show_REG]. *)
Global Instance show_lvl_REG : ShowLevel REG := show_REG.

Derive FLAG_Listable SuchThat (@FinitelyListable FLAG FLAG_Listable) As FLAG_FinitelyListable.
Proof. prove_ListableDerive. Qed.
Global Existing Instances FLAG_Listable FLAG_FinitelyListable.

(* [Show] instance for [FLAG], constructed by the [prove_Show_enum] tactic.
   The proof is closed with [Defined] (not [Qed]), so the resulting term stays
   transparent and can be unfolded and computed with. *)
Global Instance show_FLAG : Show FLAG.
Proof. prove_Show_enum (). Defined.
(* [ShowLevel] instance for [FLAG]; it simply reuses [show_FLAG]. *)
Global Instance show_lvl_FLAG : ShowLevel FLAG := show_FLAG.

Derive OpCode_Listable SuchThat (@FinitelyListable OpCode OpCode_Listable) As OpCode_FinitelyListable.
Proof. prove_ListableDerive. Qed.
Global Existing Instances OpCode_Listable OpCode_FinitelyListable.

(* [Show] instance for [OpCode], constructed by the [prove_Show_enum] tactic.
   The proof is closed with [Defined] (not [Qed]), so the resulting term stays
   transparent; [parse_OpCode_list] evaluates [show] on opcodes under
   [vm_compute]. *)
Global Instance show_OpCode : Show OpCode.
Proof. prove_Show_enum (). Defined.
(* [ShowLevel] instance for [OpCode]; it simply reuses [show_OpCode]. *)
Global Instance show_lvl_OpCode : ShowLevel OpCode := show_OpCode.

Derive OpPrefix_Listable SuchThat (@FinitelyListable OpPrefix OpPrefix_Listable) As OpPrefix_FinitelyListable.
Proof. prove_ListableDerive. Qed.
Global Existing Instances OpPrefix_Listable OpPrefix_FinitelyListable.

(* [Show] instance for [OpPrefix], constructed by the [prove_Show_enum] tactic.
   The proof is closed with [Defined] (not [Qed]), so the resulting term stays
   transparent and can be unfolded and computed with. *)
Global Instance show_OpPrefix : Show OpPrefix.
Proof. prove_Show_enum (). Defined.
(* [ShowLevel] instance for [OpPrefix]; it simply reuses [show_OpPrefix]. *)
Global Instance show_lvl_OpPrefix : ShowLevel OpPrefix := show_OpPrefix.
Expand All @@ -72,10 +56,6 @@ Definition parse_FLAG_list : list (string * FLAG)
(* Parser for a [FLAG]: hands the (string, flag) table [parse_FLAG_list] to
   [parse_strs].  NOTE(review): presumably [parse_strs] accepts any string in
   the table and returns the paired value -- confirm against its definition. *)
Definition parse_FLAG : ParserAction FLAG
:= parse_strs parse_FLAG_list.

Derive AccessSize_Listable SuchThat (@FinitelyListable AccessSize AccessSize_Listable) As AccessSize_FinitelyListable.
Proof. prove_ListableDerive. Qed.
Global Existing Instances AccessSize_Listable AccessSize_FinitelyListable.

(* [Show] instance for [AccessSize], constructed by the [prove_Show_enum]
   tactic.  The proof is closed with [Defined] (not [Qed]), so the resulting
   term stays transparent and can be unfolded and computed with. *)
Global Instance show_AccessSize : Show AccessSize.
Proof. prove_Show_enum (). Defined.
(* [ShowLevel] instance for [AccessSize]; it simply reuses [show_AccessSize]. *)
Global Instance show_lvl_AccessSize : ShowLevel AccessSize := show_AccessSize.
Expand All @@ -100,16 +80,26 @@ Definition parse_label : ParserAction string
(fun '(char, ls) => string_of_list_ascii (char :: ls))
(([a-zA-Z] || parse_any_ascii "._?$") ;;
(([a-zA-Z] || parse_any_ascii "0123456789_$#@~.?")* )).
(* Parses a label via [parse_label], guarded by
   [parse_lookahead_not parse_AccessSize].
   NOTE(review): presumably the guard rejects input on which [parse_AccessSize]
   would succeed, and [;;R] keeps only the right-hand (label) result, so that an
   access-size keyword is never mistaken for a label -- confirm against the
   combinator definitions. *)
Definition parse_non_access_size_label : ParserAction string
:= parse_lookahead_not parse_AccessSize ;;R parse_label.

Definition parse_MEM : ParserAction MEM
:= parse_map
(fun '(access_size, (br (*base reg*), sr (*scale reg, including z *), offset, base_label))
=> {| mem_bits_access_size := access_size:option AccessSize
; mem_base_reg := br:option REG
; mem_base_label := base_label
; mem_scale_reg := sr:option (Z * REG)
; mem_offset := offset:option Z |})
:= parse_option_list_map
(fun '(access_size, (constant_location_label, (br (*base reg*), sr (*scale reg, including z *), offset, base_label)))
=> match base_label, constant_location_label with
| Some _, Some _ => (* invalid? *) None
| Some _ as lbl, None
| None, Some _ as lbl
| None, None as lbl =>
Some
{| mem_bits_access_size := access_size:option AccessSize
; mem_base_reg := br:option REG
; mem_base_label := lbl
; mem_scale_reg := sr:option (Z * REG)
; mem_offset := offset:option Z |}
end)
(((strip_whitespace_after parse_AccessSize)?) ;;
(parse_non_access_size_label?) ;;
(parse_option_list_map
(fun '(offset, vars)
=> (vars <-- List.map (fun '(c, (v, e), vs) => match vs, e with [], 1%Z => Some (c, v) | _, _ => None end) vars;
Expand Down Expand Up @@ -160,7 +150,13 @@ Definition parse_OpCode_list : list (string * OpCode)
:= Eval vm_compute in
List.map
(fun r => (show r, r))
(list_all OpCode).
(list_all OpCode)
++ [(".byte", db)
; (".word", dw)
; (".long", dd)
; (".int", dd)
; (".quad", dq)
; (".octa", do)].

(* Parser for an [OpCode]: hands the (string, opcode) table [parse_OpCode_list]
   to [parse_strs_case_insensitive].  NOTE(review): going by the helper's name,
   matching ignores letter case -- confirm against its definition. *)
Definition parse_OpCode : ParserAction OpCode
:= parse_strs_case_insensitive parse_OpCode_list.
Expand Down Expand Up @@ -254,7 +250,14 @@ Global Instance show_lvl_MEM : ShowLevel MEM
:= fun m
=> (match m.(mem_bits_access_size) with
| Some n
=> show_lvl_app (fun 'tt => if n =? 8 then "byte" else if n =? 64 then "QWORD PTR" else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
=> show_lvl_app (fun 'tt => if n =? 8 then "byte"
else if n =? 16 then "word"
else if n =? 32 then "dword"
else if n =? 64 then "QWORD PTR"
else if n =? 128 then "XMMWORD PTR"
else if n =? 256 then "YMMWORD PTR"
else if n =? 512 then "ZMMWORD PTR"
else "BAD SIZE")%N (* TODO: Fix casing and stuff *)
| None => show_lvl
end)
(fun 'tt
Expand All @@ -275,11 +278,21 @@ Global Instance show_lvl_MEM : ShowLevel MEM
then "0x08 * " ++ Decimal.show_Z (offset / 8)
else Hex.show_Z offset)
end%Z) in
"[" ++ match m.(mem_base_label) with
| None => reg_part ++ offset_part
| Some l => "((" ++ l ++ offset_part ++ "))"
end
++ "]").
match m.(mem_base_label), m.(mem_base_reg), m.(mem_offset), m.(mem_scale_reg) with
| Some lbl, Some rip, None, None => lbl ++ "[" ++ reg_part ++ offset_part ++ "]"
| Some lbl, _, _, _ => let l_offset := lbl ++ offset_part in
"[" ++
(if reg_part =? ""
then "((" ++ l_offset ++ "))"
else reg_part ++ " + " ++ l_offset)
++ "]"
| None, _, _, _ =>
"[" ++
(if reg_part =? ""
then "((" ++ offset_part ++ "))"
else reg_part ++ offset_part)
++ "]"
end).
Global Instance show_MEM : Show MEM := show_lvl_MEM.

Global Instance show_lvl_JUMP_LABEL : ShowLevel JUMP_LABEL
Expand Down Expand Up @@ -498,20 +511,86 @@ Definition find_globals (ls : Lines) : list string
end)
ls.

Fixpoint split_code_to_functions' (globals : list string) (ls : Lines) : Lines (* prefix *) * list (string (* global name *) * Lines)
(* Extracts the names carried by the [LABEL] lines of [ls]: every line is
   mapped to [Some name] when its [rawline] is [LABEL name] and to [None]
   otherwise, and the results are combined with [Option.List.map].
   NOTE(review): presumably [Option.List.map] drops the [None]s and keeps the
   remaining names in order -- confirm against its definition. *)
Definition find_labels (ls : Lines) : list string
  := let label_of_line l :=
       match l.(rawline) with
       | LABEL name => Some name
       | _ => None
       end in
     Option.List.map label_of_line ls.

Fixpoint split_code_to_functions' (label_is_function : string -> bool) (ls : Lines) : Lines (* prefix *) * list (string (* global name *) * Lines)
:= match ls with
| [] => ([], [])
| l :: ls
=> let '(prefix, rest) := split_code_to_functions' globals ls in
=> let '(prefix, rest) := split_code_to_functions' label_is_function ls in
let default := (l :: prefix, rest) in
match l.(rawline) with
| LABEL name => if List.existsb (fun n => name =? n)%string globals
| LABEL name => if label_is_function name
then ([], (name, l::prefix) :: rest)
else default
| _ => default
end
end.

Definition split_code_to_functions (ls : Lines) : Lines (* prefix *) * list (string (* global name *) * Lines)
(* Compares [shorter_string] against [longer_string], with the strictness
   selected by the two flags:
   - neither flag set: string equality ([=?]);
   - only [allow_prefix]: [String.endswith shorter_string longer_string];
   - only [allow_suffix]: [String.startswith shorter_string longer_string];
   - both flags: [String.is_substring shorter_string longer_string].
   NOTE(review): presumably each [String.*] helper takes the fragment first and
   the string being searched second, i.e. [shorter_string] must occur at the
   end / start / anywhere in [longer_string] -- confirm against their
   definitions, and check that callers pass the two strings in that order. *)
Definition string_matches_loose (allow_prefix : bool) (allow_suffix : bool) (longer_string shorter_string : string) : bool
  := if allow_prefix
     then (if allow_suffix
           then String.is_substring shorter_string longer_string
           else String.endswith shorter_string longer_string)
     else (if allow_suffix
           then String.startswith shorter_string longer_string
           else (shorter_string =? longer_string)).
(* Splits [ls] with [split_code_to_functions'], treating a label as the start
   of a function when some entry [f] of [functions] satisfies
   [string_matches_loose allow_prefix allow_suffix f name], where [name] is the
   label.  The listed entry is passed as the first string argument and the
   label as the second. *)
Definition split_code_to_listed_functions {allow_prefix allow_suffix : bool} (functions : list string) (ls : Lines) : Lines (* prefix *) * list (string (* global name *) * Lines)
  := let label_is_listed name :=
       List.existsb
         (fun f => string_matches_loose allow_prefix allow_suffix f name)
         functions in
     split_code_to_functions' label_is_listed ls.
Definition split_code_to_global_functions (ls : Lines) : Lines (* prefix *) * list (string (* global name *) * Lines)
:= let globals := find_globals ls in
split_code_to_functions' globals ls.
split_code_to_listed_functions (allow_prefix:=false) (allow_suffix:=false) globals ls.
(* Splits [ls] at labels: the names returned by [find_labels ls] are used as
   the function list for [split_code_to_listed_functions], with both
   [allow_prefix] and [allow_suffix] switched off (exact name comparison). *)
Definition split_code_at_labels (ls : Lines) : Lines (* prefix *) * list (string (* label name *) * Lines)
  := split_code_to_listed_functions
       (allow_prefix:=false) (allow_suffix:=false)
       (find_labels ls)
       ls.

(* Collects the constant data declared at the front of [ls].

   Lines are visited in order:
   - an [INSTR] whose opcode [accesssize_of_declaration] maps to [Some size],
     and whose arguments [Option.List.lift] turns into [Some csts], contributes
     the entry [(size, csts)] and collection continues;
   - any other [INSTR] ends collection;
   - [LABEL], [EMPTY], [GLOBAL] and [DEFAULT_REL] lines are skipped;
   - [SECTION] and [ALIGN] lines end collection.
   Entries gathered before collection ends are still returned.
   NOTE(review): presumably [Option.List.lift] yields [Some] only when every
   argument is a [const] -- confirm against its definition. *)
Fixpoint get_initial_data (ls : Lines) : list (AccessSize * list Z)
  := let consts_of_args args :=
       Option.List.lift
         (List.map (fun arg => match arg with
                               | const c => Some c
                               | _ => None
                               end)
                   args) in
     match ls with
     | [] => []
     | l :: rest
       => match l.(rawline) with
          | INSTR instr =>
              match accesssize_of_declaration instr.(op), consts_of_args instr.(args) with
              | Some size, Some csts => (size, csts) :: get_initial_data rest
              | _, _ => []
              end
          | LABEL _
          | EMPTY
          | GLOBAL _
          | DEFAULT_REL
            => get_initial_data rest
          | SECTION _
          | ALIGN _
            => []
          end
     end.

(* Associates labels with the data declared under them: [ls] is split with
   [split_code_at_labels] (the unlabeled prefix is discarded), each labeled
   chunk is run through [get_initial_data], and labels whose resulting data
   list is empty are filtered out. *)
Definition get_labeled_data (ls : Lines) : list (string * list (AccessSize * list Z)) :=
  let '(_, chunks) := split_code_at_labels ls in
  let chunks_with_data :=
    List.map (fun '(lbl, lines) => (lbl, get_initial_data lines)) chunks in
  List.filter
    (fun '(_, data) => match data with
                       | [] => false
                       | _ :: _ => true
                       end)
    chunks_with_data.

(* Derives the program-wide options from the parsed lines: [default_rel] is
   set exactly when [List.find] locates a line whose [rawline] is
   [DEFAULT_REL]. *)
Definition parse_assembly_options (ls : Lines) : assembly_program_options
  := let is_default_rel l :=
       match l.(rawline) with
       | DEFAULT_REL => true
       | _ => false
       end in
     {| default_rel := Option.is_Some (List.find is_default_rel ls) |}.
Loading

0 comments on commit 8460918

Please sign in to comment.