diff --git a/base2k/src/module.rs b/base2k/src/module.rs index 0e7d124..aab18b4 100644 --- a/base2k/src/module.rs +++ b/base2k/src/module.rs @@ -1,106 +1,107 @@ -use crate::GALOISGENERATOR; -use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; -use std::marker::PhantomData; - -#[derive(Copy, Clone)] -#[repr(u8)] -pub enum BACKEND { - FFT64, - NTT120, -} - -pub trait Backend { - const KIND: BACKEND; - fn module_type() -> u32; -} - -pub struct FFT64; -pub struct NTT120; - -impl Backend for FFT64 { - const KIND: BACKEND = BACKEND::FFT64; - fn module_type() -> u32 { - 0 - } -} - -impl Backend for NTT120 { - const KIND: BACKEND = BACKEND::NTT120; - fn module_type() -> u32 { - 1 - } -} - -pub struct Module { - pub ptr: *mut MODULE, - n: usize, - _marker: PhantomData, -} - -impl Module { - // Instantiates a new module. - pub fn new(n: usize) -> Self { - unsafe { - let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); - if m.is_null() { - panic!("Failed to create module."); - } - Self { - ptr: m, - n: n, - _marker: PhantomData, - } - } - } - - pub fn n(&self) -> usize { - self.n - } - - pub fn log_n(&self) -> usize { - (usize::BITS - (self.n() - 1).leading_zeros()) as _ - } - - pub fn cyclotomic_order(&self) -> u64 { - (self.n() << 1) as _ - } - - // Returns GALOISGENERATOR^|generator| * sign(generator) - pub fn galois_element(&self, generator: i64) -> i64 { - if generator == 0 { - return 1; - } - ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() - } - - // Returns gen^-1 - pub fn galois_element_inv(&self, generator: i64) -> i64 { - if generator == 0 { - panic!("cannot invert 0") - } - ((mod_exp_u64( - generator.abs() as u64, - (self.cyclotomic_order() - 1) as usize, - ) & (self.cyclotomic_order() - 1)) as i64) - * generator.signum() - } - - pub fn free(self) { - unsafe { delete_module_info(self.ptr) } - drop(self); - } -} - -fn mod_exp_u64(x: u64, e: usize) -> u64 { - let mut y: u64 = 1; - let mut x_pow: u64 = x; - let mut exp = e; - while exp > 0 { - if exp & 1 == 1 { - y = y.wrapping_mul(x_pow); - } - x_pow = x_pow.wrapping_mul(x_pow); - exp >>= 1; - } - y -} +use crate::GALOISGENERATOR; +use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info}; +use std::marker::PhantomData; + +#[derive(Copy, Clone)] +#[repr(u8)] +pub enum BACKEND { + FFT64, + NTT120, +} + +pub trait Backend { + const KIND: BACKEND; + fn module_type() -> u32; +} + +pub struct FFT64; +pub struct NTT120; + +impl Backend for FFT64 { + const KIND: BACKEND = BACKEND::FFT64; + fn module_type() -> u32 { + 0 + } +} + +impl Backend for NTT120 { + const KIND: BACKEND = BACKEND::NTT120; + fn module_type() -> u32 { + 1 + } +} + +pub struct Module { + pub ptr: *mut MODULE, + n: usize, + _marker: PhantomData, +} + +impl Module { + // Instantiates a new module. + pub fn new(n: usize) -> Self { + unsafe { + let m: *mut module_info_t = new_module_info(n as u64, B::module_type()); + if m.is_null() { + panic!("Failed to create module."); + } + Self { + ptr: m, + n: n, + _marker: PhantomData, + } + } + } + + pub fn n(&self) -> usize { + self.n + } + + pub fn log_n(&self) -> usize { + (usize::BITS - (self.n() - 1).leading_zeros()) as _ + } + + pub fn cyclotomic_order(&self) -> u64 { + (self.n() << 1) as _ + } + + // Returns GALOISGENERATOR^|generator| * sign(generator) + pub fn galois_element(&self, generator: i64) -> i64 { + if generator == 0 { + return 1; + } + ((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum() + } + + // Returns gen^-1 + pub fn galois_element_inv(&self, generator: i64) -> i64 { + if generator == 0 { + panic!("cannot invert 0") + } + ((mod_exp_u64( + generator.abs() as u64, + (self.cyclotomic_order() - 1) as usize, + ) & (self.cyclotomic_order() - 1)) as i64) + * generator.signum() + } +} + +impl Drop for Module { + fn drop(&mut self) { + unsafe { delete_module_info(self.ptr) } + } +} + +fn mod_exp_u64(x: u64, e: usize) -> u64 { + let mut y: u64 = 1; + let mut x_pow: u64 = x; + let mut exp = e; + while exp > 0 { + if exp & 1 == 1 { + y = y.wrapping_mul(x_pow); + } + x_pow = x_pow.wrapping_mul(x_pow); + exp >>= 1; + } + y +} diff --git a/core/Cargo.toml b/core/Cargo.toml index 692c4fb..a54bd5a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -10,3 +10,11 @@ base2k = {path="../base2k"} sampling = {path="../sampling"} rand_distr = {workspace = true} itertools = {workspace = true} + +[[bench]] +name = "external_product_glwe_fft64" +harness = false + +[[bench]] +name = "keyswitch_glwe_fft64" +harness = false \ No newline at end of file diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs new file mode 100644 index 0000000..4462fab --- /dev/null +++ b/core/benches/external_product_glwe_fft64.rs @@ -0,0 +1,205 @@ +use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rlwe::{ + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::GLWECiphertext, + keys::{SecretKey, SecretKeyFourier}, +}; +use sampling::source::Source; + +fn bench_external_product_glwe_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("external_product_glwe_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe_in: usize, + k_rlwe_out: usize, + k_rgsw: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe_in: usize = p.k_rlwe_in; + let k_rlwe_out: usize = p.k_rlwe_out; + let k_rgsw: usize = p.k_rgsw; + + let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); + let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_rgsw.size(), + ), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + ct_rlwe_out.external_product( + black_box(&module), + black_box(&ct_rlwe_in), + black_box(&ct_rgsw), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 10, + basek: 7, + k_rlwe_in: 27, + k_rlwe_out: 27, + k_rgsw: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("external_product_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe: usize, + k_rgsw: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe: usize = p.k_rlwe; + let k_rgsw: usize = p.k_rgsw; + + let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, basek, k_rgsw, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); + let pt_rgsw: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_rgsw.encrypt_sk( + &module, + &pt_rgsw, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + let scratch_borrow = scratch.borrow(); + (0..1374).for_each(|i| { + ct_rlwe.external_product_inplace( + black_box(&module), + black_box(&ct_rgsw), + black_box(scratch_borrow), + ); + }); + } + } + + let params_set: Vec = vec![Params { + log_n: 9, + basek: 18, + k_rlwe: 27, + k_rgsw: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_external_product_glwe_fft64, + bench_external_product_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs new file mode 100644 index 0000000..3a25360 --- /dev/null +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -0,0 +1,200 @@ +use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned}; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use rlwe::{ + elem::Infos, + encryption::EncryptSkScratchSpace, + glwe::GLWECiphertext, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, +}; +use sampling::source::Source; + +fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("keyswitch_glwe_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe_in: usize, + k_rlwe_out: usize, + k_grlwe: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe_in: usize = p.k_rlwe_in; + let k_rlwe_out: usize = p.k_rlwe_out; + let k_grlwe: usize = p.k_grlwe; + + let rows: usize = (p.k_rlwe_in + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe_out); + let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::keyswitch_scratch_space( + &module, + ct_rlwe_out.size(), + ct_rlwe_in.size(), + ct_grlwe.size(), + ), + ); + + let mut source_xs = Source::new([0u8; 32]); + let mut source_xe = Source::new([0u8; 32]); + let mut source_xa = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe_in.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + ct_rlwe_out.keyswitch( + black_box(&module), + black_box(&ct_rlwe_in), + black_box(&ct_grlwe), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 16, + basek: 50, + k_rlwe_in: 1250, + k_rlwe_out: 1250, + k_grlwe: 1250 + 66, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) { + let mut group = c.benchmark_group("keyswitch_glwe_inplace_fft64"); + + struct Params { + log_n: usize, + basek: usize, + k_rlwe: usize, + k_grlwe: usize, + } + + fn runner(p: Params) -> impl FnMut() { + let module: Module = Module::::new(1 << p.log_n); + + let basek: usize = p.basek; + let k_rlwe: usize = p.k_rlwe; + let k_grlwe: usize = p.k_grlwe; + + let rows: usize = (p.k_rlwe + p.basek - 1) / p.basek; + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, basek, k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, basek, k_rlwe); + let pt_grlwe: base2k::ScalarZnx> = module.new_scalar_znx(1); + + let mut scratch = ScratchOwned::new( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + ); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); + sk_dft.dft(&module, &sk); + + ct_grlwe.encrypt_sk( + &module, + &pt_grlwe, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + ct_rlwe.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + move || { + ct_rlwe.keyswitch_inplace( + black_box(&module), + black_box(&ct_grlwe), + black_box(scratch.borrow()), + ); + } + } + + let params_set: Vec = vec![Params { + log_n: 9, + basek: 18, + k_rlwe: 27, + k_grlwe: 27, + }]; + + for params in params_set { + let id = BenchmarkId::new("KEYSWITCH_GLWE_INPLACE_FFT64", ""); + let mut runner = runner(params); + group.bench_with_input(id, &(), |b, _| b.iter(&mut runner)); + } + + group.finish(); +} + +criterion_group!( + benches, + bench_keyswitch_glwe_fft64, + bench_keyswitch_glwe_inplace_fft64 +); +criterion_main!(benches); diff --git a/core/src/elem.rs b/core/src/elem.rs index b66c86d..bf5ca1e 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,14 +1,6 @@ -use base2k::{ - Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, - VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, -}; +use base2k::{Backend, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}; -use crate::{ - grlwe::GRLWECt, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft}, - utils::derive_size, -}; +use crate::{glwe::GLWECiphertextFourier, utils::derive_size}; pub trait Infos { type Inner: ZnxInfos; @@ -31,244 +23,37 @@ pub trait Infos { } /// Returns the number of polynomials in each row. - fn cols(&self) -> usize { + fn rank(&self) -> usize { self.inner().cols() } /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k())); + debug_assert_eq!(size, derive_size(self.basek(), self.k())); size } /// Returns the total number of small polynomials. fn poly_count(&self) -> usize { - self.rows() * self.cols() * self.size() + self.rows() * self.rank() * self.size() } /// Returns the base 2 logarithm of the ciphertext base. - fn log_base2k(&self) -> usize; + fn basek(&self) -> usize; /// Returns the bit precision of the ciphertext. - fn log_k(&self) -> usize; + fn k(&self) -> usize; } pub trait GetRow { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut; } pub trait SetRow { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef; } - -pub trait ProdInplaceScratchSpace { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize; -} - -pub trait ProdInplace -where - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch); - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch); -} - -pub trait ProdScratchSpace { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; -} - -pub trait Product -where - MatZnxDft: MatZnxDftToRef, -{ - type Lhs; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch); - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch); -} - -pub(crate) trait MatRLWEProductScratchSpace { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize; - - fn prod_with_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_scratch_space(module, res_size, res_size, mat_size) - } - - fn prod_with_rlwe_dft_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - (Self::prod_with_rlwe_scratch_space(module, res_size, a_size, mat_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, a_size) - + module.bytes_of_vec_znx(2, res_size) - } - - fn prod_with_rlwe_dft_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - (Self::prod_with_rlwe_inplace_scratch_space(module, res_size, mat_size) | module.vec_znx_idft_tmp_bytes()) - + module.bytes_of_vec_znx(2, res_size) - } - - fn prod_with_mat_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_dft_scratch_space(module, res_size, a_size, mat_size) - + module.bytes_of_vec_znx_dft(2, a_size) - + module.bytes_of_vec_znx_dft(2, res_size) - } - - fn prod_with_mat_rlwe_inplace_scratch_space(module: &Module, res_size: usize, mat_size: usize) -> usize { - Self::prod_with_rlwe_dft_inplace_scratch_space(module, res_size, mat_size) + module.bytes_of_vec_znx_dft(2, res_size) - } -} - -pub(crate) trait MatRLWEProduct: Infos { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut, - VecZnx: VecZnxToRef; - - fn prod_with_rlwe_inplace(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - VecZnx: VecZnxToMut + VecZnxToRef, - { - unsafe { - let res_ptr: *mut RLWECt = res as *mut RLWECt; // This is ok because [Self::mul_rlwe] only updates res at the end. - self.prod_with_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch); - } - } - - fn prod_with_rlwe_dft( - &self, - module: &Module, - res: &mut RLWECtDft, - a: &RLWECtDft, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, - VecZnxDft: VecZnxDftToRef + ZnxInfos, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); - - let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: a_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - a.idft(module, &mut a_idft, scratch_1); - - let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - self.prod_with_rlwe(module, &mut res_idft, &a_idft, scratch_2); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - fn prod_with_rlwe_dft_inplace(&self, module: &Module, res: &mut RLWECtDft, scratch: &mut Scratch) - where - VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, - { - let log_base2k: usize = self.log_base2k(); - - #[cfg(debug_assertions)] - { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(self.n(), module.n()); - assert_eq!(res.n(), module.n()); - } - - let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); - - let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> { - data: res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - res.idft(module, &mut res_idft, scratch_1); - - self.prod_with_rlwe_inplace(module, &mut res_idft, scratch_1); - - module.vec_znx_dft(res, 0, &res_idft, 0); - module.vec_znx_dft(res, 1, &res_idft, 1); - } - - fn prod_with_mat_rlwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) - where - LHS: GetRow + Infos, - RES: SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); - - let mut tmp_a_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: a.log_base2k(), - log_k: a.log_k(), - }; - - let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_res_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_res_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - let min_rows: usize = res.rows().min(a.rows()); - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - a.get_row(module, row_i, col_j, &mut tmp_a_row); - self.prod_with_rlwe_dft(module, &mut tmp_res_row, &tmp_a_row, scratch2); - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - - tmp_res_row.data.zero(); - - (min_rows..res.rows()).for_each(|row_i| { - (0..self.cols()).for_each(|col_j| { - res.set_row(module, row_i, col_j, &tmp_res_row); - }); - }); - } - - fn prod_with_mat_rlwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) - where - RES: GetRow + SetRow + Infos, - { - let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); - - let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> { - data: tmp_row_data, - log_base2k: res.log_base2k(), - log_k: res.log_k(), - }; - - (0..res.rows()).for_each(|row_i| { - (0..res.cols()).for_each(|col_j| { - res.get_row(module, row_i, col_j, &mut tmp_row); - self.prod_with_rlwe_dft_inplace(module, &mut tmp_row, scratch1); - res.set_row(module, row_i, col_j, &tmp_row); - }); - }); - } -} diff --git a/core/src/encryption.rs b/core/src/encryption.rs new file mode 100644 index 0000000..915834c --- /dev/null +++ b/core/src/encryption.rs @@ -0,0 +1,105 @@ +use base2k::{Backend, Module, Scratch}; +use sampling::source::Source; + +pub trait EncryptSkScratchSpace { + fn encrypt_sk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptSk { + type Ciphertext; + type Plaintext; + type SecretKey; + + fn encrypt_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptZeroSkScratchSpace { + fn encrypt_zero_sk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptZeroSk { + type Ciphertext; + type SecretKey; + + fn encrypt_zero_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptPkScratchSpace { + fn encrypt_pk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptPk { + type Ciphertext; + type Plaintext; + type PublicKey; + + fn encrypt_pk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + pk: &Self::PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait EncryptZeroPkScratchSpace { + fn encrypt_zero_pk_scratch_space(module: &Module, ct_size: usize) -> usize; +} + +pub trait EncryptZeroPk { + type Ciphertext; + type PublicKey; + + fn encrypt_zero_pk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pk: &Self::PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ); +} + +pub trait Decrypt { + type Plaintext; + type Ciphertext; + type SecretKey; + + fn decrypt( + &self, + module: &Module, + pt: &mut Self::Plaintext, + ct: &Self::Ciphertext, + sk: &Self::SecretKey, + scratch: &mut Scratch, + ); +} diff --git a/core/src/external_product.rs b/core/src/external_product.rs new file mode 100644 index 0000000..e8d0a7e --- /dev/null +++ b/core/src/external_product.rs @@ -0,0 +1,19 @@ +use base2k::{FFT64, Module, Scratch}; + +pub trait ExternalProductScratchSpace { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait ExternalProduct { + type Lhs; + type Rhs; + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); +} +pub trait ExternalProductInplaceScratchSpace { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; +} + +pub trait ExternalProductInplace { + type Rhs; + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); +} diff --git a/core/src/rgsw.rs b/core/src/ggsw.rs similarity index 51% rename from core/src/rgsw.rs rename to core/src/ggsw.rs index b866252..79b12a5 100644 --- a/core/src/rgsw.rs +++ b/core/src/ggsw.rs @@ -7,23 +7,26 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{ - GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, - Product, SetRow, + elem::{GetRow, Infos, SetRow}, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - grlwe::GRLWECt, - keys::SecretKeyDft, - rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk}, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, encrypt_glwe_sk}, + keys::SecretKeyFourier, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct RGSWCt { +pub struct GGSWCiphertext { pub data: MatZnxDft, pub log_base2k: usize, pub log_k: usize, } -impl RGSWCt, B> { +impl GGSWCiphertext, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { Self { data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)), @@ -33,23 +36,23 @@ impl RGSWCt, B> { } } -impl Infos for RGSWCt { +impl Infos for GGSWCiphertext { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { &self.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.log_k } } -impl MatZnxDftToMut for RGSWCt +impl MatZnxDftToMut for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { @@ -58,7 +61,7 @@ where } } -impl MatZnxDftToRef for RGSWCt +impl MatZnxDftToRef for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { @@ -67,9 +70,9 @@ where } } -impl RGSWCt, FFT64> { +impl GGSWCiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_space(module, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) @@ -78,9 +81,9 @@ impl RGSWCt, FFT64> { pub fn encrypt_rgsw_sk( module: &Module, - ct: &mut RGSWCt, + ct: &mut GGSWCiphertext, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -92,21 +95,21 @@ pub fn encrypt_rgsw_sk( ScalarZnxDft: ScalarZnxDftToRef, { let size: usize = ct.size(); - let log_base2k: usize = ct.log_base2k(); + let log_base2k: usize = ct.basek(); let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size); - let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; - let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; (0..ct.rows()).for_each(|row_j| { @@ -114,9 +117,9 @@ pub fn encrypt_rgsw_sk( module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0); module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2); - (0..ct.cols()).for_each(|col_i| { + (0..ct.rank()).for_each(|col_i| { // rlwe encrypt of vec_znx_pt into vec_znx_ct - encrypt_rlwe_sk( + encrypt_glwe_sk( module, &mut vec_znx_ct, Some((&vec_znx_pt, col_i)), @@ -141,12 +144,12 @@ pub fn encrypt_rgsw_sk( }); } -impl RGSWCt { +impl GGSWCiphertext { pub fn encrypt_sk( &mut self, module: &Module, pt: &ScalarZnx

, + ct: &GLWECiphertext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.basek(); + pt.log_k = pt.k().min(ct.k()); +} + +impl GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_glwe_sk( + module, + self, + Some((pt, 0)), + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch, + ) + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_glwe_sk::( + module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &GLWEPlaintext, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + encrypt_glwe_pk( + module, + self, + Some(pt), + pk, + source_xu, + source_xe, + sigma, + bound, + scratch, + ) + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + { + encrypt_glwe_pk::( + module, self, None, pk, source_xu, source_xe, sigma, bound, scratch, + ) + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_glwe(module, pt, self, sk_dft, scratch); + } +} + +pub(crate) fn encrypt_glwe_pk( + module: &Module, + ct: &mut GLWECiphertext, + pt: Option<&GLWEPlaintext

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.basek(), pk.basek()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.basek(), pk.basek()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.basek(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + SecretDistribution::ZERO => {} + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + +pub struct GLWEPlaintext { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl Infos for GLWEPlaintext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for GLWEPlaintext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWEPlaintext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWEPlaintext> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct GLWECiphertextFourier { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GLWECiphertextFourier, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GLWECiphertextFourier { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxDftToMut for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl GLWECiphertextFourier +where + GLWECiphertextFourier: VecZnxDftToRef, +{ + #[allow(dead_code)] + pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { + module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + } + + pub(crate) fn idft(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + GLWECiphertext: VecZnxToMut, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), 2); + assert_eq!(res.rank(), 2); + assert_eq!(self.basek(), res.basek()) + } + + let min_size: usize = self.size().min(res.size()); + + let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); + + module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); + module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); + module.vec_znx_big_normalize(self.basek(), res, 0, &res_big, 0, scratch1); + module.vec_znx_big_normalize(self.basek(), res, 1, &res_big, 1, scratch1); + } +} + +pub(crate) fn encrypt_zero_glwe_dft_sk( + module: &Module, + ct: &mut GLWECiphertextFourier, + sk: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.basek(); + let log_k: usize = ct.k(); + let size: usize = ct.size(); + + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), + _ => {} + } + assert_eq!(ct.rank(), 2); + } + + // ct[1] = DFT(a) + { + let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); + tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); + } + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + // c0_dft = ct[1] * DFT(s) + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); + } + + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. + let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); + module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); + module.vec_znx_negate_inplace(&mut tmp_znx, 0); + // ct[0] = DFT(-as + e) + module.vec_znx_dft(ct, 0, &tmp_znx, 0); +} + +impl GLWECiphertextFourier, FFT64> { + pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) + } +} + +pub fn decrypt_rlwe_dft( + module: &Module, + pt: &mut GLWEPlaintext

, + ct: &GLWECiphertextFourier, + sk: &SecretKeyFourier, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + { + let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.basek(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.basek(); + pt.log_k = pt.k().min(ct.k()); +} + +impl GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, +{ + pub(crate) fn encrypt_zero_sk( + &mut self, + module: &Module, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + ScalarZnxDft: ScalarZnxDftToRef, + { + encrypt_zero_glwe_dft_sk( + module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) + } + + pub fn decrypt( + &self, + module: &Module, + pt: &mut GLWEPlaintext

, + sk_dft: &SecretKeyFourier, + scratch: &mut Scratch, + ) where + VecZnx

: VecZnxToMut + VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); + } +} + +impl KeySwitchScratchSpace for GLWECiphertextFourier, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl KeySwitch for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertextFourier; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWECiphertextFourier, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} + +impl ExternalProductScratchSpace for GLWECiphertextFourier, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl ExternalProduct for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertextFourier; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWECiphertextFourier, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWECiphertextFourier +where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_fourier_inplace(module, self, scratch); + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs index 8285f85..eaa569e 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -5,7 +5,7 @@ use base2k::{ }; use sampling::source::Source; -use crate::{elem::Infos, rlwe::RLWECtDft}; +use crate::{elem::Infos, glwe::GLWECiphertextFourier}; #[derive(Clone, Copy, Debug)] pub enum SecretDistribution { @@ -67,12 +67,12 @@ where } } -pub struct SecretKeyDft { +pub struct SecretKeyFourier { pub data: ScalarZnxDft, pub dist: SecretDistribution, } -impl SecretKeyDft, B> { +impl SecretKeyFourier, B> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar_znx_dft(1), @@ -82,7 +82,7 @@ impl SecretKeyDft, B> { pub fn dft(&mut self, module: &Module, sk: &SecretKey) where - SecretKeyDft, B>: ScalarZnxDftToMut, + SecretKeyFourier, B>: ScalarZnxDftToMut, SecretKey: ScalarZnxToRef, { #[cfg(debug_assertions)] @@ -98,7 +98,7 @@ impl SecretKeyDft, B> { } } -impl ScalarZnxDftToMut for SecretKeyDft +impl ScalarZnxDftToMut for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToMut, { @@ -107,7 +107,7 @@ where } } -impl ScalarZnxDftToRef for SecretKeyDft +impl ScalarZnxDftToRef for SecretKeyFourier where ScalarZnxDft: ScalarZnxDftToRef, { @@ -117,14 +117,14 @@ where } pub struct PublicKey { - pub data: RLWECtDft, + pub data: GLWECiphertextFourier, pub dist: SecretDistribution, } impl PublicKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: RLWECtDft::new(module, log_base2k, log_k), + data: GLWECiphertextFourier::new(module, log_base2k, log_k), dist: SecretDistribution::NONE, } } @@ -137,11 +137,11 @@ impl Infos for PublicKey { &self.data.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.data.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.data.log_k } } @@ -168,7 +168,7 @@ impl PublicKey { pub fn generate( &mut self, module: &Module, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -186,7 +186,7 @@ impl PublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_zero_sk_scratch_space( module, self.size(), )); diff --git a/core/src/keyswitch.rs b/core/src/keyswitch.rs new file mode 100644 index 0000000..c77ccb4 --- /dev/null +++ b/core/src/keyswitch.rs @@ -0,0 +1,20 @@ +use base2k::{FFT64, Module, Scratch}; + +pub trait KeySwitchScratchSpace { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; +} + +pub trait KeySwitch { + type Lhs; + type Rhs; + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch); +} + +pub trait KeySwitchInplaceScratchSpace { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize; +} + +pub trait KeySwitchInplace { + type Rhs; + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch); +} diff --git a/core/src/grlwe.rs b/core/src/keyswitch_key.rs similarity index 51% rename from core/src/grlwe.rs rename to core/src/keyswitch_key.rs index 80c976d..cb4c248 100644 --- a/core/src/grlwe.rs +++ b/core/src/keyswitch_key.rs @@ -7,23 +7,26 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{ - GetRow, Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, - Product, SetRow, + elem::{GetRow, Infos, SetRow}, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, }, - keys::SecretKeyDft, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::SecretKeyFourier, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, }; -pub struct GRLWECt { +pub struct GLWEKeySwitchKey { pub data: MatZnxDft, pub log_base2k: usize, pub log_k: usize, } -impl GRLWECt, B> { +impl GLWEKeySwitchKey, B> { pub fn new(module: &Module, log_base2k: usize, log_k: usize, rows: usize) -> Self { Self { data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)), @@ -33,23 +36,23 @@ impl GRLWECt, B> { } } -impl Infos for GRLWECt { +impl Infos for GLWEKeySwitchKey { type Inner = MatZnxDft; fn inner(&self) -> &Self::Inner { &self.data } - fn log_base2k(&self) -> usize { + fn basek(&self) -> usize { self.log_base2k } - fn log_k(&self) -> usize { + fn k(&self) -> usize { self.log_k } } -impl MatZnxDftToMut for GRLWECt +impl MatZnxDftToMut for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToMut, { @@ -58,7 +61,7 @@ where } } -impl MatZnxDftToRef for GRLWECt +impl MatZnxDftToRef for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef, { @@ -67,20 +70,20 @@ where } } -impl GRLWECt, FFT64> { +impl GLWEKeySwitchKey, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - RLWECt::encrypt_sk_scratch_space(module, size) + GLWECiphertext::encrypt_sk_scratch_space(module, size) + module.bytes_of_vec_znx(2, size) + module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx_dft(2, size) } } -pub fn encrypt_grlwe_sk( +pub fn encrypt_glwe_key_switch_key_sk( module: &Module, - ct: &mut GRLWECt, + ct: &mut GLWEKeySwitchKey, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -93,22 +96,22 @@ pub fn encrypt_grlwe_sk( { let rows: usize = ct.rows(); let size: usize = ct.size(); - let log_base2k: usize = ct.log_base2k(); + let log_base2k: usize = ct.basek(); let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size); let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size); let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size); - let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt { + let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { data: tmp_znx_pt, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; - let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt { + let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { data: tmp_znx_ct, log_base2k: log_base2k, - log_k: ct.log_k(), + log_k: ct.k(), }; (0..rows).for_each(|row_i| { @@ -119,7 +122,7 @@ pub fn encrypt_grlwe_sk( // rlwe encrypt of vec_znx_pt into vec_znx_ct vec_znx_ct.encrypt_sk( module, - Some(&vec_znx_pt), + &vec_znx_pt, sk_dft, source_xa, source_xe, @@ -139,12 +142,12 @@ pub fn encrypt_grlwe_sk( }); } -impl GRLWECt { +impl GLWEKeySwitchKey { pub fn encrypt_sk( &mut self, module: &Module, pt: &ScalarZnx

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -155,17 +158,17 @@ impl GRLWECt { ScalarZnx

: ScalarZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - encrypt_grlwe_sk( + encrypt_glwe_key_switch_key_sk( module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch, ) } } -impl GetRow for GRLWECt +impl GetRow for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut, { @@ -177,11 +180,11 @@ where } } -impl SetRow for GRLWECt +impl SetRow for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef, { @@ -193,8 +196,92 @@ where } } -impl MatRLWEProductScratchSpace for GRLWECt, FFT64> { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { +impl KeySwitchScratchSpace for GLWEKeySwitchKey, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl KeySwitch for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWEKeySwitchKey; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, rhs, scratch); + } +} + +impl ExternalProductScratchSpace for GLWEKeySwitchKey, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ExternalProduct for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWEKeySwitchKey; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWEKeySwitchKey, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWEKeySwitchKey +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl VecGLWEProductScratchSpace for GLWEKeySwitchKey, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, grlwe_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, grlwe_size) + (module.vec_znx_big_normalize_tmp_bytes() | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size) @@ -202,22 +289,27 @@ impl MatRLWEProductScratchSpace for GRLWECt, FFT64> { } } -impl MatRLWEProduct for GRLWECt +impl VecGLWEProduct for GLWEKeySwitchKey where MatZnxDft: MatZnxDftToRef + ZnxInfos, { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where MatZnxDft: MatZnxDftToRef, VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.log_base2k(); + let log_base2k: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); @@ -239,53 +331,3 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } - -impl ProdInplaceScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for GRLWECt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for GRLWECt -where - GRLWECt: GetRow + SetRow + Infos, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } -} - -impl Product for GRLWECt -where - MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = GRLWECt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } -} diff --git a/core/src/lib.rs b/core/src/lib.rs index bed71cc..97db860 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,8 +1,12 @@ pub mod elem; -pub mod grlwe; +pub mod encryption; +pub mod external_product; +pub mod ggsw; +pub mod glwe; pub mod keys; -pub mod rgsw; -pub mod rlwe; +pub mod keyswitch; +pub mod keyswitch_key; #[cfg(test)] mod test_fft64; mod utils; +pub mod vec_glwe_product; diff --git a/core/src/rlwe.rs b/core/src/rlwe.rs deleted file mode 100644 index 2dab803..0000000 --- a/core/src/rlwe.rs +++ /dev/null @@ -1,701 +0,0 @@ -use base2k::{ - AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, - ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, -}; -use sampling::source::Source; - -use crate::{ - elem::{Infos, MatRLWEProduct, MatRLWEProductScratchSpace, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{PublicKey, SecretDistribution, SecretKeyDft}, - rgsw::RGSWCt, - utils::derive_size, -}; - -pub struct RLWECt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWECt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWECt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWECt -where - VecZnx: VecZnxToRef, -{ - #[allow(dead_code)] - pub(crate) fn dft(&self, module: &Module, res: &mut RLWECtDft) - where - VecZnxDft: VecZnxDftToMut + ZnxInfos, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } - - module.vec_znx_dft(res, 0, self, 0); - module.vec_znx_dft(res, 1, self, 1); - } -} - -impl ProdInplaceScratchSpace for RLWECt> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RLWECt> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RLWECt -where - VecZnx: VecZnxToMut + VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_inplace(module, self, scratch); - } -} - -impl Product for RLWECt -where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RLWECt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe(module, self, lhs, scratch); - } -} - -impl RLWECt> { - pub fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } - - pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { - ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) - + module.bytes_of_scalar_znx_dft(1) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn encrypt_rlwe_sk( - module: &Module, - ct: &mut RLWECt, - pt: Option<(&RLWEPt

, usize)>, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - // c1 = a - ct.data.fill_uniform(log_base2k, 1, size, source_xa); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some((pt, col)) = pt { - match col { - 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), - 1 => { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - module.vec_znx_add_inplace(ct, 1, pt, 0); - module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); - } - _ => panic!("invalid target column: {}", col), - } - } else { - module.vec_znx_big_negate_inplace(&mut c0_big, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); -} - -pub fn decrypt_rlwe( - module: &Module, - pt: &mut RLWEPt

, - ct: &RLWECt, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.log_base2k(); - pt.log_k = pt.log_k().min(ct.log_k()); -} - -impl RLWECt { - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: Option<&RLWEPt

>, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - if let Some(pt) = pt { - encrypt_rlwe_sk( - module, - self, - Some((pt, 0)), - sk_dft, - source_xa, - source_xe, - sigma, - bound, - scratch, - ) - } else { - encrypt_rlwe_sk::( - module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut RLWEPt

, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnx: VecZnxToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe(module, pt, self, sk_dft, scratch); - } - - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - { - encrypt_rlwe_pk( - module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch, - ) - } -} - -pub(crate) fn encrypt_rlwe_pk( - module: &Module, - ct: &mut RLWECt, - pt: Option<&RLWEPt

>, - pk: &PublicKey, - source_xu: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, - VecZnxDft: VecZnxDftToRef, -{ - #[cfg(debug_assertions)] - { - assert_eq!(ct.log_base2k(), pk.log_base2k()); - assert_eq!(ct.n(), module.n()); - assert_eq!(pk.n(), module.n()); - if let Some(pt) = pt { - assert_eq!(pt.log_base2k(), pk.log_base2k()); - assert_eq!(pt.n(), module.n()); - } - } - - let log_base2k: usize = pk.log_base2k(); - let size_pk: usize = pk.size(); - - // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1); - - { - let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1); - match pk.dist { - SecretDistribution::NONE => panic!( - "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" - ), - SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), - SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), - SecretDistribution::ZERO => {} - } - - module.svp_prepare(&mut u_dft, 0, &u, 0); - } - - let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) - - // ct[0] = pk[0] * u + m + e0 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - - if let Some(pt) = pt { - module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); - } - - module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); - - // ct[1] = pk[1] * u + e1 - module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); - module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); - tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); - module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); -} - -pub struct RLWEPt { - pub data: VecZnx, - pub log_base2k: usize, - pub log_k: usize, -} - -impl Infos for RLWEPt { - type Inner = VecZnx; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxToMut for RLWEPt -where - VecZnx: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data.to_mut() - } -} - -impl VecZnxToRef for RLWEPt -where - VecZnx: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data.to_ref() - } -} - -impl RLWEPt> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -pub struct RLWECtDft { - pub data: VecZnxDft, - pub log_base2k: usize, - pub log_k: usize, -} - -impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { - Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)), - log_base2k: log_base2k, - log_k: log_k, - } - } -} - -impl Infos for RLWECtDft { - type Inner = VecZnxDft; - - fn inner(&self) -> &Self::Inner { - &self.data - } - - fn log_base2k(&self) -> usize { - self.log_base2k - } - - fn log_k(&self) -> usize { - self.log_k - } -} - -impl VecZnxDftToMut for RLWECtDft -where - VecZnxDft: VecZnxDftToMut, -{ - fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data.to_mut() - } -} - -impl VecZnxDftToRef for RLWECtDft -where - VecZnxDft: VecZnxDftToRef, -{ - fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data.to_ref() - } -} - -impl RLWECtDft -where - RLWECtDft: VecZnxDftToRef, -{ - #[allow(dead_code)] - pub(crate) fn idft_scratch_space(module: &Module, size: usize) -> usize { - module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) - } - - pub(crate) fn idft(&self, module: &Module, res: &mut RLWECt, scratch: &mut Scratch) - where - RLWECt: VecZnxToMut, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.cols(), 2); - assert_eq!(res.cols(), 2); - assert_eq!(self.log_base2k(), res.log_base2k()) - } - - let min_size: usize = self.size().min(res.size()); - - let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size); - - module.vec_znx_idft(&mut res_big, 0, self, 0, scratch1); - module.vec_znx_idft(&mut res_big, 1, self, 1, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1); - module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1); - } -} - -pub(crate) fn encrypt_zero_rlwe_dft_sk( - module: &Module, - ct: &mut RLWECtDft, - sk: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, -) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - #[cfg(debug_assertions)] - { - match sk.dist { - SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), - _ => {} - } - assert_eq!(ct.cols(), 2); - } - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = ct[1] * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - module.vec_znx_negate_inplace(&mut tmp_znx, 0); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); -} - -impl RLWECtDft, FFT64> { - pub fn encrypt_zero_sk_scratch_space(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() - } - - pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() - | module.bytes_of_vec_znx_dft(1, size) - | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) - + module.bytes_of_vec_znx_big(1, size) - } -} - -pub fn decrypt_rlwe_dft( - module: &Module, - pt: &mut RLWEPt

, - ct: &RLWECtDft, - sk: &SecretKeyDft, - scratch: &mut Scratch, -) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, -{ - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - { - let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); - // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) - module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); - module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); - } - - // pt = norm(BIG(m + e)) - module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); - - pt.log_base2k = ct.log_base2k(); - pt.log_k = pt.log_k().min(ct.log_k()); -} - -impl RLWECtDft { - pub(crate) fn encrypt_zero_sk( - &mut self, - module: &Module, - sk_dft: &SecretKeyDft, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - bound: f64, - scratch: &mut Scratch, - ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - encrypt_zero_rlwe_dft_sk( - module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, - ) - } - - pub fn decrypt( - &self, - module: &Module, - pt: &mut RLWEPt

, - sk_dft: &SecretKeyDft, - scratch: &mut Scratch, - ) where - VecZnx

: VecZnxToMut + VecZnxToRef, - VecZnxDft: VecZnxDftToRef, - ScalarZnxDft: ScalarZnxDftToRef, - { - decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); - } -} - -impl ProdInplaceScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RLWECtDft, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_rlwe_dft_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RLWECtDft -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft_inplace(module, self, scratch); - } -} - -impl Product for RLWECtDft -where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, - VecZnxDft: VecZnxDftToRef, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RLWECtDft; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_rlwe_dft(module, self, lhs, scratch); - } -} diff --git a/core/src/test_fft64/grlwe.rs b/core/src/test_fft64/grlwe.rs index 81c1023..9d9a077 100644 --- a/core/src/test_fft64/grlwe.rs +++ b/core/src/test_fft64/grlwe.rs @@ -2,11 +2,15 @@ use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, + elem::{GetRow, Infos}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::rgsw::noise_rgsw_product, }; @@ -20,8 +24,8 @@ fn encrypt_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_ct, rows); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -31,13 +35,14 @@ fn encrypt_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -51,7 +56,7 @@ fn encrypt_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); (0..ct.rows()).for_each(|row_i| { ct.get_row(&module, row_i, 0, &mut ct_rlwe_dft); @@ -60,12 +65,10 @@ fn encrypt_sk() { let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2(); assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt); }); - - module.free(); } #[test] -fn from_prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -74,18 +77,18 @@ fn from_prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s0s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) - | GRLWECt::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s2.size()) + | GLWEKeySwitchKey::keyswitch_scratch_space( &module, ct_grlwe_s0s2.size(), ct_grlwe_s0s1.size(), @@ -96,19 +99,19 @@ fn from_prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); let mut sk2: SecretKey> = SecretKey::new(&module); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -136,10 +139,11 @@ fn from_prod_by_grlwe() { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s2.prod_by_grlwe(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); + ct_grlwe_s0s2.keyswitch(&module, &ct_grlwe_s0s1, &ct_grlwe_s1s2, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); @@ -166,12 +170,10 @@ fn from_prod_by_grlwe() { noise_want ); }); - - module.free(); } #[test] -fn prod_by_grlwe() { +fn keyswitch_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -180,35 +182,35 @@ fn prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_s0s1: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_s1s2: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s0s1: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_s1s2: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) - | GRLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_s0s1.size()) + | GLWEKeySwitchKey::keyswitch_inplace_scratch_space(&module, ct_grlwe_s0s1.size(), ct_grlwe_s1s2.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); let mut sk2: SecretKey> = SecretKey::new(&module); sk2.fill_ternary_prob(0.5, &mut source_xs); - let mut sk2_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk2_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk2_dft.dft(&module, &sk2); // GRLWE_{s1}(s0) = s0 -> s1 @@ -236,12 +238,13 @@ fn prod_by_grlwe() { ); // GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0) - ct_grlwe_s0s1.prod_by_grlwe_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); + ct_grlwe_s0s1.keyswitch_inplace(&module, &ct_grlwe_s1s2, scratch.borrow()); - let ct_grlwe_s0s2: GRLWECt, FFT64> = ct_grlwe_s0s1; + let ct_grlwe_s0s2: GLWEKeySwitchKey, FFT64> = ct_grlwe_s0s1; - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); (0..ct_grlwe_s0s2.rows()).for_each(|row_i| { ct_grlwe_s0s2.get_row(&module, row_i, 0, &mut ct_rlwe_dft_s0s2); @@ -268,12 +271,10 @@ fn prod_by_grlwe() { noise_want ); }); - - module.free(); } #[test] -fn from_prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -282,9 +283,9 @@ fn from_prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe_in: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_grlwe_out: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_in: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe_out: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -294,15 +295,15 @@ fn from_prod_by_rgsw() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_out.size()) - | GRLWECt::prod_by_rgsw_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe_in.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe_out.size()) + | GLWEKeySwitchKey::external_product_scratch_space( &module, ct_grlwe_out.size(), ct_grlwe_in.size(), ct_rgsw.size(), ) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), ); let k: usize = 1; @@ -314,7 +315,7 @@ fn from_prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -341,10 +342,11 @@ fn from_prod_by_rgsw() { ); // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe_out.prod_by_rgsw(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); + ct_grlwe_out.external_product(&module, &ct_grlwe_in, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -382,12 +384,10 @@ fn from_prod_by_rgsw() { noise_want ); }); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -396,8 +396,8 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut pt_grlwe: ScalarZnx> = module.new_scalar_znx(1); @@ -407,10 +407,10 @@ fn prod_by_rgsw() { let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_grlwe.size()) - | GRLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_grlwe.size()) + | GLWEKeySwitchKey::external_product_inplace_scratch_space(&module, ct_grlwe.size(), ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()), ); let k: usize = 1; @@ -422,7 +422,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); // GRLWE_{s1}(s0) = s0 -> s1 @@ -449,10 +449,11 @@ fn prod_by_rgsw() { ); // GRLWE_(m) (x) RGSW_(X^k) = GRLWE_(m * X^k) - ct_grlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_grlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); - let mut ct_rlwe_dft_s0s2: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_grlwe); + let mut ct_rlwe_dft_s0s2: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_grlwe); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_grlwe); module.vec_znx_rotate_inplace(k as i64, &mut pt_grlwe, 0); @@ -490,8 +491,6 @@ fn prod_by_rgsw() { noise_want ); }); - - module.free(); } pub(crate) fn noise_grlwe_rlwe_product( diff --git a/core/src/test_fft64/rgsw.rs b/core/src/test_fft64/rgsw.rs index 50cd356..820b671 100644 --- a/core/src/test_fft64/rgsw.rs +++ b/core/src/test_fft64/rgsw.rs @@ -5,16 +5,20 @@ use base2k::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECtDft, RLWEPt}, + elem::{GetRow, Infos}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::grlwe::noise_grlwe_rlwe_product, }; #[test] -fn encrypt_rgsw_sk() { +fn encrypt_sk() { let module: Module = Module::::new(2048); let log_base2k: usize = 8; let log_k_ct: usize = 54; @@ -23,9 +27,9 @@ fn encrypt_rgsw_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_ct, rows); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut pt_scalar: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -35,13 +39,14 @@ fn encrypt_rgsw_sk() { pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct.encrypt_sk( @@ -55,11 +60,11 @@ fn encrypt_rgsw_sk() { scratch.borrow(), ); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct.size()); - (0..ct.cols()).for_each(|col_j| { + (0..ct.rank()).for_each(|col_j| { (0..ct.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0); @@ -82,12 +87,10 @@ fn encrypt_rgsw_sk() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -98,9 +101,9 @@ fn from_prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_in, rows); - let mut ct_rgsw_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_out, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_in, rows); + let mut ct_rgsw_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_out, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -111,10 +114,10 @@ fn from_prod_by_grlwe() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_out.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) - | RGSWCt::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_in.size()) + | GGSWCiphertext::keyswitch_scratch_space( &module, ct_rgsw_out.size(), ct_rgsw_in.size(), @@ -125,13 +128,13 @@ fn from_prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -156,15 +159,15 @@ fn from_prod_by_grlwe() { scratch.borrow(), ); - ct_rgsw_out.prod_by_grlwe(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); + ct_rgsw_out.keyswitch(&module, &ct_rgsw_in, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_out); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_out.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_out); - (0..ct_rgsw_out.cols()).for_each(|col_j| { + (0..ct_rgsw_out.rank()).for_each(|col_j| { (0..ct_rgsw_out.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); @@ -203,12 +206,10 @@ fn from_prod_by_grlwe() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_grlwe_inplace() { +fn keyswitch_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -218,8 +219,8 @@ fn from_prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw, rows); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw, rows); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); let mut source_xs: Source = Source::new([0u8; 32]); @@ -230,22 +231,22 @@ fn from_prod_by_grlwe_inplace() { pt_rgsw.fill_ternary_prob(0, 0.5, &mut source_xs); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RGSWCt::prod_by_grlwe_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct_rgsw.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -270,15 +271,15 @@ fn from_prod_by_grlwe_inplace() { scratch.borrow(), ); - ct_rgsw.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rgsw.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw); - (0..ct_rgsw.cols()).for_each(|col_j| { + (0..ct_rgsw.rank()).for_each(|col_j| { (0..ct_rgsw.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw, 0); @@ -317,12 +318,10 @@ fn from_prod_by_grlwe_inplace() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_rgsw_rhs: usize = 60; @@ -333,9 +332,9 @@ fn from_prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs_in: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); - let mut ct_rgsw_lhs_out: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs_in: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_in, rows); + let mut ct_rgsw_lhs_out: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs_out, rows); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -351,10 +350,10 @@ fn from_prod_by_rgsw() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) - | RGSWCt::prod_by_rgsw_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs_out.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs_in.size()) + | GGSWCiphertext::external_product_scratch_space( &module, ct_rgsw_lhs_out.size(), ct_rgsw_lhs_in.size(), @@ -365,7 +364,7 @@ fn from_prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -390,17 +389,18 @@ fn from_prod_by_rgsw() { scratch.borrow(), ); - ct_rgsw_lhs_out.prod_by_rgsw(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); + ct_rgsw_lhs_out.external_product(&module, &ct_rgsw_lhs_in, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs_out); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs_out.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs_out.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs_out); module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - (0..ct_rgsw_lhs_out.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs_out.rank()).for_each(|col_j| { (0..ct_rgsw_lhs_out.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); @@ -448,12 +448,10 @@ fn from_prod_by_rgsw() { pt_want.data.zero(); }); }); - - module.free(); } #[test] -fn from_prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_rgsw_rhs: usize = 60; @@ -463,8 +461,8 @@ fn from_prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw_rhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_rhs, rows); - let mut ct_rgsw_lhs: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_rgsw_lhs, rows); + let mut ct_rgsw_rhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_rhs, rows); + let mut ct_rgsw_lhs: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_rgsw_lhs, rows); let mut pt_rgsw_lhs: ScalarZnx> = module.new_scalar_znx(1); let mut pt_rgsw_rhs: ScalarZnx> = module.new_scalar_znx(1); @@ -480,16 +478,16 @@ fn from_prod_by_rgsw_inplace() { pt_rgsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) - | RLWECtDft::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) - | RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) - | RGSWCt::prod_by_rgsw_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_rgsw_rhs.size()) + | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw_lhs.size()) + | GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_rgsw_lhs.size(), ct_rgsw_rhs.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw_rhs.encrypt_sk( @@ -514,17 +512,17 @@ fn from_prod_by_rgsw_inplace() { scratch.borrow(), ); - ct_rgsw_lhs.prod_by_rgsw_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); + ct_rgsw_lhs.external_product_inplace(&module, &ct_rgsw_rhs, scratch.borrow()); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rgsw_lhs); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); let mut pt_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(1, ct_rgsw_lhs.size()); let mut pt_big: VecZnxBig, FFT64> = module.new_vec_znx_big(1, ct_rgsw_lhs.size()); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rgsw_lhs); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rgsw_lhs); module.vec_znx_rotate_inplace(k as i64, &mut pt_rgsw_lhs, 0); - (0..ct_rgsw_lhs.cols()).for_each(|col_j| { + (0..ct_rgsw_lhs.rank()).for_each(|col_j| { (0..ct_rgsw_lhs.rows()).for_each(|row_i| { module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_rgsw_lhs, 0); @@ -572,8 +570,6 @@ fn from_prod_by_rgsw_inplace() { pt_want.data.zero(); }); }); - - module.free(); } pub(crate) fn noise_rgsw_product( diff --git a/core/src/test_fft64/rlwe.rs b/core/src/test_fft64/rlwe.rs index a2fabb9..6958925 100644 --- a/core/src/test_fft64/rlwe.rs +++ b/core/src/test_fft64/rlwe.rs @@ -6,11 +6,16 @@ use itertools::izip; use sampling::source::Source; use crate::{ - elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{PublicKey, SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::{PublicKey, SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, }; @@ -24,21 +29,21 @@ fn encrypt_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_pt); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); let mut data_want: Vec = vec![0i64; module.n()]; @@ -52,7 +57,7 @@ fn encrypt_sk() { ct.encrypt_sk( &module, - Some(&pt), + &pt, &sk_dft, &mut source_xa, &mut source_xe, @@ -81,8 +86,6 @@ fn encrypt_sk() { b_scaled ) }); - - module.free(); } #[test] @@ -94,7 +97,7 @@ fn encrypt_zero_sk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([1u8; 32]); @@ -102,14 +105,14 @@ fn encrypt_zero_sk() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); - let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct); + let mut ct_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_ct); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECtDft::decrypt_scratch_space(&module, ct_dft.size()) - | RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), + GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size()) + | GLWECiphertextFourier::encrypt_zero_sk_scratch_space(&module, ct_dft.size()), ); ct_dft.encrypt_zero_sk( @@ -124,7 +127,6 @@ fn encrypt_zero_sk() { ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); - module.free(); } #[test] @@ -137,8 +139,8 @@ fn encrypt_pk() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut ct: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_ct); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -147,7 +149,7 @@ fn encrypt_pk() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); @@ -161,9 +163,9 @@ fn encrypt_pk() { ); let mut scratch: ScratchOwned = ScratchOwned::new( - RLWECt::encrypt_sk_scratch_space(&module, ct.size()) - | RLWECt::decrypt_scratch_space(&module, ct.size()) - | RLWECt::encrypt_pk_scratch_space(&module, pk.size()), + GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct.size()) + | GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()), ); let mut data_want: Vec = vec![0i64; module.n()]; @@ -178,7 +180,7 @@ fn encrypt_pk() { ct.encrypt_pk( &module, - Some(&pt_want), + &pt_want, &pk, &mut source_xu, &mut source_xe, @@ -187,19 +189,17 @@ fn encrypt_pk() { scratch.borrow(), ); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_ct); ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); - - module.free(); } #[test] -fn prod_by_grlwe() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -210,11 +210,11 @@ fn prod_by_grlwe() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -226,10 +226,10 @@ fn prod_by_grlwe() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::keyswitch_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -240,13 +240,13 @@ fn prod_by_grlwe() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -262,7 +262,7 @@ fn prod_by_grlwe() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -271,7 +271,7 @@ fn prod_by_grlwe() { scratch.borrow(), ); - ct_rlwe_out.prod_by_grlwe(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); + ct_rlwe_out.keyswitch(&module, &ct_rlwe_in, &ct_grlwe, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -296,12 +296,10 @@ fn prod_by_grlwe() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_grlwe_inplace() { +fn keyswich_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -311,10 +309,10 @@ fn prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -326,22 +324,22 @@ fn prod_by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -357,7 +355,7 @@ fn prod_by_grlwe_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -366,7 +364,7 @@ fn prod_by_grlwe_inplace() { scratch.borrow(), ); - ct_rlwe.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -391,12 +389,10 @@ fn prod_by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -407,12 +403,12 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -430,10 +426,10 @@ fn prod_by_rgsw() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_grlwe_scratch_space( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -444,7 +440,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -460,7 +456,7 @@ fn prod_by_rgsw() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -469,7 +465,7 @@ fn prod_by_rgsw() { scratch.borrow(), ); - ct_rlwe_out.prod_by_rgsw(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_out.external_product(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -505,12 +501,10 @@ fn prod_by_rgsw() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -521,11 +515,11 @@ fn prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -543,16 +537,16 @@ fn prod_by_rgsw_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -568,7 +562,7 @@ fn prod_by_rgsw_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -577,7 +571,7 @@ fn prod_by_rgsw_inplace() { scratch.borrow(), ); - ct_rlwe.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -613,6 +607,4 @@ fn prod_by_rgsw_inplace() { noise_have, noise_want ); - - module.free(); } diff --git a/core/src/test_fft64/rlwe_dft.rs b/core/src/test_fft64/rlwe_dft.rs index fe71a09..06359b1 100644 --- a/core/src/test_fft64/rlwe_dft.rs +++ b/core/src/test_fft64/rlwe_dft.rs @@ -1,16 +1,21 @@ use crate::{ - elem::{Infos, ProdInplace, ProdInplaceScratchSpace, ProdScratchSpace, Product}, - grlwe::GRLWECt, - keys::{SecretKey, SecretKeyDft}, - rgsw::RGSWCt, - rlwe::{RLWECt, RLWECtDft, RLWEPt}, + elem::Infos, + encryption::EncryptSkScratchSpace, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext}, + keys::{SecretKey, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, test_fft64::{grlwe::noise_grlwe_rlwe_product, rgsw::noise_rgsw_product}, }; use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut}; use sampling::source::Source; #[test] -fn by_grlwe_inplace() { +fn keyswitch() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -21,13 +26,15 @@ fn by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_in_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_out_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_in_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_out_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -39,10 +46,10 @@ fn by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECtDft::prod_by_grlwe_scratch_space( + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertextFourier::keyswitch_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -53,13 +60,13 @@ fn by_grlwe_inplace() { let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -75,7 +82,7 @@ fn by_grlwe_inplace() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -85,7 +92,7 @@ fn by_grlwe_inplace() { ); ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft); - ct_rlwe_out_dft.prod_by_grlwe(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); + ct_rlwe_out_dft.keyswitch(&module, &ct_rlwe_in_dft, &ct_grlwe, scratch.borrow()); ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -111,12 +118,10 @@ fn by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_grlwe_inplace() { +fn keyswich_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -126,11 +131,11 @@ fn prod_by_grlwe_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_grlwe: GRLWECt, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe); + let mut ct_grlwe: GLWEKeySwitchKey, FFT64> = GLWEKeySwitchKey::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -142,22 +147,22 @@ fn prod_by_grlwe_inplace() { .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa); let mut scratch: ScratchOwned = ScratchOwned::new( - GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECtDft::prod_by_grlwe_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), + GLWEKeySwitchKey::encrypt_sk_scratch_space(&module, ct_grlwe.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ct_grlwe.size()), ); let mut sk0: SecretKey> = SecretKey::new(&module); sk0.fill_ternary_prob(0.5, &mut source_xs); - let mut sk0_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk0_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk0_dft.dft(&module, &sk0); let mut sk1: SecretKey> = SecretKey::new(&module); sk1.fill_ternary_prob(0.5, &mut source_xs); - let mut sk1_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk1_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk1_dft.dft(&module, &sk1); ct_grlwe.encrypt_sk( @@ -173,7 +178,7 @@ fn prod_by_grlwe_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk0_dft, &mut source_xa, &mut source_xe, @@ -183,7 +188,7 @@ fn prod_by_grlwe_inplace() { ); ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_grlwe_inplace(&module, &ct_grlwe, scratch.borrow()); + ct_rlwe_dft.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow()); ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow()); @@ -209,12 +214,10 @@ fn prod_by_grlwe_inplace() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw() { +fn external_product() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -225,14 +228,16 @@ fn prod_by_rgsw() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe_in: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_out: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_out); - let mut ct_rlwe_dft_in: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft_out: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe_in: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_out: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_out); + let mut ct_rlwe_dft_in: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft_out: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_out); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -250,10 +255,10 @@ fn prod_by_rgsw() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) - | RLWECt::prod_by_rgsw_scratch_space( + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size()) + | GLWECiphertext::external_product_scratch_space( &module, ct_rlwe_out.size(), ct_rlwe_in.size(), @@ -264,7 +269,7 @@ fn prod_by_rgsw() { let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -280,7 +285,7 @@ fn prod_by_rgsw() { ct_rlwe_in.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -290,7 +295,7 @@ fn prod_by_rgsw() { ); ct_rlwe_in.dft(&module, &mut ct_rlwe_dft_in); - ct_rlwe_dft_out.prod_by_rgsw(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft_out.external_product(&module, &ct_rlwe_dft_in, &ct_rgsw, scratch.borrow()); ct_rlwe_dft_out.idft(&module, &mut ct_rlwe_out, scratch.borrow()); ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -327,12 +332,10 @@ fn prod_by_rgsw() { noise_have, noise_want ); - - module.free(); } #[test] -fn prod_by_rgsw_inplace() { +fn external_product_inplace() { let module: Module = Module::::new(2048); let log_base2k: usize = 12; let log_k_grlwe: usize = 60; @@ -343,12 +346,12 @@ fn prod_by_rgsw_inplace() { let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; - let mut ct_rgsw: RGSWCt, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows); - let mut ct_rlwe: RLWECt> = RLWECt::new(&module, log_base2k, log_k_rlwe_in); - let mut ct_rlwe_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rgsw: GGSWCiphertext, FFT64> = GGSWCiphertext::new(&module, log_base2k, log_k_grlwe, rows); + let mut ct_rlwe: GLWECiphertext> = GLWECiphertext::new(&module, log_base2k, log_k_rlwe_in); + let mut ct_rlwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::new(&module, log_base2k, log_k_rlwe_in); let mut pt_rgsw: ScalarZnx> = module.new_scalar_znx(1); - let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in); - let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out); + let mut pt_want: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_in); + let mut pt_have: GLWEPlaintext> = GLWEPlaintext::new(&module, log_base2k, log_k_rlwe_out); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); @@ -366,16 +369,16 @@ fn prod_by_rgsw_inplace() { pt_rgsw.raw_mut()[k] = 1; // X^{k} let mut scratch: ScratchOwned = ScratchOwned::new( - RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size()) - | RLWECt::decrypt_scratch_space(&module, ct_rlwe.size()) - | RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size()) - | RLWECt::prod_by_rgsw_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), + GGSWCiphertext::encrypt_sk_scratch_space(&module, ct_rgsw.size()) + | GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size()) + | GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size()), ); let mut sk: SecretKey> = SecretKey::new(&module); sk.fill_ternary_prob(0.5, &mut source_xs); - let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + let mut sk_dft: SecretKeyFourier, FFT64> = SecretKeyFourier::new(&module); sk_dft.dft(&module, &sk); ct_rgsw.encrypt_sk( @@ -391,7 +394,7 @@ fn prod_by_rgsw_inplace() { ct_rlwe.encrypt_sk( &module, - Some(&pt_want), + &pt_want, &sk_dft, &mut source_xa, &mut source_xe, @@ -401,7 +404,7 @@ fn prod_by_rgsw_inplace() { ); ct_rlwe.dft(&module, &mut ct_rlwe_dft); - ct_rlwe_dft.prod_by_rgsw_inplace(&module, &ct_rgsw, scratch.borrow()); + ct_rlwe_dft.external_product_inplace(&module, &ct_rgsw, scratch.borrow()); ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow()); ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); @@ -438,6 +441,4 @@ fn prod_by_rgsw_inplace() { noise_have, noise_want ); - - module.free(); } diff --git a/core/src/vec_glwe_product.rs b/core/src/vec_glwe_product.rs new file mode 100644 index 0000000..7920de9 --- /dev/null +++ b/core/src/vec_glwe_product.rs @@ -0,0 +1,197 @@ +use base2k::{ + FFT64, Module, Scratch, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, +}; + +use crate::{ + elem::{GetRow, Infos, SetRow}, + glwe::{GLWECiphertext, GLWECiphertextFourier}, +}; + +pub(crate) trait VecGLWEProductScratchSpace { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize; + + fn prod_with_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + Self::prod_with_glwe_scratch_space(module, res_size, res_size, rhs) + } + + fn prod_with_glwe_dft_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + (Self::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, lhs) + + module.bytes_of_vec_znx(2, res_size) + } + + fn prod_with_glwe_dft_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + (Self::prod_with_glwe_inplace_scratch_space(module, res_size, rhs) | module.vec_znx_idft_tmp_bytes()) + + module.bytes_of_vec_znx(2, res_size) + } + + fn prod_with_vec_glwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + Self::prod_with_glwe_dft_scratch_space(module, res_size, lhs, rhs) + + module.bytes_of_vec_znx_dft(2, lhs) + + module.bytes_of_vec_znx_dft(2, res_size) + } + + fn prod_with_vec_glwe_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + Self::prod_with_glwe_dft_inplace_scratch_space(module, res_size, rhs) + module.bytes_of_vec_znx_dft(2, res_size) + } +} + +pub(crate) trait VecGLWEProduct: Infos { + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut, + VecZnx: VecZnxToRef; + + fn prod_with_glwe_inplace(&self, module: &Module, res: &mut GLWECiphertext, scratch: &mut Scratch) + where + VecZnx: VecZnxToMut + VecZnxToRef, + { + unsafe { + let res_ptr: *mut GLWECiphertext = res as *mut GLWECiphertext; // This is ok because [Self::mul_rlwe] only updates res at the end. + self.prod_with_glwe(&module, &mut *res_ptr, &*res_ptr, scratch); + } + } + + fn prod_with_glwe_fourier( + &self, + module: &Module, + res: &mut GLWECiphertextFourier, + a: &GLWECiphertextFourier, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef + ZnxInfos, + VecZnxDft: VecZnxDftToRef + ZnxInfos, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size()); + + let mut a_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: a_data, + log_base2k: a.basek(), + log_k: a.k(), + }; + + a.idft(module, &mut a_idft, scratch_1); + + let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + self.prod_with_glwe(module, &mut res_idft, &a_idft, scratch_2); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn prod_with_glwe_fourier_inplace( + &self, + module: &Module, + res: &mut GLWECiphertextFourier, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef + VecZnxDftToMut, + { + let log_base2k: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.basek(), log_base2k); + assert_eq!(self.n(), module.n()); + assert_eq!(res.n(), module.n()); + } + + let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + res.idft(module, &mut res_idft, scratch_1); + + self.prod_with_glwe_inplace(module, &mut res_idft, scratch_1); + + module.vec_znx_dft(res, 0, &res_idft, 0); + module.vec_znx_dft(res, 1, &res_idft, 1); + } + + fn prod_with_vec_glwe(&self, module: &Module, res: &mut RES, a: &LHS, scratch: &mut Scratch) + where + LHS: GetRow + Infos, + RES: SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size()); + + let mut tmp_a_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: a.basek(), + log_k: a.k(), + }; + + let (tmp_res_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_res_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_res_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + let min_rows: usize = res.rows().min(a.rows()); + + (0..res.rows()).for_each(|row_i| { + (0..res.rank()).for_each(|col_j| { + a.get_row(module, row_i, col_j, &mut tmp_a_row); + self.prod_with_glwe_fourier(module, &mut tmp_res_row, &tmp_a_row, scratch2); + res.set_row(module, row_i, col_j, &tmp_res_row); + }); + }); + + tmp_res_row.data.zero(); + + (min_rows..res.rows()).for_each(|row_i| { + (0..self.rank()).for_each(|col_j| { + res.set_row(module, row_i, col_j, &tmp_res_row); + }); + }); + } + + fn prod_with_vec_glwe_inplace(&self, module: &Module, res: &mut RES, scratch: &mut Scratch) + where + RES: GetRow + SetRow + Infos, + { + let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size()); + + let mut tmp_row: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_row_data, + log_base2k: res.basek(), + log_k: res.k(), + }; + + (0..res.rows()).for_each(|row_i| { + (0..res.rank()).for_each(|col_j| { + res.get_row(module, row_i, col_j, &mut tmp_row); + self.prod_with_glwe_fourier_inplace(module, &mut tmp_row, scratch1); + res.set_row(module, row_i, col_j, &tmp_row); + }); + }); + } +}

, - sk_dft: &SecretKeyDft, + sk_dft: &SecretKeyFourier, source_xa: &mut Source, source_xe: &mut Source, sigma: f64, @@ -163,11 +166,11 @@ impl RGSWCt { } } -impl GetRow for RGSWCt +impl GetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToRef, { - fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut RLWECtDft) + fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier) where VecZnxDft: VecZnxDftToMut, { @@ -175,11 +178,11 @@ where } } -impl SetRow for RGSWCt +impl SetRow for GGSWCiphertext where MatZnxDft: MatZnxDftToMut, { - fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &RLWECtDft) + fn set_row(&mut self, module: &Module, row_i: usize, col_j: usize, a: &GLWECiphertextFourier) where VecZnxDft: VecZnxDftToRef, { @@ -187,30 +190,118 @@ where } } -impl MatRLWEProductScratchSpace for RGSWCt, FFT64> { - fn prod_with_rlwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { +impl KeySwitchScratchSpace for GGSWCiphertext, FFT64> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl KeySwitch for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GGSWCiphertext; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GGSWCiphertext, FFT64> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, rhs, scratch); + } +} + +impl ExternalProductScratchSpace for GGSWCiphertext, FFT64> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space( + module, res_size, lhs, rhs, + ) + } +} + +impl ExternalProduct for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GGSWCiphertext; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GGSWCiphertext, FFT64> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GGSWCiphertext +where + MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_vec_glwe_inplace(module, self, scratch); + } +} + +impl VecGLWEProductScratchSpace for GGSWCiphertext, FFT64> { + fn prod_with_glwe_scratch_space(module: &Module, res_size: usize, a_size: usize, rgsw_size: usize) -> usize { module.bytes_of_vec_znx_dft(2, rgsw_size) + ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size)) | module.vec_znx_big_normalize_tmp_bytes()) } } -impl MatRLWEProduct for RGSWCt +impl VecGLWEProduct for GGSWCiphertext where MatZnxDft: MatZnxDftToRef + ZnxInfos, { - fn prod_with_rlwe(&self, module: &Module, res: &mut RLWECt, a: &RLWECt, scratch: &mut Scratch) - where - MatZnxDft: MatZnxDftToRef, + fn prod_with_glwe( + &self, + module: &Module, + res: &mut GLWECiphertext, + a: &GLWECiphertext, + scratch: &mut Scratch, + ) where VecZnx: VecZnxToMut, VecZnx: VecZnxToRef, { - let log_base2k: usize = self.log_base2k(); + let log_base2k: usize = self.basek(); #[cfg(debug_assertions)] { - assert_eq!(res.log_base2k(), log_base2k); - assert_eq!(a.log_base2k(), log_base2k); + assert_eq!(res.basek(), log_base2k); + assert_eq!(a.basek(), log_base2k); assert_eq!(self.n(), module.n()); assert_eq!(res.n(), module.n()); assert_eq!(a.n(), module.n()); @@ -231,53 +322,3 @@ where module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1); } } - -impl ProdInplaceScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } - - fn prod_by_rgsw_inplace_scratch_space(module: &Module, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_inplace_scratch_space(module, lhs, rhs) - } -} - -impl ProdScratchSpace for RGSWCt, FFT64> { - fn prod_by_grlwe_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } - - fn prod_by_rgsw_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { - , FFT64> as MatRLWEProductScratchSpace>::prod_with_mat_rlwe_scratch_space(module, res_size, lhs, rhs) - } -} - -impl ProdInplace for RGSWCt -where - RGSWCt: GetRow + SetRow + Infos, - MatZnxDft: MatZnxDftToRef, -{ - fn prod_by_grlwe_inplace(&mut self, module: &Module, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } - - fn prod_by_rgsw_inplace(&mut self, module: &Module, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe_inplace(module, self, scratch); - } -} - -impl Product for RGSWCt -where - MatZnxDft: MatZnxDftToRef + MatZnxDftToMut, - MatZnxDft: MatZnxDftToRef, -{ - type Lhs = RGSWCt; - - fn prod_by_grlwe(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &GRLWECt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } - - fn prod_by_rgsw(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &RGSWCt, scratch: &mut Scratch) { - rhs.prod_with_mat_rlwe(module, self, lhs, scratch); - } -} diff --git a/core/src/glwe.rs b/core/src/glwe.rs new file mode 100644 index 0000000..e50582d --- /dev/null +++ b/core/src/glwe.rs @@ -0,0 +1,845 @@ +use base2k::{ + AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, + ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos, +}; +use sampling::source::Source; + +use crate::{ + elem::Infos, + encryption::{EncryptSk, EncryptSkScratchSpace, EncryptZeroSkScratchSpace}, + external_product::{ + ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace, + }, + ggsw::GGSWCiphertext, + keys::{PublicKey, SecretDistribution, SecretKeyFourier}, + keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace}, + keyswitch_key::GLWEKeySwitchKey, + utils::derive_size, + vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace}, +}; + +pub struct GLWECiphertext { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, +} + +impl GLWECiphertext> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(2, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for GLWECiphertext { + type Inner = VecZnx; + + fn inner(&self) -> &Self::Inner { + &self.data + } + + fn basek(&self) -> usize { + self.log_base2k + } + + fn k(&self) -> usize { + self.log_k + } +} + +impl VecZnxToMut for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl GLWECiphertext +where + VecZnx: VecZnxToRef, +{ + #[allow(dead_code)] + pub(crate) fn dft(&self, module: &Module, res: &mut GLWECiphertextFourier) + where + VecZnxDft: VecZnxDftToMut + ZnxInfos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.rank(), 2); + assert_eq!(res.rank(), 2); + assert_eq!(self.basek(), res.basek()) + } + + module.vec_znx_dft(res, 0, self, 0); + module.vec_znx_dft(res, 1, self, 1); + } +} + +impl KeySwitchScratchSpace for GLWECiphertext> { + fn keyswitch_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl KeySwitch for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertext; + type Rhs = GLWEKeySwitchKey; + + fn keyswitch(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe(module, self, lhs, scratch); + } +} + +impl KeySwitchInplaceScratchSpace for GLWECiphertext> { + fn keyswitch_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl KeySwitchInplace for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Rhs = GLWEKeySwitchKey; + + fn keyswitch_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_inplace(module, self, scratch); + } +} + +impl ExternalProductScratchSpace for GLWECiphertext> { + fn external_product_scratch_space(module: &Module, res_size: usize, lhs: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs) + } +} + +impl ExternalProduct for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + MatZnxDft: MatZnxDftToRef, +{ + type Lhs = GLWECiphertext; + type Rhs = GGSWCiphertext; + + fn external_product(&mut self, module: &Module, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe(module, self, lhs, scratch); + } +} + +impl ExternalProductInplaceScratchSpace for GLWECiphertext> { + fn external_product_inplace_scratch_space(module: &Module, res_size: usize, rhs: usize) -> usize { + , FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space( + module, res_size, rhs, + ) + } +} + +impl ExternalProductInplace for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + MatZnxDft: MatZnxDftToRef + ZnxInfos, +{ + type Rhs = GGSWCiphertext; + + fn external_product_inplace(&mut self, module: &Module, rhs: &Self::Rhs, scratch: &mut Scratch) { + rhs.prod_with_glwe_inplace(module, self, scratch); + } +} + +impl GLWECiphertext> { + pub fn encrypt_pk_scratch_space(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() + } + + pub fn decrypt_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl EncryptSkScratchSpace for GLWECiphertext> { + fn encrypt_sk_scratch_space(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + } +} + +impl EncryptSk for GLWECiphertext +where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + type Ciphertext = GLWECiphertext; + type Plaintext = GLWEPlaintext; + type SecretKey = SecretKeyFourier; + + fn encrypt_sk( + &self, + module: &Module, + ct: &mut Self::Ciphertext, + pt: &Self::Plaintext, + sk: &Self::SecretKey, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) { + encrypt_glwe_sk( + module, + ct, + Some((pt, 0)), + sk, + source_xa, + source_xe, + sigma, + bound, + scratch, + ); + } +} + +pub(crate) fn encrypt_glwe_sk( + module: &Module, + ct: &mut GLWECiphertext, + pt: Option<(&GLWEPlaintext, usize)>, + sk_dft: &SecretKeyFourier, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.basek(); + let log_k: usize = ct.k(); + let size: usize = ct.size(); + + // c1 = a + ct.data.fill_uniform(log_base2k, 1, size, source_xa); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = m - c0_big + if let Some((pt, col)) = pt { + match col { + 0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0), + 1 => { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + module.vec_znx_add_inplace(ct, 1, pt, 0); + module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1); + } + _ => panic!("invalid target column: {}", col), + } + } else { + module.vec_znx_big_negate_inplace(&mut c0_big, 0); + } + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); +} + +pub fn decrypt_glwe( + module: &Module, + pt: &mut GLWEPlaintext