From a061024555ef8561ab3979e8d1bb23844bb75037 Mon Sep 17 00:00:00 2001 From: Dustin Jamner Date: Mon, 7 Nov 2022 15:17:41 -0500 Subject: [PATCH] Prove remaining theorems in Broadcast --- .../End2End/RupicolaCrypto/Broadcast.v | 209 +++++++++++++----- 1 file changed, 155 insertions(+), 54 deletions(-) diff --git a/src/Bedrock/End2End/RupicolaCrypto/Broadcast.v b/src/Bedrock/End2End/RupicolaCrypto/Broadcast.v index 6687c58858..15c15af3d5 100644 --- a/src/Bedrock/End2End/RupicolaCrypto/Broadcast.v +++ b/src/Bedrock/End2End/RupicolaCrypto/Broadcast.v @@ -382,9 +382,9 @@ Section with_parameters. broadcast_expr l idx_var scratch a_ptr R lst_expr v -> (let v := v in - forall m, + forall tr idx to m, (v$@a_ptr* R)%sep m -> - <{ Trace := t; Memory := m; Locals := l; Functions := e }> + <{ Trace := tr; Memory := m; Locals := map.put (map.put l idx_var idx) to_var to; Functions := e }> k_impl <{ pred (k v eq_refl) }>) -> <{ Trace := t; Memory := m; Locals := l; Functions := e }> @@ -544,42 +544,47 @@ Section with_parameters. intros. destruct H10. eapply H7 in H10. - Admitted. - - Lemma compile_broadcast_expr {t m l e} (len : nat) (lst scratch : list T) : - let v := lst in - forall P (pred: P v -> predicate) (k: nlet_eq_k P v) k_impl - (a_ptr : word) a_var lst_expr idx_var to_var R, - - DEXPR m l (expr.var a_var) a_ptr -> - - (scratch$@a_ptr * R)%sep m -> - - len = length scratch -> - len = length lst -> - len < 2^width -> - - ~idx_var = to_var -> - map.get l idx_var = None -> - map.get l to_var = None -> - - broadcast_expr l idx_var scratch a_ptr R lst_expr v -> - (let v := v in - forall m, - (v$@a_ptr* R)%sep m -> - <{ Trace := t; Memory := m; Locals := l; Functions := e }> - k_impl - <{ pred (k v eq_refl) }>) -> - <{ Trace := t; Memory := m; Locals := l; Functions := e }> - cmd_loop_fresh false idx_var (expr.literal 0) to_var len - (cmd.store szT (expr.op bopname.add a_var - (expr.op bopname.mul idx_var sz_word)) - lst_expr) - k_impl - <{ pred (nlet_eq [a_var] v k) }>. - Proof using T_Fits_ok env_ok ext_spec_ok locals_ok mem_ok word_ok. - eauto using compile_broadcast_expr'. - Qed. + change (-1) with (0-1). + unfold v in H10. + subst locals0. + exact H10. + } + Qed. + + Lemma compile_broadcast_expr {t m l e} (len : nat) (lst scratch : list T) : + let v := lst in + forall P (pred: P v -> predicate) (k: nlet_eq_k P v) k_impl + (a_ptr : word) a_var lst_expr idx_var to_var R, + + DEXPR m l (expr.var a_var) a_ptr -> + + (scratch$@a_ptr * R)%sep m -> + + len = length scratch -> + len = length lst -> + len < 2^width -> + + ~idx_var = to_var -> + map.get l idx_var = None -> + map.get l to_var = None -> + + broadcast_expr l idx_var scratch a_ptr R lst_expr v -> + (let v := v in + forall tr idx to m, + (v$@a_ptr* R)%sep m -> + <{ Trace := tr; Memory := m; Locals := map.put (map.put l idx_var idx) to_var to; Functions := e }> + k_impl + <{ pred (k v eq_refl) }>) -> + <{ Trace := t; Memory := m; Locals := l; Functions := e }> + cmd_loop_fresh false idx_var (expr.literal 0) to_var len + (cmd.store szT (expr.op bopname.add a_var + (expr.op bopname.mul idx_var sz_word)) + lst_expr) + k_impl + <{ pred (nlet_eq [a_var] v k) }>. + Proof using T_Fits_ok env_ok ext_spec_ok locals_ok mem_ok word_ok. + eauto using compile_broadcast_expr'. + Qed. Section Binops. @@ -669,7 +674,8 @@ Section with_parameters. | [H : context [predT ?ptr1] |- context [ truncated_word _ ?ptr2]] => replace ptr2 with ptr1 end. - ecancel_assumption_impl. + seprewrite_in predT_to_truncated_word H3. + ecancel_assumption. f_equal. rewrite Z.mul_comm. rewrite word.ring_morph_mul. @@ -677,9 +683,17 @@ Section with_parameters. Qed. + Lemma split_hd_tl {A} (a:A) (l:list A) + : 0 < length l -> + l = hd a l :: tl l. + Proof. + destruct l; simpl in *; [lia | auto]. + Qed. + Lemma broadcast_var l idx_var scratch a_ptr b_ptr R a_var a_data : map.get l a_var = Some a_ptr -> ~a_var = idx_var -> + length scratch <= length a_data -> let R' := (a_data$@a_ptr ⋆ R) in broadcast_expr l idx_var scratch b_ptr R' @@ -688,8 +702,48 @@ Section with_parameters. (expr.op bopname.mul idx_var sz_word))) a_data. Proof using T_Fits_ok locals_ok mem_ok word_ok. - Admitted. - + unfold broadcast_expr; intuition idtac. + repeat straightline. + exists a_ptr; intuition idtac. + { + rewrite map.get_put_diff by assumption. + assumption. + } + exists (word.of_Z (Z.of_nat (length lstl))). + intuition idtac. + { + rewrite map.get_put_same; eauto. + } + simpl. + unfold WeakestPrecondition.literal. + cbv [dlet]. + exists (word_of_T (nth (length lstl) a_data default)). + split; auto. + erewrite load_of_sep. + { + erewrite truncate_of_T. + reflexivity. + } + seprewrite_in (array_append (T:=T) predT sz_word) H4. + replace (nth (length lstl) scratch) + with (nth ((length lstl)+0) scratch) by (f_equal;lia). + seprewrite_in map_predT_to_truncated_word H4. + seprewrite_in map_predT_to_truncated_word H4. + seprewrite_in map_predT_to_truncated_word H4. + rewrite <- (firstn_skipn (length lstl) a_data) in H4. + rewrite map_app in H4. + seprewrite_in (array_append (T:=word)) H4. + rewrite map_length in H4. + rewrite firstn_length in H4. + replace ((Init.Nat.min (length lstl) (length a_data))) with (length lstl) in H4 by lia. + rewrite (split_hd_tl default (skipn (length lstl) a_data)) in H4 by (rewrite skipn_length; lia). + simpl in H4. + rewrite Z.mul_comm in H4. + rewrite word.ring_morph_mul in H4. + rewrite <- hd_skipn_nth_default in H4. + rewrite nth_default_eq in H4. + ecancel_assumption. + Qed. End WithAccessSize. @@ -737,8 +791,10 @@ Section with_parameters. rewrite byte_and_xff. reflexivity. } + intros ptr t m. + split. { - intros ptr t m H. + intro H. unfold truncated_word, truncated_scalar. cbn. rewrite word.unsigned_of_Z_nowrap. @@ -746,6 +802,15 @@ Section with_parameters. ecancel_assumption. apply byte_in_word_bounds. } + { + intro H. + unfold truncated_word, truncated_scalar in H. + cbn in H. + rewrite word.unsigned_of_Z_nowrap in H. + rewrite word.byte_of_Z_unsigned in H. + ecancel_assumption. + apply byte_in_word_bounds. + } Qed. @@ -756,6 +821,9 @@ Section with_parameters. word_of_T b := b; |}. + (*TODO: where to get this fact from?*) + Axiom width_mul_8 : exists x, width = x * 8. + Instance word_ac_ok : FitsInLocal_ok word word_ac. Proof. constructor; unfold word_of_T, szT, predT, word_ac. @@ -769,23 +837,38 @@ Section with_parameters. rewrite !word.of_Z_unsigned. rewrite Z2Nat.id. unfold Memory.bytes_per_word. - replace ((width + 7) / 8 * 8) with width by admit. - admit (*TODO: a similar proof lives in CSwap.v*). - admit (*easy*). + replace ((width + 7) / 8 * 8) with width. + { + rewrite <- (word.of_Z_unsigned t). + rewrite <- word.morph_and. + rewrite word.of_Z_land_ones. + auto. + } + { + pose proof width_mul_8 as Hw; destruct Hw as [x Hw]; subst. + lia. + } + { + unfold Memory.bytes_per_word. + pose proof (word.width_pos). + pose proof width_mul_8 as Hw; destruct Hw as [x Hw]; subst. + lia. + } } { - intros ptr t m H. - unfold scalar in *. - assumption. + intros ptr t m; + split; + unfold scalar in *; + auto. } - Admitted. - - + Qed. (*TODO: define in general*) Lemma broadcast_byte_const l (idx_var : string) scratch a_ptr R (const_list : list byte) - : broadcast_expr byte l idx_var scratch a_ptr R + : length scratch <= length const_list -> + length scratch <= 2^width -> + broadcast_expr byte l idx_var scratch a_ptr R (expr.inlinetable access_size.one const_list @@ -803,9 +886,27 @@ Section with_parameters. exists (word_of_byte (nth (length lstl) const_list x00)). split; auto. eapply load_one_of_sep. - (*TODO: preds for const list*) - Admitted. + instantiate (1:= fun _ => True). + exists (map.put map.empty (word.of_Z (Z.of_nat (length lstl))) (nth (length lstl) const_list x00)). + exists (map.remove (OfListWord.map.of_list_word const_list) (word.of_Z (Z.of_nat (length lstl)))). + intuition idtac. + { + eapply map.split_comm. + eapply map.split_remove_put. + rewrite map.split_empty_r; auto. + rewrite OfListWord.map.get_of_list_word. + rewrite word.unsigned_of_Z. + rewrite word.wrap_small. + rewrite Nat2Z.id. + eapply nth_error_nth'. + lia. + split; try lia. + } + { + reflexivity. + } + Qed. Lemma broadcast_byte_and l idx_var scratch a_ptr R l1_expr l2_expr (l1 l2 : list byte)