diff --git a/src/Curves/Weierstrass/Jacobian/ScalarMult.v b/src/Curves/Weierstrass/Jacobian/ScalarMult.v index a6fc14de3c..2fe2dba939 100644 --- a/src/Curves/Weierstrass/Jacobian/ScalarMult.v +++ b/src/Curves/Weierstrass/Jacobian/ScalarMult.v @@ -190,9 +190,8 @@ Module ScalarMult. Next Obligation. Proof. generalize (proj2_sig AB); rewrite <- Heq_anonymous. auto. Qed. Next Obligation. Proof. destruct AB as ((A & B) & HAB). simpl. t. Qed. - Program Definition tplu_co_z_points (P : Wpoint) (HPnz : P <> ∞ :> W.point) : co_z_points := - tplu (of_affine P) _. - Next Obligation. Proof. t. Qed. + Program Definition tplu_co_z_points (P : point) (HPaff : z_of P = 1) : co_z_points := + tplu P _. Next Obligation. Proof. t. Qed. Lemma opp_of_affine (P : Wpoint) (HPnz : P <> ∞ :> Wpoint) : @@ -202,13 +201,13 @@ Module ScalarMult. (* Scalar Multiplication on Weierstraß Elliptic Curves from Co-Z Arithmetic *) (* Goundar, Joye, Miyaji, Rivain, Vanelli *) (* Algorithm 14 Joye’s double-add algorithm with Co-Z addition formulæ *) - (* Adapted *) - Definition joye_ladder (scalarbitsz : Z) (testbit : Z -> bool) (P : Wpoint) - (HPnz : P <> ∞ :> Wpoint) : Wpoint := - to_affine ( + (* This is an adapted version that consumes and returns points in jacobian + coordinates, correctness assumes the scalar is odd (i.e., testbit 0 = true). *) + Definition joye_ladder_inner (scalarbitsz : Z) (testbit : Z -> bool) (P : point) + (HPaff : z_of P = 1) : point := (* Initialization *) let b := testbit 1%Z in - let R1R0 := cswap_co_z_points b (tplu_co_z_points P HPnz) in + let R1R0 := cswap_co_z_points b (tplu_co_z_points P HPaff) in (* loop *) let '(R1R0, _) := (@while (co_z_points*Z) (fun '(_, i) => (Z.ltb i scalarbitsz)) @@ -220,17 +219,24 @@ Module ScalarMult. (R1R0, i)) (Z.to_nat scalarbitsz) (* bound on loop iterations *) (R1R0, 2%Z)) in - (* R0 is now (k | 1) P *) - (* Substract P from R0 if lsb was actually 0 *) - let R0 := snd (proj1_sig R1R0) in + snd (proj1_sig R1R0). + + (* Wrapper around joye_ladder_inner for points in affine coordinates, + it also readjusts the result if the scalar input is even. *) + Program Definition joye_ladder (scalarbitsz : Z) (testbit : Z -> bool) (P : Wpoint) + (HPnz : P <> ∞ :> Wpoint) : Wpoint := + to_affine ( + let P := of_affine P in + let R0 := joye_ladder_inner scalarbitsz testbit P _ in let b := testbit 0%Z in - let mP := opp (of_affine P) in + let mP := opp P in (* Make sure R0 and -P are co-z *) - let R0R1 := make_co_z_points R0 mP (opp_of_affine P HPnz) in + let R0R1 := make_co_z_points R0 mP (opp_of_affine _ HPnz) in (* if b = 0 then R0 <- R0 - P and R1 <- R0 *) (* if b = 1 then R0 <- R0 and R1 <- R0 - P *) let R0 := fst (proj1_sig (cswap_co_z_points b (zaddu_co_z_points R0R1))) in R0). + Next Obligation. Proof. t. Qed. Section Proofs. @@ -261,13 +267,414 @@ Module ScalarMult. eapply (@homomorphism_scalarmult Wpoint Weq Wadd Wzero Wopp Wgroup.(Hierarchy.commutative_group_group) point eq add zero opp Pgroup scalarmult (@scalarmult_ref_is_scalarmult Wpoint Weq Wadd Wzero Wopp Wgroup.(Hierarchy.commutative_group_group)) scalarmult' (@scalarmult_ref_is_scalarmult point eq add zero opp Pgroup) of_affine ltac:(econstructor; [eapply Jacobian.of_affine_add|eapply Jacobian.Proper_of_affine])). Qed. - (* We compute [n]P where P ≠ ∞ and n < ord(P) *) - Context {n scalarbitsz : Z} - {Hn : (2 <= n < 2^scalarbitsz)%Z} + Context {scalarbitsz : Z} {Hscalarbitsz : (2 <= scalarbitsz)%Z} - {P : Wpoint} {HPnz : P <> ∞ :> Wpoint} {ordP : Z} {HordPpos : (2 < ordP)%Z} - {HordPodd : Z.odd ordP = true :> bool} + {HordPodd : Z.odd ordP = true :> bool}. + + Section Inner. + (* Proofs about joye_ladder_inner *) + + (* Bit 0 of the scalar input is irrelevant *) + Lemma joye_ladder_inner_bit0_irr (bitsz : Z) (testbit0 testbit1 : Z -> bool) + (P : point) (HPaff : z_of P = 1) + (bit0_irr : forall i, (i >= 1)%Z -> testbit0 i = testbit1 i :> bool) : + eq (joye_ladder_inner bitsz testbit0 P HPaff) + (joye_ladder_inner bitsz testbit1 P HPaff). + Proof. + unfold joye_ladder_inner. + rewrite (surjective_pairing (while _ _ _ (cswap_co_z_points (testbit0 _) _, _))). + rewrite (surjective_pairing (while _ _ _ (cswap_co_z_points (testbit1 _) _, _))). + rewrite bit0_irr by lia. + match goal with + | |- eq (snd (proj1_sig (fst (while ?T0 ?B0 ?F0 ?I0)))) + (snd (proj1_sig (fst (while ?T1 ?B1 ?F1 ?I1)))) => + set (test0 := T0); + set (body0 := B0); + set (fuel := F0); + set (init0 := I0); + set (body1 := B1) + end. + apply (while.preservation test0 body0 test0 body1 (fun s1 s2 => eq (snd (proj1_sig (fst s1))) (snd (proj1_sig (fst s2))) /\ (s1 = s2 :> (co_z_points * Z)) /\ (2 <= snd s1)%Z)). + - intros s1 s2 (_ & <- & _). reflexivity. + - unfold test0. intros (PQ1 & i1) (PQ2 & i2) _. + cbn [fst snd]. intros (_ & Heq & Hi). inversion Heq; clear Heq. + subst PQ1; subst i1. + unfold body0, body1. rewrite bit0_irr by lia. + split; [reflexivity|split; [reflexivity|simpl; lia] ]. + - repeat (split; try reflexivity). + Qed. + + Context {n : Z} {Hnodd : n = Z.setbit n 0 :> Z} + {Hn : (2 <= n < 2^scalarbitsz)%Z} + {P : point} {HPaff : z_of P = 1} + {HordP : forall l, (eq (scalarmult' l P) zero) <-> exists k, (l = k * ordP :> Z)%Z}. + Local Notation testbitn := (Z.testbit n). + Context {HSS : forall i, (2 <= i <= scalarbitsz)%Z -> not (eq (scalarmult' (SS n (Z.to_nat i)) P) zero)} + {HTT : forall i, (2 <= i <= scalarbitsz)%Z -> not (eq (scalarmult' (TT n (Z.to_nat i)) P) zero)}. + + Lemma mult_two_power (k : Z) : + (0 <= k)%Z -> + not (eq (scalarmult' (2^k)%Z P) zero). + Proof. + eapply (Zlt_0_ind (fun x => ~ eq (scalarmult' (2 ^ x) P) zero)). + intros x Hind Hxpos Heqz. + destruct (proj1 (HordP (2^x)%Z) Heqz) as [l Hl]. + destruct (Z.eq_dec x 0); [subst x|]. + - simpl in Hl. clear -Hl HordPpos. + generalize (Z.divide_1_r_nonneg ordP ltac:(lia) ltac:(exists l; lia)). + lia. + - assert (x = Z.succ (Z.pred x) :> Z) by lia. + rewrite H in Hl. rewrite Z.pow_succ_r in Hl; [|lia]. + generalize (Znumtheory.prime_mult 2%Z Znumtheory.prime_2 l ordP ltac:(exists (2 ^ Z.pred x)%Z; lia)). + intros [A|A]; destruct A as [m Hm]; [|replace ordP with (0 + 2 * m)%Z in HordPodd by lia; rewrite Z.odd_add_mul_2 in HordPodd; simpl in HordPodd; congruence]. + rewrite Hm in Hl. + assert ((2 ^ Z.pred x)%Z = (m * ordP)%Z :> Z) by lia. + apply (Hind (Z.pred x) ltac:(lia)). + eapply HordP. exists m; assumption. + Qed. + + Lemma mult_two (k : Z) : + eq (scalarmult' (2 * k)%Z P) zero -> + eq (scalarmult' k P) zero. + Proof. + intros Heqz. destruct (proj1 (HordP (2 * k)%Z) Heqz) as [l Hl]. + generalize (Znumtheory.prime_mult 2%Z Znumtheory.prime_2 l ordP ltac:(exists k; lia)). + intros [A|A]; destruct A as [m Hm]; [|replace ordP with (0 + 2 * m)%Z in HordPodd by lia; rewrite Z.odd_add_mul_2 in HordPodd; simpl in HordPodd; congruence]. + rewrite Hm in Hl. assert (k = m * ordP :> Z)%Z as -> by lia. + apply HordP; eauto. + Qed. + + Lemma HSS_plus_TT (m : Z) (k : nat) : + not (eq (scalarmult' (SS m k + TT m k)%Z P) zero). + Proof. rewrite SS_plus_TT. apply mult_two_power. lia. Qed. + + Lemma SS1 : SS n 1 = 1%Z :> Z. + Proof. cbv [SS]. rewrite Z.pow_1_r, <-Z.bit0_mod, Hnodd, Z.setbit_eqb; trivial; lia. Qed. + + Lemma TT1 : TT n 1 = 1%Z :> Z. + Proof. cbv [TT]. rewrite SS1; trivial. Qed. + + Lemma SS2 : SS n 2 = (if Z.testbit n 1 then 3%Z else 1%Z) :> Z. + Proof. + cbv [SS]; rewrite Hnodd, <-Z.land_ones, Z.land_comm by lia. + destruct n; cbn; trivial; repeat (destruct p; cbn; trivial). + Qed. + + Lemma TT2 : TT n 2 = (if Z.testbit n 1 then 1%Z else 3%Z) :> Z. + Proof. cbv [TT]. rewrite SS2. case Z.testbit; lia. Qed. + + Lemma SS_TT2 : (if testbitn 1 then SS n 2 else TT n 2) = 3%Z :> Z. + Proof. rewrite SS2, TT2; try lia; case testbitn; trivial. Qed. + + Lemma HordP3 : + not (eq (scalarmult' 3%Z P) zero). + Proof. + rewrite <- SS_TT2, <- (Nat2Z.id 2). + case testbitn; [eapply HSS|eapply HTT]; lia. + Qed. + + Hint Unfold fst snd proj1_sig : points_as_coordinates. + Hint Unfold fieldwise fieldwise' : points_as_coordinates. + + Lemma eq_fieldwise (P1 P2 : point) : + fieldwise (n:=3) Feq + (proj1_sig P1) (proj1_sig P2) -> + eq P1 P2. + Proof. clear -field; intros; t. Qed. + + Lemma Pynz : + y_of P <> 0. + Proof. + intro Hy. assert (HA : eq P (opp P)). + - apply eq_fieldwise. destruct P as (((X & Y) & Z) & HP). + simpl; cbv in HPaff, Hy; repeat split; fsatz. + - apply (mult_two_power 1%Z ltac:(lia)). + replace (2 ^ 1)%Z with (1 - -1)%Z by lia. + rewrite (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + rewrite <- (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in HA. + rewrite HA. + rewrite <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + rewrite <- (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + replace (- (1) - -1)%Z with 0%Z by lia. reflexivity. + Qed. + + Lemma add_opp_zero (A : point) : + eq (add A (opp A)) zero. + Proof. + generalize (Jacobian.add_opp A). + destruct (add A (opp A)) as (((X & Y) & Z) & H). + cbn. intros HP; destruct (dec (Z = 0)); fsatz. + Qed. + + Lemma scalarmult_difference (A : point) (n1 n2 : Z): + eq (scalarmult' n1 A) (scalarmult' n2 A) -> + eq (scalarmult' (n1 - n2)%Z A) zero. + Proof. + intros. rewrite (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), H, <- (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), Z.sub_diag. + apply (scalarmult_0_l (is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + Qed. + + Lemma dblu_scalarmult' (Q : point) (Hz1 : z_of Q = 1) (Hynz : y_of Q <> 0) : + let '(M, N) := dblu Q Hz1 in + eq M (scalarmult' 2 Q) + /\ eq N (scalarmult' 1 Q). + Proof. + generalize (@Jacobian.dblu_correct F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a b field char_ge_3 Feq_dec char_ge_12 Q Hz1 Hynz). + rewrite (surjective_pairing (dblu _ _)) at 1. + intros (A & B & C). split. + - rewrite <- A. replace 2%Z with (1 + 1)%Z by lia. + rewrite (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + rewrite (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + rewrite <- Jacobian.add_double; reflexivity. + - rewrite (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + symmetry. assumption. + Qed. + + Lemma co_xz_implies (P1 P2 : point) (Hxeq : x_of P1 = x_of P2) + (Hzeq : z_of P1 = z_of P2) : + eq P1 P2 \/ eq P1 (opp P2). + Proof. + clear -Hxeq Hzeq. prept; [tauto|fsatz|fsatz|]. + assert (f4 = f1 \/ f4 = Fopp f1) by (destruct (dec (f4 = f1)); [left; assumption|right; fsatz]). + destruct H; [left; repeat split; fsatz|right; repeat split; fsatz]. + Qed. + + Lemma tplu_scalarmult' {p q} (H : tplu P HPaff = (p, q) :> _) : + eq p (scalarmult' 3 P) /\ eq q (scalarmult' 1 P) /\ co_z p q. + Proof. + intros; unfold tplu. generalize (dblu_scalarmult' P HPaff Pynz). + inversion_prod; subst p q. + rewrite (surjective_pairing (dblu _ _)) at 1. + set (P1 := (snd (dblu P HPaff))). + set (P2 := (fst (dblu P HPaff))). intros [A1 A2]. + destruct (dec (x_of P1 = x_of P2)) as [Hxe|Hxne]. + { destruct (co_xz_implies P1 P2 Hxe (CoZ.Jacobian.tplu_obligation_1 P HPaff)) as [A|A]. + - rewrite A1, A2 in A. elim (mult_two_power 1%Z ltac:(lia)). + rewrite <- A. replace 1%Z with (2 - 1)%Z by lia. + apply scalarmult_difference; symmetry; assumption. + - rewrite A1, A2, <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in A. + apply scalarmult_difference in A. + elim HordP3. exact A. } + generalize (Jacobian.zaddu_correct _ _ (CoZ.Jacobian.tplu_obligation_1 P HPaff) Hxne). + rewrite (surjective_pairing (zaddu _ _ _)) at 1. + intros (A & B & C); subst P1 P2. + repeat try split; trivial. + { rewrite <-A, A1, A2, (@scalarmult_add_l point eq add zero opp Pgroup scalarmult' (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup) 1 2); reflexivity. } + { rewrite <-B. rewrite A2. reflexivity. } + Qed. + + (* Since co-Z formulas are incomplete, we need to show that we won't hit the neutral point for ZDAU in the loop *) + Lemma zaddu_SS_TT (i : Z) (B1 B2 Y1 Y2 R0 R1 : point) (HB12 : co_z B1 B2) + (Hi : (2 <= i < scalarbitsz)%Z) + (HBx : x_of B1 <> x_of B2) + (HY : zaddu B1 B2 HB12 = (Y1, Y2) :> point * point) + (HR0 : eq R0 (scalarmult' (SS n (Z.to_nat i)) P)) + (HR1 : eq R1 (scalarmult' (TT n (Z.to_nat i)) P)) + (HB1 : B1 = (if testbitn i then R0 else R1) :> point) + (HB2 : B2 = (if testbitn i then R1 else R0) :> point) : + x_of Y1 <> x_of Y2. + Proof. + intro XX. generalize (Jacobian.zaddu_correct B1 B2 HB12 HBx). + rewrite HY. intros (W1 & W2 & W3). + destruct (co_xz_implies _ _ XX W3) as [W|W]; rewrite W, <- W2, HB1, HB2 in W1. + - destruct (testbitn i); + rewrite HR1, HR0, <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in W1; + apply scalarmult_difference in W1; + rewrite Z.add_simpl_l in W1; + [apply (HTT i ltac:(lia))|apply (HSS i ltac:(lia))]; auto. + - destruct (testbitn i) eqn:Hti; + rewrite HR1, HR0, <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in W1; + apply scalarmult_difference in W1. + all: match goal with + | H : eq (scalarmult' ?X _) zero |- _ => + match X with + | (SS _ _ + _ - _)%Z => + replace X with (SS n (S (Z.to_nat i))) in H by (rewrite SS_succ, Z2Nat.id, Hti; lia) + | (TT _ _ + _ - _)%Z => + replace X with (TT n (S (Z.to_nat i))) in H by (rewrite TT_succ, Z2Nat.id, Hti; lia) + end + end. + all: rewrite <- Z2Nat.inj_succ in W1; try lia. + all: match goal with + | H : eq (scalarmult' (SS _ _) _) zero |- _ => + apply (HSS (Z.succ i) ltac:(lia)) + | _ => + apply (HTT (Z.succ i) ltac:(lia)) + end; auto. + Qed. + + Lemma SS_TT_xne (i : Z) (R0 R1 : point) (HCOZ : co_z R0 R1) + (Hi : (2 <= i < scalarbitsz)%Z) + (HR0 : eq R0 (scalarmult' (SS n (Z.to_nat i)) P)) + (HR1 : eq R1 (scalarmult' (TT n (Z.to_nat i)) P)) : + x_of R0 <> x_of R1. + Proof. + assert (HH : forall A, eq (opp A) zero -> eq A zero) by (clear -field; unfold zero, Wzero; intros; t). + intros Hxe. generalize (co_xz_implies _ _ Hxe HCOZ). + destruct (Z.eq_dec 2 i). + { subst i. replace (Z.to_nat 2) with 2%nat in HR1 by lia. + replace (Z.to_nat 2) with 2%nat in HR0 by lia. + rewrite TT2 in HR1. rewrite SS2 in HR0. + rewrite HR0, HR1, <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + destruct (testbitn 1) eqn:Htj; intros [Q|Q]; apply scalarmult_difference in Q. + all: match goal with + | H : eq (scalarmult' ?X _) zero |- _ => + try replace X with 2%Z in H by lia; + try replace X with 4%Z in H by lia; + try replace X with (- (2))%Z in H by lia + end. + all: match goal with + | H : eq (scalarmult' ?X _) zero |- _ => + match X with + | 2%Z => apply (mult_two_power 1%Z ltac:(lia) H) + | 4%Z => apply (mult_two_power 2%Z ltac:(lia) H) + | (- (2))%Z => + rewrite (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in Q; + apply HH in H; + apply (mult_two_power 1%Z ltac:(lia) H) + end + end. } + { set (j := Z.pred i). assert (His : i = Z.succ j :> Z) by lia. + assert (Hj : (2 <= j)%Z) by lia. + rewrite His, Z2Nat.inj_succ in HR0, HR1 by lia. + rewrite HR0, HR1, <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + rewrite TT_succ, SS_succ, Z2Nat.id by lia. + destruct (testbitn j) eqn:Htj; intros [Q|Q]; apply scalarmult_difference in Q. + all: try match goal with + | H : eq (scalarmult' ?X _) zero |- _ => + match X with + | Z.sub _ ?Y => + match Y with + | (- _)%Z => + replace X with (2 * (SS n (Z.to_nat j) + TT n (Z.to_nat j)))%Z in H by lia; + apply mult_two in H; + apply (HSS_plus_TT n (Z.to_nat j) H) + | (TT _ _) => + replace X with (2 * SS n (Z.to_nat j))%Z in H by lia; + apply mult_two in H; + apply (HSS j ltac:(lia) H) + | (SS _ _) => + replace X with (2 * TT n (Z.to_nat j))%Z in H by lia; + apply mult_two in H; + apply (HTT j ltac:(lia) H) + | (Z.add _ _) => + replace X with (2 * - TT n (Z.to_nat j))%Z in H by lia; + apply mult_two in H; + rewrite (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in H; + apply HH in H; + apply (HTT j ltac:(lia) H) + end + end + end. } + Qed. + + (* When n is odd, joye_ladder_inner computes [n]P *) + Lemma joye_ladder_inner_correct : + eq (joye_ladder_inner scalarbitsz testbitn P HPaff) + (scalarmult' n P). + Proof. + unfold joye_ladder_inner. set (WW := while _ _ _ _). + destruct (tplu_co_z_points P HPaff) as ((A1 & A2) & HA12) eqn:Htplu. + rewrite (sig_eta (tplu_co_z_points _ _)) in Htplu. + apply proj1_sig_eq in Htplu; simpl in Htplu. + (* Initialize the ladder state with ([3]P, [1]P) or its symmetric *) + destruct (cswap_co_z_points (testbitn 1) _) as ((A3 & A4) & HA34) eqn:HA1. + rewrite (sig_eta (cswap_co_z_points _ _)) in HA1. + apply proj1_sig_eq in HA1. cbn [proj1_sig cswap_co_z_points] in HA1. + assert (A3 = (if testbitn 1 then A2 else A1) :> point) as HA3 by (destruct (testbitn 1); inversion HA1; auto). + assert (A4 = (if testbitn 1 then A1 else A2) :> point) as HA4 by (destruct (testbitn 1); inversion HA1; auto). + clear HA1. destruct (tplu_scalarmult' Htplu) as (HeqA1 & HeqA2 & _). + set (inv := + fun '(R1R0, i) => + let '(R1, R0) := proj1_sig (R1R0:co_z_points) in + (2 <= i <= scalarbitsz)%Z /\ + (eq R1 (scalarmult' (TT n (Z.to_nat i)) P) + /\ eq R0 (scalarmult' (SS n (Z.to_nat i)) P)) + /\ ((i < scalarbitsz)%Z -> x_of R1 <> x_of R0)). + assert (HH : forall (A B : Prop), A -> (A -> B) -> A /\ B) by tauto. + assert (WWinv : inv WW /\ (snd WW = scalarbitsz :> Z)). + { set (measure := fun (state : (co_z_points*Z)) => ((Z.to_nat scalarbitsz) + 2 - (Z.to_nat (snd state)))%nat). + unfold WW. replace (Z.to_nat scalarbitsz) with (measure (exist _ (A3, A4) HA34, 2%Z)) by (unfold measure; simpl; lia). + eapply (while.by_invariant inv measure (fun s => inv s /\ (snd s = scalarbitsz :> Z))). + - (* Invariant holds at beginning *) + unfold inv. cbn [proj1_sig]. + split; [lia|]. apply HH. + + change (Z.to_nat 2) with 2%nat. + rewrite SS2, TT2, HA3, HA4. + case Z.testbit; auto. + + intros [He1 He2] Hxe. symmetry. + apply (SS_TT_xne 2%Z A4 A3 ltac:(apply co_z_comm; exact HA34) ltac:(lia)); eauto. + - (* Invariant is preserved by the loop, + measure decreases, + and post-condition i = scalarbitsz. *) + intros s Hs. destruct s as (R1R0 & i). + destruct R1R0 as ((R1 & R0) & HCOZ). + destruct Hs as (Hi & (HR1 & HR0) & Hx). + destruct (Z.ltb i scalarbitsz) eqn:Hltb. + + apply Z.ltb_lt in Hltb. + split. + * (* Invariant preservation. + The loop body is basically : + (R1, R0) <- cswap (testbitn i) (R1, R0); + (R1, R0) <- ZDAU (R1, R0); + (R1, R0) <- cswap (testbitn i) (R1, R0); + *) + (* Start by giving names to all intermediate values *) + unfold inv. destruct (cswap_co_z_points (testbitn i) (exist _ _ _)) as ((B1 & B2) & HB12) eqn:Hswap1. + rewrite (sig_eta (cswap_co_z_points _ _)) in Hswap1. + apply proj1_sig_eq in Hswap1. simpl in Hswap1. + assert (HB1: B1 = (if testbitn i then R0 else R1) :> point) by (destruct (testbitn i); congruence). + assert (HB2: B2 = (if testbitn i then R1 else R0) :> point) by (destruct (testbitn i); congruence). + clear Hswap1. + destruct (zdau_co_z_points _) as ((C1 & C2) & HC12) eqn:HZDAU. + rewrite (sig_eta (zdau_co_z_points _)) in HZDAU. + apply proj1_sig_eq in HZDAU. simpl in HZDAU. + assert (HBx : x_of B1 <> x_of B2) by (rewrite HB1, HB2; destruct (testbitn i); [symmetry|]; auto). + destruct (zaddu B1 B2 (zdau_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl)) as (Y1 & Y2) eqn:HY. + assert (HYx : x_of Y1 <> x_of Y2) by (eapply zaddu_SS_TT; eauto; lia). + generalize (@Jacobian.zdau_correct_alt F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a b field char_ge_3 Feq_dec char_ge_12 ltac:(unfold id in *; fsatz) B1 B2 (zdau_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl) HBx ltac:(rewrite HY; simpl; apply HYx)). + rewrite HZDAU. intros (HC1 & HC2 & _). + destruct (cswap_co_z_points (testbitn i) _) as ((D1 & D2) & HD12) eqn:HD. + rewrite (sig_eta (cswap_co_z_points _ _)) in HD. + apply proj1_sig_eq in HD. cbn [proj1_sig cswap_co_z_points] in HD. + assert (HD1 : D1 = (if testbitn i then C2 else C1) :> point) by (destruct (testbitn i); congruence). + assert (HD2 : D2 = (if testbitn i then C1 else C2) :> point) by (destruct (testbitn i); congruence). + clear HD. simpl. + (* invariant preservation *) + (* counter still within bounds *) + split; [lia|]. rewrite HD1, HD2. apply HH. + { (* New values are indeed [SS (i+1)]P and [TT (i+1)]P *) + destruct (testbitn i) eqn:Hti; + rewrite <- HC1, <- HC2, HB1, HB2; + replace (Z.to_nat (Z.succ i)) with (S (Z.to_nat i)) by lia; + rewrite SS_succ, TT_succ, Z2Nat.id by lia; + rewrite Hti; split; try assumption; + rewrite <- Jacobian.add_double; try reflexivity; + rewrite HR0, HR1; + repeat rewrite <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))); + rewrite <- Z.add_diag; reflexivity. } + { (* Make sure we don't hit bad values *) + intros [He1 He2] Hxe. symmetry. eapply (SS_TT_xne (Z.succ i)); eauto. + - destruct (testbitn i); [|apply co_z_comm]; auto. + - lia. } + * (* measure decreases *) + apply Z.ltb_lt in Hltb. + unfold measure; simpl; lia. + + (* Post-condition *) + simpl; split; auto. + rewrite Z.ltb_nlt in Hltb. lia. } + destruct WWinv as (Hinv & Hj). + destruct WW as (R1R0 & i). destruct R1R0 as ((R1 & R0) & HCOZ). + simpl in Hj; subst i. destruct Hinv as (_ & (_ & HR0) & _). + rewrite (SSn n scalarbitsz ltac:(lia) ltac:(lia)) in HR0. + exact HR0. + Qed. + End Inner. + + (* We compute [n]P where P ≠ ∞ and n < ord(P) *) + Context {n : Z} {Hn : (2 <= n < 2^scalarbitsz)%Z} + {P : Wpoint} {HPnz : P <> ∞ :> Wpoint} {HordP : forall l, (Weq (scalarmult l P) ∞) <-> exists k, (l = k * ordP :> Z)%Z} {HordPn : (n + 2 < ordP)%Z}. @@ -292,79 +699,6 @@ Module ScalarMult. - apply (LandLorShiftBounds.Z.lor_range n 1 scalarbitsz); lia. Qed. - Lemma Htestbitn'0 : testbitn' 0 = true :> bool. - Proof. rewrite Z.setbit_eqb; lia. Qed. - - Lemma Htestbitn' j (Hj : (1 <= j)%Z) : testbitn j = testbitn' j :> bool. - Proof. rewrite Z.setbit_neq; trivial; lia. Qed. - - Lemma SS1 : (SS n' 1 = 1%Z :> Z). - Proof. cbv [SS]. rewrite Z.pow_1_r, <-Z.bit0_mod, Z.setbit_eq; trivial; lia. Qed. - - Lemma TT1 : (TT n' 1 = 1%Z :> Z). - Proof. cbv [TT]. rewrite SS1; trivial. Qed. - - Lemma SS_2 : SS n' 2 = (if Z.testbit n' 1 then 3%Z else 1%Z) :> Z. - Proof. - cbv [SS n']; rewrite <-Z.land_ones, Z.land_comm by lia. - destruct n; cbn; trivial; repeat (destruct p; cbn; trivial). - Qed. - - Lemma TT_2 : TT n' 2 = (if Z.testbit n' 1 then 1%Z else 3%Z) :> Z. - Proof. cbv [TT]. rewrite SS_2. case Z.testbit; lia. Qed. - - Lemma SS_TT2 : ((if testbitn 1 then SS n' 2 else TT n' 2) = 3 :> Z)%Z. - Proof. rewrite SS_2, TT_2, Z.setbit_neq; try lia; case testbitn; trivial. Qed. - - Lemma HordP3 : - (3 < ordP)%Z. - Proof. - destruct (Z.eq_dec 3 ordP); [|lia]. - generalize SS_TT2; intros HSSTT. - destruct (testbitn 1); [elim (HSS 2 ltac:(lia))|elim (HTT 2 ltac:(lia))]; replace (Z.to_nat 2) with 2%nat by lia; rewrite HSSTT; eapply HordP; try lia. - Qed. - - Lemma n'_alt : - n' = (if testbitn 0 then n else (n + 1)%Z) :> Z. - Proof. - apply Z.bits_inj'; intros. - destruct (Z.eq_dec n0 0) as [->|?]; [rewrite Z.setbit_eq|rewrite Z.setbit_neq]; try lia. - - destruct (testbitn 0) eqn:Hbit0; auto. - rewrite Z.bit0_odd, <- Z.even_pred. - replace (Z.pred (n + 1))%Z with n by lia. - rewrite <- Z.negb_odd, <- Z.bit0_odd, Hbit0; reflexivity. - - destruct (testbitn 0) eqn:Hbit0; auto. - replace n0 with (Z.succ (n0 - 1))%Z by lia. - rewrite Z.bit0_odd in Hbit0. - rewrite (Zeven_div2 n); [|apply Zeven_bool_iff; rewrite <- Z.negb_odd, Hbit0; reflexivity]. - rewrite Z.testbit_even_succ, Z.testbit_odd_succ; auto; lia. - Qed. - - Lemma HordPn' : - (0 < n' < ordP)%Z. - Proof. rewrite n'_alt; destruct (testbitn 0); lia. Qed. - - Lemma mult_two_power (k : Z) : - (0 <= k)%Z -> - not (Weq (scalarmult (2^k)%Z P) ∞). - Proof. - eapply (Zlt_0_ind (fun x => ~ Weq (scalarmult (2 ^ x) P) Wzero)). - intros x Hind Hxpos Heqz. - destruct (proj1 (HordP (2^x)%Z) Heqz) as [l Hl]. - destruct (Z.eq_dec x 0); [subst x|]. - - simpl in Hl. clear -Hl HordPpos. - generalize (Z.divide_1_r_nonneg ordP ltac:(lia) ltac:(exists l; lia)). - lia. - - assert (x = Z.succ (Z.pred x) :> Z) by lia. - rewrite H in Hl. rewrite Z.pow_succ_r in Hl; [|lia]. - generalize (Znumtheory.prime_mult 2%Z Znumtheory.prime_2 l ordP ltac:(exists (2 ^ Z.pred x)%Z; lia)). - intros [A|A]; destruct A as [m Hm]; [|replace ordP with (0 + 2 * m)%Z in HordPodd by lia; rewrite Z.odd_add_mul_2 in HordPodd; simpl in HordPodd; congruence]. - rewrite Hm in Hl. - assert ((2 ^ Z.pred x)%Z = (m * ordP)%Z :> Z) by lia. - apply (Hind (Z.pred x) ltac:(lia)). - eapply HordP. exists m; assumption. - Qed. - Lemma scalarmult_eq_weq_conversion (k1 k2 : Z) : Weq (scalarmult k1 P) (scalarmult k2 P) <-> eq (scalarmult' k1 (of_affine P)) (scalarmult' k2 (of_affine P)). Proof. @@ -377,117 +711,29 @@ Module ScalarMult. symmetry; repeat rewrite scalarmult_scalarmult'; auto. Qed. - Hint Unfold fst snd proj1_sig : points_as_coordinates. - Hint Unfold fieldwise fieldwise' : points_as_coordinates. - - Lemma eq_fieldwise (P1 P2 : point) : - fieldwise (n:=3) Feq - (proj1_sig P1) (proj1_sig P2) -> - eq P1 P2. - Proof. clear -field; intros. t. Qed. - - Lemma Pynz : - y_of (of_affine P) <> 0. - Proof. - intro Hy. assert (HA : eq (of_affine P) (opp (of_affine P))). - - apply eq_fieldwise. destruct P as [ [ [X Y] | u] HP]; simpl; cbv in Hy; repeat split; fsatz. - - apply (mult_two_power 1%Z ltac:(lia)). - replace Wzero with (scalarmult 0%Z P) by reflexivity. - eapply scalarmult_eq_weq_conversion. - replace (2 ^ 1)%Z with (1 - -1)%Z by lia. - rewrite (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite <- (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in HA. - rewrite HA. - rewrite <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite <- (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - replace (- (1) - -1)%Z with 0%Z by lia. reflexivity. - Qed. - Lemma HordP_alt (k : Z) : (0 < k < ordP)%Z -> not (Weq (scalarmult k P) ∞). Proof. - intros Hbds Heq. - destruct (proj1 (HordP k) Heq) as [l Hl]. + intros Hbds Heq. destruct (proj1 (HordP k) Heq) as [l Hl]. clear -Hbds Hl. generalize (Zmult_gt_0_lt_0_reg_r ordP l ltac:(lia) ltac:(lia)). intros. generalize (proj1 (Z.mul_le_mono_pos_r 1%Z l ordP ltac:(lia)) ltac:(lia)). lia. Qed. - Lemma dblu_scalarmult' (Q : point) (Hz1 : z_of Q = 1) (Hynz : y_of Q <> 0) : - let '(M, N) := dblu Q Hz1 in - eq M (scalarmult' 2 Q) - /\ eq N (scalarmult' 1 Q). - Proof. - generalize (@Jacobian.dblu_correct F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a b field char_ge_3 Feq_dec char_ge_12 Q Hz1 Hynz). - rewrite (surjective_pairing (dblu _ _)) at 1. - intros (A & B & C). split. - - rewrite <- A. replace 2%Z with (1 + 1)%Z by lia. - rewrite (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite <- Jacobian.add_double; reflexivity. - - rewrite (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - symmetry. assumption. - Qed. - - Lemma co_xz_implies (P1 P2 : point) (Hxeq : x_of P1 = x_of P2) - (Hzeq : z_of P1 = z_of P2) : - eq P1 P2 \/ eq P1 (opp P2). - Proof. - clear -Hxeq Hzeq. prept; [tauto|fsatz|fsatz|]. - assert (f4 = f1 \/ f4 = Fopp f1) by (destruct (dec (f4 = f1)); [left; assumption|right; fsatz]). - destruct H; [left; repeat split; fsatz|right; repeat split; fsatz]. - Qed. - - Lemma tplu_scalarmult' (Hz1 : z_of (of_affine P) = 1) {p q} (H : tplu (of_affine P) Hz1 = (p, q) :> _) : - eq p (scalarmult' 3 (of_affine P)) /\ eq q (scalarmult' 1 (of_affine P)) /\ co_z p q. - Proof. - intros; unfold tplu. generalize (dblu_scalarmult' (of_affine P) Hz1 Pynz). - inversion_prod; subst p q. - rewrite (surjective_pairing (dblu _ _)) at 1. - set (P1 := (snd (dblu (of_affine P) Hz1))). - set (P2 := (fst (dblu (of_affine P) Hz1))). intros [A1 A2]. - destruct (dec (x_of P1 = x_of P2)) as [Hxe|Hxne]. - { destruct (co_xz_implies P1 P2 Hxe (CoZ.Jacobian.tplu_obligation_1 (of_affine P) Hz1)) as [A|A]. - - rewrite A1, A2 in A. elim (HordP_alt 1%Z ltac:(lia)). - replace Wzero with (scalarmult 0%Z P) by reflexivity. - apply scalarmult_eq_weq_conversion. - replace 1%Z with (2 - 1)%Z by lia. - rewrite (@scalarmult_sub_l _ _ _ _ _ Pgroup _ (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup) 2 1). - rewrite A. - rewrite <- (@scalarmult_sub_l _ _ _ _ _ Pgroup _ (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup)). - replace (2 - 2)%Z with 0%Z by lia. reflexivity. - - rewrite A1, A2 in A. - elim (HordP_alt 3%Z ltac:(generalize HordP3; lia)). - replace Wzero with (scalarmult 0%Z P) by reflexivity. - apply scalarmult_eq_weq_conversion. - replace 3%Z with (1 - -2)%Z by lia. - rewrite (@scalarmult_sub_l _ _ _ _ _ Pgroup _ (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup)). - rewrite A. - rewrite <- (@scalarmult_opp_l _ _ _ _ _ Pgroup _ (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup)). - rewrite <- (@scalarmult_sub_l _ _ _ _ _ Pgroup _ (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup)). - replace (- (2) - -2)%Z with 0%Z by lia. reflexivity. } - generalize (@Jacobian.zaddu_correct _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ (CoZ.Jacobian.tplu_obligation_1 (of_affine P) Hz1) Hxne). - rewrite (surjective_pairing (zaddu _ _ _)) at 1. - intros (A & B & C); subst P1 P2. - repeat try split; trivial. - { rewrite <-A, A1, A2, (@scalarmult_add_l point eq add zero opp Pgroup scalarmult' (@scalarmult_ref_is_scalarmult _ _ _ _ _ Pgroup) 1 2); reflexivity. } - { rewrite <-B. rewrite A2. reflexivity. } - Qed. - - Lemma add_opp_zero (A : point) : - eq (add A (opp A)) zero. - Proof. - generalize (Jacobian.add_opp A). - destruct (add A (opp A)) as (((X & Y) & Z) & H). - cbn. intros HP; destruct (dec (Z = 0)); fsatz. - Qed. - - Lemma scalarmult_difference (A : point) (n1 n2 : Z): - eq (scalarmult' n1 A) (scalarmult' n2 A) -> - eq (scalarmult' (n1 - n2)%Z A) zero. + Lemma n'_alt : + n' = (if testbitn 0 then n else (n + 1)%Z) :> Z. Proof. - intros. rewrite (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), H, <- (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), Z.sub_diag. - apply (scalarmult_0_l (is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). + apply Z.bits_inj'; intros. + destruct (Z.eq_dec n0 0) as [->|?]; [rewrite Z.setbit_eq|rewrite Z.setbit_neq]; try lia. + - destruct (testbitn 0) eqn:Hbit0; auto. + rewrite Z.bit0_odd, <- Z.even_pred. + replace (Z.pred (n + 1))%Z with n by lia. + rewrite <- Z.negb_odd, <- Z.bit0_odd, Hbit0; reflexivity. + - destruct (testbitn 0) eqn:Hbit0; auto. + replace n0 with (Z.succ (n0 - 1))%Z by lia. + rewrite Z.bit0_odd in Hbit0. + rewrite (Zeven_div2 n); [|apply Zeven_bool_iff; rewrite <- Z.negb_odd, Hbit0; reflexivity]. + rewrite Z.testbit_even_succ, Z.testbit_odd_succ; auto; lia. Qed. Lemma joye_ladder_correct : @@ -496,238 +742,44 @@ Module ScalarMult. (* Unfold the ladder *) rewrite <- (Jacobian.to_affine_of_affine (scalarmult n P)). apply Jacobian.Proper_to_affine. rewrite scalarmult_scalarmult'. - destruct (tplu_co_z_points P HPnz) as ((A1 & A2) & HA12) eqn:Htplu. - rewrite (sig_eta (tplu_co_z_points _ _)) in Htplu. - apply proj1_sig_eq in Htplu; simpl in Htplu. - (* Initialize the ladder state with ([3]P, [1]P) or its symmetric *) - cbv zeta. destruct (cswap_co_z_points (testbitn 1) _) as ((A3 & A4) & HA34) eqn:HA1. - rewrite (sig_eta (cswap_co_z_points _ _)) in HA1. - apply proj1_sig_eq in HA1. cbn [proj1_sig cswap_co_z_points] in HA1. - assert (A3 = (if testbitn 1 then A2 else A1) :> point) as HA3 by (destruct (testbitn 1); inversion HA1; auto). - assert (A4 = (if testbitn 1 then A1 else A2) :> point) as HA4 by (destruct (testbitn 1); inversion HA1; auto). - clear HA1. destruct (tplu_scalarmult' (tplu_co_z_points_obligation_1 P HPnz) Htplu) as (HeqA1 & HeqA2 & _). - (* While loop *) - set (WW := while _ _ _ _). - (* Invariant is: - - loop counter i is such that 2 ≤ i ≤ scalarbitsz - - R1 = [TT i]P - - R0 = [SS i]P - - additional condition to ensure that the ladder state does not encounter the neutral point - *) - set (inv := - fun '(R1R0, i) => - let '(R1, R0) := proj1_sig (R1R0:co_z_points) in - (2 <= i <= scalarbitsz)%Z /\ - (eq R1 (scalarmult' (TT n' (Z.to_nat i)) (of_affine P)) - /\ eq R0 (scalarmult' (SS n' (Z.to_nat i)) (of_affine P))) - /\ ((i < scalarbitsz)%Z -> x_of R1 <> x_of R0)). - assert (WWinv : inv WW /\ (snd WW = scalarbitsz :> Z)). - { set (measure := fun (state : (co_z_points*Z)) => ((Z.to_nat scalarbitsz) + 2 - (Z.to_nat (snd state)))%nat). - unfold WW. replace (Z.to_nat scalarbitsz) with (measure (exist _ (A3, A4) HA34, 2%Z)) by (unfold measure; simpl; lia). - eapply (while.by_invariant inv measure (fun s => inv s /\ (snd s = scalarbitsz :> Z))). - - (* Invariant holds at beginning *) - unfold inv. cbn [proj1_sig]. - split; [lia|]. - split. - + change (Z.to_nat 2) with 2%nat. - rewrite SS_2, TT_2, <-Htestbitn', HA3, HA4 by lia. - case Z.testbit; auto. - + intros AB Hxe. destruct (co_xz_implies A3 A4 Hxe HA34) as [Heq|Hopp]; [rewrite HA3, HA4 in Heq|rewrite HA3, HA4 in Hopp]. - * eapply (HordP_alt 2 ltac:(lia)). - replace Wzero with (scalarmult 0 P) by reflexivity. - apply scalarmult_eq_weq_conversion. - replace 2%Z with (3 - 1)%Z by lia. - rewrite (scalarmult_sub_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite <- HeqA1, <- HeqA2. - destruct (testbitn 1); rewrite Heq; simpl; apply add_opp_zero. - * eapply (mult_two_power 2%Z ltac:(lia)). - replace (Z.pow 2 2) with 4%Z by lia. - replace Wzero with (scalarmult 0%Z P) by reflexivity. - replace 4%Z with (3 + 1)%Z by lia. - apply scalarmult_eq_weq_conversion. - rewrite (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))). - rewrite <- HeqA1, <- HeqA2. - destruct (testbitn 1); rewrite Hopp; [|rewrite Jacobian.add_comm]; apply add_opp_zero. - - (* Invariant is preserved by the loop, - measure decreases, - and post-condition i = scalarbitsz. - *) - intros s Hs. destruct s as (R1R0 & i). - destruct R1R0 as ((R1 & R0) & HCOZ). - destruct Hs as (Hi & (HR1 & HR0) & Hx). - destruct (Z.ltb i scalarbitsz) eqn:Hltb. - + apply Z.ltb_lt in Hltb. - split. - * (* Invariant preservation. - The loop body is basically : - (R1, R0) <- cswap (testbitn i) (R1, R0); - (R1, R0) <- ZDAU (R1, R0); - (R1, R0) <- cswap (testbitn i) (R1, R0); - *) - (* Start by giving names to all intermediate values *) - unfold inv. destruct (cswap_co_z_points (testbitn i) (exist _ _ _)) as ((B1 & B2) & HB12) eqn:Hswap1. - rewrite (sig_eta (cswap_co_z_points _ _)) in Hswap1. - apply proj1_sig_eq in Hswap1. simpl in Hswap1. - assert (HB1: B1 = (if testbitn i then R0 else R1) :> point) by (destruct (testbitn i); congruence). - assert (HB2: B2 = (if testbitn i then R1 else R0) :> point) by (destruct (testbitn i); congruence). - clear Hswap1. - destruct (zdau_co_z_points _) as ((C1 & C2) & HC12) eqn:HZDAU. - rewrite (sig_eta (zdau_co_z_points _)) in HZDAU. - apply proj1_sig_eq in HZDAU. simpl in HZDAU. - assert (HBx : x_of B1 <> x_of B2) by (rewrite HB1, HB2; destruct (testbitn i); [symmetry|]; auto). - destruct (zaddu B1 B2 (zdau_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl)) as (Y1 & Y2) eqn:HY. - (* Since co-Z formulas are incomplete, we need to show that we won't hit the neutral point for ZDAU *) - assert (HYx : x_of Y1 <> x_of Y2). - { (* We need to prove that [SS i + TT i]P and - [SS i]P or [TT i]P - (depending on testbitn i) have different X coordinates, i.e., - [SS i + TT i ± SS i]P ≠ ∞ - or [SS i + TT i ± TT i]P ≠ ∞ - *) - intro XX. generalize (Jacobian.zaddu_correct B1 B2 (zdau_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl) HBx). - rewrite HY. intros (W1 & W2 & W3). - destruct (co_xz_implies _ _ XX W3) as [W|W]; rewrite W, <- W2, HB1, HB2 in W1. - - destruct (testbitn i); - rewrite HR1, HR0, <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in W1; - apply scalarmult_difference in W1; - rewrite Z.add_simpl_l in W1; - [apply (HTT i ltac:(lia))|apply (HSS i ltac:(lia))]; - replace Wzero with (scalarmult 0 P) by reflexivity; - apply scalarmult_eq_weq_conversion; - rewrite W1; reflexivity. - - destruct (testbitn i) eqn:Hti; - rewrite (Htestbitn' i ltac:(lia)) in Hti; - rewrite HR1, HR0, <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))), <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in W1; - apply scalarmult_difference in W1. - all: match goal with - | H : eq (scalarmult' ?X _) zero |- _ => - match X with - | (SS _ _ + _ - _)%Z => - replace X with (SS n' (S (Z.to_nat i))) in H by (rewrite SS_succ, Z2Nat.id, Hti; lia) - | (TT _ _ + _ - _)%Z => - replace X with (TT n' (S (Z.to_nat i))) in H by (rewrite TT_succ, Z2Nat.id, Hti; lia) - end - end. - all: rewrite <- Z2Nat.inj_succ in W1; try lia. - all: match goal with - | H : eq (scalarmult' (SS _ _) _) zero |- _ => - apply (HSS (Z.succ i) ltac:(lia)) - | _ => - apply (HTT (Z.succ i) ltac:(lia)) - end. - all: replace Wzero with (scalarmult 0 P) by reflexivity. - all: apply scalarmult_eq_weq_conversion. - all: rewrite W1; reflexivity. } - generalize (@Jacobian.zdau_correct_alt F Feq Fzero Fone Fopp Fadd Fsub Fmul Finv Fdiv a b field char_ge_3 Feq_dec char_ge_12 ltac:(unfold id in *; fsatz) B1 B2 (zdau_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl) HBx ltac:(rewrite HY; simpl; apply HYx)). - rewrite HZDAU. intros (HC1 & HC2 & _). - destruct (cswap_co_z_points (testbitn i) _) as ((D1 & D2) & HD12) eqn:HD. - rewrite (sig_eta (cswap_co_z_points _ _)) in HD. - apply proj1_sig_eq in HD. cbn [proj1_sig cswap_co_z_points] in HD. - assert (HD1 : D1 = (if testbitn i then C2 else C1) :> point) by (destruct (testbitn i); congruence). - assert (HD2 : D2 = (if testbitn i then C1 else C2) :> point) by (destruct (testbitn i); congruence). - clear HD. simpl. - (* invariant preservation *) - (* counter still within bounds *) - split; [lia|]. rewrite HD1, HD2. split. - { (* New values are indeed [SS (i+1)]P and [TT (i+1)]P *) - destruct (testbitn i) eqn:Hti; - rewrite (Htestbitn' i ltac:(lia)) in Hti; - rewrite <- HC1, <- HC2, HB1, HB2; - replace (Z.to_nat (Z.succ i)) with (S (Z.to_nat i)) by lia; - rewrite SS_succ, TT_succ, Z2Nat.id by lia; - rewrite Hti; split; try assumption; - rewrite <- Jacobian.add_double; try reflexivity; - rewrite HR0, HR1; - repeat rewrite <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))); - rewrite <- Z.add_diag; reflexivity. } - { (* Make sure we don't hit bad values *) - intros Hsi Hxe'. - assert (Hxe : x_of C1 = x_of C2) by (destruct (testbitn i); fsatz); clear Hxe'. - generalize (co_xz_implies _ _ Hxe HC12). - rewrite <- HC1, <- HC2, <- Jacobian.add_double; [|reflexivity]. - rewrite HB1, HB2. destruct (testbitn i) eqn:Hti; - rewrite (Htestbitn' i ltac:(lia)) in Hti; - rewrite HR0, HR1; - repeat rewrite <- (scalarmult_add_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))); - rewrite Z.add_diag, <- (scalarmult_opp_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))); - intros [Q|Q]; apply scalarmult_difference in Q; - (* 3 cases *) - match goal with - | H : eq (scalarmult' ?X _) zero |- _ => - match X with - | Z.sub _ ?Y => - match Y with - | (- _)%Z => (* Case [2^(i+1)]P ≠ ∞ *) - replace X with (2 * (SS n' (Z.to_nat i) + TT n' (Z.to_nat i)))%Z in H by lia - | (TT _ _) => (* Case [2 * (SS i)]P ≠ ∞ *) - replace X with (2 * SS n' (Z.to_nat i))%Z in H by lia; - shelve - | (SS _ _) => (* Case [2 * (TT i)]P ≠ ∞ *) - replace X with (2 * TT n' (Z.to_nat i))%Z in H by lia; - shelve - end - end - end. - (* Solve case [2^(i+1)]P ≠ ∞ first *) - all: rewrite SS_plus_TT, Z2Nat.id, <- Z.pow_succ_r in Q; try lia; eauto. - all: eapply (mult_two_power (Z.succ i) ltac:(lia)). - all: replace Wzero with (scalarmult 0 P) by reflexivity. - all: apply scalarmult_eq_weq_conversion. - all: rewrite Q; reflexivity. - Unshelve. (* Solve the other cases *) - all: replace zero with (scalarmult' 0 (of_affine P)) in Q by reflexivity. - all: apply scalarmult_eq_weq_conversion in Q. - all: generalize (SS_monotone1 n' (Z.to_nat i)); rewrite SS1; intro QS. - all: generalize (TT_monotone1 n' (Z.to_nat i)); rewrite TT1; intro QT. - all: match goal with - | H : Weq (scalarmult (Z.mul 2%Z ?X) P) (_ 0%Z _) |- _ => - destruct (proj1 (HordP (Z.mul 2%Z X)) Q) as [l Hl]; - generalize (Znumtheory.prime_mult 2%Z Znumtheory.prime_2 l ordP ltac:(exists X; lia)) - end. - all: intros [A|A]; destruct A as [m Hm]; - [|replace ordP with (0 + 2 * m)%Z in HordPodd by lia; rewrite Z.odd_add_mul_2 in HordPodd; simpl in HordPodd; congruence]. - all: subst l; rewrite <- Z.mul_assoc, <- Z.mul_shuffle3 in Hl. - all: apply (Z.mul_reg_l _ _ 2%Z ltac:(lia)) in Hl. - all: match goal with - | H : SS _ (Z.to_nat ?X) = (_ * ordP)%Z :> Z |- _ => - apply (HSS X ltac:(lia)); - apply (proj2 (HordP (SS n' (Z.to_nat X)))) - | H : TT _ (Z.to_nat ?X) = (_ * ordP)%Z :> Z |- _ => - apply (HTT X ltac:(lia)); - apply (proj2 (HordP (TT n' (Z.to_nat X)))) - end; eauto. } - * (* measure decreases *) - apply Z.ltb_lt in Hltb. - unfold measure; simpl; lia. - + (* Post-condition *) - simpl; split; auto. - rewrite Z.ltb_nlt in Hltb. lia. } - (* Finalization, compute [n' - 1]P and [n']P *) - destruct WWinv as (Hinv & Hj). - destruct WW as (R1R0 & i). destruct R1R0 as ((R1 & R0) & HCOZ). - simpl in Hj; subst i. destruct Hinv as (_ & (_ & HR0) & _). - rewrite SSn in HR0; [|generalize Hn'; lia|lia]. cbn [snd proj1_sig]. + set (P' := of_affine P). set (HPaff := joye_ladder_obligation_1 P HPnz). + assert (Hnodd : n' = Z.setbit n' 0 :> Z) by (repeat rewrite Z.setbit_spec'; rewrite <- Z.lor_assoc, Z.lor_diag; reflexivity). + assert (HordP' : forall l, (eq (scalarmult' l P') zero) <-> exists k, (l = k * ordP :> Z)%Z). + { intros l; split; replace zero with (scalarmult' 0%Z P') by reflexivity. + - intros H; apply scalarmult_eq_weq_conversion in H; apply HordP; auto. + - intros H; apply scalarmult_eq_weq_conversion; apply HordP; auto. } + assert (HSS' : forall i, (2 <= i <= scalarbitsz)%Z -> not (eq (scalarmult' (SS n' (Z.to_nat i)) P') zero)). + { replace zero with (scalarmult' 0%Z P') by reflexivity. + intros i Hi He; apply scalarmult_eq_weq_conversion in He; apply (HSS i Hi); auto. } + assert (HTT' : forall i, (2 <= i <= scalarbitsz)%Z -> not (eq (scalarmult' (TT n' (Z.to_nat i)) P') zero)). + { replace zero with (scalarmult' 0%Z P') by reflexivity. + intros i Hi He; apply scalarmult_eq_weq_conversion in He; apply (HTT i Hi); auto. } + generalize (joye_ladder_inner_correct (n:=n') (P:=of_affine P) (HPaff:=HPaff) (Hnodd:=Hnodd) (Hn:=Hn') (HordP:=HordP')(HSS:=HSS') (HTT:=HTT')). + intros Hinner. rewrite (joye_ladder_inner_bit0_irr scalarbitsz testbitn' testbitn (of_affine P) HPaff ltac:(intros; rewrite Z.setbit_neq; trivial; lia)) in Hinner. + set (R0 := joye_ladder_inner scalarbitsz testbitn P' HPaff). + cbv zeta. fold R0. fold P' in Hinner. fold R0 in Hinner. + (* We now have R0 = [n']P *) destruct (make_co_z_points R0 _ _) as ((B1 & B2) & HB12) eqn:HMCZ. rewrite (sig_eta (make_co_z_points _ _ _)) in HMCZ. apply proj1_sig_eq in HMCZ; simpl in HMCZ. (* Prove [n']P ≠ ∞ *) assert (HR0znz : z_of R0 <> 0). - { intro. apply (HordP_alt n'). - - apply HordPn'. - - replace Wzero with (scalarmult 0 P) by reflexivity. - apply scalarmult_eq_weq_conversion. - rewrite <- HR0. destruct R0 as (((? & ?) & ?) & ?). - cbn in H; cbn. clear -field H. - destruct (dec (f1 = 0)); fsatz. } + { intro. apply (HSS scalarbitsz ltac:(lia)). + replace Wzero with (scalarmult 0%Z P) by reflexivity. + apply scalarmult_eq_weq_conversion. + rewrite SSn by (generalize Hn'; lia). + fold P'. rewrite <- Hinner. destruct R0 as (((? & ?) & ?) & ?). + cbn in H; cbn. clear -field H. + destruct (dec (f1 = 0)); fsatz. } (* Have co-Z representations of [n']P and [-1]P *) - generalize (Jacobian.make_co_z_correct R0 (opp (of_affine P)) (opp_of_affine P HPnz) HR0znz). + generalize (Jacobian.make_co_z_correct R0 (opp P') (opp_of_affine P HPnz) HR0znz). rewrite HMCZ. intros (HB1 & HB2 & _). destruct (zaddu_co_z_points _) as ((C1 & C2) & HC12) eqn:HZADDU. rewrite (sig_eta (zaddu_co_z_points _)) in HZADDU. apply proj1_sig_eq in HZADDU. simpl in HZADDU. (* Ensure ZADDU doesn't hit the neutral point *) assert (Hxne: x_of B1 <> x_of B2). - { intro Hxe. destruct (co_xz_implies _ _ Hxe HB12) as [HEq|HNeq]; [rewrite <- HB1, HR0, <- HB2 in HEq|rewrite <- HB1, HR0, <- HB2 in HNeq]. + { intro Hxe. destruct (co_xz_implies _ _ Hxe HB12) as [HEq|HNeq]; [rewrite <- HB1, Hinner, <- HB2 in HEq|rewrite <- HB1, Hinner, <- HB2 in HNeq]. - rewrite <- (scalarmult_opp1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup))) in HEq. apply scalarmult_difference in HEq. apply (HordP_alt (n' - -1)%Z). @@ -735,7 +787,7 @@ Module ScalarMult. + replace Wzero with (scalarmult 0 P) by reflexivity. apply scalarmult_eq_weq_conversion. auto. - rewrite (Group.inv_inv (H:=Pgroup)) in HNeq. - rewrite <- (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup)) (of_affine P)) in HNeq at 2. + rewrite <- (scalarmult_1_l (groupG:=Pgroup) (mul_is_scalarmult:=scalarmult_ref_is_scalarmult (groupG:=Pgroup)) P') in HNeq at 2. apply scalarmult_difference in HNeq. apply (HordP_alt (n' - 1)%Z). + rewrite n'_alt; destruct (testbitn 0); lia. @@ -743,7 +795,7 @@ Module ScalarMult. apply scalarmult_eq_weq_conversion. auto. } generalize (Jacobian.zaddu_correct B1 B2 (zaddu_co_z_points_obligation_1 (exist (fun '(A, B) => co_z A B) (B1, B2) HB12) B1 B2 eq_refl) Hxne). rewrite HZADDU. intros (HC1 & HC2 & _). - rewrite <- HB1, <- HB2, HR0 in HC1. rewrite <- HB1, HR0 in HC2. + rewrite <- HB1, <- HB2, Hinner in HC1. rewrite <- HB1, Hinner in HC2. destruct (cswap_co_z_points (testbitn 0) _) as ((D1 & D2) & HD12) eqn:Hswap. rewrite (sig_eta (cswap_co_z_points _ _)) in Hswap. apply proj1_sig_eq in Hswap. cbn [proj1_sig cswap_co_z_points] in Hswap.