From 97d570b0c184e39c482a0fd0e5545a16e0708257 Mon Sep 17 00:00:00 2001 From: Alix Trieu Date: Thu, 6 Feb 2025 11:10:31 +0100 Subject: [PATCH] Refined/Signed Barrett Reduction --- src/Arithmetic/BarrettReduction/Refined.v | 335 ++++++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 src/Arithmetic/BarrettReduction/Refined.v diff --git a/src/Arithmetic/BarrettReduction/Refined.v b/src/Arithmetic/BarrettReduction/Refined.v new file mode 100644 index 0000000000..b39d2c09c9 --- /dev/null +++ b/src/Arithmetic/BarrettReduction/Refined.v @@ -0,0 +1,335 @@ +From Coq Require Import ZArith QArith Qabs Qround. +From Coq Require Import Lia. +Require Import Util.QUtil. + +(* + This is a formalization of (a part of) section 2.4 of the following paper + Efficient Multiplication of Somewhat Small Integers Using Number-Theoretic Transforms + Hanno Becker, Vincent Hwang, Matthias J. Kannwischer, Lorenz Panny, Bo-Yin Yang + IWSEC 2022 +*) + +Local Open Scope Q_scope. + +Section Qround_half_up. + Local Coercion inject_Z : Z >-> Q. + Local Coercion Zpos : positive >-> Z. + + Definition Qround_half_up (x: Q): Z := Qfloor (x + (1#2))%Q. + + Lemma Qfloor_approx (x: Q): + Qabs (x - Qfloor x) < 1%Z. + Proof. + rewrite Qabs_Qlt_condition. + pose proof (Qfloor_le x) as H1. + pose proof (Qlt_floor x) as H2. + destruct x as (x_num & x_den). + cbv [Qlt Qle] in *. simpl in *. lia. + Qed. + Lemma Qround_half_up_approx (x: Q): + Qabs (x - Qround_half_up x) <= 1#2. + Proof. + rewrite Qabs_Qle_condition. + pose proof (Qfloor_le (x + (1#2))) as H1. + pose proof (Qlt_floor (x + (1#2))) as H2. + destruct x as (x_num & x_den). + cbv [Qlt Qle Qplus Qround_half_up] in *. simpl in *. lia. + Qed. + + Lemma Qfloor_diff (x1 x2: Q) (Hlt: (Qfloor x1 < Qfloor x2)%Z): + exists (n: Z), x1 < n <= x2. + Proof. + exists (Qfloor x2). + pose proof (Qlt_floor x1) as H1. + pose proof (Qfloor_le x2) as H2. + split; [|assumption]. + apply (Qlt_le_trans _ _ _ H1). + cbv [Qle]; simpl. lia. + Qed. + + Lemma Qround_half_up_diff (x1 x2: Q) + (Hlt: (Qround_half_up x1 < Qround_half_up x2)%Z): + exists (n: Z), x1 < n + (1#2) <= x2. + Proof. + unfold Qround_half_up in Hlt. + apply Qfloor_diff in Hlt. + destruct Hlt as (n & Hlt & Hle). + exists (n - 1)%Z. + destruct x1 as (x_num1 & x_den1). + destruct x2 as (x_num2 & x_den2). + cbv [Qlt Qle] in *; simpl in *. + split; try lia. + Qed. + + Lemma Qround_half_up_eq (a: Z) (b: positive): + Qround_half_up (a#b) = if (2 * Z.modulo a b -> Q. + Local Coercion Zpos : positive >-> Z. + + Definition mod_approx (approx: Q -> Z) (a: Z) (b: positive): Z := + (a - b * approx (a#b))%Z. + + Lemma Zmod_is_mod_approx_floor: + forall (a: Z) (b: positive), + Z.modulo a b = mod_approx Qfloor a b. + Proof. + intros. unfold mod_approx. + rewrite Qmake_Qdiv, <- Zdiv_Qdiv. + apply Z.mod_eq. discriminate. + Qed. + + Definition signed_mod: Z -> positive -> Z := mod_approx Qround_half_up. + + Lemma signed_mod_eq_Zmod (a: Z) (q: positive): + signed_mod a q = if (2 * Z.modulo a q -> Q. + Local Coercion Zpos : positive >-> Z. + + (* Barrett reduction is the special case with b = 1 *) + Definition barrett_multiplication_approx + (approx: Q -> Z) (R a b: Z) (q: positive): Z := + (a * b - q * Qround_half_up ((a * (approx ((b * R)#q)))%Z / R))%Z. + + (* Not the same bounds on |a| and |b| as in the paper, as theirs cannot be proved *) + Lemma barrett_multiplication_approx_correct + (approx: Q -> Z) (M R a b: Z) (k q: positive) + (Hk: Qabs (((b * R)#q) - approx ((b * R)#q)) <= 1#(Pos.pow 2 k)) + (HOddq: Z.Odd q) (HR: R = Z.pow 2 (M - 1 + Z.log2 q - Z.log2_up (Z.abs b))%Z) + (HM: (2 <= M)%Z) (Ha: (Z.abs a <= Z.pow 2 (M - 1))%Z) + (Hb: (Z.abs b <= Z.pow 2 (M - 2))%Z) + (Ha': (Z.abs a <= Z.pow 2 ((M - 2) - ((Z.log2_up (Z.abs b)) - (k - 1))))%Z): + barrett_multiplication_approx approx R a b q = signed_mod (a * b)%Z q. + Proof. + assert (Hlog2b: (Z.log2_up (Z.abs b) <= M - 2)%Z). + { refine (proj1 (Log2.Z.log2_up_le_pow2_full _ _ _) Hb). lia. } + pose proof (Z.log2_nonneg q) as Hlog2q_nonneg. + pose proof (Z.pow_pos_nonneg 2 (M - 1 + Z.log2 q - Z.log2_up (Z.abs b))%Z ltac:(lia) ltac:(lia)) as HRpos. + unfold barrett_multiplication_approx, signed_mod, mod_approx. + assert (Qround_half_up ((a * approx (b * R # q))%Z / R) = Qround_half_up (a * b # q)) as ->; [|reflexivity]. + match goal with |- ?x = ?y => destruct (Z.eq_dec x y) as [?|Hne]; auto end. + match goal with |- Qround_half_up ?x = Qround_half_up ?y => set (delta := x - y) end. + assert (Hdelta: Qabs delta == ((Qabs a) / R) * Qabs ((b * R # q) - approx (b * R # q))). + { unfold delta. rewrite Qabs_Qminus. + assert ((a * b)#q == (a / R) * ((b * R) # q)) as ->. + { rewrite Qmake_Qdiv, Qmake_Qdiv. + repeat rewrite inject_Z_mult. field. + split; [discriminate|]. cbv [Qeq]; simpl. + rewrite HR. lia. } + assert ((a * approx _)%Z / R == a / R * approx (b * R # q)) as ->. + { repeat rewrite inject_Z_mult. field. + cbv [Qeq]; simpl. rewrite HR; lia. } + assert (a / R * _ - a / R * _ == a / R * ((b * R # q) - approx (b * R # q))) as ->. + { field. cbv [Qeq]; simpl; rewrite HR; lia. } + rewrite Qabs_Qmult. assert (Qabs (a / R) == Qabs a / R) as ->; [|reflexivity]. + cbv [Qdiv]. rewrite Qabs_Qmult, Qabs_Qinv. + rewrite (Qabs_pos R); [reflexivity|]. + cbv [Qle]; simpl. rewrite HR; lia. } + match goal with |- ?x = ?y => assert (x < y \/ y < x)%Z as Hlt by lia end. + assert (Hdelta': 1#(2 * q)%positive <= Qabs delta). + { destruct Hlt as [Hlt|Hlt]; apply Qround_half_up_diff in Hlt; destruct Hlt as (n & Hlo & Hhi). + - apply (proj1 (Qle_minus_iff _ _)) in Hhi. + apply Qopp_lt_compat in Hlo. + apply (proj2 (Qplus_lt_r _ _ ((a * b # q)))) in Hlo. + setoid_replace ((a * b # q) + - ((a * approx (b * R # q))%Z / R)) with (- delta) in Hlo. + 2:{ unfold delta. unfold Qminus. rewrite Qopp_plus, Qopp_involutive, Qplus_comm. reflexivity. } + rewrite <- Qabs_opp. apply Qabs_ge. + refine (Qle_trans _ ((a * b # q) + - (n + (1 # 2))) _ _ _). + 2: apply Qlt_le_weak; assumption. + clear -HOddq Hhi. + cbv [Qopp Qplus Qle] in *. simpl in *. + assert (0 < (a * b * 2 + - (n * 2 + 1) * q))%Z; [|nia]. + destruct (Z.eq_dec 0 (a * b * 2 + - (n * 2 + 1) * q)%Z) as [He|?]; [|lia]. + exfalso. apply (Z.Even_Odd_False ((n * 2 + 1) * q)%Z). + + exists (a * b)%Z. lia. + + destruct HOddq as (? & ->). exists (2 * n * x + n + x)%Z. lia. + - apply (proj1 (Qlt_minus_iff _ _)) in Hlo. + apply (proj2 (Qplus_le_l _ _ (- (a * b # q)))) in Hhi. + apply Qabs_ge. unfold delta. unfold Qminus. + refine (Qle_trans _ _ _ _ Hhi). + clear -HOddq Hlo. + cbv [Qopp Qplus Qlt Qle] in *. simpl in *. + assert (0 < ((n * 2 + 1) * q + - (a * b) * 2))%Z; [|nia]. + lia. } + exfalso. assert (Qabs delta < 1 # 2 * q) as X; [|apply (Qlt_not_le _ _ X); auto]. + rewrite Hdelta. + pose proof (Qle_lt_or_eq _ _ (Qabs_nonneg a)) as [Hpos|Hz]. + - refine (Qle_lt_trans _ ((Qabs a / R) * (1 # 2 ^ k)) _ _ _). + + apply Qmult_le_l; auto. + apply Qlt_shift_div_l; cbv [Qmult Qlt] in *; simpl in *; auto. + rewrite HR; lia. + + assert ((Qabs a / R) * (1 # 2 ^ k) == Qabs a / (R * 2 ^ k)%Z) as ->. + { rewrite HR, <- Z.pow_add_r; try lia. + rewrite <- (Z2Pos.id (Z.pow 2 _)) by lia. + rewrite <- (Z2Pos.id (Z.pow 2 (_ + k))) by lia. + rewrite <- (Zabs_Qabs a xH), <- (Qmake_Qdiv (Z.abs a) _), <- (Qmake_Qdiv (Z.abs a) _). + cbv [Qmult Qeq Qnum Qden]. + rewrite Z2Pos.id by lia. + rewrite Pos2Z.inj_mul, Z2Pos.id by lia. + rewrite Pos2Z.inj_pow, <- Z.pow_add_r by lia. lia. } + rewrite HR, <- Z.pow_add_r; try lia. + rewrite <- (Z2Pos.id (Z.pow _ _)) by lia. + rewrite <- (Zabs_Qabs a xH), <- (Qmake_Qdiv (Z.abs a) (Z.to_pos _)). + cbv [Qlt Qnum Qden]. rewrite Z2Pos.id by lia. + apply (Z.le_lt_trans _ (2 ^ (M - 2 - (Z.log2_up (Z.abs b) - (k - 1))) * (2 * Z.pos q))%Z); [nia|]. + rewrite Z.mul_1_l, Z.mul_assoc, (Z.mul_comm _ 2), <- Z.pow_succ_r by lia. + rewrite <- Z.add_1_r. + apply (Z.lt_le_trans _ (2 ^ (M - 2 - (Z.log2_up (Z.abs b) - (k - 1)) + 1) * Z.pow 2 (Z.succ (Z.log2 q)))%Z). + * apply Zmult_lt_compat_l; [|apply Z.log2_spec; lia]. + apply Z.pow_pos_nonneg; lia. + * rewrite <- Z.add_1_l, <- Z.pow_add_r; [|lia|lia]. + apply Z.pow_le_mono_r; [lia|]. lia. + - rewrite <- Hz. cbv [Qmult Qlt]; simpl. lia. + Qed. + + (* Assumes R_pow, c and v are precomputed *) + (* M is usually the bitwidth *) + (* R_pow := M - 1 + Z.log2 q *) + (* c := approx ((Z.pow 2 R_pow)#q) *) + (* v := Z.shiftl 1 (R_pow - 1) *) + Definition barrett_reduce_approx (R_pow c v a: Z) (q: positive): Z := + let t := (a * c)%Z in + let t := (t + v)%Z in + let t := Z.shiftr t R_pow in + (a - q * t)%Z. + + Lemma barrett_reduce_approx_eq (approx: Q -> Z) (M R_pow c v a: Z) (q: positive) + (HM: (2 <= M)%Z) + (HR_pow: R_pow = (M - 1 + Z.log2 q)%Z) + (Hc: c = approx ((Z.pow 2 R_pow)#q)) + (Hv: v = Z.shiftl 1 (R_pow - 1)): + barrett_reduce_approx R_pow c v a q = barrett_multiplication_approx approx (Z.pow 2 (M - 1 + Z.log2 q)%Z) a 1%Z q. + Proof. + cbv [barrett_multiplication_approx barrett_reduce_approx Qround_half_up]. + rewrite Z.mul_1_r, Z.mul_1_l. + subst v; rewrite Z.shiftl_1_l. rewrite <- HR_pow, <- Hc. + assert (Hle: (0 <= R_pow - 1)%Z) by (rewrite HR_pow; pose proof (Z.log2_nonneg q); lia). + assert (_ + _ == (a * c + 2 ^ (R_pow - 1))%Z / (Z.pow 2 R_pow)) as ->. + { rewrite <- (Z2Pos.id (Z.pow 2 R_pow)) by (apply Z.pow_pos_nonneg; lia). + do 2 rewrite <- (Qmake_Qdiv). + cbv [Qplus Qnum Qden]. rewrite Z.mul_1_l. + rewrite <- (Qmult_1_r (_ # (Z.to_pos (Z.pow 2 R_pow)))). + assert (1 == 2#2) as -> by reflexivity. + cbv [Qmult Qnum Qden]. rewrite Z.mul_add_distr_r. + rewrite (Z.mul_comm (Z.pow _ _) 2). + rewrite <- Z.pow_succ_r by assumption. + assert (Z.succ _ = R_pow)%Z as -> by lia. + rewrite Z2Pos.id; [reflexivity|]. apply Z.pow_pos_nonneg; lia. } + rewrite <- Zdiv_Qdiv, <- Z.shiftr_div_pow2 by lia. reflexivity. + Qed. + + Lemma barrett_reduce_approx_correct + (approx: Q -> Z) (M R_pow c v a: Z) (k q: positive) + (HR_pow: R_pow = (M - 1 + Z.log2 q)%Z) + (Hc: c = approx ((Z.pow 2 R_pow)#q)) + (Hv: v = Z.shiftl 1 (R_pow - 1)) + (Hk: Qabs (((Z.pow 2 (M - 1 + Z.log2 q)%Z)#q) - approx ((Z.pow 2 (M - 1 + Z.log2 q)%Z)#q)) <= 1#(Pos.pow 2 k)) + (HOddq: Z.Odd q) (HM: (2 <= M)%Z) (Ha: (Z.abs a <= Z.pow 2 (M - 1))%Z) + (Ha': (Z.abs a <= Z.pow 2 ((M - 2) + (k - 1)))%Z): + barrett_reduce_approx R_pow c v a q = signed_mod a q. + Proof. + erewrite barrett_reduce_approx_eq by eassumption. + erewrite barrett_multiplication_approx_correct; eauto. + - rewrite Z.mul_1_r; reflexivity. + - rewrite Z.mul_1_l. exact Hk. + - assert (Z.log2_up _ = 0)%Z as -> by reflexivity. + rewrite Z.sub_0_r. reflexivity. + - assert (Z.abs 1 = Z.pow 2 0)%Z as -> by reflexivity. + apply Z.pow_le_mono_r; lia. + - assert (Z.log2_up _ = 0)%Z as -> by reflexivity. + assert (M - 2 - (0 - (k - 1)) = M - 2 + (k - 1))%Z as -> by lia. + exact Ha'. + Qed. +End RefinedBarrettReduction. + +(* Section MLKEM_Example. *) +(* Local Coercion inject_Z : Z >-> Q. *) +(* Local Coercion Zpos : positive >-> Z. *) +(* Let M := 16%Z. *) +(* Let q := 3329%positive. *) +(* Let k := 2%positive. *) + +(* Let R_pow := Eval compute in (M - 1 + Z.log2 q)%Z. (* 26%Z *) *) +(* Let c := Eval compute in Qround_half_up ((Z.pow 2 R_pow)#q). (* 20159%Z *) *) +(* Let v := Eval compute in Z.shiftl 1 (R_pow - 1)%Z. (* 33554432%Z *) *) + +(* Lemma Hk: Qabs (((Z.pow 2 (M - 1 + Z.log2 q)%Z)#q) - Qround_half_up ((Z.pow 2 (M - 1 + Z.log2 q)%Z)#q)) <= 1#(Pos.pow 2 k). *) +(* Proof. compute. congruence. Qed. *) + +(* Lemma HOddq: Z.Odd q. *) +(* Proof. exists (q/2)%Z. reflexivity. Qed. *) + +(* (* *) +(* For comparison https://github.com/pq-crystals/kyber/blob/main/ref/reduce.c *) +(* int16_t barrett_reduce(int16_t a) { *) +(* int16_t t; *) +(* const int16_t v = ((1<<26) + KYBER_Q/2)/KYBER_Q; *) + +(* t = ((int32_t)v*a + (1<<25)) >> 26; *) +(* t *= KYBER_Q; *) +(* return a - t; *) +(* } *) +(* *) *) +(* Definition mlkem_barrett_reduce (a: Z): Z := *) +(* let t := (a * 20159)%Z in *) +(* let t := (t + 33554432)%Z in *) +(* let t := Z.shiftr t 26 in *) +(* (a - q * t)%Z. *) + +(* Lemma MLKEM_barrett_reduce_correct (a: Z) (Ha: (Z.abs a <= Z.pow 2 15)%Z): *) +(* mlkem_barrett_reduce a = signed_mod a q. *) +(* Proof. *) +(* assert (mlkem_barrett_reduce a = barrett_reduce_approx R_pow c v a q) as -> by reflexivity. *) +(* apply (barrett_reduce_approx_correct Qround_half_up M R_pow c v a k q ltac:(reflexivity) ltac:(reflexivity) ltac:(reflexivity) Hk HOddq ltac:(lia) ltac:(apply Ha) ltac:(apply Ha)). *) +(* Qed. *) +(* End MLKEM_Example. *)