rlwe x rgsw not working

This commit is contained in:
Janmajaya Mall
2024-05-01 16:04:51 +05:30
parent 4b835461dd
commit 56752a7559
6 changed files with 871 additions and 692 deletions

View File

@@ -17,16 +17,16 @@ use crate::{
ntt::{Ntt, NttBackendU64, NttInit}, ntt::{Ntt, NttBackendU64, NttInit},
random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist},
rgsw::{ rgsw::{
decrypt_rlwe, encrypt_rgsw, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw, decrypt_rlwe, galois_auto, galois_key_gen, generate_auto_map, rlwe_by_rgsw,
IsTrivial, RlweCiphertext, RlweSecret, secret_key_encrypt_rgsw, IsTrivial, RlweCiphertext, RlweSecret,
}, },
utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal}, utils::{generate_prime, mod_exponent, TryConvertFrom, WithLocal},
Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret, Matrix, MatrixEntity, MatrixMut, Row, RowEntity, RowMut, Secret,
}; };
thread_local! { // thread_local! {
pub(crate) static CLIENT_KEY: RefCell<ClientKey> = RefCell::new(ClientKey::random()); // pub(crate) static CLIENT_KEY: RefCell<ClientKey> =
} // RefCell::new(ClientKey::random()); }
trait PbsKey { trait PbsKey {
type M: Matrix; type M: Matrix;
@@ -77,25 +77,25 @@ impl ClientKey {
} }
} }
impl WithLocal for ClientKey { // impl WithLocal for ClientKey {
fn with_local<F, R>(func: F) -> R // fn with_local<F, R>(func: F) -> R
where // where
F: Fn(&Self) -> R, // F: Fn(&Self) -> R,
{ // {
CLIENT_KEY.with_borrow(|client_key| func(client_key)) // CLIENT_KEY.with_borrow(|client_key| func(client_key))
} // }
fn with_local_mut<F, R>(func: F) -> R // fn with_local_mut<F, R>(func: F) -> R
where // where
F: Fn(&mut Self) -> R, // F: Fn(&mut Self) -> R,
{ // {
CLIENT_KEY.with_borrow_mut(|client_key| func(client_key)) // CLIENT_KEY.with_borrow_mut(|client_key| func(client_key))
} // }
} // }
fn set_client_key(key: &ClientKey) { // fn set_client_key(key: &ClientKey) {
ClientKey::with_local_mut(|k| *k = key.clone()) // ClientKey::with_local_mut(|k| *k = key.clone())
} // }
struct ServerKey<M> { struct ServerKey<M> {
/// Rgsw cts of LWE secret elements /// Rgsw cts of LWE secret elements
@@ -154,8 +154,6 @@ where
nand_test_vec: M::R, nand_test_vec: M::R,
rlweq_by8: M::MatElement, rlweq_by8: M::MatElement,
rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>, rlwe_auto_maps: Vec<(Vec<usize>, Vec<bool>)>,
scratch_lwen_plus1: M::R,
scratch_dplus2_ring: M,
_phantom: PhantomData<M>, _phantom: PhantomData<M>,
} }
@@ -264,10 +262,6 @@ where
rlwe_auto_maps.push(generate_auto_map(ring_size, i)) rlwe_auto_maps.push(generate_auto_map(ring_size, i))
} }
// create srcatch spaces
let scratch_lwen_plus1 = M::R::zeros(parameters.lwe_n + 1);
let scratch_dplus2_ring = M::zeros(parameters.d_rgsw + 2, parameters.rlwe_n);
BoolEvaluator { BoolEvaluator {
parameters: parameters, parameters: parameters,
decomposer_lwe, decomposer_lwe,
@@ -280,8 +274,7 @@ where
nand_test_vec: nand_test_vec_autog, nand_test_vec: nand_test_vec_autog,
rlweq_by8: rlwe_qby8, rlweq_by8: rlwe_qby8,
rlwe_auto_maps, rlwe_auto_maps,
scratch_lwen_plus1,
scratch_dplus2_ring,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
@@ -293,101 +286,103 @@ where
} }
fn server_key(&self, client_key: &ClientKey) -> ServerKey<M> { fn server_key(&self, client_key: &ClientKey) -> ServerKey<M> {
let sk_rlwe = &client_key.sk_rlwe; // let sk_rlwe = &client_key.sk_rlwe;
let sk_lwe = &client_key.sk_lwe; // let sk_lwe = &client_key.sk_lwe;
let d_rgsw_gadget_vec = gadget_vector( // let d_rgsw_gadget_vec = gadget_vector(
self.parameters.rlwe_logq, // self.parameters.rlwe_logq,
self.parameters.logb_rgsw, // self.parameters.logb_rgsw,
self.parameters.d_rgsw, // self.parameters.d_rgsw,
); // );
// generate galois key -g, g // // generate galois key -g, g
let mut galois_keys = HashMap::new(); // let mut galois_keys = HashMap::new();
let g = self.parameters.g as isize; // let g = self.parameters.g as isize;
for i in [g, -g] { // for i in [g, -g] {
let gk = DefaultSecureRng::with_local_mut(|rng| { // let gk = DefaultSecureRng::with_local_mut(|rng| {
let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2, self.parameters.rlwe_n); // let mut ksk_out = M::zeros(self.parameters.d_rgsw * 2,
galois_key_gen( // self.parameters.rlwe_n); galois_key_gen(
&mut ksk_out, // &mut ksk_out,
sk_rlwe, // sk_rlwe,
i, // i,
&d_rgsw_gadget_vec, // &d_rgsw_gadget_vec,
&self.rlwe_modop, // &self.rlwe_modop,
&self.rlwe_nttop, // &self.rlwe_nttop,
rng, // rng,
); // );
ksk_out // ksk_out
}); // });
galois_keys.insert(i, gk); // galois_keys.insert(i, gk);
} // }
// generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element // // generate rgsw ciphertexts RGSW(si) where si is i^th LWE secret element
let ring_size = self.parameters.rlwe_n; // let ring_size = self.parameters.rlwe_n;
let rlwe_q = self.parameters.rlwe_q; // let rlwe_q = self.parameters.rlwe_q;
let rgsw_cts = sk_lwe // let rgsw_cts = sk_lwe
.values() // .values()
.iter() // .iter()
.map(|si| { // .map(|si| {
// X^{si}; assume |emebedding_factor * si| < N // // X^{si}; assume |emebedding_factor * si| < N
let mut m = M::zeros(1, ring_size); // let mut m = M::zeros(1, ring_size);
let si = (self.embedding_factor as i32) * si; // let si = (self.embedding_factor as i32) * si;
// dbg!(si); // // dbg!(si);
if si < 0 { // if si < 0 {
// X^{-i} = X^{2N - i} = -X^{N-i} // // X^{-i} = X^{2N - i} = -X^{N-i}
m.set( // m.set(
0, // 0,
ring_size - (si.abs() as usize), // ring_size - (si.abs() as usize),
rlwe_q - M::MatElement::one(), // rlwe_q - M::MatElement::one(),
); // );
} else { // } else {
// X^{i} // // X^{i}
m.set(0, (si.abs() as usize), M::MatElement::one()); // m.set(0, (si.abs() as usize), M::MatElement::one());
} // }
self.rlwe_nttop.forward(m.get_row_mut(0)); // self.rlwe_nttop.forward(m.get_row_mut(0));
let rgsw_si = DefaultSecureRng::with_local_mut(|rng| { // let rgsw_si = DefaultSecureRng::with_local_mut(|rng| {
let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4, ring_size); // let mut rgsw_si = M::zeros(self.parameters.d_rgsw * 4,
encrypt_rgsw( // ring_size); secret_key_encrypt_rgsw(
&mut rgsw_si, // &mut rgsw_si,
&m, // &m,
&d_rgsw_gadget_vec, // &d_rgsw_gadget_vec,
sk_rlwe, // sk_rlwe,
&self.rlwe_modop, // &self.rlwe_modop,
&self.rlwe_nttop, // &self.rlwe_nttop,
rng, // rng,
); // );
rgsw_si // rgsw_si
}); // });
rgsw_si // rgsw_si
}) // })
.collect_vec(); // .collect_vec();
// LWE KSK from RLWE secret s -> LWE secret z // // LWE KSK from RLWE secret s -> LWE secret z
let d_lwe_gadget = gadget_vector( // let d_lwe_gadget = gadget_vector(
self.parameters.lwe_logq, // self.parameters.lwe_logq,
self.parameters.logb_lwe, // self.parameters.logb_lwe,
self.parameters.d_lwe, // self.parameters.d_lwe,
); // );
let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| { // let mut lwe_ksk = DefaultSecureRng::with_local_mut(|rng| {
let mut out = M::zeros(self.parameters.d_lwe * ring_size, self.parameters.lwe_n + 1); // let mut out = M::zeros(self.parameters.d_lwe * ring_size,
lwe_ksk_keygen( // self.parameters.lwe_n + 1); lwe_ksk_keygen(
&sk_rlwe.values(), // &sk_rlwe.values(),
&sk_lwe.values(), // &sk_lwe.values(),
&mut out, // &mut out,
&d_lwe_gadget, // &d_lwe_gadget,
&self.lwe_modop, // &self.lwe_modop,
rng, // rng,
); // );
out // out
}); // });
ServerKey { // ServerKey {
rgsw_cts, // rgsw_cts,
galois_keys, // galois_keys,
lwe_ksk, // lwe_ksk,
} // }
todo!()
} }
/// TODO(Jay): Fetch client key from thread local /// TODO(Jay): Fetch client key from thread local
@@ -434,6 +429,7 @@ where
} }
} }
// TODO(Jay): scratch spaces must be thread local. Don't pass them as arguments
pub fn nand( pub fn nand(
&self, &self,
c0: &M::R, c0: &M::R,
@@ -688,9 +684,7 @@ fn pbs<
nttop_rlweq: &NttOp, nttop_rlweq: &NttOp,
pbs_key: &K, pbs_key: &K,
) where ) where
// FIXME(Jay): TryConvertFrom<[i32], Parameters = M::MatElement> are only needed for <M as Matrix>::R: RowMut,
// debugging purposes
<M as Matrix>::R: RowMut + TryConvertFrom<[i32], Parameters = M::MatElement>,
M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display, M::MatElement: PrimInt + ToPrimitive + FromPrimitive + One + Copy + Zero + Display,
{ {
let rlwe_q = parameters.rlwe_q(); let rlwe_q = parameters.rlwe_q();
@@ -775,7 +769,8 @@ fn pbs<
gb_monomial_sign = false gb_monomial_sign = false
} }
// monomial mul // monomial mul
let mut trivial_rlwe_test_poly = RlweCiphertext(M::zeros(2, rlwe_n), true); let mut trivial_rlwe_test_poly =
RlweCiphertext::<_, DefaultSecureRng>::from_raw(M::zeros(2, rlwe_n), true);
if parameters.embedding_factor() == 1 { if parameters.embedding_factor() == 1 {
monomial_mul( monomial_mul(
test_vec.as_ref(), test_vec.as_ref(),
@@ -853,16 +848,12 @@ fn pbs<
} }
fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize { fn mod_switch_odd(v: f64, from_q: f64, to_q: f64) -> usize {
// println!("v: {v}, odd_v: {odd_v}, lwe_q:{lwe_q}, br_q:{br_q}");
let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap(); let odd_v = (((v * to_q) / (from_q)).floor()).to_usize().unwrap();
// println!(
// "v: {v}, odd_v: {odd_v}, returned_oddv: {},lwe_q:{from_q}, br_q:{to_q}",
// odd_v + ((odd_v & 1) ^ 1)
// );
//TODO(Jay): check correctness of this //TODO(Jay): check correctness of this
odd_v + ((odd_v & 1) ^ 1) odd_v + ((odd_v & 1) ^ 1)
} }
// TODO(Jay): Add tests for sample extract
fn sample_extract<M: Matrix + MatrixMut, ModOp: ArithmeticOps<Element = M::MatElement>>( fn sample_extract<M: Matrix + MatrixMut, ModOp: ArithmeticOps<Element = M::MatElement>>(
lwe_out: &mut M::R, lwe_out: &mut M::R,
rlwe_in: &M, rlwe_in: &M,
@@ -936,58 +927,26 @@ where
} }
impl PBSTracer<Vec<Vec<u64>>> { impl PBSTracer<Vec<Vec<u64>>> {
fn trace(&self, parameters: &BoolParameters<u64>, client_key: &ClientKey, expected_m: bool) { fn trace(&self, parameters: &BoolParameters<u64>, client_key: &ClientKey) {
let lwe_q = parameters.lwe_q; let modop_lweq = ModularOpsU64::new(parameters.lwe_q as u64);
let lwe_qby8 = ((lwe_q as f64) / 8.0).round() as u64;
let expected_m_lweq = if expected_m {
lwe_qby8
} else {
lwe_q - lwe_qby8
};
let modop_lweq = ModularOpsU64::new(lwe_q);
// noise after mod down Q -> Q_ks // noise after mod down Q -> Q_ks
let noise0 = { let m_back0 = decrypt_lwe(&self.ct_lwe_q_mod, client_key.sk_rlwe.values(), &modop_lweq);
measure_noise_lwe(
&self.ct_lwe_q_mod,
client_key.sk_rlwe.values(),
&modop_lweq,
&expected_m_lweq,
)
};
// noise after key switch from RLWE -> LWE // noise after key switch from RLWE -> LWE
let noise1 = { let m_back1 = decrypt_lwe(
measure_noise_lwe(
&self.ct_lwe_q_mod_after_ksk, &self.ct_lwe_q_mod_after_ksk,
client_key.sk_lwe.values(), client_key.sk_lwe.values(),
&modop_lweq, &modop_lweq,
&expected_m_lweq, );
)
};
// noise after mod down odd from Q_ks -> q // noise after mod down odd from Q_ks -> q
let br_q = parameters.br_q as u64; let modop_br_q = ModularOpsU64::new(parameters.br_q as u64);
let expected_m_brq = if expected_m { let m_back2 = decrypt_lwe(&self.ct_br_q_mod, client_key.sk_lwe.values(), &modop_br_q);
br_q >> 3
} else {
br_q - (br_q >> 3)
};
let modop_br_q = ModularOpsU64::new(br_q);
let noise2 = {
measure_noise_lwe(
&self.ct_br_q_mod,
client_key.sk_lwe.values(),
&modop_br_q,
&expected_m_brq,
)
};
println!( println!(
" "
m: {expected_m}, M after mod down Q -> Q_ks: {m_back0},
Noise after mod down Q -> Q_ks: {noise0}, M after key switch from RLWE -> LWE: {m_back1},
Noise after key switch from RLWE -> LWE: {noise1}, M after mod dwon Q_ks -> q: {m_back2}
Noise after mod dwon Q_ks -> q: {noise2}
" "
); );
} }
@@ -1025,7 +984,7 @@ mod tests {
lwe_n: 493, lwe_n: 493,
d_rgsw: 3, d_rgsw: 3,
logb_rgsw: 8, logb_rgsw: 8,
d_lwe: 3, d_lwe: 2,
logb_lwe: 4, logb_lwe: 4,
g: 5, g: 5,
w: 1, w: 1,
@@ -1063,10 +1022,9 @@ mod tests {
let bool_evaluator = let bool_evaluator =
BoolEvaluator::<Vec<Vec<u64>>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS); BoolEvaluator::<Vec<Vec<u64>>, u64, NttBackendU64, ModularOpsU64>::new(SP_BOOL_PARAMS);
// println!("{:?}", bool_evaluator.nand_test_vec); // println!("{:?}", bool_evaluator.nand_test_vec);
let client_key = bool_evaluator.client_key(); let client_key = bool_evaluator.client_key();
set_client_key(&client_key);
let server_key = bool_evaluator.server_key(&client_key); let server_key = bool_evaluator.server_key(&client_key);
let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1]; let mut scratch_lwen_plus1 = vec![0u64; bool_evaluator.parameters.lwe_n + 1];
@@ -1079,7 +1037,7 @@ mod tests {
let mut m1 = true; let mut m1 = true;
let mut ct0 = bool_evaluator.encrypt(m0, &client_key); let mut ct0 = bool_evaluator.encrypt(m0, &client_key);
let mut ct1 = bool_evaluator.encrypt(m1, &client_key); let mut ct1 = bool_evaluator.encrypt(m1, &client_key);
for _ in 0..4 { for _ in 0..100 {
let ct_back = bool_evaluator.nand( let ct_back = bool_evaluator.nand(
&ct0, &ct0,
&ct1, &ct1,
@@ -1093,9 +1051,9 @@ mod tests {
// Trace and measure PBS noise // Trace and measure PBS noise
{ {
// Trace PBS // Trace PBS
PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key, m_out)); PBSTracer::with_local(|t| t.trace(&SP_BOOL_PARAMS, &client_key));
// Calculate nosie in ciphertext post PBS // Calculate noise in ciphertext post PBS
let ideal = if m_out { let ideal = if m_out {
bool_evaluator.rlweq_by8 bool_evaluator.rlweq_by8
} else { } else {

View File

@@ -9,11 +9,13 @@ mod backend;
mod bool; mod bool;
mod decomposer; mod decomposer;
mod lwe; mod lwe;
mod multi_party;
mod ntt; mod ntt;
mod num; mod num;
mod random; mod random;
mod rgsw; mod rgsw;
mod utils; mod utils;
pub trait Matrix: AsRef<[Self::R]> { pub trait Matrix: AsRef<[Self::R]> {
type MatElement; type MatElement;
type R: Row<Element = Self::MatElement>; type R: Row<Element = Self::MatElement>;

View File

@@ -187,6 +187,11 @@ where
operator.sub(b, &sa) operator.sub(b, &sa)
} }
/// Measures noise in input LWE ciphertext with reference of `ideal_m`
///
/// - ct: Input LWE ciphertext
/// - s: corresponding secret
/// - ideal_m: Ideal `encoded` message
pub(crate) fn measure_noise_lwe<Ro: Row, Op: ArithmeticOps<Element = Ro::Element>, S>( pub(crate) fn measure_noise_lwe<Ro: Row, Op: ArithmeticOps<Element = Ro::Element>, S>(
ct: &Ro, ct: &Ro,
s: &[S], s: &[S],
@@ -206,7 +211,6 @@ where
}); });
let m = operator.sub(&ct.as_ref()[0], &sa); let m = operator.sub(&ct.as_ref()[0], &sa);
println!("measire: {m} {ideal_m}");
let mut diff = operator.sub(&m, ideal_m); let mut diff = operator.sub(&m, ideal_m);
let q = operator.modulus(); let q = operator.modulus();
if diff > (q >> 1) { if diff > (q >> 1) {
@@ -221,16 +225,19 @@ mod tests {
use crate::{ use crate::{
backend::{ModInit, ModularOpsU64}, backend::{ModInit, ModularOpsU64},
decomposer::{gadget_vector, DefaultDecomposer}, decomposer::{gadget_vector, DefaultDecomposer},
lwe::lwe_key_switch, lwe::{lwe_key_switch, measure_noise_lwe},
random::DefaultSecureRng, random::DefaultSecureRng,
rgsw::measure_noise,
Secret, Secret,
}; };
use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret}; use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret};
const K: usize = 500;
#[test] #[test]
fn encrypt_decrypt_works() { fn encrypt_decrypt_works() {
let logq = 20; let logq = 16;
let q = 1u64 << logq; let q = 1u64 << logq;
let lwe_n = 1024; let lwe_n = 1024;
let logp = 3; let logp = 3;
@@ -262,12 +269,12 @@ mod tests {
#[test] #[test]
fn key_switch_works() { fn key_switch_works() {
let logq = 16; let logq = 16;
let logp = 3; let logp = 2;
let q = 1u64 << logq; let q = 1u64 << logq;
let lwe_in_n = 1024; let lwe_in_n = 2048;
let lwe_out_n = 470; let lwe_out_n = 493;
let d_ks = 3; let d_ks = 3;
let logb = 4; let logb = 5;
let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n);
let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n);
@@ -276,7 +283,7 @@ mod tests {
let modq_op = ModularOpsU64::new(q); let modq_op = ModularOpsU64::new(q);
// genrate ksk // genrate ksk
for _ in 0..10 { for _ in 0..K {
let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n]; let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n];
let gadget = gadget_vector(logq, logb, d_ks); let gadget = gadget_vector(logq, logb, d_ks);
lwe_ksk_keygen( lwe_ksk_keygen(
@@ -311,6 +318,9 @@ mod tests {
let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round()
as u64) as u64)
% (1u64 << logp); % (1u64 << logp);
let noise =
measure_noise_lwe(&lwe_out_ct, lwe_sk_out.values(), &modq_op, &encoded_m);
println!("Noise: {noise}");
assert_eq!(m, m_back, "Expected {m} but got {m_back}"); assert_eq!(m, m_back, "Expected {m} but got {m_back}");
// dbg!(m, m_back); // dbg!(m, m_back);
// dbg!(encoded_m, encoded_m_back); // dbg!(encoded_m, encoded_m_back);

View File

@@ -1,5 +1,6 @@
use itertools::Itertools; use itertools::Itertools;
use rand::{thread_rng, Rng, RngCore}; use rand::{thread_rng, Rng, RngCore, SeedableRng};
use rand_chacha::ChaCha8Rng;
use crate::{ use crate::{
backend::{ArithmeticOps, ModInit, ModularOpsU64}, backend::{ArithmeticOps, ModInit, ModularOpsU64},
@@ -8,6 +9,8 @@ use crate::{
pub trait NttInit { pub trait NttInit {
type Element; type Element;
/// Ntt istance must be compatible across different instances with same `q`
/// and `n`
fn new(q: Self::Element, n: usize) -> Self; fn new(q: Self::Element, n: usize) -> Self;
} }
@@ -189,6 +192,7 @@ pub(crate) fn find_primitive_root<R: RngCore>(q: u64, n: u64, rng: &mut R) -> Op
None None
} }
#[derive(Debug)]
pub struct NttBackendU64 { pub struct NttBackendU64 {
q: u64, q: u64,
q_twice: u64, q_twice: u64,
@@ -204,10 +208,9 @@ impl NttInit for NttBackendU64 {
type Element = u64; type Element = u64;
fn new(q: u64, n: usize) -> Self { fn new(q: u64, n: usize) -> Self {
// \psi = 2n^{th} primitive root of unity in F_q // \psi = 2n^{th} primitive root of unity in F_q
let mut rng = thread_rng(); let mut rng = ChaCha8Rng::from_seed([0u8; 32]);
let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) let psi = find_primitive_root(q, (n * 2) as u64, &mut rng)
.expect("Unable to find 2n^th root of unity"); .expect("Unable to find 2n^th root of unity");
let psi_inv = mod_inverse(psi, q); let psi_inv = mod_inverse(psi, q);
// assert!( // assert!(
@@ -382,7 +385,7 @@ mod tests {
#[test] #[test]
fn native_ntt_negacylic_mul() { fn native_ntt_negacylic_mul() {
let primes = [40, 50, 60] let primes = [25, 40, 50, 60]
.iter() .iter()
.map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap()) .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap())
.collect_vec(); .collect_vec();

View File

@@ -11,6 +11,11 @@ thread_local! {
pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new()); pub(crate) static DEFAULT_RNG: RefCell<DefaultSecureRng> = RefCell::new(DefaultSecureRng::new());
} }
pub(crate) trait NewWithSeed {
type Seed;
fn new_with_seed(seed: Self::Seed) -> Self;
}
pub trait RandomGaussianDist<M> pub trait RandomGaussianDist<M>
where where
M: ?Sized, M: ?Sized,
@@ -41,6 +46,17 @@ impl DefaultSecureRng {
let rng = ChaCha8Rng::from_entropy(); let rng = ChaCha8Rng::from_entropy();
DefaultSecureRng { rng } DefaultSecureRng { rng }
} }
pub fn fill_bytes(&mut self, a: &mut [u8; 32]) {
self.rng.fill_bytes(a);
}
}
impl NewWithSeed for DefaultSecureRng {
type Seed = <ChaCha8Rng as SeedableRng>::Seed;
fn new_with_seed(seed: Self::Seed) -> Self {
DefaultSecureRng::new_seeded(seed)
}
} }
impl RandomUniformDist<usize> for DefaultSecureRng { impl RandomUniformDist<usize> for DefaultSecureRng {
@@ -86,19 +102,19 @@ impl RandomUniformDist<[u64]> for DefaultSecureRng {
impl RandomGaussianDist<u64> for DefaultSecureRng { impl RandomGaussianDist<u64> for DefaultSecureRng {
type Parameters = u64; type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) { fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) {
let o = rand_distr::Normal::new(0.0, 3.2f64) // let o = rand_distr::Normal::new(0.0, 3.2f64)
.unwrap() // .unwrap()
.sample(&mut self.rng) // .sample(&mut self.rng)
.round(); // .round();
// let o = 0.0f64; // // let o = 0.0f64;
let is_neg = o.is_sign_negative() && o != 0.0; // let is_neg = o.is_sign_negative() && o != 0.0;
if is_neg { // if is_neg {
*container = parameters - (o.abs() as u64); // *container = parameters - (o.abs() as u64);
} else { // } else {
*container = o as u64; // *container = o as u64;
} // }
} }
} }
@@ -124,21 +140,21 @@ impl RandomGaussianDist<u32> for DefaultSecureRng {
impl RandomGaussianDist<[u64]> for DefaultSecureRng { impl RandomGaussianDist<[u64]> for DefaultSecureRng {
type Parameters = u64; type Parameters = u64;
fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) {
izip!( // izip!(
rand_distr::Normal::new(0.0, 3.2f64) // rand_distr::Normal::new(0.0, 3.2f64)
.unwrap() // .unwrap()
.sample_iter(&mut self.rng), // .sample_iter(&mut self.rng),
container.iter_mut() // container.iter_mut()
) // )
.for_each(|(oi, v)| { // .for_each(|(oi, v)| {
let oi = oi.round(); // let oi = oi.round();
let is_neg = oi.is_sign_negative() && oi != 0.0; // let is_neg = oi.is_sign_negative() && oi != 0.0;
if is_neg { // if is_neg {
*v = parameters - (oi.abs() as u64); // *v = parameters - (oi.abs() as u64);
} else { // } else {
*v = oi as u64; // *v = oi as u64;
} // }
}); // });
} }
} }

File diff suppressed because it is too large Load Diff