From dee889dc0cbd59bf0af784efa66c726ac27816f2 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 13 May 2025 17:21:41 +0200 Subject: [PATCH] working on adding rank to glwe (all test passing) --- base2k/src/mat_znx_dft.rs | 464 +++++----- core/benches/external_product_glwe_fft64.rs | 2 +- core/src/elem.rs | 6 +- core/src/encryption.rs | 105 --- core/src/external_product.rs | 19 - core/src/gglwe_ciphertext.rs | 253 ++++++ core/src/ggsw.rs | 324 ------- core/src/ggsw_ciphertext.rs | 316 +++++++ core/src/glwe.rs | 845 ------------------ core/src/glwe_ciphertext.rs | 460 ++++++++++ core/src/glwe_ciphertext_fourier.rs | 261 ++++++ core/src/glwe_plaintext.rs | 53 ++ core/src/keys.rs | 83 +- core/src/keyswitch.rs | 20 - core/src/keyswitch_key.rs | 344 +++---- core/src/lib.rs | 10 +- core/src/test_fft64/{grlwe.rs => gglwe.rs} | 129 +-- core/src/test_fft64/{rgsw.rs => ggsw.rs} | 104 ++- core/src/test_fft64/{rlwe.rs => glwe.rs} | 115 +-- .../{rlwe_dft.rs => glwe_fourier.rs} | 96 +- core/src/test_fft64/mod.rs | 8 +- core/src/vec_glwe_product.rs | 33 +- 22 files changed, 2020 insertions(+), 2030 deletions(-) delete mode 100644 core/src/encryption.rs delete mode 100644 core/src/external_product.rs create mode 100644 core/src/gglwe_ciphertext.rs delete mode 100644 core/src/ggsw.rs create mode 100644 core/src/ggsw_ciphertext.rs delete mode 100644 core/src/glwe.rs create mode 100644 core/src/glwe_ciphertext.rs create mode 100644 core/src/glwe_ciphertext_fourier.rs create mode 100644 core/src/glwe_plaintext.rs delete mode 100644 core/src/keyswitch.rs rename core/src/test_fft64/{grlwe.rs => gglwe.rs} (79%) rename core/src/test_fft64/{rgsw.rs => ggsw.rs} (87%) rename core/src/test_fft64/{rlwe.rs => glwe.rs} (84%) rename core/src/test_fft64/{rlwe_dft.rs => glwe_fourier.rs} (84%) diff --git a/base2k/src/mat_znx_dft.rs b/base2k/src/mat_znx_dft.rs index c34115d..209c696 100644 --- a/base2k/src/mat_znx_dft.rs +++ b/base2k/src/mat_znx_dft.rs @@ -1,232 +1,232 @@ -use crate::znx_base::ZnxInfos; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; -use std::marker::PhantomData; - -/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], -/// stored as a 3D matrix in the DFT domain in a single contiguous array. -/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. -/// -/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. -/// See the trait [MatZnxDftOps] for additional information. -pub struct MatZnxDft { - data: D, - n: usize, - size: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for MatZnxDft { - fn cols(&self) -> usize { - self.cols_in - } - - fn rows(&self) -> usize { - self.rows - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for MatZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols_out() - } -} - -impl DataView for MatZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for MatZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for MatZnxDft { - type Scalar = f64; -} - -impl MatZnxDft { - pub(crate) fn cols_in(&self) -> usize { - self.cols_in - } - - pub(crate) fn cols_out(&self) -> usize { - self.cols_out - } -} - -impl>, B: Backend> MatZnxDft { - pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - unsafe { - crate::ffi::vmp::bytes_of_vmp_pmat( - module.ptr, - (rows * cols_in) as u64, - (size * cols_out) as u64, - ) as usize - } - } - - pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } -} - -impl> MatZnxDft { - /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. - /// - /// # Arguments - /// - /// * `row`: row index (i). - /// * `col`: col index (j). - #[allow(dead_code)] - fn at(&self, row: usize, col: usize) -> Vec { - let n: usize = self.n(); - - let mut res: Vec = alloc_aligned(n); - - if n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); - } else { - (0..n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); - }); - } - - res - } - - #[allow(dead_code)] - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { - let nrows: usize = self.rows(); - let nsize: usize = self.size(); - if col == (nsize - 1) && (nsize & 1 == 1) { - &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] - } else { - &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] - } - } -} - -pub type MatZnxDftOwned = MatZnxDft, B>; - -pub trait MatZnxDftToRef { - fn to_ref(&self) -> MatZnxDft<&[u8], B>; -} - -pub trait MatZnxDftToMut { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; -} - -impl MatZnxDftToMut for MatZnxDft, B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data.as_mut_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft, B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&[u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::marker::PhantomData; + +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. +/// +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. +pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() * self.cols_out() + } +} + +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { + type Scalar = f64; +} + +impl MatZnxDft { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + crate::ffi::vmp::bytes_of_vmp_pmat( + module.ptr, + (rows * cols_in) as u64, + (size * cols_out) as u64, + ) as usize + } + } + + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } +} + +impl> MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. + /// + /// # Arguments + /// + /// * `row`: row index (i). + /// * `col`: col index (j). + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); + + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); + } else { + (0..n >> 3).for_each(|blk| { + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); + }); + } + + res + } + + #[allow(dead_code)] + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { + let nrows: usize = self.rows(); + let nsize: usize = self.size(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] + } else { + &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] + } + } +} + +pub type MatZnxDftOwned = MatZnxDft, B>; + +pub trait MatZnxDftToRef { + fn to_ref(&self) -> MatZnxDft<&[u8], B>; +} + +pub trait MatZnxDftToMut { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; +} + +impl MatZnxDftToMut for MatZnxDft, B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft, B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&[u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index 4462fab..435a25f 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -6,7 +6,7 @@ use rlwe::{ external_product::{ ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - ggsw::GGSWCiphertext, + ggsw_ciphertext::GGSWCiphertext, glwe::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, }; diff --git a/core/src/elem.rs b/core/src/elem.rs index bf5ca1e..4562137 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{glwe::GLWECiphertextFourier, utils::derive_size}; +use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; pub trait Infos { type Inner: ZnxInfos; @@ -23,7 +23,7 @@ pub trait Infos { } /// Returns the number of polynomials in each row. - fn rank(&self) -> usize { + fn cols(&self) -> usize { self.inner().cols() } @@ -36,7 +36,7 @@ pub trait Infos { /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { - self.rows() * self.rank() * self.size() + self.rows() * self.cols() * self.size() } /// Returns the base 2 logarithm of the ciphertext base. diff --git a/core/src/encryption.rs b/core/src/encryption.rs deleted file mode 100644 index 915834c..0000000 --- a/core/src/encryption.rs +++ /dev/null @@ -1,105 +0,0 @@ -use base2k::{Backend, Module, Scratch}; -use sampling::source::Source; - -pub trait EncryptSkScratchSpace { - fn encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptSk { - type Ciphertext; - type Plaintext; - type SecretKey; - - fn encrypt_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait EncryptZeroSkScratchSpace { - fn encrypt_zero_sk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptZeroSk { - type Ciphertext; - type SecretKey; - - fn encrypt_zero_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait EncryptPkScratchSpace { - fn encrypt_pk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptPk { - type Ciphertext; - type Plaintext; - type PublicKey; - - fn encrypt_pk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - pk: &Self::PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait EncryptZeroPkScratchSpace { - fn encrypt_zero_pk_scratch_space(module: &Module, ct_size: usize) -> usize; -} - -pub trait EncryptZeroPk { - type Ciphertext; - type PublicKey; - - fn encrypt_zero_pk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pk: &Self::PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ); -} - -pub trait Decrypt { - type Plaintext; - type Ciphertext; - type SecretKey; - - fn decrypt( - &self, - module: &Module, - pt: &mut Self::Plaintext, - ct: &Self::Ciphertext, - sk: &Self::SecretKey, - scratch: &mut Scratch, - ); -} diff --git a/core/src/external_product.rs b/core/src/external_product.rs deleted file mode 100644 index e8d0a7e..0000000 --- a/core/src/external_product.rs +++ /dev/null @@ -1,19 +0,0 @@ -use base2k::{FFT64, Module, Scratch}; - -pub trait ExternalProductScratchSpace { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait ExternalProduct { - type Lhs; - type Rhs; - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); -} -pub trait ExternalProductInplaceScratchSpace { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; -} - -pub trait ExternalProductInplace { - type Rhs; - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); -} diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs new file mode 100644 index 0000000..9d7c45a --- /dev/null +++ b/core/src/gglwe_ciphertext.rs @@ -0,0 +1,253 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GGLWECiphertext { + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, +} + +impl GGLWECiphertext, B> { + pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(base2k, k)), + basek: base2k, + k, + } + } +} + +impl Infos for GGLWECiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GGLWECiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGLWECiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn encrypt_pk_scratch_space(_module: &Module, _rank: usize, _pk_size: usize) -> usize { + unimplemented!() + } +} + +impl GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut + ZnxInfos, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(pt.n(), module.n()); + } + + let rows: usize = self.rows(); + let size: usize = self.size(); + let basek: usize = self.basek(); + let k: usize = self.k(); + + let cols: usize = self.rank() + 1; + + let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size); + let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek, + k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek, + k, + }; + + let mut vec_znx_ct_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier { + data: tmp_znx_dft_ct, + basek, + k, + }; + + (0..rows).for_each(|row_i| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3); + + // rlwe encrypt of vec_znx_pt into vec_znx_ct + vec_znx_ct.encrypt_sk( + module, + &vec_znx_pt, + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch_3, + ); + + vec_znx_pt.data.zero(); // zeroes for next iteration + + // Switch vec_znx_ct into DFT domain + vec_znx_ct.dft(module, &mut vec_znx_ct_dft); + + // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft + module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft); + }); + } +} + +impl GetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGLWECiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(col_j, 0); + } + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl VecGLWEProductScratchSpace for GGLWECiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, grlwe_size) + + (module.vec_znx_big_normalize_tmp_bytes() + | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) + + module.bytes_of_vec_znx_dft(1, a_size))) + } +} + +impl VecGLWEProduct for GGLWECiphertext +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + + { + let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); + module.vec_znx_dft(&mut a1_dft, 0, a, 1); + module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); + } + + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } +} diff --git a/core/src/ggsw.rs b/core/src/ggsw.rs deleted file mode 100644 index 79b12a5..0000000 --- a/core/src/ggsw.rs +++ /dev/null @@ -1,324 +0,0 @@ -use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, -}; -use sampling::source::Source; - -use crate::{ - elem::{GetRow, Infos, SetRow}, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, encrypt_glwe_sk}, - keys::SecretKeyFourier, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, -}; - -pub struct GGSWCiphertext { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GGSWCiphertext, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GGSWCiphertext { - type Inner = MatZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl MatZnxDftToMut for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl MatZnxDftToRef for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl GGSWCiphertext, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, size) - + module.bytes_of_vec_znx(2, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(2, size) - } -} - -pub fn encrypt_rgsw_sk( - module: &Module, - ct: &mut GGSWCiphertext, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let size: usize = ct.size(); - let log_base2k: usize = ct.basek(); - - let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - - let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { - data: tmp_znx_pt, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_znx_ct, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - (0..ct.rows()).for_each(|row_j| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - - (0..ct.rank()).for_each(|col_i| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - encrypt_glwe_sk( - module, - &mut vec_znx_ct, - Some((&vec_znx_pt, col_i)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scrach_2, - ); - - // Switch vec_znx_ct into DFT domain - { - let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size); - module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); - module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); - module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct); - } - }); - - vec_znx_pt.data.zero(); // zeroes for next iteration - }); -} - -impl GGSWCiphertext { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

, - ct: &GLWECiphertext, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.basek(); - pt.log_k = pt.k().min(ct.k()); -} - -impl GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, -{ - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_sk( - module, - self, - Some((pt, 0)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch, - ) - } - - pub fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_sk::( - module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - { - encrypt_glwe_pk( - module, - self, - Some(pt), - pk, - source_xu, - source_xe, - sigma, - bound, - scratch, - ) - } - - pub fn encrypt_zero_pk( - &mut self, - module: &Module, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToRef, - { - encrypt_glwe_pk::( - module, self, None, pk, source_xu, source_xe, sigma, bound, scratch, - ) - } -} - -impl GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_glwe(module, pt, self, sk_dft, scratch); - } -} - -pub(crate) fn encrypt_glwe_pk( - module: &Module, - ct: &mut GLWECiphertext, - pt: Option<&GLWEPlaintext

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, -{ - #[cfg(debug_assertions)] - { - assert_eq!(ct.basek(), pk.basek()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.basek(), pk.basek()); - assert_eq!(pt.n(), module.n()); - } - } - - let log_base2k: usize = pk.basek(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); -} - -pub struct GLWEPlaintext { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for GLWEPlaintext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for GLWEPlaintext -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for GLWEPlaintext -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl GLWEPlaintext> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct GLWECiphertextFourier { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GLWECiphertextFourier, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GLWECiphertextFourier { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl GLWECiphertextFourier -where - GLWECiphertextFourier: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) - where - GLWECiphertext: VecZnxToMut, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), 2); - assert_eq!(res.rank(), 2); - assert_eq!(self.basek(), res.basek()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - - module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); - module.vec_znx_big_normalize(self.basek(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.basek(), res, 1, &res_big, 1, scratch1); - } -} - -pub(crate) fn encrypt_zero_glwe_dft_sk( - module: &Module, - ct: &mut GLWECiphertextFourier, - sk: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.basek(); - let log_k: usize = ct.k(); - let size: usize = ct.size(); - - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), - _ => {} - } - assert_eq!(ct.rank(), 2); - } - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = ct[1] * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - module.vec_znx_negate_inplace(&mut tmp_znx, 0); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); -} - -impl GLWECiphertextFourier, FFT64> { - pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn decrypt_rlwe_dft( - module: &Module, - pt: &mut GLWEPlaintext

, - ct: &GLWECiphertextFourier, - sk: &SecretKeyFourier, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - { - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.basek(); - pt.log_k = pt.k().min(ct.k()); -} - -impl GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, -{ - pub(crate) fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_zero_glwe_dft_sk( - module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut GLWEPlaintext

, - sk_dft: &SecretKeyFourier, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); - } -} - -impl KeySwitchScratchSpace for GLWECiphertextFourier, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl KeySwitch for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertextFourier; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWECiphertextFourier, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); - } -} - -impl ExternalProductScratchSpace for GLWECiphertextFourier, FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ExternalProduct for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertextFourier; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWECiphertextFourier, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GLWECiphertextFourier -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_fourier_inplace(module, self, scratch); - } -} diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs new file mode 100644 index 0000000..ed8a39e --- /dev/null +++ b/core/src/glwe_ciphertext.rs @@ -0,0 +1,460 @@ +use base2k::{ + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertext> { + pub fn new(module: &Module, basek: usize, k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx(rank + 1, derive_size(basek, k)), + basek, + k, + } + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertext { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxToMut for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_dft(res, i, self, i); + }) + } +} + +impl GLWECiphertext> { + pub fn encrypt_sk_scratch_space(module: &Module, _rank: usize, ct_size: usize) -> usize { + module.vec_znx_big_normalize_tmp_bytes() + + module.bytes_of_vec_znx_dft(1, ct_size) + + module.bytes_of_vec_znx_big(1, ct_size) + } + pub fn encrypt_pk_scratch_space(module: &Module, _rank: usize, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, ct_size)) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private( + module, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + self.encrypt_sk_private( + module, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ); + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private( + module, + Some((pt, 0)), + pk, + source_xu, + source_xe, + sigma, + bound, + scratch, + ); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + { + self.encrypt_pk_private( + module, None, pk, source_xu, source_xe, sigma, bound, scratch, + ); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_inplace(module, self, scratch); + } + + pub(crate) fn encrypt_sk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(sk_dft.n(), module.n()); + assert_eq!(self.n(), module.n()); + if let Some((pt, col)) = pt { + assert_eq!(pt.n(), module.n()); + assert!(col < self.rank() + 1); + } + } + + let log_base2k: usize = self.basek(); + let log_k: usize = self.k(); + let size: usize = self.size(); + let cols: usize = self.rank() + 1; + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + c0_big.zero(); + + { + // c[i] = uniform + // c[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size); + + // c[i] = uniform + self.data.fill_uniform(log_base2k, i, size, source_xa); + + // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) + module.vec_znx_dft(&mut ci_dft, 0, self, i); + module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + module.vec_znx_big_normalize(log_base2k, self, 0, &ci_big, 0, scratch_2); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0); + + // c[i] += m if col = i + if let Some((pt, col)) = pt { + if i == col { + module.vec_znx_add_inplace(self, i, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, self, i, scratch_2); + } + } + }); + } + + // c[0] += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt { + if col == 0 { + module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0); + } + } + + // c[0] = norm(c[0]) + module.vec_znx_normalize(log_base2k, self, 0, &c0_big, 0, scratch_1); + } + + pub(crate) fn encrypt_pk_private( + &mut self, + module: &Module, + pt: Option<(&GLWEPlaintext, usize)>, + pk: &GLWEPublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.basek(), pk.basek()); + assert_eq!(self.n(), module.n()); + assert_eq!(pk.n(), module.n()); + assert_eq!(self.rank(), pk.rank()); + if let Some((pt, _)) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.basek(); + let size_pk: usize = pk.size(); + let cols: usize = self.rank() + 1; + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ + Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + // ct[i] = pk[i] * u + ei (+ m if col = i) + (0..cols).for_each(|i| { + let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); + // ci_dft = DFT(u) * DFT(pk[i]) + module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i); + + // ci_big = u * p[i] + let mut ci_big = module.vec_znx_idft_consume(ci_dft); + + // ci_big = u * pk[i] + e + ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + + // ci_big = u * pk[i] + e + m (if col = i) + if let Some((pt, col)) = pt { + if col == i { + module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0); + } + } + + // ct[i] = norm(ci_big) + module.vec_znx_big_normalize(log_base2k, self, i, &ci_big, 0, scratch_2); + }); + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + ScalarZnxDft: ScalarZnxDftToRef, + { + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut c0_dft, 0, self, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } +} diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs new file mode 100644 index 0000000..bcc7648 --- /dev/null +++ b/core/src/glwe_ciphertext_fourier.rs @@ -0,0 +1,261 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, + VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, + VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertextFourier { + pub data: VecZnxDft, + pub basek: usize, + pub k: usize, +} + +impl GLWECiphertextFourier, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { + Self { + data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)), + basek: log_base2k, + k: log_k, + } + } +} + +impl Infos for GLWECiphertextFourier { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl GLWECiphertextFourier { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWECiphertextFourier, FFT64> { + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, ct_size: usize) -> usize { + module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, rank, ct_size) + } + + pub fn decrypt_scratch_space(module: &Module, ct_size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, ct_size) + | (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, ct_size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); + let mut ct_idft = GLWECiphertext { + data: vec_znx_tmp, + basek: self.basek, + k: self.k, + }; + ct_idft.encrypt_zero_sk( + module, sk_dft, source_xa, source_xe, sigma, bound, scratch_1, + ); + + ct_idft.dft(module, self); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let cols = self.rank() + 1; + + let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct + pt_big.zero(); + + { + (1..cols).for_each(|i| { + let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct + module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i); + let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); + }); + } + + { + let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1); + + pt.basek = self.basek(); + pt.k = pt.k().min(self.k()); + } + + pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + GLWECiphertext: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), res.rank()); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_idft(&mut res_big, 0, self, i, scratch1); + module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1); + }); + } +} diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs new file mode 100644 index 0000000..75088d1 --- /dev/null +++ b/core/src/glwe_plaintext.rs @@ -0,0 +1,53 @@ +use base2k::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; + +use crate::{elem::Infos, utils::derive_size}; + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub basek: usize, + pub k: usize, +} + +impl Infos for GLWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.basek + } + + fn k(&self) -> usize { + self.k + } +} + +impl VecZnxToMut for GLWEPlaintext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWEPlaintext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWEPlaintext> { + pub fn new(module: &Module, base2k: usize, k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(base2k, k)), + basek: base2k, + k, + } + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs index eaa569e..d57fa73 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -5,7 +5,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{elem::Infos, glwe::GLWECiphertextFourier}; +use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -21,25 +21,43 @@ pub struct SecretKey { } impl SecretKey> { - pub fn new(module: &Module) -> Self { + pub fn new(module: &Module, rank: usize) -> Self { Self { - data: module.new_scalar_znx(1), + data: module.new_scalar_znx(rank), dist: SecretDistribution::NONE, } } } +impl SecretKey { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + impl SecretKey where S: AsMut<[u8]> + AsRef<[u8]>, { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { - self.data.fill_ternary_prob(0, prob, source); + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_prob(i, prob, source); + }); self.dist = SecretDistribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { - self.data.fill_ternary_hw(0, hw, source); + (0..self.rank()).for_each(|i| { + self.data.fill_ternary_hw(i, hw, source); + }); self.dist = SecretDistribution::TernaryFixed(hw); } @@ -72,10 +90,24 @@ pub struct SecretKeyFourier { pub dist: SecretDistribution, } +impl SecretKeyFourier { + pub fn n(&self) -> usize { + self.data.n() + } + + pub fn log_n(&self) -> usize { + self.data.log_n() + } + + pub fn rank(&self) -> usize { + self.data.cols() + } +} + impl SecretKeyFourier, B> { - pub fn new(module: &Module) -> Self { + pub fn new(module: &Module, rank: usize) -> Self { Self { - data: module.new_scalar_znx_dft(1), + data: module.new_scalar_znx_dft(rank), dist: SecretDistribution::NONE, } } @@ -91,9 +123,15 @@ impl SecretKeyFourier, B> { SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), _ => {} } + + assert_eq!(self.n(), module.n()); + assert_eq!(sk.n(), module.n()); + assert_eq!(self.rank(), sk.rank()); } - module.svp_prepare(self, 0, sk, 0); + (0..self.rank()).for_each(|i| { + module.svp_prepare(self, i, sk, i); + }); self.dist = sk.dist; } } @@ -116,21 +154,21 @@ where } } -pub struct PublicKey { +pub struct GLWEPublicKey { pub data: GLWECiphertextFourier, pub dist: SecretDistribution, } -impl PublicKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { +impl GLWEPublicKey, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rank: usize) -> Self { Self { - data: GLWECiphertextFourier::new(module, log_base2k, log_k), + data: GLWECiphertextFourier::new(module, log_base2k, log_k, rank), dist: SecretDistribution::NONE, } } } -impl Infos for PublicKey { +impl Infos for GLWEPublicKey { type Inner = VecZnxDft; fn inner(&self) -> &Self::Inner { @@ -138,15 +176,21 @@ impl Infos for PublicKey { } fn basek(&self) -> usize { - self.data.log_base2k + self.data.basek } fn k(&self) -> usize { - self.data.log_k + self.data.k } } -impl VecZnxDftToMut for PublicKey +impl GLWEPublicKey { + pub fn rank(&self) -> usize { + self.cols() - 1 + } +} + +impl VecZnxDftToMut for GLWEPublicKey where VecZnxDft: VecZnxDftToMut, { @@ -155,7 +199,7 @@ where } } -impl VecZnxDftToRef for PublicKey +impl VecZnxDftToRef for GLWEPublicKey where VecZnxDft: VecZnxDftToRef, { @@ -164,7 +208,7 @@ where } } -impl PublicKey { +impl GLWEPublicKey { pub fn generate( &mut self, module: &Module, @@ -186,8 +230,9 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_zero_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space( module, + self.rank(), self.size(), )); self.data.encrypt_zero_sk( diff --git a/core/src/keyswitch.rs b/core/src/keyswitch.rs deleted file mode 100644 index c77ccb4..0000000 --- a/core/src/keyswitch.rs +++ /dev/null @@ -1,20 +0,0 @@ -use base2k::{FFT64, Module, Scratch}; - -pub trait KeySwitchScratchSpace { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait KeySwitch { - type Lhs; - type Rhs; - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); -} - -pub trait KeySwitchInplaceScratchSpace { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; -} - -pub trait KeySwitchInplace { - type Rhs; - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); -} diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index cb4c248..33d2a45 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -1,170 +1,63 @@ use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, - ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, - VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, - ZnxZero, + Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, }; use sampling::source::Source; use crate::{ elem::{GetRow, Infos, SetRow}, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + gglwe_ciphertext::GGLWECiphertext, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, keys::SecretKeyFourier, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - utils::derive_size, vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct GLWEKeySwitchKey { - pub data: MatZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} +pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); -impl GLWEKeySwitchKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } +impl GLWESwitchingKey, FFT64> { + pub fn new(module: &Module, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + GLWESwitchingKey(GGLWECiphertext::new( + module, base2k, k, rows, rank_in, rank_out, + )) } } -impl Infos for GLWEKeySwitchKey { +impl Infos for GLWESwitchingKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { - &self.data + &self.0.inner() } fn basek(&self) -> usize { - self.log_base2k + self.0.basek() } fn k(&self) -> usize { - self.log_k + self.0.k() } } -impl MatZnxDftToMut for GLWEKeySwitchKey +impl MatZnxDftToMut for GLWESwitchingKey where - MatZnxDft: MatZnxDftToMut, + MatZnxDft: MatZnxDftToMut, { fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data.to_mut() + self.0.data.to_mut() } } -impl MatZnxDftToRef for GLWEKeySwitchKey +impl MatZnxDftToRef for GLWESwitchingKey where - MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data.to_ref() + self.0.data.to_ref() } } -impl GLWEKeySwitchKey, FFT64> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - GLWECiphertext::encrypt_sk_scratch_space(module, size) - + module.bytes_of_vec_znx(2, size) - + module.bytes_of_vec_znx(1, size) - + module.bytes_of_vec_znx_dft(2, size) - } -} - -pub fn encrypt_glwe_key_switch_key_sk( - module: &Module, - ct: &mut GLWEKeySwitchKey, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let rows: usize = ct.rows(); - let size: usize = ct.size(); - let log_base2k: usize = ct.basek(); - - let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); - let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); - - let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { - data: tmp_znx_pt, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_znx_ct, - log_base2k: log_base2k, - log_k: ct.k(), - }; - - (0..rows).for_each(|row_i| { - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - vec_znx_ct.encrypt_sk( - module, - &vec_znx_pt, - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch_3, - ); - - vec_znx_pt.data.zero(); // zeroes for next iteration - - // Switch vec_znx_ct into DFT domain - module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0); - module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1); - - // Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft - module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct); - }); -} - -impl GLWEKeySwitchKey { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx

, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - MatZnxDft: MatZnxDftToMut, - ScalarZnx

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_glwe_key_switch_key_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } -} - -impl GetRow for GLWEKeySwitchKey +impl GetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToRef, { @@ -180,7 +73,7 @@ where } } -impl SetRow for GLWEKeySwitchKey +impl SetRow for GLWESwitchingKey where MatZnxDft: MatZnxDftToMut, { @@ -196,138 +89,117 @@ where } } -impl KeySwitchScratchSpace for GLWEKeySwitchKey, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( +impl GLWESwitchingKey, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + } + + pub fn encrypt_pk_scratch_space(module: &Module, rank: usize, pk_size: usize) -> usize { + GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size) + } +} + +impl GLWESwitchingKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + self.0.encrypt_sk( + module, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ); + } +} + +impl GLWESwitchingKey, FFT64> { + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( module, res_size, lhs, rhs, ) } -} -impl KeySwitch for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWEKeySwitchKey; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( module, res_size, rhs, ) } -} -impl KeySwitchInplace for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, rhs, scratch); - } -} - -impl ExternalProductScratchSpace for GLWEKeySwitchKey, FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( module, res_size, lhs, rhs, ) } -} -impl ExternalProduct for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWEKeySwitchKey; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( module, res_size, rhs, ) } } -impl ExternalProductInplace for GLWEKeySwitchKey +impl GLWESwitchingKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, { - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe_inplace(module, self, scratch); - } -} - -impl VecGLWEProductScratchSpace for GLWEKeySwitchKey, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, grlwe_size) - + (module.vec_znx_big_normalize_tmp_bytes() - | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) - + module.bytes_of_vec_znx_dft(1, a_size))) - } -} - -impl VecGLWEProduct for GLWEKeySwitchKey -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, + pub fn keyswitch( + &mut self, module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, + lhs: &GLWESwitchingKey, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0 + .prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0 + .prod_with_vec_glwe_inplace(module, &mut self.0, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GLWESwitchingKey, + rhs: &GGSWCiphertext, scratch: &mut Scratch, ) where - MatZnxDft: MatZnxDftToRef, - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, { - let log_base2k: usize = self.basek(); + rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch); + } - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - - { - let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size()); - module.vec_znx_dft(&mut a1_dft, 0, a, 1); - module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2); - } - - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0); - - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch); } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 97db860..cdd83d1 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,10 +1,10 @@ pub mod elem; -pub mod encryption; -pub mod external_product; -pub mod ggsw; -pub mod glwe; +pub mod gglwe_ciphertext; +pub mod ggsw_ciphertext; +pub mod glwe_ciphertext; +pub mod glwe_ciphertext_fourier; +pub mod glwe_plaintext; pub mod keys; -pub mod keyswitch; pub mod keyswitch_key; #[cfg(test)] mod test_fft64; diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/gglwe.rs similarity index 79% rename from core/src/test_fft64/grlwe.rs rename to core/src/test_fft64/gglwe.rs index 9d9a077..7a7de6d 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -3,15 +3,12 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos}, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertextFourier, GLWEPlaintext}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::rgsw::noise_rgsw_product, + keyswitch_key::GLWESwitchingKey, + test_fft64::ggsw::noise_rgsw_product, }; #[test] @@ -20,11 +17,13 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let rows: usize = 4; + let rank: usize = 1; + let rank_out: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_ct, rows); + let mut ct: GLWESwitchingKey, FFT64> = GLWESwitchingKey::new(&module, log_base2k, log_k_ct, rows, rank, rank_out); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -35,14 +34,15 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); - sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk: SecretKey> = SecretKey::new(&module, rank); + // sk.fill_ternary_prob(0.5, &mut source_xs); + sk.fill_zero(); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -56,7 +56,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); (0..ct.rows()).for_each(|row_i| { ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); @@ -74,21 +74,26 @@ fn keyswitch() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GLWEKeySwitchKey::keyswitch_scratch_space( + | GLWESwitchingKey::keyswitch_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -96,22 +101,22 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module); + let mut sk2: SecretKey> = SecretKey::new(&module, rank); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -142,7 +147,7 @@ fn keyswitch() { ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { @@ -179,38 +184,43 @@ fn keyswitch_inplace() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let rank_out: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_s1s2: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_s0s1.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GLWEKeySwitchKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + | GLWESwitchingKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); - let mut sk2: SecretKey> = SecretKey::new(&module); + let mut sk2: SecretKey> = SecretKey::new(&module, rank); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -240,10 +250,10 @@ fn keyswitch_inplace() { // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - let ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = ct_grlwe_s0s1; + let ct_grlwe_s0s2: GLWESwitchingKey, FFT64> = ct_grlwe_s0s1; let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { @@ -280,12 +290,17 @@ fn external_product() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let rank_out: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_in: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_in: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_grlwe_out: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -295,15 +310,15 @@ fn external_product() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe_in.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GLWEKeySwitchKey::external_product_scratch_space( + | GLWESwitchingKey::external_product_scratch_space( &module, ct_grlwe_out.size(), ct_grlwe_in.size(), ct_rgsw.size(), ) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), ); let k: usize = 1; @@ -312,10 +327,10 @@ fn external_product() { pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -345,7 +360,7 @@ fn external_product() { ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -393,11 +408,15 @@ fn external_product_inplace() { let log_k_grlwe: usize = 60; let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k; + let rank = 1; + let rank_out = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -407,10 +426,10 @@ fn external_product_inplace() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) - | GLWEKeySwitchKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()), ); let k: usize = 1; @@ -419,10 +438,10 @@ fn external_product_inplace() { pt_grlwe.fill_ternary_prob(0, 0.5, &mut source_xs); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -452,7 +471,7 @@ fn external_product_inplace() { ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/ggsw.rs similarity index 87% rename from core/src/test_fft64/rgsw.rs rename to core/src/test_fft64/ggsw.rs index 820b671..ce16ea5 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/ggsw.rs @@ -6,15 +6,12 @@ use sampling::source::Source; use crate::{ elem::{GetRow, Infos}, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertextFourier, GLWEPlaintext}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::grlwe::noise_grlwe_rlwe_product, + keyswitch_key::GLWESwitchingKey, + test_fft64::gglwe::noise_grlwe_rlwe_product, }; #[test] @@ -23,11 +20,12 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let rows: usize = 4; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows, rank); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); @@ -39,14 +37,14 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -60,7 +58,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); @@ -98,12 +96,15 @@ fn keyswitch() { let log_k_rgsw_out: usize = 45; let rows: usize = (log_k_rgsw_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows); - let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows, rank); + let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -114,9 +115,9 @@ fn keyswitch() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_in.size()) | GGSWCiphertext::keyswitch_scratch_space( &module, ct_rgsw_out.size(), @@ -125,16 +126,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -161,7 +162,8 @@ fn keyswitch() { ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); @@ -215,12 +217,14 @@ fn keyswitch_inplace() { let log_k_grlwe: usize = 60; let log_k_rgsw: usize = 45; let rows: usize = (log_k_rgsw + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -231,22 +235,22 @@ fn keyswitch_inplace() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -273,7 +277,8 @@ fn keyswitch_inplace() { ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); @@ -328,13 +333,16 @@ fn external_product() { let log_k_rgsw_lhs_in: usize = 45; let log_k_rgsw_lhs_out: usize = 45; let rows: usize = (log_k_rgsw_lhs_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); - let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); + let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = + GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows, rank); + let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = + GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows, rank); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -350,9 +358,9 @@ fn external_product() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs_in.size()) | GGSWCiphertext::external_product_scratch_space( &module, ct_rgsw_lhs_out.size(), @@ -361,10 +369,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -392,7 +400,7 @@ fn external_product() { ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); @@ -457,12 +465,13 @@ fn external_product_inplace() { let log_k_rgsw_rhs: usize = 60; let log_k_rgsw_lhs: usize = 45; let rows: usize = (log_k_rgsw_lhs + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows, rank); + let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows, rank); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -478,16 +487,16 @@ fn external_product_inplace() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_rgsw_rhs.size()) | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw_lhs.size()) | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -514,7 +523,8 @@ fn external_product_inplace() { ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/glwe.rs similarity index 84% rename from core/src/test_fft64/rlwe.rs rename to core/src/test_fft64/glwe.rs index 6958925..48b6cb6 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/glwe.rs @@ -7,16 +7,13 @@ use sampling::source::Source; use crate::{ elem::Infos, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, - keys::{PublicKey, SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::{GLWEPublicKey, SecretKey, SecretKeyFourier}, + keyswitch_key::GLWESwitchingKey, + test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, }; #[test] @@ -25,11 +22,12 @@ fn encrypt_sk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let log_k_pt: usize = 30; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); let mut source_xs: Source = Source::new([0u8; 32]); @@ -37,13 +35,14 @@ fn encrypt_sk() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -93,6 +92,7 @@ fn encrypt_zero_sk() { let module: Module = Module::::new(1024); let log_base2k: usize = 8; let log_k_ct: usize = 55; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; @@ -103,16 +103,16 @@ fn encrypt_zero_sk() { let mut source_xe: Source = Source::new([1u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct, rank); let mut scratch: ScratchOwned = ScratchOwned::new( GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) - | GLWECiphertextFourier::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + | GLWECiphertextFourier::encrypt_sk_scratch_space(&module, rank, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -135,11 +135,12 @@ fn encrypt_pk() { let log_base2k: usize = 8; let log_k_ct: usize = 54; let log_k_pk: usize = 64; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); @@ -147,12 +148,12 @@ fn encrypt_pk() { let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xu: Source = Source::new([0u8; 32]); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); - let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + let mut pk: GLWEPublicKey, FFT64> = GLWEPublicKey::new(&module, log_base2k, log_k_pk, rank); pk.generate( &module, &sk_dft, @@ -163,9 +164,9 @@ fn encrypt_pk() { ); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) - | GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), + | GLWECiphertext::encrypt_pk_scratch_space(&module, rank, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; @@ -206,13 +207,15 @@ fn keyswitch() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -226,9 +229,9 @@ fn keyswitch() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::keyswitch_scratch_space( &module, ct_rlwe_out.size(), @@ -237,16 +240,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -305,12 +308,14 @@ fn keyswich_inplace() { let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); @@ -324,22 +329,22 @@ fn keyswich_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -399,13 +404,14 @@ fn external_product() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -426,9 +432,9 @@ fn external_product() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), @@ -437,10 +443,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -511,12 +517,13 @@ fn external_product_inplace() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -537,16 +544,16 @@ fn external_product_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/glwe_fourier.rs similarity index 84% rename from core/src/test_fft64/rlwe_dft.rs rename to core/src/test_fft64/glwe_fourier.rs index 06359b1..661a1e5 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -1,15 +1,12 @@ use crate::{ elem::Infos, - encryption::EncryptSkScratchSpace, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, keys::{SecretKey, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, + keyswitch_key::GLWESwitchingKey, + test_fft64::{gglwe::noise_grlwe_rlwe_product, ggsw::noise_rgsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; @@ -23,16 +20,19 @@ fn keyswitch() { let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; + let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -46,9 +46,9 @@ fn keyswitch() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertextFourier::keyswitch_scratch_space( &module, ct_rlwe_out.size(), @@ -57,16 +57,16 @@ fn keyswitch() { ), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -127,13 +127,16 @@ fn keyswich_inplace() { let log_k_grlwe: usize = 60; let log_k_rlwe: usize = 45; let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWESwitchingKey, FFT64> = + GLWESwitchingKey::new(&module, log_base2k, log_k_grlwe, rows, rank, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe, rank); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); @@ -147,22 +150,22 @@ fn keyswich_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), ); - let mut sk0: SecretKey> = SecretKey::new(&module); + let mut sk0: SecretKey> = SecretKey::new(&module, rank); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk0_dft.dft(&module, &sk0); - let mut sk1: SecretKey> = SecretKey::new(&module); + let mut sk1: SecretKey> = SecretKey::new(&module, rank); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -224,17 +227,18 @@ fn external_product() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out, rank); let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = - GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -255,9 +259,9 @@ fn external_product() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe_in.size()) | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), @@ -266,10 +270,10 @@ fn external_product() { ), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -342,13 +346,15 @@ fn external_product_inplace() { let log_k_rlwe_in: usize = 45; let log_k_rlwe_out: usize = 60; let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k; + let rank: usize = 1; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows, rank); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in, rank); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in, rank); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); @@ -369,16 +375,16 @@ fn external_product_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) - | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, rank, ct_rlwe.size()) | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); - let mut sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module, rank); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module, rank); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( diff --git a/core/src/test_fft64/mod.rs b/core/src/test_fft64/mod.rs index 59e2895..ffaf1dc 100644 --- a/core/src/test_fft64/mod.rs +++ b/core/src/test_fft64/mod.rs @@ -1,4 +1,4 @@ -mod grlwe; -mod rgsw; -mod rlwe; -mod rlwe_dft; +mod gglwe; +mod ggsw; +mod glwe; +mod glwe_fourier; diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs index 7920de9..63c4769 100644 --- a/core/src/vec_glwe_product.rs +++ b/core/src/vec_glwe_product.rs @@ -5,7 +5,8 @@ use base2k::{ use crate::{ elem::{GetRow, Infos, SetRow}, - glwe::{GLWECiphertext, GLWECiphertextFourier}, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, }; pub(crate) trait VecGLWEProductScratchSpace { @@ -81,8 +82,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: a_data, - log_base2k: a.basek(), - log_k: a.k(), + basek: a.basek(), + k: a.k(), }; a.idft(module, &mut a_idft, scratch_1); @@ -91,8 +92,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); @@ -122,8 +123,8 @@ pub(crate) trait VecGLWEProduct: Infos { let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { data: res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; res.idft(module, &mut res_idft, scratch_1); @@ -143,22 +144,22 @@ pub(crate) trait VecGLWEProduct: Infos { let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, - log_base2k: a.basek(), - log_k: a.k(), + basek: a.basek(), + k: a.k(), }; let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_res_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; let min_rows: usize = res.rows().min(a.rows()); (0..res.rows()).for_each(|row_i| { - (0..res.rank()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { a.get_row(module, row_i, col_j, &mut tmp_a_row); self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); res.set_row(module, row_i, col_j, &tmp_res_row); @@ -168,7 +169,7 @@ pub(crate) trait VecGLWEProduct: Infos { tmp_res_row.data.zero(); (min_rows..res.rows()).for_each(|row_i| { - (0..self.rank()).for_each(|col_j| { + (0..self.cols()).for_each(|col_j| { res.set_row(module, row_i, col_j, &tmp_res_row); }); }); @@ -182,12 +183,12 @@ pub(crate) trait VecGLWEProduct: Infos { let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { data: tmp_row_data, - log_base2k: res.basek(), - log_k: res.k(), + basek: res.basek(), + k: res.k(), }; (0..res.rows()).for_each(|row_i| { - (0..res.rank()).for_each(|col_j| { + (0..res.cols()).for_each(|col_j| { res.get_row(module, row_i, col_j, &mut tmp_row); self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); res.set_row(module, row_i, col_j, &tmp_row);

: ScalarZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_rgsw_sk( - module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } -} - -impl GetRow for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef, -{ - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToMut, - { - module.vmp_extract_row(res, self, row_i, col_j); - } -} - -impl SetRow for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut, -{ - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToRef, - { - module.vmp_prepare_row(self, row_i, col_j, a); - } -} - -impl KeySwitchScratchSpace for GGSWCiphertext, FFT64> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, - ) - } -} - -impl KeySwitch for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GGSWCiphertext; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GGSWCiphertext, FFT64> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, rhs, scratch); - } -} - -impl ExternalProductScratchSpace for GGSWCiphertext, FFT64> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( - module, res_size, lhs, rhs, - ) - } -} - -impl ExternalProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GGSWCiphertext; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GGSWCiphertext, FFT64> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GGSWCiphertext -where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_vec_glwe_inplace(module, self, scratch); - } -} - -impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { - fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { - module.bytes_of_vec_znx_dft(2, rgsw_size) - + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) - | module.vec_znx_big_normalize_tmp_bytes()) - } -} - -impl VecGLWEProduct for GGSWCiphertext -where - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - fn prod_with_glwe( - &self, - module: &Module, - res: &mut GLWECiphertext, - a: &GLWECiphertext, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef, - { - let log_base2k: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.basek(), log_base2k); - assert_eq!(a.basek(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - assert_eq!(a.n(), module.n()); - } - - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise - - { - let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); - module.vec_znx_dft(&mut a_dft, 0, a, 0); - module.vec_znx_dft(&mut a_dft, 1, a, 1); - module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); - } - - let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); - - module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); - } -} diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs new file mode 100644 index 0000000..9d42df8 --- /dev/null +++ b/core/src/ggsw_ciphertext.rs @@ -0,0 +1,316 @@ +use base2k::{ + Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, + ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch, + VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, + ZnxZero, +}; +use sampling::source::Source; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + gglwe_ciphertext::GGLWECiphertext, + glwe_ciphertext::GLWECiphertext, + glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_plaintext::GLWEPlaintext, + keys::SecretKeyFourier, + keyswitch_key::GLWESwitchingKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GGSWCiphertext { + pub data: MatZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GGSWCiphertext, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize, rank: usize) -> Self { + Self { + data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GGSWCiphertext { + type Inner = MatZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl GGSWCiphertext { + pub fn rank(&self) -> usize { + self.data.cols_out() - 1 + } +} + +impl MatZnxDftToMut for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl MatZnxDftToRef for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GGSWCiphertext, FFT64> { + pub fn encrypt_sk_scratch_space(module: &Module, rank: usize, size: usize) -> usize { + GLWECiphertext::encrypt_sk_scratch_space(module, rank, size) + + module.bytes_of_vec_znx(rank + 1, size) + + module.bytes_of_vec_znx(1, size) + + module.bytes_of_vec_znx_dft(rank + 1, size) + } + + pub fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + pub fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } + + pub fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } + + pub fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnx: ScalarZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), sk_dft.rank()); + assert_eq!(self.n(), module.n()); + assert_eq!(pt.n(), module.n()); + assert_eq!(sk_dft.n(), module.n()); + } + + let size: usize = self.size(); + let log_base2k: usize = self.basek(); + let k: usize = self.k(); + let cols: usize = self.rank() + 1; + + let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); + let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, cols, size); + + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { + data: tmp_znx_pt, + basek: log_base2k, + k: k, + }; + + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { + data: tmp_znx_ct, + basek: log_base2k, + k, + }; + + (0..self.rows()).for_each(|row_j| { + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); + + (0..cols).for_each(|col_i| { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + vec_znx_ct.encrypt_sk_private( + module, + Some((&vec_znx_pt, col_i)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scrach_2, + ); + + // Switch vec_znx_ct into DFT domain + { + let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size); + + (0..cols).for_each(|i| { + module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); + }); + + module.vmp_prepare_row(self, row_j, col_i, &vec_znx_dft_ct); + } + }); + + vec_znx_pt.data.zero(); // zeroes for next iteration + }); + } + + pub fn keyswitch( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_vec_glwe(module, self, lhs, scratch); + } + + pub fn keyswitch_inplace( + &mut self, + module: &Module, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.0.prod_with_vec_glwe_inplace(module, self, scratch); + } + + pub fn external_product( + &mut self, + module: &Module, + lhs: &GGSWCiphertext, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } + + pub fn external_product_inplace( + &mut self, + module: &Module, + rhs: &GGSWCiphertext, + scratch: &mut Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl GetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef, +{ + fn get_row( + &self, + module: &Module, + row_i: usize, + col_j: usize, + res: &mut GLWECiphertextFourier, + ) where + VecZnxDft: VecZnxDftToMut, + { + module.vmp_extract_row(res, self, row_i, col_j); + } +} + +impl SetRow for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut, +{ + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToRef, + { + module.vmp_prepare_row(self, row_i, col_j, a); + } +} + +impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { + module.bytes_of_vec_znx_dft(2, rgsw_size) + + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) + | module.vec_znx_big_normalize_tmp_bytes()) + } +} + +impl VecGLWEProduct for GGSWCiphertext +where + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise + + { + let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size()); + module.vec_znx_dft(&mut a_dft, 0, a, 0); + module.vec_znx_dft(&mut a_dft, 1, a, 1); + module.vmp_apply(&mut res_dft, &a_dft, self, scratch2); + } + + let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); + + module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); + } +} diff --git a/core/src/glwe.rs b/core/src/glwe.rs deleted file mode 100644 index e50582d..0000000 --- a/core/src/glwe.rs +++ /dev/null @@ -1,845 +0,0 @@ -use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; -use sampling::source::Source; - -use crate::{ - elem::Infos, - encryption::{EncryptSk, EncryptSkScratchSpace, EncryptZeroSkScratchSpace}, - external_product::{ - ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, - }, - ggsw::GGSWCiphertext, - keys::{PublicKey, SecretDistribution, SecretKeyFourier}, - keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, - keyswitch_key::GLWEKeySwitchKey, - utils::derive_size, - vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, -}; - -pub struct GLWECiphertext { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl GLWECiphertext> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for GLWECiphertext { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn basek(&self) -> usize { - self.log_base2k - } - - fn k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for GLWECiphertext -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl GLWECiphertext -where - VecZnx: VecZnxToRef, -{ - #[allow(dead_code)] - pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) - where - VecZnxDft: VecZnxDftToMut + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), 2); - assert_eq!(res.rank(), 2); - assert_eq!(self.basek(), res.basek()) - } - - module.vec_znx_dft(res, 0, self, 0); - module.vec_znx_dft(res, 1, self, 1); - } -} - -impl KeySwitchScratchSpace for GLWECiphertext> { - fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl KeySwitch for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertext; - type Rhs = GLWEKeySwitchKey; - - fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe(module, self, lhs, scratch); - } -} - -impl KeySwitchInplaceScratchSpace for GLWECiphertext> { - fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl KeySwitchInplace for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Rhs = GLWEKeySwitchKey; - - fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_inplace(module, self, scratch); - } -} - -impl ExternalProductScratchSpace for GLWECiphertext> { - fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ExternalProduct for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GLWECiphertext; - type Rhs = GGSWCiphertext; - - fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe(module, self, lhs, scratch); - } -} - -impl ExternalProductInplaceScratchSpace for GLWECiphertext> { - fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { - , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( - module, res_size, rhs, - ) - } -} - -impl ExternalProductInplace for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef + ZnxInfos, -{ - type Rhs = GGSWCiphertext; - - fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { - rhs.prod_with_glwe_inplace(module, self, scratch); - } -} - -impl GLWECiphertext> { - pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -impl EncryptSkScratchSpace for GLWECiphertext> { - fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -impl EncryptSk for GLWECiphertext -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - type Ciphertext = GLWECiphertext; - type Plaintext = GLWEPlaintext; - type SecretKey = SecretKeyFourier; - - fn encrypt_sk( - &self, - module: &Module, - ct: &mut Self::Ciphertext, - pt: &Self::Plaintext, - sk: &Self::SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) { - encrypt_glwe_sk( - module, - ct, - Some((pt, 0)), - sk, - source_xa, - source_xe, - sigma, - bound, - scratch, - ); - } -} - -pub(crate) fn encrypt_glwe_sk( - module: &Module, - ct: &mut GLWECiphertext, - pt: Option<(&GLWEPlaintext, usize)>, - sk_dft: &SecretKeyFourier, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.basek(); - let log_k: usize = ct.k(); - let size: usize = ct.size(); - - // c1 = a - ct.data.fill_uniform(log_base2k, 1, size, source_xa); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some((pt, col)) = pt { - match col { - 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), - 1 => { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - module.vec_znx_add_inplace(ct, 1, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); - } - _ => panic!("invalid target column: {}", col), - } - } else { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); -} - -pub fn decrypt_glwe( - module: &Module, - pt: &mut GLWEPlaintext