Skip to content

Commit

Permalink
Set array sizes from bounds
Browse files Browse the repository at this point in the history
Fix mit-plv#2040 (I hope)

Still TODO: add a test-case for this based on from_bytes assembly
  • Loading branch information
JasonGross committed Mar 8, 2025
1 parent c177f8e commit 00afedb
Showing 1 changed file with 48 additions and 33 deletions.
81 changes: 48 additions & 33 deletions src/Assembly/Equivalence.v
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ Inductive EquivalenceCheckingError :=
| Internal_error_output_load_failed (_ : option Symbolic.error) (_ : list ((REG + idx) + idx)) (_ : symbolic_state)
| Internal_error_extra_input_arguments (t : API.type) (unused_arguments : list (idx + list idx))
| Internal_error_lingering_memory (_ : symbolic_state)
| Internal_error_LoadOutputs_length_mismatch (outputaddrs : list ((REG + idx) + idx)) (output_types : list (option nat))
| Internal_error_LoadOutputs_length_mismatch (outputaddrs : list ((REG + idx) + idx)) (output_types : list (N * (option nat)))
| Not_enough_registers (num_given num_extra_needed : nat)
| Registers_too_narrow (bad_reg : list REG)
| Callee_saved_registers_too_narrow (bad_reg : list REG)
Expand All @@ -233,6 +233,9 @@ Inductive EquivalenceCheckingError :=
| Expected_const_in_reference_code (_ : idx)
| Expected_power_of_two (w : N) (_ : idx)
| Unknown_array_length (t : base.type)
| Unknown_array_bounds {t : base.type} (bs : list (ZRange.type.base.option.interp t))
| Unknown_scalar_size (t : base.type)
| Invalid_zero_size_array (t : base.type)
| Registers_not_saved (regs : list (REG * (idx (* before *) * idx (* after *)))) (_ : symbolic_state)
| Invalid_arrow_type (t : API.type)
| Invalid_argument_type (t : API.type)
Expand Down Expand Up @@ -266,6 +269,9 @@ Definition symbolic_state_of_EquivalenceCheckingError (e : EquivalenceCheckingEr
| Expected_const_in_reference_code _
| Expected_power_of_two _ _
| Unknown_array_length _
| Unknown_scalar_size _
| Unknown_array_bounds _ _
| Invalid_zero_size_array _
| Invalid_arrow_type _
| Invalid_argument_type _
| Invalid_return_type _
Expand Down Expand Up @@ -708,6 +714,9 @@ Global Instance show_lines_EquivalenceCheckingError : ShowLines EquivalenceCheck
| Invalid_return_type t
=> ["Invalid type for return: " ++ show t]%string
| Unknown_array_length t => ["Unknown array length of type " ++ show t ++ "."]%string
| Unknown_array_bounds t bs => ["Unknown array bounds of type " ++ show t ++ ": " ++ show bs]%string
| Unknown_scalar_size t => ["Unknown scalar size of type " ++ show t ++ "."]%string
| Invalid_zero_size_array t => ["Array of type " ++ show t ++ " has zero size."]%string
| Invalid_arrow_type t => ["Invalid higher order function involving the type " ++ show t ++ "."]%string
| Invalid_higher_order_application var s d f x
=> let __ := @Compilers.ToString.PHOAS.expr.partially_show_expr in
Expand Down Expand Up @@ -780,7 +789,7 @@ Definition RevealWidth (i : idx) : symexM N :=
then symex_return w
else symex_error (Expected_power_of_two s i).

Definition type_spec := list (option nat). (* list of array lengths; None means not an array *)
Definition type_spec := list (N * option nat). (* list of element size in bytes * length; None means not an array *)

(** Convert PHOAS info about types and argument bounds into a simplified specification *)
Fixpoint simplify_base_type
Expand All @@ -789,18 +798,24 @@ Fixpoint simplify_base_type
:= match t return ZRange.type.base.option.interp t -> _ with
| base.type.unit
=> fun 'tt => Success []
| base.type.type_base base.type.Z
=> fun _ => Success [None]
| (base.type.type_base base.type.Z) as t
=> fun r
=> match ZRange.type.base.option.lift_Some r with
| Some r => Success [(Z.to_N (ZRange.type.base.bitwidth r), None)]
| None => Error (Unknown_scalar_size t)
end
| base.type.prod A B
=> fun '(bA, bB)
=> (vA <- simplify_base_type A bA;
vB <- simplify_base_type B bB;
Success (vA ++ vB))
| base.type.list (base.type.type_base base.type.Z)
| (base.type.list (base.type.type_base base.type.Z as tZ)) as t
=> fun b
=> match b with
| None => Error (Unknown_array_length t)
| Some bs => Success [Some (List.length bs)]
=> match b, option_map ZRange.type.base.bitwidth (ZRange.type.base.option.lift_Some b) with
| None, _ => Error (Unknown_array_length t)
| Some b, None => Error (@Unknown_array_bounds tZ b)
| Some nil, _ | _, Some nil => Error (Invalid_zero_size_array t)
| Some _, Some bs => Success [(Z.to_N (List.fold_right Z.max 0%Z bs), Some (List.length bs))]
end
| base.type.type_base _
| base.type.option _
Expand Down Expand Up @@ -829,39 +844,39 @@ Fixpoint simplify_input_type
Definition build_inputarray {descr:description} (len : nat) : dag.M (list idx) :=
List.foldmap (fun _ => merge_fresh_symbol) (List.seq 0 len).

Fixpoint build_inputs {descr:description} (types : type_spec) : dag.M (list (idx + list idx))
Fixpoint build_inputs {descr:description} (types : type_spec) : dag.M (list (N * (idx + list idx)))
:= match types with
| [] => dag.ret []
| None :: tys
| (sz, None) :: tys
=> (idx <- merge_fresh_symbol;
rest <- build_inputs tys;
dag.ret (inl idx :: rest))
| Some len :: tys
dag.ret ((sz, inl idx) :: rest))
| (sz, Some len) :: tys
=> (idxs <- build_inputarray len;
rest <- build_inputs tys;
dag.ret (inr idxs :: rest))
dag.ret ((sz, inr idxs) :: rest))
end%dagM.

(* we factor this out so that conversion is not slow when proving things about this *)
Definition compute_array_address {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (i : nat)
:= (offset <- Symbolic.App (zconst 64%N (8 * Z.of_nat i), nil);
Definition compute_array_address {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (i : nat)
:= (offset <- Symbolic.App (zconst 64%N (Z.of_N bytes_per_element * Z.of_nat i), nil);
Symbolic.App (add 64%N, [base; offset]))%x86symex.

Definition build_merge_array_addresses {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (items : list idx) : M (list idx)
Definition build_merge_array_addresses {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (items : list idx) : M (list idx)
:= mapM (fun '(i, idx) =>
(addr <- compute_array_address base i;
(addr <- compute_array_address (bytes_per_element:=bytes_per_element) base i;
(fun s => Success (addr, update_mem_with s (cons (addr,idx)))))
)%x86symex (List.enumerate items).

Fixpoint build_merge_base_addresses {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (items : list (idx + list idx)) (reg_available : list REG) : M (list ((REG + idx) + idx))
Fixpoint build_merge_base_addresses {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (items : list (N * (idx + list idx))) (reg_available : list REG) : M (list ((REG + idx) + idx))
:= match items, reg_available with
| [], _ | _, [] => Symbolic.ret []
| inr idxs :: xs, r :: reg_available
| (sz, inr idxs) :: xs, r :: reg_available
=> (base <- SetRegFresh r; (* note: overwrites initial value *)
addrs <- build_merge_array_addresses base idxs; (* note: overwrites initial value *)
addrs <- build_merge_array_addresses (bytes_per_element:=sz) base idxs; (* note: overwrites initial value *)
rest <- build_merge_base_addresses (dereference_scalar:=dereference_scalar) xs reg_available;
Symbolic.ret (inr base :: rest))
| inl idx :: xs, r :: reg_available =>
| (_sz, inl idx) :: xs, r :: reg_available =>
(addr <- (if dereference_scalar
then
(addr <- SetRegFresh r;
Expand Down Expand Up @@ -1273,10 +1288,10 @@ Definition symex_PHOAS_PHOAS {opts : symbolic_options_computed_opt} {t} (expr :
Definition symex_PHOAS
{opts : symbolic_options_computed_opt}
{t} (expr : API.Expr t)
(inputs : list (idx + list idx))
(inputs : list (N * (idx + list idx)))
(d : dag)
: ErrorT EquivalenceCheckingError (list (idx + list idx) * dag)
:= (input_var_data <- build_input_var t inputs;
:= (input_var_data <- build_input_var t (List.map snd inputs);
let '(input_var_data, unused_inputs) := input_var_data in
_ <- (if (List.length unused_inputs =? 0)%nat
then Success tt
Expand Down Expand Up @@ -1307,18 +1322,18 @@ Definition build_merge_stack_placeholders {opts : symbolic_options_computed_opt}
: M idx
:= (stack_placeholders <- lift_dag (build_inputarray stack_size);
stack_base <- compute_stack_base stack_size;
_ <- build_merge_array_addresses stack_base stack_placeholders;
_ <- build_merge_array_addresses (bytes_per_element:=64%N) stack_base stack_placeholders;
ret stack_base)%x86symex.

Definition LoadArray {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (len : nat) : M (list idx)
Definition LoadArray {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (len : nat) : M (list idx)
:= mapM (fun i =>
(addr <- compute_array_address base i;
(addr <- compute_array_address (bytes_per_element:=bytes_per_element) base i;
Remove64 addr)%x86symex)
(seq 0 len).

Definition LoadOutputs_internal {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (outputaddrs : list ((REG + idx) + idx)) (output_types : type_spec)
: M (list (idx + list idx))
:= (mapM (fun '(ocells, spec) =>
:= (mapM (fun '(ocells, (sz, spec)) =>
match ocells, spec with
| inl _, Some _ | inr _, None => err (error.unsupported_memory_access_size 0)
| inl addr, None
Expand All @@ -1331,7 +1346,7 @@ Definition LoadOutputs_internal {opts : symbolic_options_computed_opt} {descr:de
end;
ret (inl v))
| inr base, Some len
=> (v <- LoadArray base len;
=> (v <- LoadArray (bytes_per_element:=sz) base len;
ret (inr v))
end) (List.combine outputaddrs output_types))%N%x86symex.

Expand All @@ -1358,7 +1373,7 @@ Definition symex_asm_func_M
{dereference_output_scalars:bool}
(callee_saved_registers : list REG)
(output_types : type_spec) (stack_size : nat)
(inputs : list (idx + list idx)) (reg_available : list REG) (asm : Lines)
(inputs : list (N * (idx + list idx))) (reg_available : list REG) (asm : Lines)
: M (ErrorT EquivalenceCheckingError (list (idx + list idx)))
:= (output_placeholders <- lift_dag (build_inputs (descr:=Build_description "output_placeholders" true) output_types);
let n_outputs := List.length output_placeholders in
Expand All @@ -1369,12 +1384,12 @@ Definition symex_asm_func_M
initial_register_values <- mapM (GetReg (descr:=Build_description "initial_register_values" true)) callee_saved_registers;
_ <- SymexLines asm;
final_register_values <- mapM (GetReg (descr:=Build_description "final_register_values" true)) callee_saved_registers;
_ <- LoadArray (descr:=Build_description "load final stack" true) stack_base stack_size;
_ <- LoadArray (descr:=Build_description "load final stack" true) (bytes_per_element:=64%N) stack_base stack_size;
let unsaved_registers : list (REG * (idx * idx)) := List.filter (fun '(r, (init, final)) => negb (init =? final)%N) (List.combine callee_saved_registers (List.combine initial_register_values final_register_values)) in
asm_output <- LoadOutputs (descr:=Build_description "asm_output" true) (dereference_scalar:=dereference_output_scalars) outputaddrs output_types;
(* also load inputs, for the sake of the proof *)
(* reconstruct input types *)
let input_types := List.map (fun v => match v with inl _ => None | inr ls => Some (List.length ls) end) inputs in
let input_types := List.map (fun '(sz, v) => (sz, match v with inl _ => None | inr ls => Some (List.length ls) end)) inputs in
asm_input <- LoadOutputs (descr:=Build_description "asm_input <- LoadOutputs" true) (dereference_scalar:=dereference_input_scalars) inputaddrs input_types;
(fun s => Success
(match asm_output, asm_input, unsaved_registers, s.(symbolic_mem_state) with
Expand All @@ -1394,7 +1409,7 @@ Definition symex_asm_func
{opts : symbolic_options_computed_opt}
{dereference_output_scalars:bool}
(d : dag) (callee_saved_registers : list REG) (output_types : type_spec) (stack_size : nat)
(inputs : list (idx + list idx)) (reg_available : list REG) (asm : Lines)
(inputs : list (N * (idx + list idx))) (reg_available : list REG) (asm : Lines)
: ErrorT EquivalenceCheckingError (list (idx + list idx) * symbolic_state)
:= let num_reg_given := List.length reg_available in
let num_reg_needed := List.length inputs + List.length output_types in
Expand Down Expand Up @@ -1452,7 +1467,7 @@ Section check_equivalence.
Local Notation map_err_None v := (ErrorT.map_error (fun e => (None, e)) v).
Local Notation map_err_Some label v := (ErrorT.map_error (fun e => (Some label, e)) v).

Definition map_symex_asm (inputs : list (idx + list idx)) (output_types : type_spec) (d : dag)
Definition map_symex_asm (inputs : list (N * (idx + list idx))) (output_types : type_spec) (d : dag)
: ErrorT
(option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError)
(list ((string (* fname *) * Lines (* asm lines *)) * (list (idx + list idx) * symbolic_state))) :=
Expand Down

0 comments on commit 00afedb

Please sign in to comment.