diff --git a/backend/src/scalar_znx.rs b/backend/src/scalar_znx.rs index 2cc1797..4acedb5 100644 --- a/backend/src/scalar_znx.rs +++ b/backend/src/scalar_znx.rs @@ -109,7 +109,7 @@ impl>> ScalarZnx { } pub fn new(n: usize, cols: usize) -> Self { - let data = alloc_aligned::(Self::bytes_of(n, cols)); + let data: Vec = alloc_aligned::(Self::bytes_of(n, cols)); Self { data: data.into(), n, diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index 38c8bf2..87471e9 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,29 +1,46 @@ use std::time::Instant; use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps, - VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use crate::{ - FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, + FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWECiphertextToMut, GLWEPlaintext, Infos, LWECiphertext, + ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; pub fn cggi_blind_rotate_scratch_space( module: &Module, + extension_factor: usize, basek: usize, k_lut: usize, k_brk: usize, rows: usize, rank: usize, ) -> usize { - let size = k_brk.div_ceil(basek); - GGSWCiphertext::, FFT64>::bytes_of(module, basek, k_brk, rows, 1, rank) - + (module.mat_znx_dft_mul_x_pow_minus_one_scratch_space(size, rank + 1) - | GLWECiphertext::external_product_inplace_scratch_space(module, basek, k_lut, k_brk, 1, rank)) + let lut_size: usize = k_lut.div_ceil(basek); + let brk_size: usize = k_brk.div_ceil(basek); + + let acc_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let acc_big: usize = module.bytes_of_vec_znx_big(rank + 1, brk_size); + let acc_dft_add: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let vmp_res: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor; + let xai_plus_y: usize = module.bytes_of_scalar_znx(1); + let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1); + let vmp: usize = module.vmp_apply_tmp_bytes(lut_size, lut_size, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2) + + let acc: usize; + if extension_factor > 1 { + acc = GLWECiphertext::bytes_of(module, basek, k_lut, rank) * extension_factor; + } else { + acc = 0; + } + + return acc + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | acc_big); } pub fn cggi_blind_rotate( @@ -37,8 +54,8 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - if lut.data.len() > 1 { - cggi_blind_rotate_block_binary_exnteded(module, res, lwe, lut, brk, scratch); + if lut.extension_factor() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); } else if brk.block_size() > 1 { cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); } else { @@ -46,7 +63,7 @@ pub fn cggi_blind_rotate( } } -pub(crate) fn cggi_blind_rotate_block_binary_exnteded( +pub(crate) fn cggi_blind_rotate_block_binary_extended( module: &Module, res: &mut GLWECiphertext, lwe: &LWECiphertext, @@ -57,198 +74,164 @@ pub(crate) fn cggi_blind_rotate_block_binary_exnteded( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { + let extension_factor: usize = lut.extension_factor(); + let basek: usize = res.basek(); + + let (mut acc, scratch1) = scratch.tmp_vec_glwe_ct(extension_factor, module, basek, res.k(), res.rank()); + let (mut acc_dft, scratch2) = scratch1.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + let (mut vmp_res, scratch3) = scratch2.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + let (mut acc_add_dft, scratch4) = scratch3.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank()); + + (0..extension_factor).for_each(|i| { + acc[i].data.zero(); + }); + + let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1); + let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1); + let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size()); + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); - let basek: usize = out_mut.basek(); - let cols: usize = out_mut.rank() + 1; + let two_n_ext: usize = 2 * lut.domain_size(); - mod_switch_2n( - 2 * module.n() * lut.extension_factor(), - &mut lwe_2n, - &lwe_ref, - ); + let cols: usize = res.rank() + 1; - let extension_factor: i64 = lut.extension_factor() as i64; - - let mut acc: Vec>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc.push(GLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } + negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; - let b: i64 = lwe_2n[0]; + let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) % two_n_ext as i64) as usize; - let b_inner: i64 = b / extension_factor; - let b_outer: i64 = b % extension_factor; + let b_hi: usize = b_pos / extension_factor; + let b_lo: usize = b_pos % extension_factor; - for (i, j) in (0..b_outer).zip(extension_factor - b_outer..extension_factor) { - module.vec_znx_rotate( - b_inner + 1, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) { + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); } - for (i, j) in (b_outer..extension_factor).zip(0..extension_factor - b_outer) { - module.vec_znx_rotate( - b_inner, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + for (i, j) in (b_lo..extension_factor).zip(0..extension_factor - b_lo) { + module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); } let block_size: usize = brk.block_size(); - let mut acc_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc_dft.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut vmp_res: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - vmp_res.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut acc_add_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); - - for _ in 0..extension_factor { - acc_add_dft.push(FourierGLWECiphertext::alloc( - module, - basek, - out_mut.k(), - out_mut.rank(), - )); - } - - let mut xai_minus_one: backend::ScalarZnx> = module.new_scalar_znx(1); - let mut xai_minus_one_dft: backend::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); - izip!( a.chunks_exact(block_size), brk.data.chunks_exact(block_size) ) .enumerate() .for_each(|(i, (ai, ski))| { - (0..lut.extension_factor()).for_each(|i| { - acc[i].dft(module, &mut acc_dft[i]); + (0..extension_factor).for_each(|i| { + (0..cols).for_each(|j| { + module.vec_znx_dft(1, 0, &mut acc_dft[i].data, j, &acc[i].data, j); + }); acc_add_dft[i].data.zero(); }); + // TODO: first & last iterations can be optimized izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { - let aii_inner: i64 = aii / extension_factor; - let aii_outer: i64 = aii % extension_factor; + let ai_pos: usize = ((aii + two_n_ext as i64) % two_n_ext as i64) as usize; + let ai_hi: usize = ai_pos / extension_factor; + let ai_lo: usize = ai_pos % extension_factor; // vmp_res = DFT(acc) * BRK[i] - (0..lut.extension_factor()).for_each(|i| { - module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch); + (0..extension_factor).for_each(|i| { + module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch7); }); - if aii_outer == 0 { - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1) + if ai_lo == 0 { + // DFT X^{-ai} + set_xai_plus_y( + module, + ai_hi as i64, + -1, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); - (0..lut.extension_factor()).for_each(|j| { + // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1) + (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_minus_one_dft, 0); + module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_plus_y_dft, 0); module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i); }); - }) + }); + // Non trivial case: rotation between polynomials + // In this case we can't directly multiply with (X^{-ai} - 1) because of the + // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the + // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai} } else { - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner + 1, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - - for (i, j) in (0..aii_outer).zip(extension_factor - aii_outer..extension_factor) { - module.vec_znx_rotate( - b_inner + 1, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); - + // Sets acc_add_dft[i] = acc[i] * sk + (0..extension_factor).for_each(|i| { (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace( - &mut acc_add_dft[j as usize].data, - k, - &vmp_res[i as usize].data, - k, - ); + module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i].data, k, &vmp_res[i].data, k); + }) + }); + + // DFT X^{-ai+1} + set_xai_plus_y( + module, + ai_hi as i64 + 1, + 0, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); + + // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1} + for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) { + module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0); + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); }); } - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); - - for (i, j) in (aii_outer..extension_factor).zip(0..extension_factor - aii_outer) { - module.vec_znx_rotate( - b_inner, - &mut acc[j as usize].data, - 0, - &lut.data[i as usize], - 0, - ); + // DFT X^{-ai} + set_xai_plus_y( + module, + ai_hi as i64, + 0, + &mut xai_plus_y_dft, + &mut xai_plus_y, + ); + // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai} + for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) { + module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0); (0..cols).for_each(|k| { - module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); - module.vec_znx_dft_add_inplace( - &mut acc_add_dft[j as usize].data, - k, - &vmp_res[i as usize].data, - k, - ); + module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k); }); } } }); - if i == lwe.n() - block_size { + (0..extension_factor).for_each(|j| { (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft[0].data, i, &acc_add_dft[0].data, i); + module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); + module.vec_znx_idft(&mut acc_add_big, 0, &acc_dft[j].data, i, scratch7); + module.vec_znx_big_normalize(basek, &mut acc[j].data, i, &acc_add_big, 0, scratch7); }); - acc_dft[0].idft(module, &mut out_mut, scratch); - } else { - (0..lut.extension_factor()).for_each(|j| { - (0..cols).for_each(|i| { - module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); - }); - - acc_dft[j].idft(module, &mut acc[j], scratch); - }) - } + }); }); + + (0..cols).for_each(|i| { + module.vec_znx_copy(&mut res.data, i, &acc[0].data, i); + }); +} + +fn set_xai_plus_y( + module: &Module, + k: i64, + y: i64, + res: &mut ScalarZnxDft<&mut [u8], FFT64>, + buf: &mut ScalarZnx<&mut [u8]>, +) { + buf.zero(); + buf.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(k, buf, 0); + buf.at_mut(0, 0)[0] += y; + module.svp_prepare(res, 0, buf, 0); } pub(crate) fn cggi_blind_rotate_block_binary( @@ -270,7 +253,7 @@ pub(crate) fn cggi_blind_rotate_block_binary( let cols: usize = out_mut.rank() + 1; - mod_switch_2n(2 * module.n(), &mut lwe_2n, &lwe_ref); + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -278,17 +261,17 @@ pub(crate) fn cggi_blind_rotate_block_binary( out_mut.data.zero(); // Initialize out to X^{b} * LUT(X) - module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0); + module.vec_znx_rotate(-b, &mut out_mut.data, 0, &lut.data[0], 0); let block_size: usize = brk.block_size(); // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)] - let (mut acc_dft, scratch1) = scratch.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut acc_add_dft, scratch2) = scratch1.tmp_glwe_fourier(module, brk.basek(), out_mut.k(), out_mut.rank()); - let (mut vmp_res, scratch3) = scratch2.tmp_glwe_fourier(module, basek, out_mut.k(), out_mut.rank()); - let (mut xai_minus_one, scratch4) = scratch3.tmp_scalar_znx(module, 1); - let (mut xai_minus_one_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); + let (mut acc_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut acc_add_dft, scratch2) = scratch1.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank()); + let (mut vmp_res, scratch3) = scratch2.tmp_fourier_glwe_ct(module, basek, out_mut.k(), out_mut.rank()); + let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1); + let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1); let start: Instant = Instant::now(); izip!( @@ -304,15 +287,11 @@ pub(crate) fn cggi_blind_rotate_block_binary( module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5); // DFT(X^ai -1) - xai_minus_one.zero(); - xai_minus_one.at_mut(0, 0)[0] = 1; - module.vec_znx_rotate_inplace(*aii, &mut xai_minus_one, 0); - xai_minus_one.at_mut(0, 0)[0] -= 1; - module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + set_xai_plus_y(module, *aii, -1, &mut xai_plus_y_dft, &mut xai_plus_y); // DFT(X^ai -1) * (DFT(acc) * BRK[i]) (0..cols).for_each(|i| { - module.svp_apply_inplace(&mut vmp_res.data, i, &xai_minus_one_dft, 0); + module.svp_apply_inplace(&mut vmp_res.data, i, &xai_plus_y_dft, 0); module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i); }); }); @@ -324,15 +303,15 @@ pub(crate) fn cggi_blind_rotate_block_binary( acc_dft.idft(module, &mut out_mut, scratch5); }); let duration: std::time::Duration = start.elapsed(); - println!("external products: {} us", duration.as_micros()); } -pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { +pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; res.copy_from_slice(&lwe.data.at(0, 0)); + res.iter_mut().for_each(|x| *x = -*x); if basek > log2n { let diff: usize = basek - log2n; diff --git a/core/src/blind_rotation/key.rs b/core/src/blind_rotation/key.rs index 8cb24cb..b7f9c3f 100644 --- a/core/src/blind_rotation/key.rs +++ b/core/src/blind_rotation/key.rs @@ -46,8 +46,13 @@ impl BlindRotationKeyCGGI { assert_eq!(sk_glwe.n(), module.n()); assert_eq!(sk_glwe.rank(), self.data[0].rank()); match sk_lwe.dist { - Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) => {} - _ => panic!("invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb"), + Distribution::BinaryBlock(_) + | Distribution::BinaryFixed(_) + | Distribution::BinaryProb(_) + | Distribution::ZERO => {} + _ => panic!( + "invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)" + ), } } @@ -79,6 +84,11 @@ impl BlindRotationKeyCGGI { self.data[0].k() } + #[allow(dead_code)] + pub(crate) fn size(&self) -> usize { + self.data[0].size() + } + #[allow(dead_code)] pub(crate) fn rank(&self) -> usize { self.data[0].rank() diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index b6e9a7f..ed54dd2 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; +use backend::{FFT64, Module, ScalarZnx, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -84,6 +84,31 @@ impl LookUpTable { } } + pub fn set_raw(&mut self, module: &Module, lut: &ScalarZnx) + where + D: AsRef<[u8]>, + { + let domain_size: usize = self.domain_size(); + + let size: usize = self.k.div_ceil(self.basek); + + let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); + + lut_full.at_mut(0, 0).copy_from_slice(lut.raw()); + + if self.extension_factor() > 1 { + (0..self.extension_factor()).for_each(|i| { + module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + if i < self.extension_factor() { + lut_full.rotate(-1); + } + }); + } else { + module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); + } + } + + #[allow(dead_code)] pub(crate) fn rotate(&mut self, k: i64) { let extension_factor: usize = self.extension_factor(); let two_n: usize = 2 * self.data[0].n(); diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index 6322c75..5deb497 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -1,12 +1,12 @@ use std::time::Instant; -use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView}; +use backend::{Encoding, FFT64, Module, ScalarZnx, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut}; use sampling::source::Source; use crate::{ FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret, blind_rotation::{ - ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, mod_switch_2n}, + ccgi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n}, key::BlindRotationKeyCGGI, lut::LookUpTable, }, @@ -16,22 +16,24 @@ use crate::{ #[test] fn blind_rotation() { let module: Module = Module::::new(2048); - let basek: usize = 20; + let basek: usize = 18; let n_lwe: usize = 1071; - let k_lwe: usize = 22; - let k_brk: usize = 60; + let k_lwe: usize = 24; + let k_brk: usize = 3 * basek; let rows_brk: usize = 2; - let k_lut: usize = 60; + let k_lut: usize = 2 * basek; let rank: usize = 1; let block_size: usize = 7; - let message_modulus: usize = 64; + let extension_factor: usize = 2; - let mut source_xs: Source = Source::new([0u8; 32]); - let mut source_xe: Source = Source::new([0u8; 32]); - let mut source_xa: Source = Source::new([0u8; 32]); + let message_modulus: usize = 1 << 6; + + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([1u8; 32]); let mut sk_glwe: GLWESecret> = GLWESecret::alloc(&module, rank); sk_glwe.fill_ternary_prob(0.5, &mut source_xs); @@ -40,9 +42,21 @@ fn blind_rotation() { let mut sk_lwe: LWESecret> = LWESecret::alloc(n_lwe); sk_lwe.fill_binary_block(block_size, &mut source_xs); + sk_lwe.data.raw_mut()[0] = 0; + + println!("sk_lwe: {:?}", sk_lwe.data.raw()); + let mut scratch: ScratchOwned = ScratchOwned::new( BlindRotationKeyCGGI::generate_from_sk_scratch_space(&module, basek, k_brk, rank) - | cggi_blind_rotate_scratch_space(&module, basek, k_lut, k_brk, rows_brk, rank), + | cggi_blind_rotate_scratch_space( + &module, + extension_factor, + basek, + k_lut, + k_brk, + rows_brk, + rank, + ), ); let start: Instant = Instant::now(); @@ -65,8 +79,8 @@ fn blind_rotation() { let mut pt_lwe: LWEPlaintext> = LWEPlaintext::alloc(basek, k_lwe); - let x: i64 = 0; - let bits: usize = 6; + let x: i64 = 1; + let bits: usize = 8; pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits); @@ -82,7 +96,7 @@ fn blind_rotation() { 2 * x + 1 } - let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, 1); + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); lut.set(&module, lut_fn, message_modulus); let mut res: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_lut, rank); @@ -103,7 +117,7 @@ fn blind_rotation() { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - mod_switch_2n(module.n() * 2, &mut lwe_2n, &lwe.to_ref()); + negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref()); let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..] @@ -111,15 +125,22 @@ fn blind_rotation() { .zip(sk_lwe.data.at(0, 0)) .map(|(x, y)| x * y) .sum::()) - % (module.n() as i64 * 2); + % (2 * lut.domain_size()) as i64; - module.vec_znx_rotate_inplace(pt_want, &mut lut.data[0], 0); + println!("pt_want: {}", pt_want); - println!("pt_want: {}", lut.data[0]); + lut.rotate(pt_want); + lut.data.iter().for_each(|d| { + println!("{}", d); + }); + + // First limb should be exactly equal (test are parameterized such that the noise does not reach + // the first limb) + // assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0)); + + // Then checks the noise module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0); - let noise: f64 = lut.data[0].std(0, basek); - println!("noise: {}", noise); } diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs index 9377d76..58c393b 100644 --- a/core/src/blind_rotation/test_fft64/lut.rs +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -25,7 +25,7 @@ fn standard() { let step: usize = lut.domain_size().div_round(message_modulus); (0..lut.domain_size()).step_by(step).for_each(|i| { - (0..step).for_each(|j| { + (0..step).for_each(|_| { assert_eq!( lut_fn((i / step) as i64) % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 @@ -58,7 +58,7 @@ fn extended() { let step: usize = lut.domain_size().div_round(message_modulus); (0..lut.domain_size()).step_by(step).for_each(|i| { - (0..step).for_each(|j| { + (0..step).for_each(|_| { assert_eq!( lut_fn((i / step) as i64) % message_modulus as i64, lut.data[0].raw()[0] / scale as i64 diff --git a/core/src/gglwe/automorphism.rs b/core/src/gglwe/automorphism.rs index e18e65a..07fcb14 100644 --- a/core/src/gglwe/automorphism.rs +++ b/core/src/gglwe/automorphism.rs @@ -77,7 +77,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { let (mut tmp_idft_data, scratct1) = scratch.tmp_vec_znx_big(module, cols_out, self.size()); { - let (mut tmp_dft, scratch2) = scratct1.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_dft, scratch2) = scratct1.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); // Extracts relevant row lhs.get_row(module, row_j, col_i, &mut tmp_dft); @@ -109,7 +109,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { tmp_idft.keyswitch_inplace(module, &rhs.key, scratct1); { - let (mut tmp_dft, _) = scratct1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft, _) = scratct1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) // and switches back to DFT domain @@ -124,7 +124,7 @@ impl + AsRef<[u8]>> GLWEAutomorphismKey { }); }); - let (mut tmp_dft, _) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft, _) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); tmp_dft.data.zero(); (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { diff --git a/core/src/gglwe/encryption.rs b/core/src/gglwe/encryption.rs index 6c31b2a..bc1137c 100644 --- a/core/src/gglwe/encryption.rs +++ b/core/src/gglwe/encryption.rs @@ -70,7 +70,7 @@ impl + AsRef<[u8]>> GGLWECiphertext { let (mut tmp_pt, scrach_1) = scratch.tmp_glwe_pt(module, basek, k); let (mut tmp_ct, scrach_2) = scrach_1.tmp_glwe_ct(module, basek, k, rank_out); - let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_glwe_fourier(module, basek, k, rank_out); + let (mut tmp_ct_dft, scratch_3) = scrach_2.tmp_fourier_glwe_ct(module, basek, k, rank_out); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // diff --git a/core/src/gglwe/external_product.rs b/core/src/gglwe/external_product.rs index 2e063ef..26a8c92 100644 --- a/core/src/gglwe/external_product.rs +++ b/core/src/gglwe/external_product.rs @@ -66,8 +66,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -103,7 +103,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); println!("tmp: {}", tmp.size()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/gglwe/keyswitch.rs b/core/src/gglwe/keyswitch.rs index 632309d..fe4a3f6 100644 --- a/core/src/gglwe/keyswitch.rs +++ b/core/src/gglwe/keyswitch.rs @@ -113,8 +113,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -150,7 +150,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/ggsw/ciphertext.rs b/core/src/ggsw/ciphertext.rs index 12e9723..e52f305 100644 --- a/core/src/ggsw/ciphertext.rs +++ b/core/src/ggsw/ciphertext.rs @@ -290,7 +290,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // Switch vec_znx_ct into DFT domain { - let (mut tmp_ct_dft, _) = scratch2.tmp_glwe_fourier(module, basek, k, rank); + let (mut tmp_ct_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, k, rank); tmp_ct.dft(module, &mut tmp_ct_dft); self.set_row(module, row_i, col_j, &tmp_ct_dft); } @@ -438,7 +438,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) (1..cols).for_each(|col_j| { self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -541,7 +541,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { tensor_key, scratch2, ); - let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + let (mut tmp_res_dft, _) = scratch2.tmp_fourier_glwe_ct(module, basek, self.k(), rank); tmp_res.dft(module, &mut tmp_res_dft); self.set_row(module, row_i, col_j, &tmp_res_dft); }); @@ -599,8 +599,8 @@ impl + AsRef<[u8]>> GGSWCiphertext { ) } - let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); - let (mut tmp_ct_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct_in, scratch1) = scratch.tmp_fourier_glwe_ct(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_ct_out, scratch2) = scratch1.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -636,7 +636,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { ); } - let (mut tmp_ct, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_ct, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -674,7 +674,7 @@ impl> GGSWCiphertext { ) ) } - let (mut tmp_dft_dft, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + let (mut tmp_dft_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, self.basek(), self.k(), self.rank()); self.get_row(module, row_i, 0, &mut tmp_dft_dft); res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1); } diff --git a/core/src/lib.rs b/core/src/lib.rs index afdcc55..ba28589 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -28,6 +28,14 @@ pub(crate) const SIX_SIGMA: f64 = 6.0; pub trait ScratchCore { fn tmp_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); fn tmp_gglwe( &mut self, @@ -48,13 +56,21 @@ pub trait ScratchCore { digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, k: usize, rank: usize, ) -> (FourierGLWECiphertext<&mut [u8], B>, &mut Self); + fn tmp_vec_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self); fn tmp_sk(&mut self, module: &Module, rank: usize) -> (GLWESecret<&mut [u8]>, &mut Self); fn tmp_fourier_sk(&mut self, module: &Module, rank: usize) -> (FourierGLWESecret<&mut [u8], B>, &mut Self); fn tmp_glwe_pk( @@ -106,6 +122,24 @@ impl ScratchCore for Scratch { (GLWECiphertext { data, basek, k }, scratch) } + fn tmp_vec_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } + fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { let (data, scratch) = self.tmp_vec_znx(module, 1, k.div_ceil(basek)); (GLWEPlaintext { data, basek, k }, scratch) @@ -166,7 +200,7 @@ impl ScratchCore for Scratch { ) } - fn tmp_glwe_fourier( + fn tmp_fourier_glwe_ct( &mut self, module: &Module, basek: usize, @@ -177,6 +211,24 @@ impl ScratchCore for Scratch { (FourierGLWECiphertext { data, basek, k }, scratch) } + fn tmp_vec_fourier_glwe_ct( + &mut self, + size: usize, + module: &Module, + basek: usize, + k: usize, + rank: usize, + ) -> (Vec>, &mut Self) { + let mut scratch: &mut Scratch = self; + let mut cts: Vec> = Vec::with_capacity(size); + for _ in 0..size { + let (ct, new_scratch) = scratch.tmp_fourier_glwe_ct(module, basek, k, rank); + scratch = new_scratch; + cts.push(ct); + } + (cts, scratch) + } + fn tmp_glwe_pk( &mut self, module: &Module, @@ -184,7 +236,7 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWEPublicKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_glwe_fourier(module, basek, k, rank); + let (data, scratch) = self.tmp_fourier_glwe_ct(module, basek, k, rank); ( GLWEPublicKey { data,