diff --git a/src/backend.rs b/src/backend.rs index 0dbf50d..fb3339f 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,5 +1,10 @@ use itertools::izip; +pub trait ModInit { + type Element; + fn new(q: Self::Element) -> Self; +} + pub trait VectorOps { type Element; @@ -40,8 +45,9 @@ pub struct ModularOpsU64 { barrett_alpha: usize, } -impl ModularOpsU64 { - pub fn new(q: u64) -> ModularOpsU64 { +impl ModInit for ModularOpsU64 { + type Element = u64; + fn new(q: u64) -> ModularOpsU64 { let logq = 64 - q.leading_zeros(); // barrett calculation @@ -55,7 +61,9 @@ impl ModularOpsU64 { barrett_mu: mu, } } +} +impl ModularOpsU64 { fn add_mod_fast(&self, a: u64, b: u64) -> u64 { debug_assert!(a < self.q); debug_assert!(b < self.q); diff --git a/src/bool.rs b/src/bool.rs index 310bbc3..c6c88e5 100644 --- a/src/bool.rs +++ b/src/bool.rs @@ -1,20 +1,19 @@ -use std::collections::HashMap; +use std::{collections::HashMap, fmt::Debug, marker::PhantomData}; -use num_traits::{FromPrimitive, One, PrimInt, ToPrimitive, Zero}; +use itertools::Itertools; +use num_traits::{FromPrimitive, Num, One, PrimInt, ToPrimitive, Zero}; use crate::{ - backend::{ArithmeticOps, VectorOps}, - decomposer::Decomposer, - lwe::lwe_key_switch, - ntt::Ntt, - rgsw::{galois_auto, rlwe_by_rgsw, IsTrivial}, - Matrix, MatrixEntity, MatrixMut, Row, RowMut, + backend::{ArithmeticOps, ModInit, VectorOps}, + decomposer::{gadget_vector, Decomposer, DefaultDecomposer, NumInfo}, + lwe::{decrypt_lwe, encrypt_lwe, lwe_key_switch, lwe_ksk_keygen, LweSecret}, + ntt::{Ntt, NttInit}, + random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, + rgsw::{encrypt_rgsw, galois_auto, galois_key_gen, rlwe_by_rgsw, IsTrivial, RlweSecret}, + utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal}, + Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, }; -struct BoolEvaluator {} - -impl BoolEvaluator {} - trait PbsKey { type M: Matrix; @@ -23,6 +22,278 @@ trait PbsKey { fn auto_map_index(&self, k: isize) -> &[usize]; fn auto_map_sign(&self, k: isize) -> &[bool]; } +trait Parameters { + type Element; + type D: Decomposer; + fn rlwe_q(&self) -> Self::Element; + fn lwe_q(&self) -> Self::Element; + fn br_q(&self) -> usize; + fn d_rgsw(&self) -> usize; + fn d_lwe(&self) -> usize; + fn rlwe_n(&self) -> usize; + fn lwe_n(&self) -> usize; + /// Embedding fator for ring X^{q}+1 inside + fn embedding_factor(&self) -> usize; + /// generator g + fn g(&self) -> isize; + fn decomoposer_lwe(&self) -> &Self::D; + fn decomoposer_rlwe(&self) -> &Self::D; + + /// Maps a \in Z^*_{q} to discrete log k, with generator g (i.e. g^k = + /// a). Returned vector is of size q that stores dlog of a at `vec[a]`. + /// For any a, if k is s.t. a = g^{k}, then k is expressed as k. If k is s.t + /// a = -g^{k}, then k is expressed as k=k+q/2 + fn g_k_dlog_map(&self) -> &[usize]; +} +struct ClientKey { + sk_rlwe: RlweSecret, + sk_lwe: LweSecret, +} + +struct ServerKey { + /// Rgsw cts of LWE secret elements + rgsw_cts: Vec, + /// Galois keys + galois_keys: HashMap, + /// LWE ksk to key switching LWE ciphertext from RLWE secret to LWE secret + lwe_ksk: M, +} + +struct BoolParameters { + rlwe_q: El, + rlwe_logq: usize, + lwe_q: El, + lwe_logq: usize, + br_q: usize, + rlwe_n: usize, + lwe_n: usize, + d_rgsw: usize, + logb_rgsw: usize, + d_lwe: usize, + logb_lwe: usize, + g: usize, + w: usize, +} + +struct BoolEvaluator { + parameters: BoolParameters, + decomposer_rlwe: DefaultDecomposer, + decomposer_lwe: DefaultDecomposer, + g_k_dlog_map: Vec, + rlwe_nttop: Ntt, + rlwe_modop: ModOp, + lwe_modop: ModOp, + embedding_factor: usize, + + _phantom: PhantomData, +} + +impl BoolEvaluator +where + NttOp: NttInit + Ntt, + ModOp: ModInit + + ArithmeticOps + + VectorOps, + M::MatElement: PrimInt + Debug + NumInfo + FromPrimitive, + M: MatrixEntity + MatrixMut, + M::R: TryConvertFrom<[i32], Parameters = M::MatElement> + RowEntity, + M: TryConvertFrom<[i32], Parameters = M::MatElement>, + ::R: RowMut, + DefaultSecureRng: RandomGaussianDist<[M::MatElement], Parameters = M::MatElement> + + RandomGaussianDist + + RandomUniformDist<[M::MatElement], Parameters = M::MatElement>, +{ + fn new(parameters: BoolParameters) -> Self { + //TODO(Jay): Run sanity checks for modulus values in parameters + + let decomposer_rlwe = + DefaultDecomposer::new(parameters.rlwe_q, parameters.logb_rgsw, parameters.d_rgsw); + let decomposer_lwe = + DefaultDecomposer::new(parameters.lwe_q, parameters.logb_lwe, parameters.d_lwe); + + // generatr dlog map s.t. g^{k} % q = a, for all a \in Z*_{q} + let g = parameters.g; + let q = parameters.br_q; + let mut g_k_dlog_map = vec![0usize; q]; + for i in 0..q / 2 { + let v = mod_exponent(g as u64, i as u64, q as u64) as usize; + // g^i + g_k_dlog_map[v] = i; + // -(g^i) + g_k_dlog_map[q - v] = i + (q / 2); + } + + let embedding_factor = (2 * parameters.rlwe_n) / q; + + let rlwe_nttop = NttOp::new(parameters.rlwe_q, parameters.rlwe_n); + let rlwe_modop = ModInit::new(parameters.rlwe_q); + let lwe_modop = ModInit::new(parameters.lwe_q); + + BoolEvaluator { + parameters: parameters, + decomposer_lwe, + decomposer_rlwe, + g_k_dlog_map, + embedding_factor, + lwe_modop, + rlwe_modop, + rlwe_nttop, + + _phantom: PhantomData, + } + } + + fn client_key(&self) -> ClientKey { + let sk_lwe = LweSecret::random(self.parameters.lwe_n >> 1, self.parameters.lwe_n); + let sk_rlwe = RlweSecret::random(self.parameters.rlwe_n >> 1, self.parameters.rlwe_n); + ClientKey { sk_rlwe, sk_lwe } + } + + fn server_key(&self, client_key: &ClientKey) -> ServerKey { + let sk_rlwe = &client_key.sk_rlwe; + let sk_lwe = &client_key.sk_lwe; + + let d_rgsw_gadget_vec = gadget_vector( + self.parameters.rlwe_logq, + self.parameters.logb_rgsw, + self.parameters.d_rgsw, + ); + + // generate galois key -g, g + let mut galois_keys = HashMap::new(); + let g = self.parameters.g as isize; + for i in [g, -g] { + let gk = DefaultSecureRng::with_local_mut(|rng| { + let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2, self.parameters.rlwe_n); + galois_key_gen( + &mut ksk_out, + sk_rlwe, + i, + &d_rgsw_gadget_vec, + &self.rlwe_modop, + &self.rlwe_nttop, + rng, + ); + ksk_out + }); + + galois_keys.insert(i, gk); + } + + // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element + let ring_size = self.parameters.rlwe_n; + let rlwe_q = self.parameters.rlwe_q; + let rgsw_cts = sk_lwe + .values() + .iter() + .map(|si| { + // X^{si}; assume |emebedding_factor * si| < N + let mut m = M::zeros(1, ring_size); + let si = (self.embedding_factor as i32) * si; + if si < 0 { + // X^{-i} = X^{2N - i} = -X^{N-i} + m.set( + 0, + ring_size - (si.abs() as usize), + rlwe_q - M::MatElement::one(), + ); + } else { + // X^{i} + m.set(0, (si.abs() as usize), M::MatElement::one()); + } + self.rlwe_nttop.forward(m.get_row_mut(0)); + + let rgsw_si = DefaultSecureRng::with_local_mut(|rng| { + let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4, ring_size); + encrypt_rgsw( + &mut rgsw_si, + &m, + &d_rgsw_gadget_vec, + sk_rlwe, + &self.rlwe_modop, + &self.rlwe_nttop, + rng, + ); + rgsw_si + }); + rgsw_si + }) + .collect_vec(); + + // LWE KSK from RLWE secret s -> LWE secret z + let d_lwe_gadget = gadget_vector( + self.parameters.lwe_logq, + self.parameters.logb_lwe, + self.parameters.d_lwe, + ); + let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| { + let mut out = M::zeros(self.parameters.d_lwe * ring_size, self.parameters.lwe_n + 1); + lwe_ksk_keygen( + &sk_rlwe.values(), + &sk_lwe.values(), + &mut out, + &d_lwe_gadget, + &self.lwe_modop, + rng, + ); + out + }); + + ServerKey { + rgsw_cts, + galois_keys, + lwe_ksk, + } + } + + pub fn encrypt(&self, m: bool, client_key: &ClientKey) -> M::R { + let rlwe_q_by8 = + M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) + .unwrap(); + let m = if m { + // Q/8 + rlwe_q_by8 + } else { + // -Q/8 + self.parameters.rlwe_q - rlwe_q_by8 + }; + + DefaultSecureRng::with_local_mut(|rng| { + let mut lwe_out = M::R::zeros(self.parameters.rlwe_n + 1); + encrypt_lwe( + &mut lwe_out, + &m, + client_key.sk_rlwe.values(), + &self.rlwe_modop, + rng, + ); + lwe_out + }) + } + + pub fn decrypt(&self, lwe_ct: &M::R, client_key: &ClientKey) -> bool { + let m = decrypt_lwe(lwe_ct, client_key.sk_rlwe.values(), &self.rlwe_modop); + let m = { + // m + q/8 => {0,q/4 1} + let rlwe_q_by8 = + M::MatElement::from_f64((self.parameters.rlwe_q.to_f64().unwrap() / 8.0).round()) + .unwrap(); + (((m + rlwe_q_by8).to_f64().unwrap() * 4.0) / self.parameters.rlwe_q.to_f64().unwrap()) + .round() + .to_usize() + .unwrap() + % 4 + }; + + if m == 0 { + false + } else if m == 1 { + true + } else { + panic!("Incorrect bool decryption. Got m={m} expected m to be 0 or 1") + } + } +} /// LMKCY+ Blind rotation /// @@ -137,29 +408,6 @@ fn blind_rotation< }); } -trait Parameters { - type Element; - type D: Decomposer; - fn rlwe_q(&self) -> Self::Element; - fn lwe_q(&self) -> Self::Element; - fn br_q(&self) -> usize; - fn d_rgsw(&self) -> usize; - fn d_lwe(&self) -> usize; - fn rlwe_n(&self) -> usize; - fn lwe_n(&self) -> usize; - // Embedding fator for ring X^{q}+1 inside - fn embedding_factor(&self) -> usize; - // generator g - fn g(&self) -> isize; - fn decomoposer_lwe(&self) -> &Self::D; - fn decomoposer_rlwe(&self) -> &Self::D; - /// Maps a \in Z^*_{2q} to discrete log k, with generator g (i.e. g^k = - /// a). Returned vector is of size q that stores dlog of a at `vec[a]`. - /// For any a, k is s.t. a = g^{k}, then k is expressed as k. If k is s.t a - /// = -g^{k/2}, then k is expressed as k=k+q/2 - fn g_k_dlog_map(&self) -> &[usize]; -} - /// - Mod down /// - key switching /// - mod down @@ -274,7 +522,6 @@ fn pbs< partb_trivial_rlwe[2 * index] = *v; }); } - // TODO Rotate test input // blind rotate blind_rotation( @@ -358,3 +605,44 @@ fn monomial_mul>( } }); } + +#[cfg(test)] +mod tests { + use crate::{backend::ModularOpsU64, ntt::NttBackendU64}; + + use super::*; + + const SP_BOOL_PARAMS: BoolParameters = BoolParameters:: { + rlwe_q: 4294957057u64, + rlwe_logq: 32, + lwe_q: 1 << 16, + lwe_logq: 16, + br_q: 1 << 9, + rlwe_n: 1 << 10, + lwe_n: 490, + d_rgsw: 4, + logb_rgsw: 7, + d_lwe: 4, + logb_lwe: 4, + g: 5, + w: 1, + }; + + #[test] + fn encrypt_decrypt_works() { + // let prime = generate_prime(32, 2 * 1024, 1 << 32); + // dbg!(prime); + let bool_evaluator = + BoolEvaluator::>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); + let client_key = bool_evaluator.client_key(); + // let sever_key = bool_evaluator.server_key(&client_key); + + let mut m = true; + for _ in 0..1000 { + let lwe_ct = bool_evaluator.encrypt(m, &client_key); + let m_back = bool_evaluator.decrypt(&lwe_ct, &client_key); + assert_eq!(m, m_back); + m = !m; + } + } +} diff --git a/src/decomposer.rs b/src/decomposer.rs index f148705..5b88334 100644 --- a/src/decomposer.rs +++ b/src/decomposer.rs @@ -143,7 +143,11 @@ fn round_value(value: T, ignore_bits: usize) -> T { mod tests { use rand::{thread_rng, Rng}; - use crate::{backend::ModularOpsU64, decomposer::round_value, utils::generate_prime}; + use crate::{ + backend::{ModInit, ModularOpsU64}, + decomposer::round_value, + utils::generate_prime, + }; use super::{Decomposer, DefaultDecomposer}; diff --git a/src/lib.rs b/src/lib.rs index 23a4a11..eca001b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,6 +77,10 @@ pub trait Row: AsRef<[Self::Element]> { pub trait RowMut: Row + AsMut<[::Element]> {} +pub trait RowEntity: Row { + fn zeros(col: usize) -> Self; +} + trait Secret { type Element; fn values(&self) -> &[Self::Element]; @@ -123,3 +127,9 @@ impl Row for Vec { } impl RowMut for Vec {} + +impl RowEntity for Vec { + fn zeros(col: usize) -> Self { + vec![T::zero(); col] + } +} diff --git a/src/lwe.rs b/src/lwe.rs index b66e7cd..849450f 100644 --- a/src/lwe.rs +++ b/src/lwe.rs @@ -21,7 +21,7 @@ trait LweKeySwitchParameters { trait LweCiphertext {} -struct LweSecret { +pub struct LweSecret { values: Vec, } @@ -33,7 +33,7 @@ impl Secret for LweSecret { } impl LweSecret { - fn random(hw: usize, n: usize) -> LweSecret { + pub(crate) fn random(hw: usize, n: usize) -> LweSecret { DefaultSecureRng::with_local_mut(|rng| { let mut out = vec![0i32; n]; fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); @@ -71,38 +71,32 @@ pub(crate) fn lwe_key_switch< lwe_out.as_mut()[0] = out_b; } -fn lwe_ksk_keygen< +pub fn lwe_ksk_keygen< Mmut: MatrixMut, - S: Secret, + S, Op: VectorOps + ArithmeticOps, R: RandomGaussianDist + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, >( - from_lwe_sk: &S, - to_lwe_sk: &S, + from_lwe_sk: &[S], + to_lwe_sk: &[S], ksk_out: &mut Mmut, gadget: &[Mmut::MatElement], operator: &Op, rng: &mut R, ) where ::R: RowMut, - Mmut::R: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::R: TryConvertFrom<[S], Parameters = Mmut::MatElement>, Mmut::MatElement: Zero + Debug, { - assert!( - ksk_out.dimension() - == ( - from_lwe_sk.values().len() * gadget.len(), - to_lwe_sk.values().len() + 1, - ) - ); + assert!(ksk_out.dimension() == (from_lwe_sk.len() * gadget.len(), to_lwe_sk.len() + 1,)); let d = gadget.len(); let modulus = VectorOps::modulus(operator); - let mut neg_sk_in_m = Mmut::R::try_convert_from(from_lwe_sk.values(), &modulus); + let mut neg_sk_in_m = Mmut::R::try_convert_from(from_lwe_sk, &modulus); operator.elwise_neg_mut(neg_sk_in_m.as_mut()); - let sk_out_m = Mmut::R::try_convert_from(to_lwe_sk.values(), &modulus); + let sk_out_m = Mmut::R::try_convert_from(to_lwe_sk, &modulus); izip!( neg_sk_in_m.as_ref(), @@ -134,23 +128,23 @@ fn lwe_ksk_keygen< } /// Encrypts encoded message m as LWE ciphertext -fn encrypt_lwe< +pub fn encrypt_lwe< Ro: Row + RowMut, R: RandomGaussianDist + RandomUniformDist<[Ro::Element], Parameters = Ro::Element>, - S: Secret, + S, Op: ArithmeticOps, >( lwe_out: &mut Ro, m: &Ro::Element, - s: &S, + s: &[S], operator: &Op, rng: &mut R, ) where - Ro: TryConvertFrom<[S::Element], Parameters = Ro::Element>, + Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero, { - let s = Ro::try_convert_from(s.values(), &operator.modulus()); + let s = Ro::try_convert_from(s, &operator.modulus()); assert!(s.as_ref().len() == (lwe_out.as_ref().len() - 1)); // a*s @@ -168,16 +162,16 @@ fn encrypt_lwe< lwe_out.as_mut()[0] = b; } -fn decrypt_lwe, S: Secret>( +pub fn decrypt_lwe, S>( lwe_ct: &Ro, - s: &S, + s: &[S], operator: &Op, ) -> Ro::Element where - Ro: TryConvertFrom<[S::Element], Parameters = Ro::Element>, + Ro: TryConvertFrom<[S], Parameters = Ro::Element>, Ro::Element: Zero, { - let s = Ro::try_convert_from(s.values(), &operator.modulus()); + let s = Ro::try_convert_from(s, &operator.modulus()); let mut sa = Ro::Element::zero(); izip!(lwe_ct.as_ref().iter().skip(1), s.as_ref()).for_each(|(ai, si)| { @@ -193,10 +187,11 @@ where mod tests { use crate::{ - backend::ModularOpsU64, + backend::{ModInit, ModularOpsU64}, decomposer::{gadget_vector, DefaultDecomposer}, lwe::lwe_key_switch, random::DefaultSecureRng, + Secret, }; use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret}; @@ -217,8 +212,14 @@ mod tests { for m in 0..1u64 << logp { let encoded_m = m << (logq - logp); let mut lwe_ct = vec![0u64; lwe_n + 1]; - encrypt_lwe(&mut lwe_ct, &encoded_m, &lwe_sk, &modq_op, &mut rng); - let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk, &modq_op); + encrypt_lwe( + &mut lwe_ct, + &encoded_m, + &lwe_sk.values(), + &modq_op, + &mut rng, + ); + let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) % (1u64 << logp); @@ -247,8 +248,8 @@ mod tests { let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n]; let gadget = gadget_vector(logq, logb, d_ks); lwe_ksk_keygen( - &lwe_sk_in, - &lwe_sk_out, + &lwe_sk_in.values(), + &lwe_sk_out.values(), &mut ksk, &gadget, &modq_op, @@ -260,7 +261,13 @@ mod tests { // encrypt using lwe_sk_in let encoded_m = m << (logq - logp); let mut lwe_in_ct = vec![0u64; lwe_in_n + 1]; - encrypt_lwe(&mut lwe_in_ct, &encoded_m, &lwe_sk_in, &modq_op, &mut rng); + encrypt_lwe( + &mut lwe_in_ct, + &encoded_m, + lwe_sk_in.values(), + &modq_op, + &mut rng, + ); // key switch from lwe_sk_in to lwe_sk_out let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); @@ -268,7 +275,7 @@ mod tests { lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer); // decrypt lwe_out_ct using lwe_sk_out - let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out, &modq_op); + let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out.values(), &modq_op); let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() as u64) % (1u64 << logp); diff --git a/src/ntt.rs b/src/ntt.rs index 320a28c..9c57402 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -2,10 +2,15 @@ use itertools::Itertools; use rand::{thread_rng, Rng, RngCore}; use crate::{ - backend::{ArithmeticOps, ModularOpsU64}, + backend::{ArithmeticOps, ModInit, ModularOpsU64}, utils::{mod_exponent, mod_inverse, shoup_representation_fq}, }; +pub trait NttInit { + type Element; + fn new(q: Self::Element, n: usize) -> Self; +} + pub trait Ntt { type Element; fn forward_lazy(&self, v: &mut [Self::Element]); @@ -195,8 +200,9 @@ pub struct NttBackendU64 { psi_inv_powers_bo_shoup: Box<[u64]>, } -impl NttBackendU64 { - pub fn new(q: u64, n: usize) -> Self { +impl NttInit for NttBackendU64 { + type Element = u64; + fn new(q: u64, n: usize) -> Self { // \psi = 2n^{th} primitive root of unity in F_q let mut rng = thread_rng(); let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) @@ -325,9 +331,9 @@ mod tests { use rand::{thread_rng, Rng}; use rand_distr::Uniform; - use super::NttBackendU64; + use super::{NttBackendU64, NttInit}; use crate::{ - backend::{ModularOpsU64, VectorOps}, + backend::{ArithmeticOps, ModInit, ModularOpsU64, VectorOps}, ntt::Ntt, utils::{generate_prime, negacyclic_mul}, }; diff --git a/src/rgsw.rs b/src/rgsw.rs index dac3361..b3b400a 100644 --- a/src/rgsw.rs +++ b/src/rgsw.rs @@ -16,7 +16,7 @@ use crate::{ Matrix, MatrixEntity, MatrixMut, RowMut, Secret, }; -struct RlweCiphertext(M, bool); +pub struct RlweCiphertext(M, bool); impl Matrix for RlweCiphertext { type MatElement = M::MatElement; @@ -58,7 +58,7 @@ pub trait IsTrivial { fn set_not_trivial(&mut self); } -struct RlweSecret { +pub struct RlweSecret { values: Vec, } @@ -70,7 +70,7 @@ impl Secret for RlweSecret { } impl RlweSecret { - fn random(hw: usize, n: usize) -> RlweSecret { + pub fn random(hw: usize, n: usize) -> RlweSecret { DefaultSecureRng::with_local_mut(|rng| { let mut out = vec![0i32; n]; fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); @@ -80,8 +80,15 @@ impl RlweSecret { } } -fn generate_auto_map(ring_size: usize, k: usize) -> (Vec, Vec) { +fn generate_auto_map(ring_size: usize, k: isize) -> (Vec, Vec) { assert!(k & 1 == 1, "Auto {k} must be odd"); + + // k = k % 2*N + let k = if k < 0 { + (2 * ring_size) - (k.abs() as usize) + } else { + k as usize + }; let (auto_map_index, auto_sign_index): (Vec, Vec) = (0..ring_size) .into_iter() .map(|i| { @@ -183,13 +190,14 @@ pub(crate) fn galois_key_gen< >( ksk_out: &mut Mmut, s: &S, - auto_k: usize, + auto_k: isize, gadget_vector: &[Mmut::MatElement], mod_op: &ModOp, ntt_op: &NttOp, rng: &mut R, ) where ::R: RowMut, + //FIXME(Jay): Why isn't this bound Mmut::R: given that secret is a vector (Row) not a matrix Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, Mmut::MatElement: Copy + Sub, { @@ -327,7 +335,7 @@ pub(crate) fn galois_auto< /// RLWE'_B(-sm) || RLWE'_A(m) || RLWE'_B(m)]^T pub(crate) fn encrypt_rgsw< Mmut: MatrixMut + MatrixEntity, - M: Matrix + Clone, + M: Matrix, S: Secret, R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, @@ -805,9 +813,9 @@ mod tests { use rand::{thread_rng, Rng}; use crate::{ - backend::ModularOpsU64, + backend::{ModInit, ModularOpsU64}, decomposer::{gadget_vector, DefaultDecomposer}, - ntt::{self, Ntt, NttBackendU64}, + ntt::{self, Ntt, NttBackendU64, NttInit}, random::{DefaultSecureRng, RandomUniformDist}, rgsw::{measure_noise, RlweCiphertext}, utils::{generate_prime, negacyclic_mul}, @@ -933,7 +941,7 @@ mod tests { &mut rng, ); - let auto_k = 25; + let auto_k = -25; // Generate galois key to key switch from s^k to s let mut ksk_out = vec![vec![0u64; ring_size as usize]; d_rgsw * 2];