From d989867c914314f91f4b5a5009a850ca656246ff Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 24 Oct 2025 18:13:43 +0200 Subject: [PATCH 01/11] Add bdd rotation --- Cargo.lock | 8 +-- poulpy-backend/Cargo.toml | 2 +- poulpy-core/Cargo.toml | 2 +- poulpy-core/src/operations/ggsw.rs | 55 +++++++++++++++++++ poulpy-core/src/operations/glwe.rs | 5 ++ poulpy-core/src/operations/mod.rs | 2 + poulpy-hal/Cargo.toml | 2 +- poulpy-schemes/Cargo.toml | 2 +- .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 42 ++++++++++++++ .../tfhe/bdd_arithmetic/ciphertexts/block.rs | 11 ++++ .../ciphertexts/block_prepared.rs | 23 ++++++++ .../src/tfhe/bdd_arithmetic/eval.rs | 53 ++++++++++-------- poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs | 2 + 13 files changed, 177 insertions(+), 32 deletions(-) create mode 100644 poulpy-core/src/operations/ggsw.rs create mode 100644 poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs diff --git a/Cargo.lock b/Cargo.lock index 01ab4c1..0037bc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,7 +353,7 @@ dependencies = [ [[package]] name = "poulpy-backend" -version = "0.2.0" +version = "0.3.1" dependencies = [ "byteorder", "cmake", @@ -370,7 +370,7 @@ dependencies = [ [[package]] name = "poulpy-core" -version = "0.2.0" +version = "0.3.1" dependencies = [ "byteorder", "criterion", @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "poulpy-hal" -version = "0.2.0" +version = "0.3.1" dependencies = [ "bytemuck", "byteorder", @@ -400,7 +400,7 @@ dependencies = [ [[package]] name = "poulpy-schemes" -version = "0.2.0" +version = "0.3.0" dependencies = [ "byteorder", "criterion", diff --git a/poulpy-backend/Cargo.toml b/poulpy-backend/Cargo.toml index 1b74aec..0ca482c 100644 --- a/poulpy-backend/Cargo.toml +++ b/poulpy-backend/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "poulpy-backend" -version = "0.2.0" +version = "0.3.1" edition = "2024" license = "Apache-2.0" readme = "README.md" diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index ad78124..42eb830 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "poulpy-core" -version = "0.2.0" +version = "0.3.1" edition = "2024" license = "Apache-2.0" description = "A backend agnostic crate implementing RLWE-based encryption & arithmetic." diff --git a/poulpy-core/src/operations/ggsw.rs b/poulpy-core/src/operations/ggsw.rs new file mode 100644 index 0000000..b850c3c --- /dev/null +++ b/poulpy-core/src/operations/ggsw.rs @@ -0,0 +1,55 @@ +use poulpy_hal::layouts::{Backend, Module, Scratch}; + +use crate::{ + GLWERotate, ScratchTakeCore, + layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos}, +}; + +impl GGSWRotate for Module where Module: GLWERotate {} + +pub trait GGSWRotate +where + Self: GLWERotate, +{ + fn ggsw_rotate_tmp_bytes(&self) -> usize { + self.glwe_rotate_tmp_bytes() + } + + fn ggsw_rotate(&self, k: i64, res: &mut R, a: &A) + where + R: GGSWToMut, + A: GGSWToRef, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let a: &GGSW<&[u8]> = &a.to_ref(); + + assert!(res.dnum() <= a.dnum()); + assert_eq!(res.dsize(), a.dsize()); + assert_eq!(res.rank(), a.rank()); + let rows: usize = res.dnum().into(); + let cols: usize = (res.rank() + 1).into(); + + for row in 0..rows { + for col in 0..cols { + self.glwe_rotate(k, &mut res.at_mut(row, col), &a.at(row, col)); + } + } + } + + fn ggsw_rotate_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) + where + R: GGSWToMut, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + let rows: usize = res.dnum().into(); + let cols: usize = (res.rank() + 1).into(); + + for row in 0..rows { + for col in 0..cols { + self.glwe_rotate_inplace(k, &mut res.at_mut(row, col), scratch); + } + } + } +} diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index c6f1818..6dcb52c 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -5,6 +5,7 @@ use poulpy_hal::{ VecZnxSubInplace, VecZnxSubNegateInplace, }, layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, + reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, }; use crate::{ @@ -185,6 +186,10 @@ pub trait GLWERotate where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace, { + fn glwe_rotate_tmp_bytes(&self) -> usize { + vec_znx_rotate_inplace_tmp_bytes(self.n()) + } + fn glwe_rotate(&self, k: i64, res: &mut R, a: &A) where R: GLWEToMut, diff --git a/poulpy-core/src/operations/mod.rs b/poulpy-core/src/operations/mod.rs index 3b2432e..8775060 100644 --- a/poulpy-core/src/operations/mod.rs +++ b/poulpy-core/src/operations/mod.rs @@ -1,3 +1,5 @@ +mod ggsw; mod glwe; +pub use ggsw::*; pub use glwe::*; diff --git a/poulpy-hal/Cargo.toml b/poulpy-hal/Cargo.toml index 5364c76..93325b3 100644 --- a/poulpy-hal/Cargo.toml +++ b/poulpy-hal/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "poulpy-hal" -version = "0.2.0" +version = "0.3.1" edition = "2024" license = "Apache-2.0" readme = "README.md" diff --git a/poulpy-schemes/Cargo.toml b/poulpy-schemes/Cargo.toml index 62db78a..75464ba 100644 --- a/poulpy-schemes/Cargo.toml +++ b/poulpy-schemes/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "poulpy-schemes" -version = "0.2.0" +version = "0.3.0" edition = "2024" license = "Apache-2.0" readme = "README.md" diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs new file mode 100644 index 0000000..a3bbf2e --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -0,0 +1,42 @@ +use poulpy_core::{ + GLWECopy, GLWERotate, ScratchTakeCore, + layouts::{GLWE, GLWEToMut}, +}; +use poulpy_hal::layouts::{Backend, Scratch}; + +use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; + +pub trait BDDRotation +where + Self: GLWECopy + GLWERotate + Cmux, + Scratch: ScratchTakeCore, +{ + /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. + fn bdd_rotate( + &self, + res: &mut R, + k: K, + bit_start: usize, + bit_size: usize, + bit_step: usize, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + let (mut tmp_res, scratch_1) = scratch.take_glwe(res); + + self.glwe_copy(&mut tmp_res, res); + + for i in 1..bit_size { + // res' = res * X^2^(i * bit_step) + self.glwe_rotate(1 << (i + bit_step), &mut tmp_res, res); + + // res = (res - res') * GGSW(b[i]) + res' + self.cmux_inplace(res, &tmp_res, &k.get_bit(i + bit_start), scratch_1); + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs index 0109eb7..549c9e0 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block.rs @@ -39,6 +39,17 @@ impl GLWEInfos for FheUintBlocks { } } +impl FheUintBlocks { + pub fn new(blocks: Vec>) -> Self { + assert_eq!(blocks.len(), T::WORD_SIZE); + Self { + blocks, + _base: 1, + _phantom: PhantomData, + } + } +} + impl FheUintBlocks, T> { pub fn alloc_from_infos(module: &Module, infos: &A) -> Self where diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs index bdc1945..2814773 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/block_prepared.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use poulpy_core::layouts::{ Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared, }; +use poulpy_core::layouts::{GGSWPreparedToMut, GGSWPreparedToRef}; use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef}; use poulpy_hal::layouts::{Backend, Data, DataRef, Module}; @@ -28,6 +29,28 @@ impl FheUintBlocksPreparedFactory for Mo { } +pub trait GetGGSWBit { + fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>; +} + +impl GetGGSWBit for FheUintBlocksPrepared { + fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> { + assert!(bit <= self.blocks.len()); + self.blocks[bit].to_ref() + } +} + +pub trait GetGGSWBitMut { + fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE>; +} + +impl GetGGSWBitMut for FheUintBlocksPrepared { + fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE> { + assert!(bit <= self.blocks.len()); + self.blocks[bit].to_mut() + } +} + pub trait FheUintBlocksPreparedFactory where Self: Sized + GGSWPreparedFactory, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 4f65eff..3c34f94 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -3,12 +3,9 @@ use core::panic; use itertools::Itertools; use poulpy_core::{ GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore, - layouts::{ - GLWE, LWEInfos, - prepared::{GGSWPrepared, GGSWPreparedToRef}, - }, + layouts::{GLWE, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, }; -use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}; +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero}; use crate::tfhe::bdd_arithmetic::UnsignedInteger; @@ -146,30 +143,38 @@ pub enum Node { None, } -pub trait Cmux { - fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) +pub trait Cmux +where + Self: GLWEExternalProduct + GLWESub + GLWEAdd, + Scratch: ScratchTakeCore, +{ + fn cmux(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch) where - O: DataMut, - T: DataRef, - F: DataRef, - S: DataRef; + R: GLWEToMut, + T: GLWEToRef, + F: GLWEToRef, + S: GGSWPreparedToRef, + { + self.glwe_sub(res, t, f); + self.glwe_external_product_inplace(res, s, scratch); + self.glwe_add_inplace(res, f); + } + + fn cmux_inplace(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + S: GGSWPreparedToRef, + { + self.glwe_sub_inplace(res, a); + self.glwe_external_product_inplace(res, s, scratch); + self.glwe_add_inplace(res, a); + } } impl Cmux for Module where - Module: GLWEExternalProduct + GLWESub + GLWEAdd, + Self: GLWEExternalProduct + GLWESub + GLWEAdd, Scratch: ScratchTakeCore, { - fn cmux(&self, out: &mut GLWE, t: &GLWE, f: &GLWE, s: &GGSWPrepared, scratch: &mut Scratch) - where - O: DataMut, - T: DataRef, - F: DataRef, - S: DataRef, - { - // let mut out: GLWECiphertext<&mut [u8]> = out.to_mut(); - self.glwe_sub(out, t, f); - self.glwe_external_product_inplace(out, s, scratch); - self.glwe_add_inplace(out, f); - } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 0f66049..22e5073 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -1,10 +1,12 @@ mod bdd_2w_to_1w; +mod bdd_rotation; mod ciphertexts; mod circuits; mod eval; mod key; pub use bdd_2w_to_1w::*; +pub use bdd_rotation::*; pub use ciphertexts::*; pub(crate) use circuits::*; pub(crate) use eval::*; From eaac9c07d8cd948c44b16668b9b8282ffe507d12 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sat, 25 Oct 2025 15:43:18 +0200 Subject: [PATCH 02/11] Add GLWETensor --- poulpy-core/src/layouts/glwe_tensor.rs | 146 +++++++++++++++++++++++++ poulpy-core/src/layouts/mod.rs | 2 + 2 files changed, 148 insertions(+) create mode 100644 poulpy-core/src/layouts/glwe_tensor.rs diff --git a/poulpy-core/src/layouts/glwe_tensor.rs b/poulpy-core/src/layouts/glwe_tensor.rs new file mode 100644 index 0000000..516a854 --- /dev/null +++ b/poulpy-core/src/layouts/glwe_tensor.rs @@ -0,0 +1,146 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos}, + source::Source, +}; + +use crate::layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, SetGLWEInfos, TorusPrecision}; +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GLWETensor { + pub(crate) data: VecZnx, + pub(crate) base2k: Base2K, + pub(crate) rank: Rank, + pub(crate) k: TorusPrecision, +} + +impl SetGLWEInfos for GLWETensor { + fn set_base2k(&mut self, base2k: Base2K) { + self.base2k = base2k + } + + fn set_k(&mut self, k: TorusPrecision) { + self.k = k + } +} + +impl GLWETensor { + pub fn data(&self) -> &VecZnx { + &self.data + } +} + +impl GLWETensor { + pub fn data_mut(&mut self) -> &mut VecZnx { + &mut self.data + } +} + +impl LWEInfos for GLWETensor { + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + self.data.size() + } +} + +impl GLWEInfos for GLWETensor { + fn rank(&self) -> Rank { + self.rank + } +} + +impl fmt::Debug for GLWETensor { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl fmt::Display for GLWETensor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "GLWETensor: base2k={} k={}: {}", + self.base2k().0, + self.k().0, + self.data + ) + } +} + +impl FillUniform for GLWETensor { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.data.fill_uniform(log_bound, source); + } +} + +impl GLWETensor> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { + let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize; + GLWETensor { + data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize), + base2k, + k, + rank, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank()) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { + let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize; + VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize) + } +} + +pub trait GLWETensorToRef { + fn to_ref(&self) -> GLWETensor<&[u8]>; +} + +impl GLWETensorToRef for GLWETensor { + fn to_ref(&self) -> GLWETensor<&[u8]> { + GLWETensor { + k: self.k, + base2k: self.base2k, + data: self.data.to_ref(), + rank: self.rank, + } + } +} + +pub trait GLWETensorToMut { + fn to_mut(&mut self) -> GLWETensor<&mut [u8]>; +} + +impl GLWETensorToMut for GLWETensor { + fn to_mut(&mut self) -> GLWETensor<&mut [u8]> { + GLWETensor { + k: self.k, + base2k: self.base2k, + rank: self.rank, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 1168ac8..7c5cc5b 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -6,6 +6,7 @@ mod glwe_plaintext; mod glwe_public_key; mod glwe_secret; mod glwe_switching_key; +mod glwe_tensor; mod glwe_tensor_key; mod glwe_to_lwe_switching_key; mod lwe; @@ -26,6 +27,7 @@ pub use glwe_plaintext::*; pub use glwe_public_key::*; pub use glwe_secret::*; pub use glwe_switching_key::*; +pub use glwe_tensor::*; pub use glwe_tensor_key::*; pub use glwe_to_lwe_switching_key::*; pub use lwe::*; From e6e685c00ee226f938ffb00a2dc7cf96b6756504 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sat, 25 Oct 2025 15:55:06 +0200 Subject: [PATCH 03/11] Add GGSW blind rotation --- .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 51 +++++++++++++++++-- .../src/tfhe/bdd_arithmetic/eval.rs | 11 +++- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index a3bbf2e..7124907 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -1,21 +1,64 @@ use poulpy_core::{ GLWECopy, GLWERotate, ScratchTakeCore, - layouts::{GLWE, GLWEToMut}, + layouts::{GGSW, GGSWInfos, GGSWToMut, GLWE, GLWEInfos, GLWEToMut}, }; use poulpy_hal::layouts::{Backend, Scratch}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; -pub trait BDDRotation +pub trait GGSWBlindRotation +where + Self: GLWEBlindRotation, + Scratch: ScratchTakeCore, +{ + fn ggsw_blind_rotation( + &self, + res: &mut R, + k: &K, + bit_start: usize, + bit_size: usize, + bit_step: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + + for row in 0..res.dnum().into() { + for col in 0..(res.rank() + 1).into() { + self.glwe_blind_rotation( + &mut res.at_mut(row, col), + k, + bit_start, + bit_size, + bit_step, + scratch, + ); + } + } + } +} + +pub trait GLWEBlindRotation where Self: GLWECopy + GLWERotate + Cmux, Scratch: ScratchTakeCore, { + fn glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, b_infos: &B) -> usize + where + R: GLWEInfos, + B: GGSWInfos, + { + self.cmux_tmp_bytes(res_infos, res_infos, b_infos) + GLWE::bytes_of_from_infos(res_infos) + } + /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. - fn bdd_rotate( + fn glwe_blind_rotation( &self, res: &mut R, - k: K, + k: &K, bit_start: usize, bit_size: usize, bit_step: usize, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 3c34f94..4613208 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -3,7 +3,7 @@ use core::panic; use itertools::Itertools; use poulpy_core::{ GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore, - layouts::{GLWE, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, + layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, }; use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero}; @@ -148,6 +148,15 @@ where Self: GLWEExternalProduct + GLWESub + GLWEAdd, Scratch: ScratchTakeCore, { + fn cmux_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGSWInfos, + { + self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos) + } + fn cmux(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch) where R: GLWEToMut, From 6d6d00e9e45613b099516251b1880c37bfedfeee Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sat, 25 Oct 2025 15:56:26 +0200 Subject: [PATCH 04/11] Add scratch space for ggsw blind rotation --- .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index 7124907..bf2116e 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -11,6 +11,14 @@ where Self: GLWEBlindRotation, Scratch: ScratchTakeCore, { + fn ggsw_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + where + R: GLWEInfos, + K: GGSWInfos, + { + self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + } + fn ggsw_blind_rotation( &self, res: &mut R, @@ -46,12 +54,12 @@ where Self: GLWECopy + GLWERotate + Cmux, Scratch: ScratchTakeCore, { - fn glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, b_infos: &B) -> usize + fn glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, - B: GGSWInfos, + K: GGSWInfos, { - self.cmux_tmp_bytes(res_infos, res_infos, b_infos) + GLWE::bytes_of_from_infos(res_infos) + self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. From 98208d5e67d307b03e41ce78f367ed2ac41f0ea3 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sat, 25 Oct 2025 17:58:34 +0200 Subject: [PATCH 05/11] add test for GLWEBlindRotation --- poulpy-backend/src/cpu_fft64_avx/vec_znx.rs | 13 +- poulpy-backend/src/cpu_fft64_ref/vec_znx.rs | 13 +- .../src/cpu_spqlios/fft64/vec_znx.rs | 13 +- poulpy-core/src/operations/glwe.rs | 15 +- poulpy-hal/src/api/vec_znx.rs | 6 + poulpy-hal/src/delegates/vec_znx.rs | 16 ++- poulpy-hal/src/oep/vec_znx.rs | 10 ++ poulpy-hal/src/reference/vec_znx/mod.rs | 2 + poulpy-hal/src/reference/vec_znx/zero.rs | 16 +++ .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 65 +++++++-- .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 7 +- .../tests/test_suite/glwe_blind_rotation.rs | 134 ++++++++++++++++++ .../bdd_arithmetic/tests/test_suite/mod.rs | 2 + 13 files changed, 286 insertions(+), 26 deletions(-) create mode 100644 poulpy-hal/src/reference/vec_znx/zero.rs create mode 100644 poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs index 33325a7..9e286b7 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx.rs @@ -14,7 +14,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, @@ -25,13 +25,22 @@ use poulpy_hal::{ vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, - vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero, }, source::Source, }; use crate::cpu_fft64_avx::FFT64Avx; +unsafe impl VecZnxZeroImpl for FFT64Avx { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Avx>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Avx { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs index fa88aaa..a2a2086 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs @@ -14,7 +14,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::vec_znx::{ vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, @@ -25,13 +25,22 @@ use poulpy_hal::{ vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, - vec_znx_sub_scalar_inplace, vec_znx_switch_ring, + vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero, }, source::Source, }; use crate::cpu_fft64_ref::FFT64Ref; +unsafe impl VecZnxZeroImpl for FFT64Ref { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Ref>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Ref { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs index c3a110a..2c35145 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx.rs @@ -15,7 +15,7 @@ use poulpy_hal::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, reference::{ vec_znx::{ @@ -23,7 +23,7 @@ use poulpy_hal::{ vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, - vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, + vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, vec_znx_zero, }, znx::{znx_copy_ref, znx_zero_ref}, }, @@ -35,6 +35,15 @@ use crate::cpu_spqlios::{ ffi::{module::module_info_t, vec_znx, znx}, }; +unsafe impl VecZnxZeroImpl for FFT64Spqlios { + fn vec_znx_zero_impl(_module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + vec_znx_zero::<_, FFT64Spqlios>(res, res_col); + } +} + unsafe impl VecZnxNormalizeTmpBytesImpl for FFT64Spqlios { fn vec_znx_normalize_tmp_bytes_impl(module: &Module) -> usize { vec_znx_normalize_tmp_bytes(module.n()) diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 6dcb52c..95df49a 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -2,7 +2,7 @@ use poulpy_hal::{ api::{ ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSubNegateInplace, + VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero, }, layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, @@ -262,11 +262,11 @@ where } } -impl GLWECopy for Module where Self: ModuleN + VecZnxCopy {} +impl GLWECopy for Module where Self: ModuleN + VecZnxCopy + VecZnxZero {} pub trait GLWECopy where - Self: ModuleN + VecZnxCopy, + Self: ModuleN + VecZnxCopy + VecZnxZero, { fn glwe_copy(&self, res: &mut R, a: &A) where @@ -278,12 +278,17 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); - assert_eq!(res.rank(), a.rank()); - for i in 0..res.rank().as_usize() + 1 { + let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1; + + for i in 0..min_rank { self.vec_znx_copy(res.data_mut(), i, a.data(), i); } + for i in min_rank..(res.rank() + 1).into() { + self.vec_znx_zero(res.data_mut(), i); + } + res.set_k(a.k().min(res.max_k())); res.set_base2k(a.base2k()); } diff --git a/poulpy-hal/src/api/vec_znx.rs b/poulpy-hal/src/api/vec_znx.rs index afa30b8..8bf0e65 100644 --- a/poulpy-hal/src/api/vec_znx.rs +++ b/poulpy-hal/src/api/vec_znx.rs @@ -8,6 +8,12 @@ pub trait VecZnxNormalizeTmpBytes { fn vec_znx_normalize_tmp_bytes(&self) -> usize; } +pub trait VecZnxZero { + fn vec_znx_zero(&self, res: &mut R, res_col: usize) + where + R: VecZnxToMut; +} + pub trait VecZnxNormalize { #[allow(clippy::too_many_arguments)] /// Normalizes the selected column of `a` and stores the result into the selected column of `res`. diff --git a/poulpy-hal/src/delegates/vec_znx.rs b/poulpy-hal/src/delegates/vec_znx.rs index 60a961e..02f512a 100644 --- a/poulpy-hal/src/delegates/vec_znx.rs +++ b/poulpy-hal/src/delegates/vec_znx.rs @@ -6,7 +6,7 @@ use crate::{ VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, - VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, + VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, VecZnxZero, }, layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, oep::{ @@ -18,11 +18,23 @@ use crate::{ VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, - VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, + VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl, }, source::Source, }; +impl VecZnxZero for Module +where + B: Backend + VecZnxZeroImpl, +{ + fn vec_znx_zero(&self, res: &mut R, res_col: usize) + where + R: VecZnxToMut, + { + B::vec_znx_zero_impl(self, res, res_col); + } +} + impl VecZnxNormalizeTmpBytes for Module where B: Backend + VecZnxNormalizeTmpBytesImpl, diff --git a/poulpy-hal/src/oep/vec_znx.rs b/poulpy-hal/src/oep/vec_znx.rs index 380253e..47bc94a 100644 --- a/poulpy-hal/src/oep/vec_znx.rs +++ b/poulpy-hal/src/oep/vec_znx.rs @@ -3,6 +3,16 @@ use crate::{ source::Source, }; +/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) +/// * See TODO for reference implementation. +/// * See [crate::api::VecZnxZero] for corresponding public API. +/// # Safety [crate::doc::backend_safety] for safety contract. +pub unsafe trait VecZnxZeroImpl { + fn vec_znx_zero_impl(module: &Module, res: &mut R, res_col: usize) + where + R: VecZnxToMut; +} + /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. diff --git a/poulpy-hal/src/reference/vec_znx/mod.rs b/poulpy-hal/src/reference/vec_znx/mod.rs index 4edb574..cb945ea 100644 --- a/poulpy-hal/src/reference/vec_znx/mod.rs +++ b/poulpy-hal/src/reference/vec_znx/mod.rs @@ -13,6 +13,7 @@ mod split_ring; mod sub; mod sub_scalar; mod switch_ring; +mod zero; pub use add::*; pub use add_scalar::*; @@ -29,3 +30,4 @@ pub use split_ring::*; pub use sub::*; pub use sub_scalar::*; pub use switch_ring::*; +pub use zero::*; diff --git a/poulpy-hal/src/reference/vec_znx/zero.rs b/poulpy-hal/src/reference/vec_znx/zero.rs new file mode 100644 index 0000000..7febbf5 --- /dev/null +++ b/poulpy-hal/src/reference/vec_znx/zero.rs @@ -0,0 +1,16 @@ +use crate::{ + layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut}, + reference::znx::ZnxZero, +}; + +pub fn vec_znx_zero(res: &mut R, res_col: usize) +where + R: VecZnxToMut, + ZNXARI: ZnxZero, +{ + let mut res: VecZnx<&mut [u8]> = res.to_mut(); + let res_size = res.size(); + for j in 0..res_size { + ZNXARI::znx_zero(res.at_mut(res_col, j)); + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index bf2116e..a457208 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -1,11 +1,18 @@ use poulpy_core::{ GLWECopy, GLWERotate, ScratchTakeCore, - layouts::{GGSW, GGSWInfos, GGSWToMut, GLWE, GLWEInfos, GLWEToMut}, + layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, }; -use poulpy_hal::layouts::{Backend, Scratch}; +use poulpy_hal::layouts::{Backend, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; +impl GGSWBlindRotation for Module +where + Self: GLWEBlindRotation, + Scratch: ScratchTakeCore, +{ +} + pub trait GGSWBlindRotation where Self: GLWEBlindRotation, @@ -19,9 +26,10 @@ where self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } - fn ggsw_blind_rotation( + fn ggsw_blind_rotation( &self, res: &mut R, + test_ggsw: &G, k: &K, bit_start: usize, bit_size: usize, @@ -29,15 +37,18 @@ where scratch: &mut Scratch, ) where R: GGSWToMut, + G: GGSWToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let test_ggsw: &GGSW<&[u8]> = &test_ggsw.to_ref(); for row in 0..res.dnum().into() { for col in 0..(res.rank() + 1).into() { self.glwe_blind_rotation( &mut res.at_mut(row, col), + &test_ggsw.at(row, col), k, bit_start, bit_size, @@ -49,6 +60,13 @@ where } } +impl GLWEBlindRotation for Module +where + Self: GLWECopy + GLWERotate + Cmux, + Scratch: ScratchTakeCore, +{ +} + pub trait GLWEBlindRotation where Self: GLWECopy + GLWERotate + Cmux, @@ -63,9 +81,10 @@ where } /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. - fn glwe_blind_rotation( + fn glwe_blind_rotation( &self, res: &mut R, + test_glwe: &G, k: &K, bit_start: usize, bit_size: usize, @@ -73,21 +92,43 @@ where scratch: &mut Scratch, ) where R: GLWEToMut, + G: GLWEToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { - let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + assert!(bit_start + bit_size <= T::WORD_SIZE); - let (mut tmp_res, scratch_1) = scratch.take_glwe(res); + let mut res: GLWE<&mut [u8]> = res.to_mut(); - self.glwe_copy(&mut tmp_res, res); + let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); - for i in 1..bit_size { - // res' = res * X^2^(i * bit_step) - self.glwe_rotate(1 << (i + bit_step), &mut tmp_res, res); + // res <- test_glwe + self.glwe_copy(&mut res, test_glwe); - // res = (res - res') * GGSW(b[i]) + res' - self.cmux_inplace(res, &tmp_res, &k.get_bit(i + bit_start), scratch_1); + // a_is_res = true => (a, b) = (&mut res, &mut tmp_res) + // a_is_res = false => (a, b) = (&mut tmp_res, &mut res) + let mut a_is_res: bool = true; + + for i in 0..bit_size { + let (a, b) = if a_is_res { + (&mut res, &mut tmp_res) + } else { + (&mut tmp_res, &mut res) + }; + + // a <- a ; b <- a * X^{-2^{i + bit_step}} + self.glwe_rotate(-1 << (i + bit_step), b, a); + + // b <- (b - a) * GGSW(b[i]) + a + self.cmux_inplace(b, a, &k.get_bit(i + bit_start), scratch_1); + + // ping-pong roles for next iter + a_is_res = !a_is_res; + } + + // Ensure the final value ends up in `res` + if !a_is_res { + self.glwe_copy(&mut res, &tmp_res); } } } diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index fbae3c8..05c5085 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,11 +3,16 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_blind_rotation, }, blind_rotation::CGGI, }; +#[test] +fn test_glwe_blind_rotation_fft64_ref() { + test_glwe_blind_rotation::() +} + #[test] fn test_bdd_prepare_fft64_ref() { test_bdd_prepare::() diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs new file mode 100644 index 0000000..ed1fdd0 --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -0,0 +1,134 @@ +use poulpy_core::{ + GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, Dsize, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, + GLWESecretPrepared, GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision, + }, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GLWEBlindRotation}; + +pub fn test_glwe_blind_rotation() +where + Module: ModuleNew + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + GGSWEncryptSk + + GLWEBlindRotation + + GLWEDecrypt + + GLWEEncryptSk, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let n: Degree = Degree(1 << 11); + let base2k: Base2K = Base2K(13); + let rank: Rank = Rank(1); + let k_glwe: TorusPrecision = TorusPrecision(26); + let k_ggsw: TorusPrecision = TorusPrecision(39); + let dnum: Dnum = Dnum(3); + + let glwe_infos: GLWELayout = GLWELayout { + n, + base2k, + k: k_glwe, + rank, + }; + let ggsw_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw, + rank, + dnum, + dsize: Dsize(1), + }; + + let n_glwe: usize = glwe_infos.n().into(); + + let module: Module = Module::::new(n_glwe as u64); + let mut source: Source = Source::new([6u8; 32]); + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc_from_infos(&glwe_infos); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_glwe_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc_from_infos(&module, &glwe_infos); + sk_glwe_prep.prepare(&module, &sk_glwe); + + let mut res: GLWE> = GLWE::alloc_from_infos(&glwe_infos); + + let mut test_glwe: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + let mut data: Vec = vec![0i64; module.n()]; + data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + test_glwe.encode_vec_i64(&data, base2k.as_usize().into()); + + println!("pt: {}", test_glwe); + + let k: u32 = source.next_u32(); + + println!("k: {k}"); + + let mut k_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); + k_enc_prep.encrypt_sk( + &module, + k, + &sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let base: [usize; 2] = [6, 5]; + + assert_eq!(base.iter().sum::(), module.log_n()); + + // Starting bit + let mut bit_start: usize = 0; + + let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&glwe_infos); + + for _ in 0..32_usize.div_ceil(module.log_n()) { + // By how many bits to left shift + let mut bit_step: usize = 0; + + for digit in base { + let mask: u32 = (1 << digit) - 1; + + // How many bits to take + let bit_size: usize = (32 - bit_start).min(digit); + + module.glwe_blind_rotation( + &mut res, + &test_glwe, + &k_enc_prep, + bit_start, + bit_size, + bit_step, + scratch.borrow(), + ); + + res.decrypt(&module, &mut pt, &sk_glwe_prep, scratch.borrow()); + + assert_eq!( + (((k >> bit_start) & mask) << bit_step) as i64, + pt.decode_coeff_i64(base2k.as_usize().into(), 0) + ); + + bit_step += digit; + bit_start += digit; + + if bit_start >= 32 { + break; + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 73c96d3..76ef5e9 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -1,5 +1,6 @@ mod add; mod and; +mod glwe_blind_rotation; mod or; mod prepare; mod sll; @@ -12,6 +13,7 @@ mod xor; pub use add::*; pub use and::*; +pub use glwe_blind_rotation::*; pub use or::*; pub use prepare::*; pub use sll::*; From 6dd93ceaeaa4090f19b7746d0d53fc930fba198a Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 26 Oct 2025 10:28:13 +0100 Subject: [PATCH 06/11] Add test for ggsw scalar blind rotation --- poulpy-core/src/noise/ggsw.rs | 1 + .../src/tfhe/bdd_arithmetic/bdd_rotation.rs | 126 +++++++++++---- .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 7 +- .../tests/test_suite/ggsw_blind_rotations.rs | 149 ++++++++++++++++++ .../bdd_arithmetic/tests/test_suite/mod.rs | 2 + 5 files changed, 254 insertions(+), 31 deletions(-) create mode 100644 poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs diff --git a/poulpy-core/src/noise/ggsw.rs b/poulpy-core/src/noise/ggsw.rs index b6805c6..616d7c2 100644 --- a/poulpy-core/src/noise/ggsw.rs +++ b/poulpy-core/src/noise/ggsw.rs @@ -162,6 +162,7 @@ where sk_prepared, scratch.borrow(), ); + self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs index a457208..89965ef 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs @@ -1,24 +1,27 @@ use poulpy_core::{ GLWECopy, GLWERotate, ScratchTakeCore, - layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef}, + layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos}, +}; +use poulpy_hal::{ + api::{VecZnxAddScalarInplace, VecZnxNormalizeInplace}, + layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, }; -use poulpy_hal::layouts::{Backend, Module, Scratch}; use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger}; impl GGSWBlindRotation for Module where - Self: GLWEBlindRotation, + Self: GLWEBlindRotation + VecZnxAddScalarInplace + VecZnxNormalizeInplace, Scratch: ScratchTakeCore, { } pub trait GGSWBlindRotation where - Self: GLWEBlindRotation, + Self: GLWEBlindRotation + VecZnxAddScalarInplace + VecZnxNormalizeInplace, Scratch: ScratchTakeCore, { - fn ggsw_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn ggsw_blind_rotate_from_ggsw_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, @@ -26,38 +29,98 @@ where self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } - fn ggsw_blind_rotation( + /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn ggsw_blind_rotate_from_ggsw( &self, res: &mut R, - test_ggsw: &G, + a: &A, k: &K, bit_start: usize, - bit_size: usize, - bit_step: usize, + bit_mask: usize, + bit_lsh: usize, scratch: &mut Scratch, ) where R: GGSWToMut, - G: GGSWToRef, + A: GGSWToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let test_ggsw: &GGSW<&[u8]> = &test_ggsw.to_ref(); + let a: &GGSW<&[u8]> = &a.to_ref(); - for row in 0..res.dnum().into() { - for col in 0..(res.rank() + 1).into() { + assert!(res.dnum() <= a.dnum()); + assert_eq!(res.dsize(), a.dsize()); + + for col in 0..(res.rank() + 1).into() { + for row in 0..res.dnum().into() { self.glwe_blind_rotation( &mut res.at_mut(row, col), - &test_ggsw.at(row, col), + &a.at(row, col), k, bit_start, - bit_size, - bit_step, + bit_mask, + bit_lsh, scratch, ); } } } + + fn ggsw_blind_rotate_from_scalar_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + where + R: GLWEInfos, + K: GGSWInfos, + { + self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) + } + + fn ggsw_blind_rotate_from_scalar( + &self, + res: &mut R, + test_vector: &S, + k: &K, + bit_start: usize, + bit_mask: usize, + bit_lsh: usize, + scratch: &mut Scratch, + ) where + R: GGSWToMut, + S: ScalarZnxToRef, + K: GetGGSWBit, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); + let test_vector: &ScalarZnx<&[u8]> = &test_vector.to_ref(); + + let base2k: usize = res.base2k().into(); + let dsize: usize = res.dsize().into(); + + let (mut tmp_glwe, scratch_1) = scratch.take_glwe(res); + + for col in 0..(res.rank() + 1).into() { + for row in 0..res.dnum().into() { + tmp_glwe.data_mut().zero(); + self.vec_znx_add_scalar_inplace( + tmp_glwe.data_mut(), + col, + (dsize - 1) + row * dsize, + test_vector, + 0, + ); + self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); + + self.glwe_blind_rotation( + &mut res.at_mut(row, col), + &tmp_glwe, + k, + bit_start, + bit_mask, + bit_lsh, + scratch_1, + ); + } + } + } } impl GLWEBlindRotation for Module @@ -80,47 +143,50 @@ where self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } - /// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}. - fn glwe_blind_rotation( + /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. + fn glwe_blind_rotation( &self, res: &mut R, - test_glwe: &G, + a: &A, k: &K, - bit_start: usize, - bit_size: usize, - bit_step: usize, + bit_rsh: usize, + bit_mask: usize, + bit_lsh: usize, scratch: &mut Scratch, ) where R: GLWEToMut, - G: GLWEToRef, + A: GLWEToRef, K: GetGGSWBit, Scratch: ScratchTakeCore, { - assert!(bit_start + bit_size <= T::WORD_SIZE); + assert!(bit_rsh + bit_mask <= T::WORD_SIZE); let mut res: GLWE<&mut [u8]> = res.to_mut(); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); - // res <- test_glwe - self.glwe_copy(&mut res, test_glwe); + // a <- a ; b <- a * X^{-2^{i + bit_lsh}} + self.glwe_rotate(-1 << bit_lsh, &mut res, a); + + // b <- (b - a) * GGSW(b[i]) + a + self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1); // a_is_res = true => (a, b) = (&mut res, &mut tmp_res) // a_is_res = false => (a, b) = (&mut tmp_res, &mut res) let mut a_is_res: bool = true; - for i in 0..bit_size { + for i in 1..bit_mask { let (a, b) = if a_is_res { (&mut res, &mut tmp_res) } else { (&mut tmp_res, &mut res) }; - // a <- a ; b <- a * X^{-2^{i + bit_step}} - self.glwe_rotate(-1 << (i + bit_step), b, a); + // a <- a ; b <- a * X^{-2^{i + bit_lsh}} + self.glwe_rotate(-1 << (i + bit_lsh), b, a); // b <- (b - a) * GGSW(b[i]) + a - self.cmux_inplace(b, a, &k.get_bit(i + bit_start), scratch_1); + self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1); // ping-pong roles for next iter a_is_res = !a_is_res; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index 05c5085..94c7352 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_blind_rotation, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_ggsw_blind_rotation, test_glwe_blind_rotation, }, blind_rotation::CGGI, }; @@ -13,6 +13,11 @@ fn test_glwe_blind_rotation_fft64_ref() { test_glwe_blind_rotation::() } +#[test] +fn test_ggsw_blind_rotation_fft64_ref() { + test_ggsw_blind_rotation::() +} + #[test] fn test_bdd_prepare_fft64_ref() { test_bdd_prepare::() diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs new file mode 100644 index 0000000..1e0146f --- /dev/null +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs @@ -0,0 +1,149 @@ +use poulpy_core::{ + GGSWEncryptSk, GGSWNoise, GLWEDecrypt, GLWEEncryptSk, SIGMA, ScratchTakeCore, + layouts::{ + Base2K, Degree, Dnum, Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPrepared, + GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision, + }, +}; +use poulpy_hal::{ + api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxView, ZnxViewMut}, + source::Source, +}; +use rand::RngCore; + +use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GGSWBlindRotation}; + +pub fn test_ggsw_blind_rotation() +where + Module: ModuleNew + + GLWESecretPreparedFactory + + GGSWPreparedFactory + + GGSWEncryptSk + + GGSWBlindRotation + + GGSWNoise + + GLWEDecrypt + + GLWEEncryptSk + + VecZnxRotateInplace, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchTakeCore, +{ + let n: Degree = Degree(1 << 11); + let base2k: Base2K = Base2K(13); + let rank: Rank = Rank(1); + let k_ggsw_res: TorusPrecision = TorusPrecision(39); + let k_ggsw_apply: TorusPrecision = TorusPrecision(52); + + let ggsw_res_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw_res, + rank, + dnum: Dnum(2), + dsize: Dsize(1), + }; + + let ggsw_k_infos: GGSWLayout = GGSWLayout { + n, + base2k, + k: k_ggsw_apply, + rank, + dnum: Dnum(3), + dsize: Dsize(1), + }; + + let n_glwe: usize = n.into(); + + let module: Module = Module::::new(n_glwe as u64); + let mut source: Source = Source::new([6u8; 32]); + let mut source_xs: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([2u8; 32]); + let mut source_xe: Source = Source::new([3u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(1 << 22); + + let mut sk_glwe: GLWESecret> = GLWESecret::alloc(n, rank); + sk_glwe.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_glwe_prep: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(&module, rank); + sk_glwe_prep.prepare(&module, &sk_glwe); + + let mut res: GGSW> = GGSW::alloc_from_infos(&ggsw_res_infos); + + let mut scalar: ScalarZnx> = ScalarZnx::alloc(n_glwe, 1); + scalar + .raw_mut() + .iter_mut() + .enumerate() + .for_each(|(i, x)| *x = i as i64); + + let k: u32 = source.next_u32(); + + // println!("k: {k}"); + + let mut k_enc_prep: FheUintBlocksPrepared, u32, BE> = + FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_k_infos); + k_enc_prep.encrypt_sk( + &module, + k, + &sk_glwe_prep, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let base: [usize; 2] = [6, 5]; + + assert_eq!(base.iter().sum::(), module.log_n()); + + // Starting bit + let mut bit_start: usize = 0; + + let max_noise = |col_i: usize| { + let mut noise: f64 = -(ggsw_res_infos.size() as f64 * base2k.as_usize() as f64) + SIGMA.log2() + 2.0; + noise += 0.5 * ggsw_res_infos.log_n() as f64; + if col_i != 0 { + noise += 0.5 * ggsw_res_infos.log_n() as f64 + } + noise + }; + + for _ in 0..32_usize.div_ceil(module.log_n()) { + // By how many bits to left shift + let mut bit_step: usize = 0; + + for digit in base { + let mask: u32 = (1 << digit) - 1; + + // How many bits to take + let bit_size: usize = (32 - bit_start).min(digit); + + module.ggsw_blind_rotate_from_scalar( + &mut res, + &scalar, + &k_enc_prep, + bit_start, + bit_size, + bit_step, + scratch.borrow(), + ); + + let rot: i64 = (((k >> bit_start) & mask) << bit_step) as i64; + + let mut scalar_want: ScalarZnx> = ScalarZnx::alloc(module.n(), 1); + scalar_want.raw_mut().copy_from_slice(scalar.raw()); + + module.vec_znx_rotate_inplace(-rot, &mut scalar_want.as_vec_znx_mut(), 0, scratch.borrow()); + + // res.print_noise(&module, &sk_glwe_prep, &scalar_want); + + res.assert_noise(&module, &sk_glwe_prep, &scalar_want, &max_noise); + + bit_step += digit; + bit_start += digit; + + if bit_start >= 32 { + break; + } + } + } +} diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs index 76ef5e9..1b2c645 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/mod.rs @@ -1,5 +1,6 @@ mod add; mod and; +mod ggsw_blind_rotations; mod glwe_blind_rotation; mod or; mod prepare; @@ -13,6 +14,7 @@ mod xor; pub use add::*; pub use and::*; +pub use ggsw_blind_rotations::*; pub use glwe_blind_rotation::*; pub use or::*; pub use prepare::*; From 96c32c531cde1cfd08031207454a53dc9ae145f3 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 26 Oct 2025 10:45:30 +0100 Subject: [PATCH 07/11] rename to what it actually does --- .../{bdd_rotation.rs => blind_rotation.rs} | 20 +++++++++---------- poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs | 4 ++-- .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 10 +++++----- .../tests/test_suite/ggsw_blind_rotations.rs | 4 ++-- .../tests/test_suite/glwe_blind_rotation.rs | 4 ++-- 5 files changed, 21 insertions(+), 21 deletions(-) rename poulpy-schemes/src/tfhe/bdd_arithmetic/{bdd_rotation.rs => blind_rotation.rs} (89%) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs similarity index 89% rename from poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs rename to poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs index 89965ef..27d97d5 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/bdd_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs @@ -21,16 +21,16 @@ where Self: GLWEBlindRotation + VecZnxAddScalarInplace + VecZnxNormalizeInplace, Scratch: ScratchTakeCore, { - fn ggsw_blind_rotate_from_ggsw_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn ggsw_to_ggsw_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, { - self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. - fn ggsw_blind_rotate_from_ggsw( + fn ggsw_to_ggsw_blind_rotation( &self, res: &mut R, a: &A, @@ -53,7 +53,7 @@ where for col in 0..(res.rank() + 1).into() { for row in 0..res.dnum().into() { - self.glwe_blind_rotation( + self.glwe_to_glwe_blind_rotation( &mut res.at_mut(row, col), &a.at(row, col), k, @@ -66,15 +66,15 @@ where } } - fn ggsw_blind_rotate_from_scalar_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn scalar_to_ggsw_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, { - self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) + self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } - fn ggsw_blind_rotate_from_scalar( + fn scalar_to_ggsw_blind_rotation( &self, res: &mut R, test_vector: &S, @@ -109,7 +109,7 @@ where ); self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); - self.glwe_blind_rotation( + self.glwe_to_glwe_blind_rotation( &mut res.at_mut(row, col), &tmp_glwe, k, @@ -135,7 +135,7 @@ where Self: GLWECopy + GLWERotate + Cmux, Scratch: ScratchTakeCore, { - fn glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize + fn glwe_to_glwe_blind_rotation_tmp_bytes(&self, res_infos: &R, k_infos: &K) -> usize where R: GLWEInfos, K: GGSWInfos, @@ -144,7 +144,7 @@ where } /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. - fn glwe_blind_rotation( + fn glwe_to_glwe_blind_rotation( &self, res: &mut R, a: &A, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs index 22e5073..f73eb30 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/mod.rs @@ -1,12 +1,12 @@ mod bdd_2w_to_1w; -mod bdd_rotation; +mod blind_rotation; mod ciphertexts; mod circuits; mod eval; mod key; pub use bdd_2w_to_1w::*; -pub use bdd_rotation::*; +pub use blind_rotation::*; pub use ciphertexts::*; pub(crate) use circuits::*; pub(crate) use eval::*; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index 94c7352..a34cdae 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,19 +3,19 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, test_ggsw_blind_rotation, test_glwe_blind_rotation, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_scalar_to_ggsw_blind_rotation, test_glwe_to_glwe_blind_rotation, }, blind_rotation::CGGI, }; #[test] -fn test_glwe_blind_rotation_fft64_ref() { - test_glwe_blind_rotation::() +fn test_glwe_to_glwe_blind_rotation_fft64_ref() { + test_glwe_to_glwe_blind_rotation::() } #[test] -fn test_ggsw_blind_rotation_fft64_ref() { - test_ggsw_blind_rotation::() +fn test_scalar_to_ggsw_blind_rotation_fft64_ref() { + test_scalar_to_ggsw_blind_rotation::() } #[test] diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs index 1e0146f..5a0f6bd 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/ggsw_blind_rotations.rs @@ -14,7 +14,7 @@ use rand::RngCore; use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GGSWBlindRotation}; -pub fn test_ggsw_blind_rotation() +pub fn test_scalar_to_ggsw_blind_rotation() where Module: ModuleNew + GLWESecretPreparedFactory @@ -117,7 +117,7 @@ where // How many bits to take let bit_size: usize = (32 - bit_start).min(digit); - module.ggsw_blind_rotate_from_scalar( + module.scalar_to_ggsw_blind_rotation( &mut res, &scalar, &k_enc_prep, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs index ed1fdd0..47f97f3 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -14,7 +14,7 @@ use rand::RngCore; use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GLWEBlindRotation}; -pub fn test_glwe_blind_rotation() +pub fn test_glwe_to_glwe_blind_rotation() where Module: ModuleNew + GLWESecretPreparedFactory @@ -106,7 +106,7 @@ where // How many bits to take let bit_size: usize = (32 - bit_start).min(digit); - module.glwe_blind_rotation( + module.glwe_to_glwe_blind_rotation( &mut res, &test_glwe, &k_enc_prep, From 881483d1bbc148581d23edf7b8440d01159b4357 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Sun, 26 Oct 2025 16:32:22 +0100 Subject: [PATCH 08/11] wip --- .../src/cpu_fft64_avx/vec_znx_dft.rs | 4 +- poulpy-backend/src/cpu_fft64_ref/tests.rs | 2 +- .../src/cpu_fft64_ref/vec_znx_dft.rs | 4 +- .../src/cpu_spqlios/fft64/vec_znx_dft.rs | 6 +- poulpy-hal/src/api/convolution.rs | 161 +++++++++++------- poulpy-hal/src/api/vec_znx_dft.rs | 2 +- poulpy-hal/src/delegates/vec_znx_dft.rs | 4 +- poulpy-hal/src/oep/vec_znx_dft.rs | 2 +- poulpy-hal/src/reference/fft64/vec_znx_dft.rs | 9 +- poulpy-hal/src/test_suite/convolution.rs | 102 ++++++----- .../algorithms/cggi/algorithm.rs | 17 +- 11 files changed, 173 insertions(+), 140 deletions(-) diff --git a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs index 57ffc6f..1e1954e 100644 --- a/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_avx/vec_znx_dft.rs @@ -194,10 +194,10 @@ unsafe impl VecZnxDftCopyImpl for FFT64Avx { } unsafe impl VecZnxDftZeroImpl for FFT64Avx { - fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R, res_col: usize) where R: VecZnxDftToMut, { - vec_znx_dft_zero(res); + vec_znx_dft_zero(res, res_col); } } diff --git a/poulpy-backend/src/cpu_fft64_ref/tests.rs b/poulpy-backend/src/cpu_fft64_ref/tests.rs index 3f824c3..4531117 100644 --- a/poulpy-backend/src/cpu_fft64_ref/tests.rs +++ b/poulpy-backend/src/cpu_fft64_ref/tests.rs @@ -4,6 +4,6 @@ use crate::FFT64Ref; #[test] fn test_convolution_fft64_ref() { - let module: Module = Module::::new(64); + let module: Module = Module::::new(8); test_convolution(&module); } diff --git a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs index 5ad6400..b6ee4dd 100644 --- a/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_fft64_ref/vec_znx_dft.rs @@ -194,10 +194,10 @@ unsafe impl VecZnxDftCopyImpl for FFT64Ref { } unsafe impl VecZnxDftZeroImpl for FFT64Ref { - fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R, res_col: usize) where R: VecZnxDftToMut, { - vec_znx_dft_zero(res); + vec_znx_dft_zero(res, res_col); } } diff --git a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs index 3b67089..cdffb41 100644 --- a/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs +++ b/poulpy-backend/src/cpu_spqlios/fft64/vec_znx_dft.rs @@ -12,7 +12,7 @@ use poulpy_hal::{ reference::{ fft64::{ reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref}, - vec_znx_dft::vec_znx_dft_copy, + vec_znx_dft::{vec_znx_dft_copy, vec_znx_dft_zero}, }, znx::znx_zero_ref, }, @@ -426,10 +426,10 @@ impl ReimZero for FFT64Spqlios { } unsafe impl VecZnxDftZeroImpl for FFT64Spqlios { - fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R) + fn vec_znx_dft_zero_impl(_module: &Module, res: &mut R, res_col: usize) where R: VecZnxDftToMut, { - res.to_mut().data.fill(0); + vec_znx_dft_zero(res, res_col); } } diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index 2f32de2..9b34ead 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -1,9 +1,9 @@ use crate::{ api::{ ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace, - VecZnxDftBytesOf, + VecZnxDftBytesOf, VecZnxDftZero, }, - layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero}, + layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos}, }; impl Convolution for Module @@ -15,7 +15,8 @@ where + SvpPrepare + SvpPPolBytesOf + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace, + + VecZnxDftAddScaledInplace + + VecZnxDftZero, Scratch: ScratchTakeBasic, { } @@ -29,46 +30,15 @@ where + SvpPrepare + SvpPPolBytesOf + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace, + + VecZnxDftAddScaledInplace + + VecZnxDftZero, Scratch: ScratchTakeBasic, { - fn convolution_tmp_bytes(&self, res_size: usize) -> usize { - self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size) + fn convolution_tmp_bytes(&self, b_size: usize) -> usize { + self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) } - /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K - /// and scales the result by 2^{res_scale * K} - /// - /// # Example - /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ... - /// [a01, a11, a21, a31] - /// - /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ... - /// [b01, b11, b21, b31] - /// - /// If res_scale = 0: - /// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ... - /// [r01, r11, r21, r31] - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] - /// [r04, r14, r24, r34] - /// - /// If res_scale = 1: - /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ... - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] - /// [r04, r14, r24, r34] - /// [r05, r15, r25, r35] - /// - /// If res_scale = -1: - /// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ... - /// [ 0, 0, 0, 0] - /// [r01, r11, r21, r31] - /// [r02, r12, r22, r32] - /// [r03, r13, r23, r33] - /// - /// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension. - fn convolution(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch) + fn bivariate_convolution_full(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxToRef, @@ -78,32 +48,99 @@ where let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); - assert!(res.cols() >= a.cols() + b.cols() - 1); + let res_cols: usize = res.cols(); + let a_cols: usize = a.cols(); + let b_cols: usize = b.cols(); - res.zero(); + assert!(res_cols >= a_cols + b_cols - 1); - let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1); - let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); - - for a_col in 0..a.cols() { - for a_limb in 0..a.size() { - // Prepares the j-th limb of the i-th col of A - self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0); - - for b_col in 0..b.cols() { - // Multiplies with the i-th col of B - self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); - - // Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K} - self.vec_znx_dft_add_scaled_inplace( - res, - a_col + b_col, - &res_tmp, - 0, - -(1 + a_limb as i64) + res_scale, - ); - } + for res_col in 0..res_cols { + let a_min: usize = res_col.saturating_sub(b_cols - 1); + let a_max: usize = res_col.min(a_cols - 1); + self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch); + for a_col in a_min + 1..a_max + 1 { + self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch); } } } + + /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the + /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K} + /// + /// # Example + /// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ... + /// [a01, a11, a21, a31] + /// + /// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ... + /// [b01, b11, b21, b31] + /// + /// If k = 0: + /// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ... + /// [r01, r11, r21, r31] + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// [r04, r14, r24, r34] + /// + /// If k = 1: + /// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ... + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// [r04, r14, r24, r34] + /// [r05, r15, r25, r35] + /// + /// If k = -1: + /// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ... + /// [ 0, 0, 0, 0] + /// [r01, r11, r21, r31] + /// [r02, r12, r22, r32] + /// [r03, r13, r23, r33] + /// + /// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension. + fn bivariate_convolution_single_add( + &self, + k: i64, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef, + B: VecZnxDftToRef, + { + let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); + let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); + let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); + + let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1); + let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size()); + + for a_limb in 0..a.size() { + self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0); + self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); + self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k); + } + } + + fn bivariate_convolution_single( + &self, + k: i64, + res: &mut R, + res_col: usize, + a: &A, + a_col: usize, + b: &B, + b_col: usize, + scratch: &mut Scratch, + ) where + R: VecZnxDftToMut, + A: VecZnxToRef, + B: VecZnxDftToRef, + { + self.vec_znx_dft_zero(res, res_col); + self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch); + } } diff --git a/poulpy-hal/src/api/vec_znx_dft.rs b/poulpy-hal/src/api/vec_znx_dft.rs index 61396c4..0044c18 100644 --- a/poulpy-hal/src/api/vec_znx_dft.rs +++ b/poulpy-hal/src/api/vec_znx_dft.rs @@ -97,7 +97,7 @@ pub trait VecZnxDftCopy { } pub trait VecZnxDftZero { - fn vec_znx_dft_zero(&self, res: &mut R) + fn vec_znx_dft_zero(&self, res: &mut R, res_col: usize) where R: VecZnxDftToMut; } diff --git a/poulpy-hal/src/delegates/vec_znx_dft.rs b/poulpy-hal/src/delegates/vec_znx_dft.rs index 7dfb25f..3e9cd03 100644 --- a/poulpy-hal/src/delegates/vec_znx_dft.rs +++ b/poulpy-hal/src/delegates/vec_znx_dft.rs @@ -200,10 +200,10 @@ impl VecZnxDftZero for Module where B: Backend + VecZnxDftZeroImpl, { - fn vec_znx_dft_zero(&self, res: &mut R) + fn vec_znx_dft_zero(&self, res: &mut R, res_col: usize) where R: VecZnxDftToMut, { - B::vec_znx_dft_zero_impl(self, res); + B::vec_znx_dft_zero_impl(self, res, res_col); } } diff --git a/poulpy-hal/src/oep/vec_znx_dft.rs b/poulpy-hal/src/oep/vec_znx_dft.rs index f561084..abdb92a 100644 --- a/poulpy-hal/src/oep/vec_znx_dft.rs +++ b/poulpy-hal/src/oep/vec_znx_dft.rs @@ -188,7 +188,7 @@ pub unsafe trait VecZnxDftCopyImpl { /// * See [crate::api::VecZnxDftZero] for corresponding public API. /// # Safety [crate::doc::backend_safety] for safety contract. pub unsafe trait VecZnxDftZeroImpl { - fn vec_znx_dft_zero_impl(module: &Module, res: &mut R) + fn vec_znx_dft_zero_impl(module: &Module, res: &mut R, res_col: usize) where R: VecZnxDftToMut; } diff --git a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs index e8d12e6..fa8d9e1 100644 --- a/poulpy-hal/src/reference/fft64/vec_znx_dft.rs +++ b/poulpy-hal/src/reference/fft64/vec_znx_dft.rs @@ -118,7 +118,7 @@ where } } else if a_scale < 0 { let shift: usize = (a_scale.unsigned_abs() as usize).min(res_size); - let sum_size: usize = a_size.min(res_size).saturating_sub(shift); + let sum_size: usize = a_size.min(res_size.saturating_sub(shift)); for j in 0..sum_size { BE::reim_add_inplace(res.at_mut(res_col, j + shift), a.at(a_col, j)); } @@ -398,10 +398,13 @@ where } } -pub fn vec_znx_dft_zero(res: &mut R) +pub fn vec_znx_dft_zero(res: &mut R, res_col: usize) where R: VecZnxDftToMut, BE: Backend + ReimZero, { - BE::reim_zero(res.to_mut().raw_mut()); + let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut(); + for j in 0..res.size() { + BE::reim_zero(res.at_mut(res_col, j)) + } } diff --git a/poulpy-hal/src/test_suite/convolution.rs b/poulpy-hal/src/test_suite/convolution.rs index 8f4c71c..05f2df9 100644 --- a/poulpy-hal/src/test_suite/convolution.rs +++ b/poulpy-hal/src/test_suite/convolution.rs @@ -1,7 +1,7 @@ use crate::{ api::{ - Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigNormalize, - VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeInplace, + Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc, + VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace, }, layouts::{ Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, @@ -16,9 +16,10 @@ where + Convolution + VecZnxDftAlloc + VecZnxDftApply - + VecZnxIdftApplyConsume + + VecZnxIdftApplyTmpA + VecZnxBigNormalize - + VecZnxNormalizeInplace, + + VecZnxNormalizeInplace + + VecZnxBigAlloc, Scratch: ScratchTakeBasic, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, { @@ -26,70 +27,63 @@ where let base2k: usize = 12; - for a_cols in 1..3 { - for b_cols in 1..3 { - for a_size in 1..5 { - for b_size in 1..5 { - let mut a: VecZnx> = VecZnx::alloc(module.n(), a_cols, a_size); - let mut b: VecZnx> = VecZnx::alloc(module.n(), b_cols, b_size); + let a_cols: usize = 3; + let b_cols: usize = 3; + let a_size: usize = 3; + let b_size: usize = 3; + let c_cols: usize = a_cols + b_cols - 1; + let c_size: usize = a_size + b_size; - let mut c_want: VecZnx> = VecZnx::alloc(module.n(), a_cols + b_cols - 1, b_size + a_size); - let mut c_have: VecZnx> = VecZnx::alloc(module.n(), c_want.cols(), c_want.size()); + let mut a: VecZnx> = VecZnx::alloc(module.n(), a_cols, a_size); + let mut b: VecZnx> = VecZnx::alloc(module.n(), b_cols, b_size); - let mut scratch: ScratchOwned = ScratchOwned::alloc(module.convolution_tmp_bytes(c_want.size())); + let mut c_want: VecZnx> = VecZnx::alloc(module.n(), c_cols, c_size); + let mut c_have: VecZnx> = VecZnx::alloc(module.n(), c_cols, c_size); + let mut c_have_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(c_cols, c_size); + let mut c_have_big: VecZnxBig, BE> = module.vec_znx_big_alloc(c_cols, c_size); - a.fill_uniform(base2k, &mut source); - b.fill_uniform(base2k, &mut source); + let mut scratch: ScratchOwned = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size)); - let mut b_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(b.cols(), b.size()); + a.fill_uniform(base2k, &mut source); + b.fill_uniform(base2k, &mut source); - for i in 0..b.cols() { - module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i); - } + let mut b_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(b_cols, b_size); + for i in 0..b.cols() { + module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i); + } - for mut res_scale in 0..2 * c_want.size() as i64 + 1 { - res_scale -= c_want.size() as i64; + for mut k in 0..(2 * c_size + 1) as i64 { + k -= c_size as i64; - let mut c_have_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(c_have.cols(), c_have.size()); - module.convolution(&mut c_have_dft, res_scale, &a, &b_dft, scratch.borrow()); + module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow()); - let c_have_big: VecZnxBig, BE> = module.vec_znx_idft_apply_consume(c_have_dft); - - for i in 0..c_have.cols() { - module.vec_znx_big_normalize( - base2k, - &mut c_have, - i, - base2k, - &c_have_big, - i, - scratch.borrow(), - ); - } - - convolution_naive( - module, - base2k, - &mut c_want, - res_scale, - &a, - &b, - scratch.borrow(), - ); - - assert_eq!(c_want, c_have); - } - } - } + for i in 0..c_cols { + module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i); } + + for i in 0..c_cols { + module.vec_znx_big_normalize( + base2k, + &mut c_have, + i, + base2k, + &c_have_big, + i, + scratch.borrow(), + ); + } + + convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow()); + + assert_eq!(c_want, c_have); } } fn convolution_naive( module: &M, base2k: usize, + k: i64, res: &mut R, - res_scale: i64, a: &A, b: &B, scratch: &mut Scratch, @@ -112,11 +106,11 @@ fn convolution_naive( for a_limb in 0..a.size() { for b_col in 0..b.cols() { for b_limb in 0..b.size() { - let res_scale_abs = res_scale.unsigned_abs() as usize; + let res_scale_abs = k.unsigned_abs() as usize; let mut res_limb: usize = a_limb + b_limb + 1; - if res_scale <= 0 { + if k <= 0 { res_limb += res_scale_abs; if res_limb < res.size() { diff --git a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs index b9ec277..c65db52 100644 --- a/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs +++ b/poulpy-schemes/src/tfhe/blind_rotation/algorithms/cggi/algorithm.rs @@ -189,12 +189,12 @@ fn execute_block_binary_extended( brk.data.chunks_exact(block_size) ) .for_each(|(ai, ski)| { - (0..extension_factor).for_each(|i| { - (0..cols).for_each(|j| { + for i in 0..extension_factor { + for j in 0..cols { module.vec_znx_dft_apply(1, 0, &mut acc_dft[i], j, &acc[i], j); - }); - module.vec_znx_dft_zero(&mut acc_add_dft[i]) - }); + module.vec_znx_dft_zero(&mut acc_add_dft[i], j) + } + } // TODO: first & last iterations can be optimized izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { @@ -342,11 +342,10 @@ fn execute_block_binary( brk.data.chunks_exact(block_size) ) .for_each(|(ai, ski)| { - (0..cols).for_each(|j| { + for j in 0..cols { module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j); - }); - - module.vec_znx_dft_zero(&mut acc_add_dft); + module.vec_znx_dft_zero(&mut acc_add_dft, j) + } izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; From 6e9cef5ecd69dc28c6777e86e36cffa7ace53a88 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 26 Oct 2025 17:31:07 +0100 Subject: [PATCH 09/11] Auto stash before merge of "dev_bdd_selector" and "origin/dev_bdd_selector" --- poulpy-core/src/operations/glwe.rs | 18 +++++++++++++----- poulpy-hal/src/api/convolution.rs | 1 + 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 95df49a..bd1256b 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,18 +1,26 @@ use poulpy_hal::{ api::{ - ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero, + Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero }, layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, }; use crate::{ - ScratchTakeCore, - layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}, + layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore }; +pub trait GLWETensoring where Self: Convolution, Scratch: ScratchTakeCore { + fn glwe_tensor(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{ + + let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut(); + let a: &mut GLWE<&[u8]> = &mut a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch); + } +} + pub trait GLWEAdd where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index 9b34ead..420ea83 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -123,6 +123,7 @@ where self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k); } + } fn bivariate_convolution_single( From 41ca5aafcc8f69c9162a93ca63cb39c4f32f0004 Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Sun, 26 Oct 2025 19:03:15 +0100 Subject: [PATCH 10/11] Add glwe tensoiring --- poulpy-backend/src/cpu_fft64_avx/tests.rs | 5 +- poulpy-backend/src/cpu_fft64_ref/tests.rs | 4 +- poulpy-core/src/layouts/glwe_tensor.rs | 4 +- poulpy-core/src/operations/glwe.rs | 221 +++++++++++------- poulpy-hal/src/api/convolution.rs | 81 ++++--- poulpy-hal/src/test_suite/convolution.rs | 12 +- .../src/tfhe/bdd_arithmetic/blind_rotation.rs | 4 + .../tfhe/bdd_arithmetic/tests/fft64_ref.rs | 2 +- .../tests/test_suite/glwe_blind_rotation.rs | 4 - 9 files changed, 199 insertions(+), 138 deletions(-) diff --git a/poulpy-backend/src/cpu_fft64_avx/tests.rs b/poulpy-backend/src/cpu_fft64_avx/tests.rs index 35ae6d3..9abddf0 100644 --- a/poulpy-backend/src/cpu_fft64_avx/tests.rs +++ b/poulpy-backend/src/cpu_fft64_avx/tests.rs @@ -1,5 +1,6 @@ use poulpy_hal::{ - api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution, + api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, + test_suite::convolution::test_bivariate_tensoring, }; use crate::FFT64Avx; @@ -123,5 +124,5 @@ backend_test_suite! { #[test] fn test_convolution_fft64_avx() { let module: Module = Module::::new(64); - test_convolution(&module); + test_bivariate_tensoring(&module); } diff --git a/poulpy-backend/src/cpu_fft64_ref/tests.rs b/poulpy-backend/src/cpu_fft64_ref/tests.rs index 4531117..177dfb8 100644 --- a/poulpy-backend/src/cpu_fft64_ref/tests.rs +++ b/poulpy-backend/src/cpu_fft64_ref/tests.rs @@ -1,9 +1,9 @@ -use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution}; +use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring}; use crate::FFT64Ref; #[test] fn test_convolution_fft64_ref() { let module: Module = Module::::new(8); - test_convolution(&module); + test_bivariate_tensoring(&module); } diff --git a/poulpy-core/src/layouts/glwe_tensor.rs b/poulpy-core/src/layouts/glwe_tensor.rs index 516a854..8ff6428 100644 --- a/poulpy-core/src/layouts/glwe_tensor.rs +++ b/poulpy-core/src/layouts/glwe_tensor.rs @@ -93,7 +93,7 @@ impl GLWETensor> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self { - let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize; + let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1); GLWETensor { data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize), base2k, @@ -110,7 +110,7 @@ impl GLWETensor> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize { - let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1) as usize; + let pairs: usize = (((rank + 1) * rank).as_usize() >> 1).max(1); VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize) } } diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index bd1256b..492b611 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,29 +1,77 @@ use poulpy_hal::{ api::{ - Convolution, ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero + BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy, + VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize, + VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, + VecZnxSubNegateInplace, VecZnxZero, }, - layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, + layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos}, reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes, }; use crate::{ - layouts::{GLWEInfos, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision, GLWE}, ScratchTakeCore + ScratchTakeCore, + layouts::{ + GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos, + TorusPrecision, + }, }; -pub trait GLWETensoring where Self: Convolution, Scratch: ScratchTakeCore { - fn glwe_tensor(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch) where R: GLWETensorToMut, A: GLWEToRef, B: GLWEToRef{ - +pub trait GLWETensoring +where + Self: BivariateTensoring + VecZnxIdftApplyConsume + VecZnxBigNormalize, + Scratch: ScratchTakeCore, +{ + /// res = (a (x) b) * 2^{k * a_base2k} + /// + /// # Requires + /// * a.base2k() == b.base2k() + /// * res.cols() >= a.cols() + b.cols() - 1 + /// + /// # Behavior + /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k) + fn glwe_tensor(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + where + R: GLWETensorToMut, + A: GLWEToRef, + B: GLWEPreparedToRef, + { let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut(); - let a: &mut GLWE<&[u8]> = &mut a.to_ref(); - let b: &GLWE<&[u8]> = &b.to_ref(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWEPrepared<&[u8], BE> = &b.to_ref(); - self.bivariate_convolution(res.data_mut(), res_scale, a, b, scratch); + assert_eq!(a.base2k(), b.base2k()); + assert_eq!(a.rank(), res.rank()); + + let res_cols: usize = res.data.cols(); + + // Get tmp buffer of min precision between a_prec * b_prec and res_prec + let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize); + + // DFT(res) = DFT(a) (x) DFT(b) + self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1); + + // res = IDFT(res) + let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); + + // Normalize and switches basis if required + for res_col in 0..res_cols { + self.vec_znx_big_normalize( + res.base2k().into(), + &mut res.data, + res_col, + a.base2k().into(), + &res_big, + res_col, + scratch_1, + ); + } } } pub trait GLWEAdd where - Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, + Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero, { fn glwe_add(&self, res: &mut R, a: &A, b: &B) where @@ -39,35 +87,38 @@ where assert_eq!(b.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32); assert_eq!(a.base2k(), b.base2k()); - assert!(res.rank() >= a.rank().max(b.rank())); + assert_eq!(res.base2k(), b.base2k()); + + if a.rank() == 0 { + assert_eq!(res.rank(), b.rank()); + } else if b.rank() == 0 { + assert_eq!(res.rank(), a.rank()); + } else { + assert_eq!(res.rank(), a.rank()); + assert_eq!(res.rank(), b.rank()); + } let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let self_col: usize = (res.rank() + 1).into(); - (0..min_col).for_each(|i| { + for i in 0..min_col { self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i); - }); - - if a.rank() > b.rank() { - (min_col..max_col).for_each(|i| { - self.vec_znx_copy(res.data_mut(), i, a.data(), i); - }); - } else { - (min_col..max_col).for_each(|i| { - self.vec_znx_copy(res.data_mut(), i, b.data(), i); - }); } - let size: usize = res.size(); - (max_col..self_col).for_each(|i| { - (0..size).for_each(|j| { - res.data.zero_at(i, j); - }); - }); + if a.rank() > b.rank() { + for i in min_col..max_col { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + } + } else { + for i in min_col..max_col { + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + } + } - res.set_base2k(a.base2k()); - res.set_k(set_k_binary(res, a, b)); + for i in max_col..self_col { + self.vec_znx_zero(res.data_mut(), i); + } } fn glwe_add_inplace(&self, res: &mut R, a: &A) @@ -83,24 +134,22 @@ where assert_eq!(res.base2k(), a.base2k()); assert!(res.rank() >= a.rank()); - (0..(a.rank() + 1).into()).for_each(|i| { + for i in 0..(a.rank() + 1).into() { self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i); - }); - - res.set_k(set_k_unary(res, a)) + } } } -impl GLWEAdd for Module where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {} +impl GLWEAdd for Module where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {} impl GLWESub for Module where - Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace { } pub trait GLWESub where - Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace, + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace, { fn glwe_sub(&self, res: &mut R, a: &A, b: &B) where @@ -114,37 +163,40 @@ where assert_eq!(a.n(), self.n() as u32); assert_eq!(b.n(), self.n() as u32); - assert_eq!(a.base2k(), b.base2k()); - assert!(res.rank() >= a.rank().max(b.rank())); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.base2k(), res.base2k()); + assert_eq!(b.base2k(), res.base2k()); + + if a.rank() == 0 { + assert_eq!(res.rank(), b.rank()); + } else if b.rank() == 0 { + assert_eq!(res.rank(), a.rank()); + } else { + assert_eq!(res.rank(), a.rank()); + assert_eq!(res.rank(), b.rank()); + } let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let self_col: usize = (res.rank() + 1).into(); - (0..min_col).for_each(|i| { + for i in 0..min_col { self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i); - }); - - if a.rank() > b.rank() { - (min_col..max_col).for_each(|i| { - self.vec_znx_copy(res.data_mut(), i, a.data(), i); - }); - } else { - (min_col..max_col).for_each(|i| { - self.vec_znx_copy(res.data_mut(), i, b.data(), i); - self.vec_znx_negate_inplace(res.data_mut(), i); - }); } - let size: usize = res.size(); - (max_col..self_col).for_each(|i| { - (0..size).for_each(|j| { - res.data.zero_at(i, j); - }); - }); + if a.rank() > b.rank() { + for i in min_col..max_col { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + } + } else { + for i in min_col..max_col { + self.vec_znx_negate(res.data_mut(), i, b.data(), i); + } + } - res.set_base2k(a.base2k()); - res.set_k(set_k_binary(res, a, b)); + for i in max_col..self_col { + self.vec_znx_zero(res.data_mut(), i); + } } fn glwe_sub_inplace(&self, res: &mut R, a: &A) @@ -158,13 +210,11 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); assert_eq!(res.base2k(), a.base2k()); - assert!(res.rank() >= a.rank()); + assert!(res.rank() == a.rank() || a.rank() == 0); - (0..(a.rank() + 1).into()).for_each(|i| { + for i in 0..(a.rank() + 1).into() { self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i); - }); - - res.set_k(set_k_unary(res, a)) + } } fn glwe_sub_negate_inplace(&self, res: &mut R, a: &A) @@ -178,21 +228,19 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); assert_eq!(res.base2k(), a.base2k()); - assert!(res.rank() >= a.rank()); + assert!(res.rank() == a.rank() || a.rank() == 0); - (0..(a.rank() + 1).into()).for_each(|i| { + for i in 0..(a.rank() + 1).into() { self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i); - }); - - res.set_k(set_k_unary(res, a)) + } } } -impl GLWERotate for Module where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace {} +impl GLWERotate for Module where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace + VecZnxZero {} pub trait GLWERotate where - Self: ModuleN + VecZnxRotate + VecZnxRotateInplace, + Self: ModuleN + VecZnxRotate + VecZnxRotateInplace + VecZnxZero, { fn glwe_rotate_tmp_bytes(&self) -> usize { vec_znx_rotate_inplace_tmp_bytes(self.n()) @@ -207,14 +255,18 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); assert_eq!(a.n(), self.n() as u32); - assert_eq!(res.rank(), a.rank()); + assert_eq!(res.n(), self.n() as u32); + assert!(res.rank() == a.rank() || a.rank() == 0); - (0..(a.rank() + 1).into()).for_each(|i| { + let res_cols = (res.rank() + 1).into(); + let a_cols = (a.rank() + 1).into(); + + for i in 0..a_cols { self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i); - }); - - res.set_base2k(a.base2k()); - res.set_k(set_k_unary(res, a)) + } + for i in a_cols..res_cols { + self.vec_znx_zero(res.data_mut(), i); + } } fn glwe_rotate_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) @@ -224,9 +276,9 @@ where { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - (0..(res.rank() + 1).into()).for_each(|i| { + for i in 0..(res.rank() + 1).into() { self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch); - }); + } } } @@ -251,9 +303,6 @@ where for i in 0..res.rank().as_usize() + 1 { self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i); } - - res.set_base2k(a.base2k()); - res.set_k(set_k_unary(res, a)) } fn glwe_mul_xp_minus_one_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) @@ -286,6 +335,7 @@ where assert_eq!(res.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32); + assert!(res.rank() == a.rank() || a.rank() == 0); let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1; @@ -296,9 +346,6 @@ where for i in min_rank..(res.rank() + 1).into() { self.vec_znx_zero(res.data_mut(), i); } - - res.set_k(a.k().min(res.max_k())); - res.set_base2k(a.base2k()); } } @@ -364,8 +411,6 @@ where scratch, ); } - - res.set_k(a.k().min(res.k())); } fn glwe_normalize_inplace(&self, res: &mut R, scratch: &mut Scratch) @@ -380,6 +425,7 @@ where } } +#[allow(dead_code)] // c = op(a, b) fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { // If either operands is a ciphertext @@ -401,6 +447,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T } } +#[allow(dead_code)] // a = op(a, b) fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { if a.rank() != 0 || b.rank() != 0 { diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index 420ea83..10caf6b 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -6,39 +6,19 @@ use crate::{ layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos}, }; -impl Convolution for Module +impl BivariateTensoring for Module where - Self: Sized - + ModuleN - + SvpPPolAlloc - + SvpApplyDftToDft - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace - + VecZnxDftZero, + Self: BivariateConvolution, Scratch: ScratchTakeBasic, { } -pub trait Convolution +pub trait BivariateTensoring where - Self: Sized - + ModuleN - + SvpPPolAlloc - + SvpApplyDftToDft - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxDftAddScaledInplace - + VecZnxDftZero, + Self: BivariateConvolution, Scratch: ScratchTakeBasic, { - fn convolution_tmp_bytes(&self, b_size: usize) -> usize { - self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) - } - - fn bivariate_convolution_full(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn bivariate_tensoring(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxToRef, @@ -55,14 +35,48 @@ where assert!(res_cols >= a_cols + b_cols - 1); for res_col in 0..res_cols { - let a_min: usize = res_col.saturating_sub(b_cols - 1); - let a_max: usize = res_col.min(a_cols - 1); - self.bivariate_convolution_single(k, res, res_col, a, a_min, b, res_col - a_min, scratch); - for a_col in a_min + 1..a_max + 1 { - self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, res_col - a_col, scratch); + self.vec_znx_dft_zero(res, res_col); + } + + for a_col in 0..a_cols { + for b_col in 0..b_cols { + self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch); } } } +} + +impl BivariateConvolution for Module +where + Self: Sized + + ModuleN + + SvpPPolAlloc + + SvpApplyDftToDft + + SvpPrepare + + SvpPPolBytesOf + + VecZnxDftBytesOf + + VecZnxDftAddScaledInplace + + VecZnxDftZero, + Scratch: ScratchTakeBasic, +{ +} + +pub trait BivariateConvolution +where + Self: Sized + + ModuleN + + SvpPPolAlloc + + SvpApplyDftToDft + + SvpPrepare + + SvpPPolBytesOf + + VecZnxDftBytesOf + + VecZnxDftAddScaledInplace + + VecZnxDftZero, + Scratch: ScratchTakeBasic, +{ + fn convolution_tmp_bytes(&self, b_size: usize) -> usize { + self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) + } /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K} @@ -96,7 +110,7 @@ where /// [r03, r13, r23, r33] /// /// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension. - fn bivariate_convolution_single_add( + fn bivariate_convolution_add( &self, k: i64, res: &mut R, @@ -123,10 +137,9 @@ where self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k); } - } - fn bivariate_convolution_single( + fn bivariate_convolution( &self, k: i64, res: &mut R, @@ -142,6 +155,6 @@ where B: VecZnxDftToRef, { self.vec_znx_dft_zero(res, res_col); - self.bivariate_convolution_single_add(k, res, res_col, a, a_col, b, b_col, scratch); + self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch); } } diff --git a/poulpy-hal/src/test_suite/convolution.rs b/poulpy-hal/src/test_suite/convolution.rs index 05f2df9..d175656 100644 --- a/poulpy-hal/src/test_suite/convolution.rs +++ b/poulpy-hal/src/test_suite/convolution.rs @@ -1,6 +1,6 @@ use crate::{ api::{ - Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc, + BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace, }, layouts::{ @@ -10,10 +10,10 @@ use crate::{ source::Source, }; -pub fn test_convolution(module: &M) +pub fn test_bivariate_tensoring(module: &M) where M: ModuleN - + Convolution + + BivariateTensoring + VecZnxDftAlloc + VecZnxDftApply + VecZnxIdftApplyTmpA @@ -55,7 +55,7 @@ where for mut k in 0..(2 * c_size + 1) as i64 { k -= c_size as i64; - module.bivariate_convolution_full(k, &mut c_have_dft, &a, &b_dft, scratch.borrow()); + module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow()); for i in 0..c_cols { module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i); @@ -73,13 +73,13 @@ where ); } - convolution_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow()); + bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow()); assert_eq!(c_want, c_have); } } -fn convolution_naive( +fn bivariate_tensoring_naive( module: &M, base2k: usize, k: i64, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs index 27d97d5..96dc418 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/blind_rotation.rs @@ -29,6 +29,7 @@ where self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) } + #[allow(clippy::too_many_arguments)] /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. fn ggsw_to_ggsw_blind_rotation( &self, @@ -74,6 +75,7 @@ where self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } + #[allow(clippy::too_many_arguments)] fn scalar_to_ggsw_blind_rotation( &self, res: &mut R, @@ -143,6 +145,7 @@ where self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) } + #[allow(clippy::too_many_arguments)] /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. fn glwe_to_glwe_blind_rotation( &self, @@ -162,6 +165,7 @@ where assert!(bit_rsh + bit_mask <= T::WORD_SIZE); let mut res: GLWE<&mut [u8]> = res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs index a34cdae..b855ffa 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/fft64_ref.rs @@ -3,7 +3,7 @@ use poulpy_backend::FFT64Ref; use crate::tfhe::{ bdd_arithmetic::tests::test_suite::{ test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, - test_bdd_srl, test_bdd_sub, test_bdd_xor, test_scalar_to_ggsw_blind_rotation, test_glwe_to_glwe_blind_rotation, + test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation, }, blind_rotation::CGGI, }; diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs index 47f97f3..7573b49 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/tests/test_suite/glwe_blind_rotation.rs @@ -70,12 +70,8 @@ where data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); test_glwe.encode_vec_i64(&data, base2k.as_usize().into()); - println!("pt: {}", test_glwe); - let k: u32 = source.next_u32(); - println!("k: {k}"); - let mut k_enc_prep: FheUintBlocksPrepared, u32, BE> = FheUintBlocksPrepared::, u32, BE>::alloc(&module, &ggsw_infos); k_enc_prep.encrypt_sk( From 8d4c19a304bdc473e393e9be57f479cdca8dbd3c Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Mon, 27 Oct 2025 11:28:53 +0100 Subject: [PATCH 11/11] Distinguish between gglwe_to_ggsw key and tensor_key + update key repreentation --- Cargo.lock | 9 +- Cargo.toml | 7 +- poulpy-core/Cargo.toml | 1 + poulpy-core/src/automorphism/gglwe_atk.rs | 37 +- poulpy-core/src/automorphism/ggsw_ct.rs | 23 +- poulpy-core/src/automorphism/glwe_ct.rs | 91 +++- poulpy-core/src/conversion/gglwe_to_ggsw.rs | 265 +++++------ poulpy-core/src/conversion/lwe_to_glwe.rs | 27 +- .../compressed/gglwe_to_ggsw_key.rs | 124 ++++++ .../encryption/compressed/glwe_tensor_key.rs | 128 ++---- poulpy-core/src/encryption/compressed/mod.rs | 2 + poulpy-core/src/encryption/gglwe.rs | 2 +- .../src/encryption/gglwe_to_ggsw_key.rs | 112 +++++ poulpy-core/src/encryption/glwe_tensor_key.rs | 100 ++--- ...we_switching_key.rs => glwe_to_lwe_key.rs} | 19 +- ...we_switching_key.rs => lwe_to_glwe_key.rs} | 20 +- poulpy-core/src/encryption/mod.rs | 10 +- poulpy-core/src/glwe_packer.rs | 388 ++++++++++++++++ poulpy-core/src/glwe_packing.rs | 336 +------------- poulpy-core/src/glwe_trace.rs | 53 ++- poulpy-core/src/keyswitching/ggsw.rs | 82 ++-- poulpy-core/src/keyswitching/glwe.rs | 414 ++++++++++-------- .../layouts/compressed/gglwe_to_ggsw_key.rs | 237 ++++++++++ .../src/layouts/compressed/glwe_tensor_key.rs | 142 ++---- ...we_switching_key.rs => glwe_to_lwe_key.rs} | 8 +- ...we_switching_key.rs => lwe_to_glwe_key.rs} | 40 +- poulpy-core/src/layouts/compressed/mod.rs | 10 +- poulpy-core/src/layouts/gglwe_to_ggsw_key.rs | 254 +++++++++++ poulpy-core/src/layouts/glwe_secret_tensor.rs | 221 ++++++++++ poulpy-core/src/layouts/glwe_tensor_key.rs | 101 +---- ...we_switching_key.rs => glwe_to_lwe_key.rs} | 36 +- ...we_switching_key.rs => lwe_to_glwe_key.rs} | 48 +- poulpy-core/src/layouts/mod.rs | 12 +- .../src/layouts/prepared/gglwe_to_ggsw_key.rs | 252 +++++++++++ .../layouts/prepared/glwe_switching_key.rs | 6 +- .../src/layouts/prepared/glwe_tensor_key.rs | 91 +--- ...we_switching_key.rs => glwe_to_lwe_key.rs} | 80 ++-- .../src/layouts/prepared/lwe_switching_key.rs | 2 +- ...we_switching_key.rs => lwe_to_glwe_key.rs} | 84 ++-- poulpy-core/src/layouts/prepared/mod.rs | 10 +- poulpy-core/src/lib.rs | 2 + poulpy-core/src/noise/gglwe.rs | 2 +- poulpy-core/src/operations/glwe.rs | 8 + poulpy-core/src/scratch.rs | 57 +-- poulpy-core/src/tests/mod.rs | 4 +- poulpy-core/src/tests/serialization.rs | 21 +- .../tests/test_suite/automorphism/ggsw_ct.rs | 50 +-- .../src/tests/test_suite/conversion.rs | 28 +- .../encryption/gglwe_to_ggsw_key.rs | 144 ++++++ .../tests/test_suite/encryption/glwe_tsk.rs | 114 +---- .../src/tests/test_suite/encryption/mod.rs | 2 + .../src/tests/test_suite/keyswitch/ggsw_ct.rs | 29 +- poulpy-core/src/tests/test_suite/packing.rs | 4 +- poulpy-hal/Cargo.toml | 2 +- poulpy-hal/src/api/convolution.rs | 2 + poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs | 18 +- .../src/tfhe/circuit_bootstrapping/circuit.rs | 11 +- .../src/tfhe/circuit_bootstrapping/key.rs | 12 +- .../circuit_bootstrapping/key_prepared.rs | 14 +- 59 files changed, 2812 insertions(+), 1596 deletions(-) create mode 100644 poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs create mode 100644 poulpy-core/src/encryption/gglwe_to_ggsw_key.rs rename poulpy-core/src/encryption/{glwe_to_lwe_switching_key.rs => glwe_to_lwe_key.rs} (83%) rename poulpy-core/src/encryption/{lwe_to_glwe_switching_key.rs => lwe_to_glwe_key.rs} (81%) create mode 100644 poulpy-core/src/glwe_packer.rs create mode 100644 poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs rename poulpy-core/src/layouts/compressed/{glwe_to_lwe_switching_key.rs => glwe_to_lwe_key.rs} (95%) rename poulpy-core/src/layouts/compressed/{lwe_to_glwe_switching_key.rs => lwe_to_glwe_key.rs} (73%) create mode 100644 poulpy-core/src/layouts/gglwe_to_ggsw_key.rs create mode 100644 poulpy-core/src/layouts/glwe_secret_tensor.rs rename poulpy-core/src/layouts/{glwe_to_lwe_switching_key.rs => glwe_to_lwe_key.rs} (79%) rename poulpy-core/src/layouts/{lwe_to_glwe_switching_key.rs => lwe_to_glwe_key.rs} (73%) create mode 100644 poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs rename poulpy-core/src/layouts/prepared/{glwe_to_lwe_switching_key.rs => glwe_to_lwe_key.rs} (54%) rename poulpy-core/src/layouts/prepared/{lwe_to_glwe_switching_key.rs => lwe_to_glwe_key.rs} (53%) create mode 100644 poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs diff --git a/Cargo.lock b/Cargo.lock index 0037bc7..76ebe9b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,9 +49,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.23.2" +version = "1.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" +checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4" [[package]] name = "byteorder" @@ -372,6 +372,7 @@ dependencies = [ name = "poulpy-core" version = "0.3.1" dependencies = [ + "bytemuck", "byteorder", "criterion", "itertools 0.14.0", @@ -534,9 +535,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "rug" -version = "1.27.0" +version = "1.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3" +checksum = "58ad2e973fe3c3214251a840a621812a4f40468da814b1a3d6947d433c2af11f" dependencies = [ "az", "gmp-mpfr-sys", diff --git a/Cargo.toml b/Cargo.toml index ed44102..964e654 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,8 +7,8 @@ poulpy-hal = {path = "poulpy-hal"} poulpy-core = {path = "poulpy-core"} poulpy-backend = {path = "poulpy-backend"} poulpy-schemes = {path = "poulpy-schemes"} -rug = "1.27" -rand = "0.9.1" +rug = "1.28.0" +rand = "0.9.2" rand_chacha = "0.9.0" rand_core = "0.9.3" rand_distr = "0.5.1" @@ -16,4 +16,5 @@ itertools = "0.14.0" criterion = "0.7.0" byteorder = "1.5.0" zstd = "0.13.3" -once_cell = "1.21.3" \ No newline at end of file +once_cell = "1.21.3" +bytemuck = "1.24.0" \ No newline at end of file diff --git a/poulpy-core/Cargo.toml b/poulpy-core/Cargo.toml index 42eb830..cf15d04 100644 --- a/poulpy-core/Cargo.toml +++ b/poulpy-core/Cargo.toml @@ -15,6 +15,7 @@ poulpy-hal = {workspace = true} poulpy-backend = {workspace = true} itertools = {workspace = true} byteorder = {workspace = true} +bytemuck = {workspace = true} once_cell = {workspace = true} [[bench]] diff --git a/poulpy-core/src/automorphism/gglwe_atk.rs b/poulpy-core/src/automorphism/gglwe_atk.rs index ffc35cd..87e545e 100644 --- a/poulpy-core/src/automorphism/gglwe_atk.rs +++ b/poulpy-core/src/automorphism/gglwe_atk.rs @@ -1,11 +1,10 @@ use poulpy_hal::{ - api::VecZnxAutomorphism, - layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, + api::{VecZnxAutomorphism, VecZnxAutomorphismInplace}, + layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch}, }; use crate::{ - ScratchTakeCore, - automorphism::glwe_ct::GLWEAutomorphism, + GLWEKeyswitch, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWE, GLWEAutomorphismKey, GetGaloisElement, SetGaloisElement, @@ -45,14 +44,10 @@ impl GLWEAutomorphismKey { } } -impl GLWEAutomorphismKeyAutomorphism for Module where - Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism -{ -} - -pub trait GLWEAutomorphismKeyAutomorphism +impl GLWEAutomorphismKeyAutomorphism for Module where - Self: GaloisElement + GLWEAutomorphism + VecZnxAutomorphism, + Self: GaloisElement + GLWEKeyswitch + VecZnxAutomorphism + VecZnxAutomorphismInplace + CyclotomicOrder, + Scratch: ScratchTakeCore, { fn glwe_automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where @@ -68,7 +63,6 @@ where R: GGLWEToMut + SetGaloisElement + GGLWEInfos, A: GGLWEToRef + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { assert!( res.dnum().as_u32() <= a.dnum().as_u32(), @@ -163,3 +157,22 @@ where res.set_p((res.p() * key.p()) % self.cyclotomic_order()); } } + +pub trait GLWEAutomorphismKeyAutomorphism { + fn glwe_automorphism_key_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GGLWEInfos, + A: GGLWEInfos, + K: GGLWEInfos; + + fn glwe_automorphism_key_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GGLWEInfos, + A: GGLWEToRef + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; + + fn glwe_automorphism_key_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GGLWEToMut + SetGaloisElement + GetGaloisElement + GGLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; +} diff --git a/poulpy-core/src/automorphism/ggsw_ct.rs b/poulpy-core/src/automorphism/ggsw_ct.rs index fb54f6d..8644f98 100644 --- a/poulpy-core/src/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/automorphism/ggsw_ct.rs @@ -7,8 +7,8 @@ use crate::{ GGSWExpandRows, ScratchTakeCore, automorphism::glwe_ct::GLWEAutomorphism, layouts::{ - GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GetGaloisElement, - prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, + GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, + GGSWToRef, GetGaloisElement, }, }; @@ -36,7 +36,7 @@ impl GGSW { where A: GGSWToRef, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -46,7 +46,7 @@ impl GGSW { pub fn automorphism_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWAutomorphism, { @@ -67,11 +67,8 @@ where K: GGLWEInfos, T: GGLWEInfos, { - let out_size: usize = res_infos.size(); - let ci_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); - let ks_internal: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos); - let expand: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); - ci_dft + (ks_internal.max(expand)) + self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)) } fn ggsw_automorphism(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) @@ -79,12 +76,12 @@ where R: GGSWToMut, A: GGSWToRef, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGSW<&[u8]> = &a.to_ref(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); assert_eq!(res.dsize(), a.dsize()); assert!(res.dnum() <= a.dnum()); @@ -104,11 +101,11 @@ where where R: GGSWToMut, K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); // Keyswitch the j-th row of the col 0 for row in 0..res.dnum().as_usize() { diff --git a/poulpy-core/src/automorphism/glwe_ct.rs b/poulpy-core/src/automorphism/glwe_ct.rs index 7161239..b382197 100644 --- a/poulpy-core/src/automorphism/glwe_ct.rs +++ b/poulpy-core/src/automorphism/glwe_ct.rs @@ -1,13 +1,13 @@ use poulpy_hal::{ api::{ - ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallInplace, - VecZnxBigSubSmallNegateInplace, + ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, + VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, VecZnxNormalize, }, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; use crate::{ - GLWEKeyswitch, ScratchTakeCore, keyswitch_internal, + GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, }; @@ -101,13 +101,71 @@ impl GLWE { } } -pub trait GLWEAutomorphism +pub trait GLWEAutomorphism { + fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos; + + fn glwe_automorphism(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_add(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_add_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_negate(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; + + fn glwe_automorphism_sub_negate_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GetGaloisElement + GGLWEPreparedToRef + GGLWEInfos; +} + +impl GLWEAutomorphism for Module where - Self: GLWEKeyswitch + Self: Sized + + GLWEKeyswitch + + GLWEKeySwitchInternal + + VecZnxNormalize + VecZnxAutomorphismInplace + VecZnxBigAutomorphismInplace + VecZnxBigSubSmallInplace - + VecZnxBigSubSmallNegateInplace, + + VecZnxBigSubSmallNegateInplace + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, + Scratch: ScratchTakeCore, { fn glwe_automorphism_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize where @@ -160,7 +218,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -186,7 +244,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -214,7 +272,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -242,7 +300,7 @@ where let a: &GLWE<&[u8]> = &a.to_ref(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -268,7 +326,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -294,7 +352,7 @@ where let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size - let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); + let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); for i in 0..res.rank().as_usize() + 1 { self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); @@ -311,12 +369,3 @@ where } } } - -impl GLWEAutomorphism for Module where - Self: GLWEKeyswitch - + VecZnxAutomorphismInplace - + VecZnxBigAutomorphismInplace - + VecZnxBigSubSmallInplace - + VecZnxBigSubSmallNegateInplace -{ -} diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index b33759e..8554e50 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,17 +1,16 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, + VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, }, - layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, }; use crate::{ - GLWECopy, ScratchTakeCore, + GGLWEProduct, GLWECopy, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, - prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, + GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE, + GLWEInfos, LWEInfos, }, }; @@ -31,7 +30,7 @@ impl GGSW { where M: GGSWFromGGLWE, G: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { module.ggsw_from_gglwe(self, gglwe, tsk, scratch); @@ -54,12 +53,12 @@ where where R: GGSWToMut, A: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let a: &GGLWE<&[u8]> = &a.to_ref(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); assert_eq!(res.rank(), a.rank_out()); assert_eq!(res.dnum(), a.dnum()); @@ -85,177 +84,140 @@ pub trait GGSWFromGGLWE { where R: GGSWToMut, A: GGLWEToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore; } -impl GGSWExpandRows for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace - + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize -{ +pub trait GGSWExpandRows { + fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize + where + R: GGSWInfos, + A: GGLWEInfos; + + fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; } -pub trait GGSWExpandRows +impl GGSWExpandRows for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigBytesOf - + VecZnxNormalizeTmpBytes - + VecZnxDftApply - + VecZnxDftCopy - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftAddInplace + Self: GGLWEProduct + VecZnxBigNormalize - + VecZnxIdftApplyTmpA - + VecZnxNormalize, + + VecZnxBigNormalizeTmpBytes + + VecZnxBigBytesOf + + VecZnxDftBytesOf + + VecZnxDftApply + + VecZnxNormalize + + VecZnxBigAddSmallInplace + + VecZnxIdftApplyConsume, { fn ggsw_expand_rows_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where R: GGSWInfos, A: GGLWEInfos, { - let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; - let size_in: usize = res_infos - .k() - .div_ceil(tsk_infos.base2k()) - .div_ceil(tsk_infos.dsize().into()) as usize; + let base2k_in: usize = res_infos.base2k().into(); + let base2k_tsk: usize = tsk_infos.base2k().into(); - let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); - let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); - let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( - tsk_size, - size_in, - size_in, - (tsk_infos.rank_in()).into(), // Verify if rank+1 - (tsk_infos.rank_out()).into(), // Verify if rank+1 - tsk_size, - ); - let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size); - let norm: usize = self.vec_znx_normalize_tmp_bytes(); + let rank: usize = res_infos.rank().into(); + let cols: usize = rank + 1; - tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) + let res_size = res_infos.size(); + let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk); + + let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size); + let res_dft = self.bytes_of_vec_znx_dft(cols, a_size); + let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos); + let normalize = self.vec_znx_big_normalize_tmp_bytes(); + + (a_dft + res_dft + gglwe_prod).max(normalize) } fn ggsw_expand_row(&self, res: &mut R, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); + let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); - let basek_in: usize = res.base2k().into(); - let basek_tsk: usize = tsk.base2k().into(); + let base2k_in: usize = res.base2k().into(); + let base2k_tsk: usize = tsk.base2k().into(); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); let rank: usize = res.rank().into(); let cols: usize = rank + 1; - let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); + let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk); // Keyswitch the j-th row of the col 0 - for row_i in 0..res.dnum().into() { - let a = &res.at(row_i, 0).data; + for row in 0..res.dnum().as_usize() { + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); - // Pre-compute DFT of (a0, a1, a2) - let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); + { + let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); - if basek_in == basek_tsk { - for i in 0..cols { - self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); - } - } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); - for i in 0..cols { - self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); - self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); + if base2k_in == base2k_tsk { + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); + } + } else { + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); + for i in 0..cols - 1 { + self.vec_znx_normalize( + base2k_tsk, + &mut a_conv, + 0, + base2k_in, + glwe_mi_1.data(), + i + 1, + scratch_2, + ); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); + } } } - for col_j in 1..cols { - // Example for rank 3: + // Example for rank 3: + // + // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is + // actually composed of that many dnum and we focus on a specific row here + // implicitely given ci_dft. + // + // # Input + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (0, 0, 0, 0) + // col 2: (0, 0, 0, 0) + // col 3: (0, 0, 0, 0) + // + // # Output + // + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) + // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) + // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) + // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + for col in 1..cols { + let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise + + // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // - // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is - // actually composed of that many dnum and we focus on a specific row here - // implicitely given ci_dft. + // # Example for col=1 // - // # Input - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (0, 0, 0, 0) - // col 2: (0, 0, 0, 0) - // col 3: (0, 0, 0, 0) - // - // # Output - // - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 ) - // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) - // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) - // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) + // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) + // + + // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) + // + + // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) + // = + // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) + self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2); - let dsize: usize = tsk.dsize().into(); - - let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); - let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize)); - - { - // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 - // - // # Example for col=1 - // - // a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2) - // + - // a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2) - // + - // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) - // = - // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) - for col_i in 1..cols { - let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).data; // Selects Enc(s[i]s[j]) - - // Extracts a[i] and multipies with Enc(s[i]s[j]) - for di in 0..dsize { - tmp_a.set_size((ci_dft.size() + di) / dsize); - - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize); - - self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i); - if di == 0 && col_i == 1 { - self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3); - } else { - self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3); - } - } - } - } + let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // @@ -266,18 +228,17 @@ where // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) // = // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) - self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); - let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size()); - for i in 0..cols { - self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i); + self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0); + + for j in 0..cols { self.vec_znx_big_normalize( - basek_in, - &mut res.at_mut(row_i, col_j).data, - i, - basek_tsk, - &tmp_idft, - 0, - scratch_3, + res.base2k().as_usize(), + res.at_mut(row, col).data_mut(), + j, + tsk.base2k().as_usize(), + &res_big, + j, + scratch_2, ); } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index c759ee5..c9b40c3 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - api::ScratchTakeBasic, + api::{ScratchTakeBasic, VecZnxNormalize, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; @@ -8,11 +8,10 @@ use crate::{ layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef}, }; -impl GLWEFromLWE for Module where Self: GLWEKeyswitch {} - -pub trait GLWEFromLWE +impl GLWEFromLWE for Module where - Self: GLWEKeyswitch, + Self: GLWEKeyswitch + VecZnxNormalizeTmpBytes + VecZnxNormalize, + Scratch: ScratchTakeCore, { fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where @@ -41,7 +40,6 @@ where R: GLWEToMut, A: LWEToRef, K: GGLWEPreparedToRef + GGLWEInfos, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let lwe: &LWE<&[u8]> = &lwe.to_ref(); @@ -105,6 +103,23 @@ where } } +pub trait GLWEFromLWE +where + Self: GLWEKeyswitch, +{ + fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos; + + fn glwe_from_lwe(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: LWEToRef, + K: GGLWEPreparedToRef + GGLWEInfos; +} + impl GLWE> { pub fn from_lwe_tmp_bytes(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where diff --git a/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..92f382c --- /dev/null +++ b/poulpy-core/src/encryption/compressed/gglwe_to_ggsw_key.rs @@ -0,0 +1,124 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchTakeBasic, VecZnxCopy}, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyCompressedToMut, GLWEInfos, GLWESecret, GLWESecretTensor, + GLWESecretTensorFactory, GLWESecretToRef, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GGLWEToGGSWKeyCompressed> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyCompressedEncryptSk, + { + module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGLWEToGGSWKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GGLWEToGGSWKeyCompressedEncryptSk, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.gglwe_to_ggsw_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} + +pub trait GGLWEToGGSWKeyCompressedEncryptSk { + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos, + S: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl GGLWEToGGSWKeyCompressedEncryptSk for Module +where + Self: ModuleN + GGLWECompressedEncryptSk + GLWESecretTensorFactory + GLWESecretPreparedFactory + VecZnxCopy, + Scratch: ScratchTakeCore, +{ + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(infos); + let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank()); + (sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) + } + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + { + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); + + let res: &mut GGLWEToGGSWKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let rank: usize = res.rank_out().as_usize(); + + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); + + let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank); + + let mut source_xa = Source::new(seed_xa); + + for i in 0..rank { + for j in 0..rank { + self.vec_znx_copy( + &mut sk_ij.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + let (seed_xa_tmp, _) = source_xa.branch(); + + res.at_mut(i).encrypt_sk( + self, + &sk_ij, + &sk_prepared, + seed_xa_tmp, + source_xe, + scratch_3, + ); + } + } +} diff --git a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs index 14c9217..12af7ee 100644 --- a/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/compressed/glwe_tensor_key.rs @@ -1,17 +1,15 @@ use poulpy_hal::{ - api::{ - ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, - }, + api::ScratchTakeBasic, layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; use crate::{ - GGLWECompressedEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretToRef, - GLWETensorKeyCompressedAtMut, LWEInfos, Rank, compressed::GLWETensorKeyCompressed, + GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWEInfos, GGLWELayout, GLWEInfos, GLWESecretPrepared, + GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef, + compressed::GLWETensorKeyCompressed, }, }; @@ -34,7 +32,7 @@ impl GLWETensorKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - S: GLWESecretToRef + GetDistribution, + S: GLWESecretToRef + GetDistribution + GLWEInfos, M: GLWETensorKeyCompressedEncryptSk, { module.glwe_tensor_key_compressed_encrypt_sk(self, sk, seed_xa, source_xe, scratch); @@ -46,7 +44,7 @@ pub trait GLWETensorKeyCompressedEncryptSk { where A: GGLWEInfos; - fn glwe_tensor_key_compressed_encrypt_sk( + fn glwe_tensor_key_compressed_encrypt_sk( &self, res: &mut R, sk: &S, @@ -54,40 +52,38 @@ pub trait GLWETensorKeyCompressedEncryptSk { source_xe: &mut Source, scratch: &mut Scratch, ) where - D: DataMut, - R: GLWETensorKeyCompressedAtMut + GGLWEInfos, - S: GLWESecretToRef + GetDistribution; + R: GGLWECompressedToMut + GGLWEInfos + GGLWECompressedSeedMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos; } impl GLWETensorKeyCompressedEncryptSk for Module where - Self: ModuleN - + GGLWECompressedEncryptSk - + GLWETensorKeyEncryptSk - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize - + SvpPrepare - + SvpPPolBytesOf - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + GLWESecretPreparedFactory, + Self: GGLWECompressedEncryptSk + GLWESecretPreparedFactory + GLWESecretTensorFactory, Scratch: ScratchTakeBasic + ScratchTakeCore, { fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { - GLWESecretPrepared::bytes_of(self, infos.rank_out()) - + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) - + self.bytes_of_vec_znx_big(1, 1) - + self.bytes_of_vec_znx_dft(1, 1) - + GLWESecret::bytes_of(self.n().into(), Rank(1)) - + self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + + let tensor_infos: GGLWELayout = GGLWELayout { + n: infos.n(), + base2k: infos.base2k(), + k: infos.k(), + rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(), + rank_out: infos.rank_out(), + dnum: infos.dnum(), + dsize: infos.dsize(), + }; + + let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(&tensor_infos); + + (sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) } - fn glwe_tensor_key_compressed_encrypt_sk( + fn glwe_tensor_key_compressed_encrypt_sk( &self, res: &mut R, sk: &S, @@ -95,62 +91,24 @@ where source_xe: &mut Source, scratch: &mut Scratch, ) where - D: DataMut, - R: GGLWEInfos + GLWETensorKeyCompressedAtMut, - S: GLWESecretToRef + GetDistribution, + R: GGLWEInfos + GGLWECompressedToMut + GGLWECompressedSeedMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos, { - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); - sk_dft_prep.prepare(self, sk); + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); - let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); - #[cfg(debug_assertions)] - { - assert_eq!(res.rank_out(), sk.rank()); - assert_eq!(res.n(), sk.n()); - } - - // let n: usize = sk.n().into(); - let rank: usize = res.rank_out().into(); - - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); - - for i in 0..rank { - self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); - - let mut source_xa: Source = Source::new(seed_xa); - - for i in 0..rank { - for j in i..rank { - self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - - self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - self.vec_znx_big_normalize( - res.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - res.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - let (seed_xa_tmp, _) = source_xa.branch(); - - self.gglwe_compressed_encrypt_sk( - res.at_mut(i, j), - &sk_ij.data, - &sk_dft_prep, - seed_xa_tmp, - source_xe, - scratch_5, - ); - } - } + self.gglwe_compressed_encrypt_sk( + res, + &sk_tensor.data, + &sk_prepared, + seed_xa, + source_xe, + scratch_2, + ); } } diff --git a/poulpy-core/src/encryption/compressed/mod.rs b/poulpy-core/src/encryption/compressed/mod.rs index e96eeb5..1b21e1f 100644 --- a/poulpy-core/src/encryption/compressed/mod.rs +++ b/poulpy-core/src/encryption/compressed/mod.rs @@ -1,4 +1,5 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe_automorphism_key; mod glwe_ct; @@ -6,6 +7,7 @@ mod glwe_switching_key; mod glwe_tensor_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe_automorphism_key::*; pub use glwe_ct::*; diff --git a/poulpy-core/src/encryption/gglwe.rs b/poulpy-core/src/encryption/gglwe.rs index ba78cde..a50b565 100644 --- a/poulpy-core/src/encryption/gglwe.rs +++ b/poulpy-core/src/encryption/gglwe.rs @@ -148,7 +148,7 @@ where // Example for ksk rank 2 to rank 3: // // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) - // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) + // (-(b0*s0 + b1*s1 + b2*s2) + s1', b0, b1, b2) // // Example ksk rank 2 to rank 1 // diff --git a/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..017455f --- /dev/null +++ b/poulpy-core/src/encryption/gglwe_to_ggsw_key.rs @@ -0,0 +1,112 @@ +use poulpy_hal::{ + api::{ModuleN, ScratchTakeBasic, VecZnxCopy}, + layouts::{Backend, DataMut, Module, Scratch}, + source::Source, +}; + +use crate::{ + GGLWEEncryptSk, GetDistribution, ScratchTakeCore, + layouts::{ + GGLWEInfos, GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, GLWESecret, GLWESecretTensor, GLWESecretTensorFactory, + GLWESecretToRef, + prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, + }, +}; + +impl GGLWEToGGSWKey> { + pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyEncryptSk, + { + module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos) + } +} + +impl GGLWEToGGSWKey { + pub fn encrypt_sk( + &mut self, + module: &M, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + M: GGLWEToGGSWKeyEncryptSk, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.gglwe_to_ggsw_key_encrypt_sk(self, sk, source_xa, source_xe, scratch); + } +} + +pub trait GGLWEToGGSWKeyEncryptSk { + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos; +} + +impl GGLWEToGGSWKeyEncryptSk for Module +where + Self: ModuleN + GGLWEEncryptSk + GLWESecretTensorFactory + GLWESecretPreparedFactory + VecZnxCopy, + Scratch: ScratchTakeCore, +{ + fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(infos); + let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank()); + (sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) + } + + fn gglwe_to_ggsw_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEToGGSWKeyToMut, + S: GLWESecretToRef + GetDistribution + GLWEInfos, + { + let res: &mut GGLWEToGGSWKey<&mut [u8]> = &mut res.to_mut(); + + let rank: usize = res.rank_out().as_usize(); + + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); + + let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank); + + for i in 0..rank { + for j in 0..rank { + self.vec_znx_copy( + &mut sk_ij.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + res.at_mut(i) + .encrypt_sk(self, &sk_ij, &sk_prepared, source_xa, source_xe, scratch_3); + } + } +} diff --git a/poulpy-core/src/encryption/glwe_tensor_key.rs b/poulpy-core/src/encryption/glwe_tensor_key.rs index b7afae5..08df09b 100644 --- a/poulpy-core/src/encryption/glwe_tensor_key.rs +++ b/poulpy-core/src/encryption/glwe_tensor_key.rs @@ -1,8 +1,5 @@ use poulpy_hal::{ - api::{ - ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxIdftApplyTmpA, - }, + api::ModuleN, layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; @@ -10,7 +7,8 @@ use poulpy_hal::{ use crate::{ GGLWEEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, + GGLWEInfos, GGLWELayout, GGLWEToMut, GLWEInfos, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef, + GLWETensorKey, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, }, }; @@ -55,33 +53,35 @@ pub trait GLWETensorKeyEncryptSk { source_xe: &mut Source, scratch: &mut Scratch, ) where - R: GLWETensorKeyToMut, + R: GGLWEToMut + GGLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos; } impl GLWETensorKeyEncryptSk for Module where - Self: ModuleN - + GGLWEEncryptSk - + VecZnxDftBytesOf - + VecZnxBigBytesOf - + GLWESecretPreparedFactory - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize, + Self: ModuleN + GGLWEEncryptSk + GLWESecretPreparedFactory + GLWESecretTensorFactory, Scratch: ScratchTakeCore, { fn glwe_tensor_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { - GLWESecretPrepared::bytes_of(self, infos.rank_out()) - + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) - + self.bytes_of_vec_znx_big(1, 1) - + self.bytes_of_vec_znx_dft(1, 1) - + GLWESecret::bytes_of(self.n().into(), Rank(1)) - + GGLWE::encrypt_sk_tmp_bytes(self, infos) + let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out()); + let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos); + + let tensor_infos: GGLWELayout = GGLWELayout { + n: infos.n(), + base2k: infos.base2k(), + k: infos.k(), + rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(), + rank_out: infos.rank_out(), + dnum: infos.dnum(), + dsize: infos.dsize(), + }; + + let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(&tensor_infos); + + (sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank())) } fn glwe_tensor_key_encrypt_sk( @@ -92,56 +92,24 @@ where source_xe: &mut Source, scratch: &mut Scratch, ) where - R: GLWETensorKeyToMut, + R: GGLWEToMut + GGLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos, { - let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); - - // let n: RingDegree = sk.n(); - let rank: Rank = res.rank_out(); - - let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank()); - sk_prepared.prepare(self, sk); - - let sk: &GLWESecret<&[u8]> = &sk.to_ref(); - assert_eq!(res.rank_out(), sk.rank()); assert_eq!(res.n(), sk.n()); - let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1); + let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank()); + let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank()); + sk_prepared.prepare(self, sk); + sk_tensor.prepare(self, sk, scratch_2); - (0..rank.into()).for_each(|i| { - self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - }); - - let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); - let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1)); - let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1); - - (0..rank.into()).for_each(|i| { - (i..rank.into()).for_each(|j| { - self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - - self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - self.vec_znx_big_normalize( - res.base2k().into(), - &mut sk_ij.data.as_vec_znx_mut(), - 0, - res.base2k().into(), - &sk_ij_big, - 0, - scratch_5, - ); - - res.at_mut(i, j).encrypt_sk( - self, - &sk_ij.data, - &sk_prepared, - source_xa, - source_xe, - scratch_5, - ); - }); - }) + self.gglwe_encrypt_sk( + res, + &sk_tensor.data, + &sk_prepared, + source_xa, + source_xe, + scratch_2, + ); } } diff --git a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs b/poulpy-core/src/encryption/glwe_to_lwe_key.rs similarity index 83% rename from poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/encryption/glwe_to_lwe_key.rs index 71877a4..0609fb6 100644 --- a/poulpy-core/src/encryption/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/encryption/glwe_to_lwe_key.rs @@ -7,23 +7,22 @@ use poulpy_hal::{ use crate::{ GGLWEEncryptSk, ScratchTakeCore, layouts::{ - GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWESwitchingKey, LWEInfos, LWESecret, LWESecretToRef, - Rank, + GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWEKey, LWEInfos, LWESecret, LWESecretToRef, Rank, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, }, }; -impl GLWEToLWESwitchingKey> { +impl GLWEToLWEKey> { pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize where A: GGLWEInfos, M: GLWEToLWESwitchingKeyEncryptSk, { - module.glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(infos) + module.glwe_to_lwe_key_encrypt_sk_tmp_bytes(infos) } } -impl GLWEToLWESwitchingKey { +impl GLWEToLWEKey { pub fn encrypt_sk( &mut self, module: &M, @@ -38,16 +37,16 @@ impl GLWEToLWESwitchingKey { S2: GLWESecretToRef, Scratch: ScratchTakeCore, { - module.glwe_to_lwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + module.glwe_to_lwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } } pub trait GLWEToLWESwitchingKeyEncryptSk { - fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn glwe_to_lwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos; - fn glwe_to_lwe_switching_key_encrypt_sk( + fn glwe_to_lwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, @@ -70,7 +69,7 @@ where + VecZnxAutomorphismInplaceTmpBytes, Scratch: ScratchTakeCore, { - fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn glwe_to_lwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { @@ -79,7 +78,7 @@ where .max(GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + self.vec_znx_automorphism_inplace_tmp_bytes()) } - fn glwe_to_lwe_switching_key_encrypt_sk( + fn glwe_to_lwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, diff --git a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs b/poulpy-core/src/encryption/lwe_to_glwe_key.rs similarity index 81% rename from poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/encryption/lwe_to_glwe_key.rs index af31420..c5fcd15 100644 --- a/poulpy-core/src/encryption/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/encryption/lwe_to_glwe_key.rs @@ -8,21 +8,21 @@ use crate::{ GGLWEEncryptSk, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretPreparedFactory, GLWESecretPreparedToRef, LWEInfos, LWESecret, - LWESecretToRef, LWEToGLWESwitchingKey, Rank, + LWESecretToRef, LWEToGLWEKey, Rank, }, }; -impl LWEToGLWESwitchingKey> { +impl LWEToGLWEKey> { pub fn encrypt_sk_tmp_bytes(module: &M, infos: &A) -> usize where A: GGLWEInfos, M: LWEToGLWESwitchingKeyEncryptSk, { - module.lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(infos) + module.lwe_to_glwe_key_encrypt_sk_tmp_bytes(infos) } } -impl LWEToGLWESwitchingKey { +impl LWEToGLWEKey { pub fn encrypt_sk( &mut self, module: &M, @@ -37,16 +37,16 @@ impl LWEToGLWESwitchingKey { M: LWEToGLWESwitchingKeyEncryptSk, Scratch: ScratchTakeCore, { - module.lwe_to_glwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); + module.lwe_to_glwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); } } pub trait LWEToGLWESwitchingKeyEncryptSk { - fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn lwe_to_glwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos; - fn lwe_to_glwe_switching_key_encrypt_sk( + fn lwe_to_glwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, @@ -69,20 +69,20 @@ where + VecZnxAutomorphismInplaceTmpBytes, Scratch: ScratchTakeCore, { - fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize + fn lwe_to_glwe_key_encrypt_sk_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in(), Rank(1), - "rank_in != 1 is not supported for LWEToGLWESwitchingKey" + "rank_in != 1 is not supported for LWEToGLWEKeyPrepared" ); GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + GGLWE::encrypt_sk_tmp_bytes(self, infos).max(self.vec_znx_automorphism_inplace_tmp_bytes()) } - fn lwe_to_glwe_switching_key_encrypt_sk( + fn lwe_to_glwe_key_encrypt_sk( &self, res: &mut R, sk_lwe: &S1, diff --git a/poulpy-core/src/encryption/mod.rs b/poulpy-core/src/encryption/mod.rs index 7a391a6..d64757f 100644 --- a/poulpy-core/src/encryption/mod.rs +++ b/poulpy-core/src/encryption/mod.rs @@ -1,28 +1,30 @@ mod compressed; mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_public_key; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use compressed::*; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_public_key::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; pub const SIGMA: f64 = 3.2; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; diff --git a/poulpy-core/src/glwe_packer.rs b/poulpy-core/src/glwe_packer.rs new file mode 100644 index 0000000..da8c93e --- /dev/null +++ b/poulpy-core/src/glwe_packer.rs @@ -0,0 +1,388 @@ +use std::collections::HashMap; + +use poulpy_hal::{ + api::ModuleLogN, + layouts::{Backend, GaloisElement, Module, Scratch}, +}; + +use crate::{ + GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, + glwe_trace::GLWETrace, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, +}; + +/// [GLWEPacker] enables only the fly GLWE packing +/// with constant memory of Log(N) ciphertexts. +/// Main difference with usual GLWE packing is that +/// the output is bit-reversed. +pub struct GLWEPacker { + accumulators: Vec, + log_batch: usize, + counter: usize, +} + +/// [Accumulator] stores intermediate packing result. +/// There are Log(N) such accumulators in a [GLWEPacker]. +struct Accumulator { + data: GLWE>, + value: bool, // Implicit flag for zero ciphertext + control: bool, // Can be combined with incoming value +} + +impl Accumulator { + /// Allocates a new [Accumulator]. + /// + /// #Arguments + /// + /// * `module`: static backend FFT tables. + /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. + /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. + /// * `rank`: rank of the GLWE ciphertext. + pub fn alloc(infos: &A) -> Self + where + A: GLWEInfos, + { + Self { + data: GLWE::alloc_from_infos(infos), + value: false, + control: false, + } + } +} + +impl GLWEPacker { + /// Instantiates a new [GLWEPacker]. + /// + /// # Arguments + /// + /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}. + /// i.e. with `log_batch=0` only the constant coefficient is packed + /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients + /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts + /// can be packed. + pub fn alloc(infos: &A, log_batch: usize) -> Self + where + A: GLWEInfos, + { + let mut accumulators: Vec = Vec::::new(); + let log_n: usize = infos.n().log2(); + (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); + GLWEPacker { + accumulators, + log_batch, + counter: 0, + } + } + + /// Implicit reset of the internal state (to be called before a new packing procedure). + fn reset(&mut self) { + for i in 0..self.accumulators.len() { + self.accumulators[i].value = false; + self.accumulators[i].control = false; + } + self.counter = 0; + } + + /// Number of scratch space bytes required to call [Self::add]. + pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize + where + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEPackerOps, + { + GLWE::bytes_of_from_infos(res_infos) + + module + .glwe_rsh_tmp_byte() + .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) + } + + pub fn galois_elements(module: &M) -> Vec + where + M: GLWETrace, + { + module.glwe_trace_galois_elements() + } + + /// Adds a GLWE ciphertext to the [GLWEPacker]. + /// #Arguments + /// + /// * `module`: static backend FFT tables. + /// * `res`: space to append fully packed ciphertext. Only when the number + /// of packed ciphertexts reaches N/2^log_batch is a result written. + /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. + /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. + /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. + pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) + where + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + M: GLWEPackerOps, + Scratch: ScratchTakeCore, + { + assert!( + (self.counter as u32) < self.accumulators[0].data.n(), + "Packing limit of {} reached", + self.accumulators[0].data.n().0 as usize >> self.log_batch + ); + + module.packer_add(self, a, self.log_batch, auto_keys, scratch); + self.counter += 1 << self.log_batch; + } + + /// Flush result to`res`. + pub fn flush(&mut self, module: &M, res: &mut R) + where + R: GLWEToMut, + M: GLWEPackerOps, + { + assert!(self.counter as u32 == self.accumulators[0].data.n()); + // Copy result GLWE into res GLWE + module.glwe_copy( + res, + &self.accumulators[module.log_n() - self.log_batch - 1].data, + ); + + self.reset(); + } +} + +impl GLWEPackerOps for Module where + Self: Sized + + ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize +{ +} + +pub trait GLWEPackerOps +where + Self: Sized + + ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, +{ + fn packer_add( + &self, + packer: &mut GLWEPacker, + a: Option<&A>, + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, + ) where + A: GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, + { + pack_core(self, a, &mut packer.accumulators, i, auto_keys, scratch) + } +} + +fn pack_core( + module: &M, + a: Option<&A>, + accumulators: &mut [Accumulator], + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, +) where + A: GLWEToRef + GLWEInfos, + M: ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, +{ + let log_n: usize = module.log_n(); + + if i == log_n { + return; + } + + // Isolate the first accumulator + let (acc_prev, acc_next) = accumulators.split_at_mut(1); + + // Control = true accumlator is free to overide + if !acc_prev[0].control { + let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut + + // No previous value -> copies and sets flags accordingly + if let Some(a_ref) = a { + module.glwe_copy(&mut acc_mut_ref.data, a_ref); + acc_mut_ref.value = true + } else { + acc_mut_ref.value = false + } + acc_mut_ref.control = true; // Able to be combined on next call + } else { + // Compresses acc_prev <- combine(acc_prev, a). + combine(module, &mut acc_prev[0], a, i, auto_keys, scratch); + acc_prev[0].control = false; + + // Propagates to next accumulator + if acc_prev[0].value { + pack_core( + module, + Some(&acc_prev[0].data), + acc_next, + i + 1, + auto_keys, + scratch, + ); + } else { + pack_core( + module, + None::<&GLWE>>, + acc_next, + i + 1, + auto_keys, + scratch, + ); + } + } +} + +fn combine( + module: &M, + acc: &mut Accumulator, + b: Option<&B>, + i: usize, + auto_keys: &HashMap, + scratch: &mut Scratch, +) where + B: GLWEToRef + GLWEInfos, + B: GLWEToRef + GLWEInfos, + M: ModuleLogN + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize + + GLWECopy + + GLWEAutomorphism + + GaloisElement + + GLWERotate + + GLWESub + + GLWEShift + + GLWEAdd + + GLWENormalize, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + Scratch: ScratchTakeCore, +{ + let log_n: usize = acc.data.n().log2(); + let a: &mut GLWE> = &mut acc.data; + + let gal_el: i64 = if i == 0 { + -1 + } else { + module.galois_element(1 << (i - 1)) + }; + + let t: i64 = 1 << (log_n - i - 1); + + // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) + // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) + // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} + // Different cases for wether a and/or b are zero. + // + // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. + // Necessary so that the scaling of the plaintext remains constant. + // It however is ok to do so here because coefficients are eventually + // either mapped to garbage or twice their value which vanishes I(X) + // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. + if acc.value { + if let Some(b) = b { + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + + // a = a * X^-t + module.glwe_rotate_inplace(-t, a, scratch_1); + + // tmp_b = a * X^-t - b + module.glwe_sub(&mut tmp_b, a, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); + + // a = a * X^-t + b + module.glwe_add_inplace(a, b); + module.glwe_rsh(1, a, scratch_1); + + module.glwe_normalize_inplace(&mut tmp_b, scratch_1); + + // tmp_b = phi(a * X^-t - b) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); + } else { + panic!("auto_key[{gal_el}] not found"); + } + + // a = a * X^-t + b - phi(a * X^-t - b) + module.glwe_sub_inplace(a, &tmp_b); + module.glwe_normalize_inplace(a, scratch_1); + + // a = a + b * X^t - phi(a * X^-t - b) * X^t + // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) + // = a + b * X^t + phi(a - b * X^t) + module.glwe_rotate_inplace(t, a, scratch_1); + } else { + module.glwe_rsh(1, a, scratch); + // a = a + phi(a) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_add_inplace(a, auto_key, scratch); + } else { + panic!("auto_key[{gal_el}] not found"); + } + } + } else if let Some(b) = b { + let (mut tmp_b, scratch_1) = scratch.take_glwe(a); + module.glwe_rotate(t, &mut tmp_b, b); + module.glwe_rsh(1, &mut tmp_b, scratch_1); + + // a = (b* X^t - phi(b* X^t)) + if let Some(auto_key) = auto_keys.get(&gal_el) { + module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1); + } else { + panic!("auto_key[{gal_el}] not found"); + } + + acc.value = true; + } +} diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 09540b2..6debd0d 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -7,166 +7,23 @@ use poulpy_hal::{ use crate::{ GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, - glwe_trace::GLWETrace, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement}, }; - -/// [GLWEPacker] enables only the fly GLWE packing -/// with constant memory of Log(N) ciphertexts. -/// Main difference with usual GLWE packing is that -/// the output is bit-reversed. -pub struct GLWEPacker { - accumulators: Vec, - log_batch: usize, - counter: usize, +pub trait GLWEPacking { + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] + /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] + fn glwe_pack( + &self, + cts: &mut HashMap, + log_gap_out: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut + GLWEToRef + GLWEInfos, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; } -/// [Accumulator] stores intermediate packing result. -/// There are Log(N) such accumulators in a [GLWEPacker]. -struct Accumulator { - data: GLWE>, - value: bool, // Implicit flag for zero ciphertext - control: bool, // Can be combined with incoming value -} - -impl Accumulator { - /// Allocates a new [Accumulator]. - /// - /// #Arguments - /// - /// * `module`: static backend FFT tables. - /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. - /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. - /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(infos: &A) -> Self - where - A: GLWEInfos, - { - Self { - data: GLWE::alloc_from_infos(infos), - value: false, - control: false, - } - } -} - -impl GLWEPacker { - /// Instantiates a new [GLWEPacker]. - /// - /// # Arguments - /// - /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}. - /// i.e. with `log_batch=0` only the constant coefficient is packed - /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients - /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts - /// can be packed. - pub fn alloc(infos: &A, log_batch: usize) -> Self - where - A: GLWEInfos, - { - let mut accumulators: Vec = Vec::::new(); - let log_n: usize = infos.n().log2(); - (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos))); - GLWEPacker { - accumulators, - log_batch, - counter: 0, - } - } - - /// Implicit reset of the internal state (to be called before a new packing procedure). - fn reset(&mut self) { - for i in 0..self.accumulators.len() { - self.accumulators[i].value = false; - self.accumulators[i].control = false; - } - self.counter = 0; - } - - /// Number of scratch space bytes required to call [Self::add]. - pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize - where - R: GLWEInfos, - K: GGLWEInfos, - M: GLWEPacking, - { - GLWE::bytes_of_from_infos(res_infos) - + module - .glwe_rsh_tmp_byte() - .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) - } - - pub fn galois_elements(module: &M) -> Vec - where - M: GLWETrace, - { - module.glwe_trace_galois_elements() - } - - /// Adds a GLWE ciphertext to the [GLWEPacker]. - /// #Arguments - /// - /// * `module`: static backend FFT tables. - /// * `res`: space to append fully packed ciphertext. Only when the number - /// of packed ciphertexts reaches N/2^log_batch is a result written. - /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. - /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. - /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. - pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) - where - A: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - M: GLWEPacking, - Scratch: ScratchTakeCore, - { - assert!( - (self.counter as u32) < self.accumulators[0].data.n(), - "Packing limit of {} reached", - self.accumulators[0].data.n().0 as usize >> self.log_batch - ); - - pack_core( - module, - a, - &mut self.accumulators, - self.log_batch, - auto_keys, - scratch, - ); - self.counter += 1 << self.log_batch; - } - - /// Flush result to`res`. - pub fn flush(&mut self, module: &M, res: &mut R) - where - R: GLWEToMut, - M: GLWEPacking, - { - assert!(self.counter as u32 == self.accumulators[0].data.n()); - // Copy result GLWE into res GLWE - module.glwe_copy( - res, - &self.accumulators[module.log_n() - self.log_batch - 1].data, - ); - - self.reset(); - } -} - -impl GLWEPacking for Module where - Self: GLWEAutomorphism - + GaloisElement - + ModuleLogN - + GLWERotate - + GLWESub - + GLWEShift - + GLWEAdd - + GLWENormalize - + GLWECopy -{ -} - -pub trait GLWEPacking +impl GLWEPacking for Module where Self: GLWEAutomorphism + GaloisElement @@ -177,6 +34,7 @@ where + GLWEAdd + GLWENormalize + GLWECopy, + Scratch: ScratchTakeCore, { /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] @@ -189,7 +47,6 @@ where ) where R: GLWEToMut + GLWEToRef + GLWEInfos, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { #[cfg(debug_assertions)] { @@ -223,169 +80,6 @@ where } } -fn pack_core( - module: &M, - a: Option<&A>, - accumulators: &mut [Accumulator], - i: usize, - auto_keys: &HashMap, - scratch: &mut Scratch, -) where - A: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - M: ModuleLogN - + GLWEAutomorphism - + GaloisElement - + GLWERotate - + GLWESub - + GLWEShift - + GLWEAdd - + GLWENormalize - + GLWECopy, - Scratch: ScratchTakeCore, -{ - let log_n: usize = module.log_n(); - - if i == log_n { - return; - } - - // Isolate the first accumulator - let (acc_prev, acc_next) = accumulators.split_at_mut(1); - - // Control = true accumlator is free to overide - if !acc_prev[0].control { - let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut - - // No previous value -> copies and sets flags accordingly - if let Some(a_ref) = a { - module.glwe_copy(&mut acc_mut_ref.data, a_ref); - acc_mut_ref.value = true - } else { - acc_mut_ref.value = false - } - acc_mut_ref.control = true; // Able to be combined on next call - } else { - // Compresses acc_prev <- combine(acc_prev, a). - combine(module, &mut acc_prev[0], a, i, auto_keys, scratch); - acc_prev[0].control = false; - - // Propagates to next accumulator - if acc_prev[0].value { - pack_core( - module, - Some(&acc_prev[0].data), - acc_next, - i + 1, - auto_keys, - scratch, - ); - } else { - pack_core( - module, - None::<&GLWE>>, - acc_next, - i + 1, - auto_keys, - scratch, - ); - } - } -} - -/// [combine] merges two ciphertexts together. -fn combine( - module: &M, - acc: &mut Accumulator, - b: Option<&B>, - i: usize, - auto_keys: &HashMap, - scratch: &mut Scratch, -) where - B: GLWEToRef + GLWEInfos, - M: GLWEAutomorphism + GaloisElement + GLWERotate + GLWESub + GLWEShift + GLWEAdd + GLWENormalize, - B: GLWEToRef + GLWEInfos, - K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, -{ - let log_n: usize = acc.data.n().log2(); - let a: &mut GLWE> = &mut acc.data; - - let gal_el: i64 = if i == 0 { - -1 - } else { - module.galois_element(1 << (i - 1)) - }; - - let t: i64 = 1 << (log_n - i - 1); - - // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) - // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) - // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)} - // Different cases for wether a and/or b are zero. - // - // Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption. - // Necessary so that the scaling of the plaintext remains constant. - // It however is ok to do so here because coefficients are eventually - // either mapped to garbage or twice their value which vanishes I(X) - // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. - if acc.value { - if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(a); - - // a = a * X^-t - module.glwe_rotate_inplace(-t, a, scratch_1); - - // tmp_b = a * X^-t - b - module.glwe_sub(&mut tmp_b, a, b); - module.glwe_rsh(1, &mut tmp_b, scratch_1); - - // a = a * X^-t + b - module.glwe_add_inplace(a, b); - module.glwe_rsh(1, a, scratch_1); - - module.glwe_normalize_inplace(&mut tmp_b, scratch_1); - - // tmp_b = phi(a * X^-t - b) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1); - } else { - panic!("auto_key[{gal_el}] not found"); - } - - // a = a * X^-t + b - phi(a * X^-t - b) - module.glwe_sub_inplace(a, &tmp_b); - module.glwe_normalize_inplace(a, scratch_1); - - // a = a + b * X^t - phi(a * X^-t - b) * X^t - // = a + b * X^t - phi(a * X^-t - b) * - phi(X^t) - // = a + b * X^t + phi(a - b * X^t) - module.glwe_rotate_inplace(t, a, scratch_1); - } else { - module.glwe_rsh(1, a, scratch); - // a = a + phi(a) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_add_inplace(a, auto_key, scratch); - } else { - panic!("auto_key[{gal_el}] not found"); - } - } - } else if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe(a); - module.glwe_rotate(t, &mut tmp_b, b); - module.glwe_rsh(1, &mut tmp_b, scratch_1); - - // a = (b* X^t - phi(b* X^t)) - if let Some(auto_key) = auto_keys.get(&gal_el) { - module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1); - } else { - panic!("auto_key[{gal_el}] not found"); - } - - acc.value = true; - } -} - #[allow(clippy::too_many_arguments)] fn pack_internal( module: &M, diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index c2ba15c..0ba7b81 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use poulpy_hal::{ - api::ModuleLogN, - layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, + api::{ModuleLogN, VecZnxNormalize, VecZnxNormalizeTmpBytes}, + layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, }; use crate::{ @@ -27,7 +27,7 @@ impl GLWE> { K: GGLWEInfos, M: GLWETrace, { - module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) + module.glwe_trace_tmp_bytes(res_infos, a_infos, key_infos) } } @@ -65,11 +65,6 @@ impl GLWE { } } -impl GLWETrace for Module where - Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy -{ -} - #[inline(always)] pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec { (0..log_n) @@ -83,9 +78,17 @@ pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec { .collect() } -pub trait GLWETrace +impl GLWETrace for Module where - Self: ModuleLogN + GaloisElement + GLWEAutomorphism + GLWEShift + GLWECopy, + Self: ModuleLogN + + GaloisElement + + GLWEAutomorphism + + GLWEShift + + GLWECopy + + CyclotomicOrder + + VecZnxNormalizeTmpBytes + + VecZnxNormalize, + Scratch: ScratchTakeCore, { fn glwe_trace_galois_elements(&self) -> Vec { trace_galois_elements(self.log_n(), self.cyclotomic_order()) @@ -115,7 +118,6 @@ where R: GLWEToMut, A: GLWEToRef, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { self.glwe_copy(res, a); self.glwe_trace_inplace(res, start, end, keys, scratch); @@ -125,7 +127,6 @@ where where R: GLWEToMut, K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); @@ -212,3 +213,31 @@ where } } } + +pub trait GLWETrace { + fn glwe_trace_galois_elements(&self) -> Vec; + + fn glwe_trace_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos; + + fn glwe_trace( + &self, + res: &mut R, + start: usize, + end: usize, + a: &A, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; + + fn glwe_trace_inplace(&self, res: &mut R, start: usize, end: usize, keys: &HashMap, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos; +} diff --git a/poulpy-core/src/keyswitching/ggsw.rs b/poulpy-core/src/keyswitching/ggsw.rs index 231b071..3dfb0b1 100644 --- a/poulpy-core/src/keyswitching/ggsw.rs +++ b/poulpy-core/src/keyswitching/ggsw.rs @@ -1,9 +1,9 @@ -use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx}; +use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch}; use crate::{ GGSWExpandRows, ScratchTakeCore, keyswitching::GLWEKeyswitch, - layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef}, + layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef}, }; impl GGSW> { @@ -30,7 +30,7 @@ impl GGSW { where A: GGSWToRef, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, { @@ -40,7 +40,7 @@ impl GGSW { pub fn keyswitch_inplace(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch) where K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, M: GGSWKeyswitch, { @@ -48,9 +48,7 @@ impl GGSW { } } -impl GGSWKeyswitch for Module where Self: GLWEKeyswitch + GGSWExpandRows {} - -pub trait GGSWKeyswitch +impl GGSWKeyswitch for Module where Self: GLWEKeyswitch + GGSWExpandRows, { @@ -65,25 +63,26 @@ where assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); - let rank: usize = key_infos.rank_out().into(); + self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)) + } - let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize; - let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out); - let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); - let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos); - let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos); - let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out); + fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + K: GGLWEPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - if a_infos.base2k() == tsk_infos.base2k() { - res_znx + ci_dft + (ks | expand_rows | res_dft) - } else { - let a_conv: usize = VecZnx::bytes_of( - self.n(), - 1, - res_infos.k().div_ceil(tsk_infos.base2k()) as usize, - ) + self.vec_znx_normalize_tmp_bytes(); - res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft) + for row in 0..res.dnum().into() { + // Key-switch column 0, i.e. + // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) + self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); } + + self.ggsw_expand_row(res, tsk, scratch); } fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) @@ -91,7 +90,7 @@ where R: GGSWToMut, A: GGSWToRef, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, Scratch: ScratchTakeCore, { let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); @@ -108,22 +107,31 @@ where self.ggsw_expand_row(res, tsk, scratch); } +} + +pub trait GGSWKeyswitch +where + Self: GLWEKeyswitch + GGSWExpandRows, +{ + fn ggsw_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize + where + R: GGSWInfos, + A: GGSWInfos, + K: GGLWEInfos, + T: GGLWEInfos; + + fn ggsw_keyswitch(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch) + where + R: GGSWToMut, + A: GGSWToRef, + K: GGLWEPreparedToRef, + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; fn ggsw_keyswitch_inplace(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch) where R: GGSWToMut, K: GGLWEPreparedToRef, - T: GLWETensorKeyPreparedToRef, - Scratch: ScratchTakeCore, - { - let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); - - for row in 0..res.dnum().into() { - // Key-switch column 0, i.e. - // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch); - } - - self.ggsw_expand_row(res, tsk, scratch); - } + T: GGLWEToGGSWKeyPreparedToRef, + Scratch: ScratchTakeCore; } diff --git a/poulpy-core/src/keyswitching/glwe.rs b/poulpy-core/src/keyswitching/glwe.rs index a021777..72def40 100644 --- a/poulpy-core/src/keyswitching/glwe.rs +++ b/poulpy-core/src/keyswitching/glwe.rs @@ -1,10 +1,10 @@ use poulpy_hal::{ api::{ ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, + VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, - layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, + layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos}, }; use crate::{ @@ -45,46 +45,10 @@ impl GLWE { } } -impl GLWEKeyswitch for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes -{ -} - -pub trait GLWEKeyswitch +impl GLWEKeyswitch for Module where - Self: Sized - + ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VecZnxNormalizeTmpBytes - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, + Self: Sized + GLWEKeySwitchInternal + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize, + Scratch: ScratchTakeCore, { fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize where @@ -92,34 +56,10 @@ where A: GLWEInfos, B: GGLWEInfos, { - let in_size: usize = a_infos - .k() - .div_ceil(key_infos.base2k()) - .div_ceil(key_infos.dsize().into()) as usize; - let out_size: usize = res_infos.size(); - let ksk_size: usize = key_infos.size(); - let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE - let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); - let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( - out_size, - in_size, - in_size, - (key_infos.rank_in()).into(), - (key_infos.rank_out() + 1).into(), - ksk_size, - ) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size); - let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes(); - if a_infos.base2k() == key_infos.base2k() { - res_dft + ((ai_dft + vmp) | normalize_big) - } else if key_infos.dsize() == 1 { - // In this case, we only need one column, temporary, that we can drop once a_dft is computed. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes(); - res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big) - } else { - // Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion. - let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size); - res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big) - } + let cols: usize = res_infos.rank().as_usize() + 1; + self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) + .max(self.vec_znx_big_normalize_tmp_bytes()) + + self.bytes_of_vec_znx_dft(cols, key_infos.size()) } fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) @@ -127,7 +67,6 @@ where R: GLWEToMut, A: GLWEToRef, K: GGLWEPreparedToRef, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let a: &GLWE<&[u8]> = &a.to_ref(); @@ -164,8 +103,8 @@ where let base2k_out: usize = b.base2k().into(); let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); - (0..(res.rank() + 1).into()).for_each(|i| { + let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1); + for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( basek_out, &mut res.data, @@ -175,37 +114,36 @@ where i, scratch_1, ); - }) + } } fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) where R: GLWEToMut, K: GGLWEPreparedToRef, - Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); assert_eq!( res.rank(), - a.rank_in(), + key.rank_in(), "res.rank(): {} != a.rank_in(): {}", res.rank(), - a.rank_in() + key.rank_in() ); assert_eq!( res.rank(), - a.rank_out(), + key.rank_out(), "res.rank(): {} != b.rank_out(): {}", res.rank(), - a.rank_out() + key.rank_out() ); assert_eq!(res.n(), self.n() as u32); - assert_eq!(a.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); - let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); + let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key); assert!( scratch.available() >= scrach_needed, @@ -214,11 +152,11 @@ where ); let base2k_in: usize = res.base2k().into(); - let base2k_out: usize = a.base2k().into(); + let base2k_out: usize = key.base2k().into(); - let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise - let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); - (0..(res.rank() + 1).into()).for_each(|i| { + let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise + let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); + for i in 0..(res.rank() + 1).into() { self.vec_znx_big_normalize( base2k_in, &mut res.data, @@ -228,143 +166,235 @@ where i, scratch_1, ); - }) + } } } -impl GLWE> {} +pub trait GLWEKeyswitch { + fn glwe_keyswitch_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + B: GGLWEInfos; -impl GLWE {} + fn glwe_keyswitch(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + K: GGLWEPreparedToRef; -pub(crate) fn keyswitch_internal( - module: &M, - mut res: VecZnxDft, - a: &A, - key: &K, - scratch: &mut Scratch, -) -> VecZnxBig -where - DR: DataMut, - A: GLWEToRef, - K: GGLWEPreparedToRef, - M: ModuleN - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDftTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd + fn glwe_keyswitch_inplace(&self, res: &mut R, key: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + K: GGLWEPreparedToRef; +} + +impl GLWEKeySwitchInternal for Module where + Self: GGLWEProduct + VecZnxDftApply + + VecZnxNormalize + VecZnxIdftApplyConsume + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize, - Scratch: ScratchTakeCore, + + VecZnxNormalizeTmpBytes { - let a: &GLWE<&[u8]> = &a.to_ref(); - let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); +} - let base2k_in: usize = a.base2k().into(); - let base2k_out: usize = key.base2k().into(); - let cols: usize = (a.rank() + 1).into(); - let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); - let pmat: &VmpPMat<&[u8], BE> = &key.data; +pub(crate) trait GLWEKeySwitchInternal +where + Self: GGLWEProduct + + VecZnxDftApply + + VecZnxNormalize + + VecZnxIdftApplyConsume + + VecZnxBigAddSmallInplace + + VecZnxNormalizeTmpBytes, +{ + fn glwe_keyswitch_internal_tmp_bytes(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let cols: usize = (a_infos.rank() + 1).into(); + let a_size: usize = a_infos.size(); - if key.dsize() == 1 { - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size()); + let a_conv = if a_infos.base2k() == key_infos.base2k() { + 0 + } else { + VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes() + }; + + self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv + } + + fn glwe_keyswitch_internal( + &self, + mut res: VecZnxDft, + a: &A, + key: &K, + scratch: &mut Scratch, + ) -> VecZnxBig + where + DR: DataMut, + A: GLWEToRef, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + let base2k_in: usize = a.base2k().into(); + let base2k_out: usize = key.base2k().into(); + let cols: usize = (a.rank() + 1).into(); + let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); + + let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size); if base2k_in == base2k_out { - (0..cols - 1).for_each(|col_i| { - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); - }); + for col_i in 0..cols - 1 { + self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1); + } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size); - (0..cols - 1).for_each(|col_i| { - module.vec_znx_normalize( + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); + for i in 0..cols - 1 { + self.vec_znx_normalize( base2k_out, &mut a_conv, 0, base2k_in, a.data(), - col_i + 1, + i + 1, scratch_2, ); - module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); - }); + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); + } } - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); - } else { - let dsize: usize = key.dsize().into(); + self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1); - let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); - ai_dft.data_mut().fill(0); + let mut res_big: VecZnxBig = self.vec_znx_idft_apply_consume(res); + self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); + res_big + } +} - if base2k_in == base2k_out { - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); +impl GGLWEProduct for Module where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftCopy +{ +} - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); +pub(crate) trait GGLWEProduct +where + Self: Sized + + ModuleN + + VecZnxDftBytesOf + + VmpApplyDftToDftTmpBytes + + VmpApplyDftToDft + + VmpApplyDftToDftAdd + + VecZnxDftCopy, +{ + fn gglwe_product_dft_tmp_bytes(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize + where + K: GGLWEInfos, + { + let dsize: usize = key_infos.dsize().as_usize(); - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); - } else { - module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); - } - } + if dsize == 1 { + self.vmp_apply_dft_to_dft_tmp_bytes( + res_size, + a_size, + key_infos.dnum().into(), + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), + key_infos.size(), + ) } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size); - for j in 0..cols - 1 { - module.vec_znx_normalize( - base2k_out, - &mut a_conv, - j, - base2k_in, - a.data(), - j + 1, - scratch_2, - ); - } + let dnum: usize = key_infos.dnum().into(); + let a_size: usize = a_size.div_ceil(dsize).min(dnum); + let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size); - for di in 0..dsize { - ai_dft.set_size((a_size + di) / dsize); + let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes( + res_size, + a_size, + dnum, + (key_infos.rank_in()).into(), + (key_infos.rank_out() + 1).into(), + key_infos.size(), + ); - // Small optimization for dsize > 2 - // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then - // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. - // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. - // It is possible to further ignore the last dsize-1 limbs, but this introduce - // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same - // noise is kept with respect to the ideal functionality. - res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); - - for j in 0..cols - 1 { - module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j); - } - - if di == 0 { - module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2); - } else { - module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2); - } - } + ai_dft + vmp } - - res.set_size(res.max_size()); } - let mut res_big: VecZnxBig = module.vec_znx_idft_apply_consume(res); - module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0); - res_big + fn gglwe_product_dft(&self, res: &mut VecZnxDft, a: &A, key: &K, scratch: &mut Scratch) + where + DR: DataMut, + A: VecZnxDftToRef, + K: GGLWEPreparedToRef, + Scratch: ScratchTakeCore, + { + let a: &VecZnxDft<&[u8], BE> = &a.to_ref(); + let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); + + let cols: usize = a.cols(); + let a_size: usize = a.size(); + let pmat: &VmpPMat<&[u8], BE> = &key.data; + + // If dsize == 1, then the digit decomposition is equal to Base2K and we can simply + // can the vmp API. + if key.dsize() == 1 { + self.vmp_apply_dft_to_dft(res, a, pmat, scratch); + // If dsize != 1, then the digit decomposition is k * Base2K with k > 1. + // As such we need to perform a bivariate polynomial convolution in (X, Y) / (X^{N}+1) with Y = 2^-K + // (instead of yn univariate one in X). + // + // Since the basis in Y is small (in practice degree 6-7 max), we perform it naiveley. + // To do so, we group the different limbs of ai_dft by their respective degree in Y + // which are multiples of the current digit. + // For example if dsize = 3, with ai_dft = [a0, a1, a2, a3, a4, a5, a6], + // we group them as [[a0, a3, a5], [a1, a4, a6], [a2, a5, 0]] + // and evaluate sum(a_di * pmat * 2^{di*Base2k}) + } else { + let dsize: usize = key.dsize().into(); + let dnum: usize = key.dnum().into(); + + // We bound ai_dft size by the number of rows of the matrix + let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum)); + ai_dft.data_mut().fill(0); + + for di in 0..dsize { + // Sets ai_dft size according to the current digit (if dsize does not divides a_size), + // bounded by the number of rows (digits) in the prepared matrix. + ai_dft.set_size(((a_size + di) / dsize).min(dnum)); + + // Small optimization for dsize > 2 + // VMP produce some error e, and since we aggregate vmp * 2^{di * Base2k}, then + // we also aggregate ei * 2^{di * Base2k}, with the largest error being ei * 2^{(dsize-1) * Base2k}. + // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. + // It is possible to further ignore the last dsize-1 limbs, but this introduce + // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same + // noise is kept with respect to the ideal functionality. + res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); + + for j in 0..cols { + self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j); + } + + if di == 0 { + // res = pmat * ai_dft + self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1); + } else { + // res = (pmat * ai_dft) * 2^{di * Base2k} + self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1); + } + } + + res.set_size(res.max_size()); + } + } } diff --git a/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..e158a0c --- /dev/null +++ b/poulpy-core/src/layouts/compressed/gglwe_to_ggsw_key.rs @@ -0,0 +1,237 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, + GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use std::fmt; + +#[derive(PartialEq, Eq, Clone)] +pub struct GGLWEToGGSWKeyCompressed { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKeyCompressed { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKeyCompressed { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyCompressed { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +impl fmt::Debug for GGLWEToGGSWKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GGLWEToGGSWKeyCompressed { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GGLWECompressed| key.fill_uniform(log_bound, source)) + } +} + +impl fmt::Display for GGLWEToGGSWKeyCompressed { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GGLWEToGGSWKeyCompressed)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{i}: {key}")?; + } + Ok(()) + } +} + +impl GGLWEToGGSWKeyCompressed> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed" + ); + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GGLWEToGGSWKeyCompressed { + keys: (0..rank.as_usize()) + .map(|_| GGLWECompressed::alloc(n, base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed" + ); + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * GGLWECompressed::bytes_of(n, base2k, k, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKeyCompressed { + // Returns a mutable reference to GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWECompressed { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKeyCompressed { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, i: usize) -> &GGLWECompressed { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +impl ReaderFrom for GGLWEToGGSWKeyCompressed { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GGLWEToGGSWKeyCompressed { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +pub trait GGLWEToGGSWKeyDecompress +where + Self: GGLWEDecompress, +{ + fn decompress_gglwe_to_ggsw_key(&self, res: &mut R, other: &O) + where + R: GGLWEToGGSWKeyToMut, + O: GGLWEToGGSWKeyCompressedToRef, + { + let res: &mut GGLWEToGGSWKey<&mut [u8]> = &mut res.to_mut(); + let other: &GGLWEToGGSWKeyCompressed<&[u8]> = &other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.decompress_gglwe(a, b); + } + } +} + +impl GGLWEToGGSWKey { + pub fn decompress(&mut self, module: &M, other: &O) + where + M: GGLWEToGGSWKeyDecompress, + O: GGLWEToGGSWKeyCompressedToRef, + { + module.decompress_gglwe_to_ggsw_key(self, other); + } +} + +pub trait GGLWEToGGSWKeyCompressedToRef { + fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]>; +} + +impl GGLWEToGGSWKeyCompressedToRef for GGLWEToGGSWKeyCompressed +where + GGLWECompressed: GGLWECompressedToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]> { + GGLWEToGGSWKeyCompressed { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyCompressedToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]>; +} + +impl GGLWEToGGSWKeyCompressedToMut for GGLWEToGGSWKeyCompressed +where + GGLWECompressed: GGLWECompressedToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]> { + GGLWEToGGSWKeyCompressed { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs index 6939ff2..c6e9297 100644 --- a/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_tensor_key.rs @@ -4,31 +4,34 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, - GLWEInfos, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWECompressedToRef, + GGLWEDecompress, GGLWEInfos, GGLWEToMut, GLWEInfos, GLWETensorKey, LWEInfos, Rank, TorusPrecision, }; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GLWETensorKeyCompressed { - pub(crate) keys: Vec>, +pub struct GLWETensorKeyCompressed(pub(crate) GGLWECompressed); + +impl GGLWECompressedSeedMut for GLWETensorKeyCompressed { + fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> { + &mut self.0.seed + } } impl LWEInfos for GLWETensorKeyCompressed { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } impl GLWEInfos for GLWETensorKeyCompressed { @@ -43,15 +46,15 @@ impl GGLWEInfos for GLWETensorKeyCompressed { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -63,18 +66,14 @@ impl fmt::Debug for GLWETensorKeyCompressed { impl FillUniform for GLWETensorKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWECompressed| key.fill_uniform(log_bound, source)) + self.0.fill_uniform(log_bound, source); } } impl fmt::Display for GLWETensorKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKeyCompressed)",)?; - for (i, key) in self.keys.iter().enumerate() { - write!(f, "{i}: {key}")?; - } + write!(f, "{}", self.0)?; Ok(()) } } @@ -96,11 +95,15 @@ impl GLWETensorKeyCompressed> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); - GLWETensorKeyCompressed { - keys: (0..pairs) - .map(|_| GGLWECompressed::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKeyCompressed(GGLWECompressed::alloc( + n, + base2k, + k, + Rank(pairs), + rank, + dnum, + dsize, + )) } pub fn bytes_of_from_infos(infos: &A) -> usize @@ -118,88 +121,35 @@ impl GLWETensorKeyCompressed> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWECompressed::bytes_of(n, base2k, k, Rank(1), dnum, dsize) + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + GGLWECompressed::bytes_of(n, base2k, k, Rank(pairs), dnum, dsize) } } impl ReaderFrom for GLWETensorKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - let len: usize = reader.read_u64::()? as usize; - if self.keys.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("self.keys.len()={} != read len={}", self.keys.len(), len), - )); - } - for key in &mut self.keys { - key.read_from(reader)?; - } + self.0.read_from(reader)?; Ok(()) } } impl WriterTo for GLWETensorKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.keys.len() as u64)?; - for key in &self.keys { - key.write_to(writer)?; - } + self.0.write_to(writer)?; Ok(()) } } -pub trait GLWETensorKeyCompressedAtRef { - fn at(&self, i: usize, j: usize) -> &GGLWECompressed; -} - -impl GLWETensorKeyCompressedAtRef for GLWETensorKeyCompressed { - fn at(&self, mut i: usize, mut j: usize) -> &GGLWECompressed { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -pub trait GLWETensorKeyCompressedAtMut { - fn at_mut(&mut self, i: usize, j: usize) -> &mut GGLWECompressed; -} - -impl GLWETensorKeyCompressedAtMut for GLWETensorKeyCompressed { - fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWECompressed { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - pub trait GLWETensorKeyDecompress where Self: GGLWEDecompress, { fn decompress_tensor_key(&self, res: &mut R, other: &O) where - R: GLWETensorKeyToMut, - O: GLWETensorKeyCompressedToRef, + R: GGLWEToMut, + O: GGLWECompressedToRef, { - let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); - let other: &GLWETensorKeyCompressed<&[u8]> = &other.to_ref(); - - assert_eq!( - res.keys.len(), - other.keys.len(), - "invalid receiver: res.keys.len()={} != other.keys.len()={}", - res.keys.len(), - other.keys.len() - ); - - for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { - self.decompress_gglwe(a, b); - } + self.decompress_gglwe(res, other); } } @@ -208,39 +158,27 @@ impl GLWETensorKeyDecompress for Module where Self: GGLWEDecompre impl GLWETensorKey { pub fn decompress(&mut self, module: &M, other: &O) where - O: GLWETensorKeyCompressedToRef, + O: GGLWECompressedToRef, M: GLWETensorKeyDecompress, { module.decompress_tensor_key(self, other); } } -pub trait GLWETensorKeyCompressedToMut { - fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]>; -} - -impl GLWETensorKeyCompressedToMut for GLWETensorKeyCompressed +impl GGLWECompressedToMut for GLWETensorKeyCompressed where GGLWECompressed: GGLWECompressedToMut, { - fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]> { - GLWETensorKeyCompressed { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { + self.0.to_mut() } } -pub trait GLWETensorKeyCompressedToRef { - fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]>; -} - -impl GLWETensorKeyCompressedToRef for GLWETensorKeyCompressed +impl GGLWECompressedToRef for GLWETensorKeyCompressed where GGLWECompressed: GGLWECompressedToRef, { - fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]> { - GLWETensorKeyCompressed { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWECompressed<&[u8]> { + self.0.to_ref() } } diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs similarity index 95% rename from poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs index 6ac325c..5552d11 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_key.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, - GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWESwitchingKey, LWEInfos, Rank, TorusPrecision, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWEKey, LWEInfos, Rank, TorusPrecision, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; @@ -147,7 +147,7 @@ pub trait GLWEToLWESwitchingKeyDecompress where Self: GLWESwitchingKeyDecompress, { - fn decompress_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O) + fn decompress_glwe_to_lwe_key(&self, res: &mut R, other: &O) where R: GGLWEToMut + GLWESwitchingKeyDegreesMut, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, @@ -158,13 +158,13 @@ where impl GLWEToLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} -impl GLWEToLWESwitchingKey { +impl GLWEToLWEKey { pub fn decompress(&mut self, module: &M, other: &O) where O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, M: GLWEToLWESwitchingKeyDecompress, { - module.decompress_glwe_to_lwe_switching_key(self, other); + module.decompress_glwe_to_lwe_key(self, other); } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs similarity index 73% rename from poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs index 7a724c9..984ed05 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_key.rs @@ -5,15 +5,15 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, - GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, + GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWEKey, Rank, TorusPrecision, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); +pub struct LWEToGLWEKeyCompressed(pub(crate) GLWESwitchingKeyCompressed); -impl LWEInfos for LWEToGLWESwitchingKeyCompressed { +impl LWEInfos for LWEToGLWEKeyCompressed { fn n(&self) -> Degree { self.0.n() } @@ -29,13 +29,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyCompressed { self.0.size() } } -impl GLWEInfos for LWEToGLWESwitchingKeyCompressed { +impl GLWEInfos for LWEToGLWEKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyCompressed { +impl GGLWEInfos for LWEToGLWEKeyCompressed { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -53,37 +53,37 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyCompressed { } } -impl fmt::Debug for LWEToGLWESwitchingKeyCompressed { +impl fmt::Debug for LWEToGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for LWEToGLWESwitchingKeyCompressed { +impl FillUniform for LWEToGLWEKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for LWEToGLWESwitchingKeyCompressed { +impl fmt::Display for LWEToGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "(LWEToGLWESwitchingKeyCompressed) {}", self.0) } } -impl ReaderFrom for LWEToGLWESwitchingKeyCompressed { +impl ReaderFrom for LWEToGLWEKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for LWEToGLWESwitchingKeyCompressed { +impl WriterTo for LWEToGLWEKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl LWEToGLWESwitchingKeyCompressed> { +impl LWEToGLWEKeyCompressed> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -108,7 +108,7 @@ impl LWEToGLWESwitchingKeyCompressed> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - LWEToGLWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( + LWEToGLWEKeyCompressed(GLWESwitchingKeyCompressed::alloc( n, base2k, k, @@ -141,11 +141,11 @@ impl LWEToGLWESwitchingKeyCompressed> { } } -pub trait LWEToGLWESwitchingKeyDecompress +pub trait LWEToGLWEKeyDecompress where Self: GLWESwitchingKeyDecompress, { - fn decompress_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O) + fn decompress_lwe_to_glwe_key(&self, res: &mut R, other: &O) where R: GGLWEToMut + GLWESwitchingKeyDegreesMut, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, @@ -154,25 +154,25 @@ where } } -impl LWEToGLWESwitchingKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} +impl LWEToGLWEKeyDecompress for Module where Self: GLWESwitchingKeyDecompress {} -impl LWEToGLWESwitchingKey { +impl LWEToGLWEKey { pub fn decompress(&mut self, module: &M, other: &O) where O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, - M: LWEToGLWESwitchingKeyDecompress, + M: LWEToGLWEKeyDecompress, { - module.decompress_lwe_to_glwe_switching_key(self, other); + module.decompress_lwe_to_glwe_key(self, other); } } -impl GGLWECompressedToRef for LWEToGLWESwitchingKeyCompressed { +impl GGLWECompressedToRef for LWEToGLWEKeyCompressed { fn to_ref(&self) -> GGLWECompressed<&[u8]> { self.0.to_ref() } } -impl GGLWECompressedToMut for LWEToGLWESwitchingKeyCompressed { +impl GGLWECompressedToMut for LWEToGLWEKeyCompressed { fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { self.0.to_mut() } diff --git a/poulpy-core/src/layouts/compressed/mod.rs b/poulpy-core/src/layouts/compressed/mod.rs index b85d48d..8dd6145 100644 --- a/poulpy-core/src/layouts/compressed/mod.rs +++ b/poulpy-core/src/layouts/compressed/mod.rs @@ -1,21 +1,23 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; diff --git a/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..398dfd2 --- /dev/null +++ b/poulpy-core/src/layouts/gglwe_to_ggsw_key.rs @@ -0,0 +1,254 @@ +use poulpy_hal::{ + layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo}, + source::Source, +}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; + +use std::fmt; + +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +pub struct GGLWEToGGSWKeyLayout { + pub n: Degree, + pub base2k: Base2K, + pub k: TorusPrecision, + pub rank: Rank, + pub dnum: Dnum, + pub dsize: Dsize, +} + +#[derive(PartialEq, Eq, Clone)] +pub struct GGLWEToGGSWKey { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKey { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKey { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKey { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +impl LWEInfos for GGLWEToGGSWKeyLayout { + fn n(&self) -> Degree { + self.n + } + + fn base2k(&self) -> Base2K { + self.base2k + } + + fn k(&self) -> TorusPrecision { + self.k + } +} + +impl GLWEInfos for GGLWEToGGSWKeyLayout { + fn rank(&self) -> Rank { + self.rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyLayout { + fn rank_in(&self) -> Rank { + self.rank + } + + fn dsize(&self) -> Dsize { + self.dsize + } + + fn rank_out(&self) -> Rank { + self.rank + } + + fn dnum(&self) -> Dnum { + self.dnum + } +} + +impl fmt::Debug for GGLWEToGGSWKey { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{self}") + } +} + +impl FillUniform for GGLWEToGGSWKey { + fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { + self.keys + .iter_mut() + .for_each(|key: &mut GGLWE| key.fill_uniform(log_bound, source)) + } +} + +impl fmt::Display for GGLWEToGGSWKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "(GGLWEToGGSWKey)",)?; + for (i, key) in self.keys.iter().enumerate() { + write!(f, "{i}: {key}")?; + } + Ok(()) + } +} + +impl GGLWEToGGSWKey> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKey" + ); + Self::alloc( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { + GGLWEToGGSWKey { + keys: (0..rank.as_usize()) + .map(|_| GGLWE::alloc(n, base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKey" + ); + Self::bytes_of( + infos.n(), + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * GGLWE::bytes_of(n, base2k, k, rank, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKey { + // Returns a mutable reference to GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWE { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKey { + // Returns a reference to GGLWE_{s}(s[i] * s[j]) + pub fn at(&self, i: usize) -> &GGLWE { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +impl ReaderFrom for GGLWEToGGSWKey { + fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { + let len: usize = reader.read_u64::()? as usize; + if self.keys.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("self.keys.len()={} != read len={}", self.keys.len(), len), + )); + } + for key in &mut self.keys { + key.read_from(reader)?; + } + Ok(()) + } +} + +impl WriterTo for GGLWEToGGSWKey { + fn write_to(&self, writer: &mut W) -> std::io::Result<()> { + writer.write_u64::(self.keys.len() as u64)?; + for key in &self.keys { + key.write_to(writer)?; + } + Ok(()) + } +} + +pub trait GGLWEToGGSWKeyToRef { + fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]>; +} + +impl GGLWEToGGSWKeyToRef for GGLWEToGGSWKey +where + GGLWE: GGLWEToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]> { + GGLWEToGGSWKey { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]>; +} + +impl GGLWEToGGSWKeyToMut for GGLWEToGGSWKey +where + GGLWE: GGLWEToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]> { + GGLWEToGGSWKey { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_secret_tensor.rs b/poulpy-core/src/layouts/glwe_secret_tensor.rs new file mode 100644 index 0000000..287eda8 --- /dev/null +++ b/poulpy-core/src/layouts/glwe_secret_tensor.rs @@ -0,0 +1,221 @@ +use poulpy_hal::{ + api::{ + ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, + VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA, + }, + layouts::{ + Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, Scratch, ZnxInfos, ZnxView, + ZnxViewMut, + }, +}; + +use crate::{ + ScratchTakeCore, + dist::Distribution, + layouts::{ + Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank, + TorusPrecision, + }, +}; + +pub struct GLWESecretTensor { + pub(crate) data: ScalarZnx, + pub(crate) rank: Rank, + pub(crate) dist: Distribution, +} + +impl GLWESecretTensor> { + pub(crate) fn pairs(rank: usize) -> usize { + (((rank + 1) * rank) >> 1).max(1) + } +} + +impl LWEInfos for GLWESecretTensor { + fn base2k(&self) -> Base2K { + Base2K(0) + } + + fn k(&self) -> TorusPrecision { + TorusPrecision(0) + } + + fn n(&self) -> Degree { + Degree(self.data.n() as u32) + } + + fn size(&self) -> usize { + 1 + } +} + +impl GLWESecretTensor { + pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank().into(); + ScalarZnx { + data: bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)), + n: self.n().into(), + cols: 1, + } + } +} + +impl GLWESecretTensor { + pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> { + if i > j { + std::mem::swap(&mut i, &mut j); + }; + let rank: usize = self.rank().into(); + ScalarZnx { + n: self.n().into(), + data: bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)), + cols: 1, + } + } +} + +impl GLWEInfos for GLWESecretTensor { + fn rank(&self) -> Rank { + self.rank + } +} + +impl GLWESecretToRef for GLWESecretTensor { + fn to_ref(&self) -> GLWESecret<&[u8]> { + GLWESecret { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +impl GLWESecretToMut for GLWESecretTensor { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]> { + GLWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} + +impl GLWESecretTensor> { + pub fn alloc_from_infos(infos: &A) -> Self + where + A: GLWEInfos, + { + Self::alloc(infos.n(), infos.rank()) + } + + pub fn alloc(n: Degree, rank: Rank) -> Self { + GLWESecretTensor { + data: ScalarZnx::alloc(n.into(), Self::pairs(rank.into())), + rank, + dist: Distribution::NONE, + } + } + + pub fn bytes_of_from_infos(infos: &A) -> usize + where + A: GLWEInfos, + { + Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into()) + } + + pub fn bytes_of(n: Degree, rank: Rank) -> usize { + ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into())) + } +} + +impl GLWESecretTensor { + pub fn prepare(&mut self, module: &M, other: &S, scratch: &mut Scratch) + where + M: GLWESecretTensorFactory, + S: GLWESecretToRef + GLWEInfos, + Scratch: ScratchTakeCore, + { + module.glwe_secret_tensor_prepare(self, other, scratch); + } +} + +pub trait GLWESecretTensorFactory { + fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize; + + fn glwe_secret_tensor_prepare(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GLWESecretToMut + GLWEInfos, + O: GLWESecretToRef + GLWEInfos; +} + +impl GLWESecretTensorFactory for Module +where + Self: ModuleN + + GLWESecretPreparedFactory + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + VecZnxDftBytesOf + + VecZnxBigBytesOf + + VecZnxBigNormalizeTmpBytes, + Scratch: ScratchTakeCore, +{ + fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize { + self.bytes_of_glwe_secret_prepared(rank) + + self.bytes_of_vec_znx_dft(rank.into(), 1) + + self.bytes_of_vec_znx_dft(1, 1) + + self.bytes_of_vec_znx_big(1, 1) + + self.vec_znx_big_normalize_tmp_bytes() + } + + fn glwe_secret_tensor_prepare(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GLWESecretToMut + GLWEInfos, + A: GLWESecretToRef + GLWEInfos, + { + let res: &mut GLWESecret<&mut [u8]> = &mut res.to_mut(); + let a: &GLWESecret<&[u8]> = &a.to_ref(); + + println!("res.rank: {} a.rank: {}", res.rank(), a.rank()); + + assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + + let rank: usize = a.rank().into(); + + let (mut a_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank.into()); + a_prepared.prepare(self, a); + + let base2k: usize = 17; + + let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); + for i in 0..rank { + self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a.data.as_vec_znx(), i); + } + + let (mut a_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1); + let (mut a_ij_dft, scratch_4) = scratch_3.take_vec_znx_dft(self, 1, 1); + + // sk_tensor = sk (x) sk + // For example: (s0, s1) (x) (s0, s1) = (s0^2, s0s1, s1^2) + for i in 0..rank { + for j in i..rank { + let idx: usize = i * rank + j - (i * (i + 1) / 2); + self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i); + self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0); + self.vec_znx_big_normalize( + base2k, + &mut res.data.as_vec_znx_mut(), + idx, + base2k, + &a_ij_big, + 0, + scratch_4, + ); + } + } + } +} diff --git a/poulpy-core/src/layouts/glwe_tensor_key.rs b/poulpy-core/src/layouts/glwe_tensor_key.rs index bc0100f..032a892 100644 --- a/poulpy-core/src/layouts/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/glwe_tensor_key.rs @@ -6,7 +6,6 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, }; -use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; @@ -21,31 +20,29 @@ pub struct GLWETensorKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct GLWETensorKey { - pub(crate) keys: Vec>, -} +pub struct GLWETensorKey(pub(crate) GGLWE); impl LWEInfos for GLWETensorKey { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } impl GLWEInfos for GLWETensorKey { fn rank(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } } @@ -55,15 +52,15 @@ impl GGLWEInfos for GLWETensorKey { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -113,18 +110,14 @@ impl fmt::Debug for GLWETensorKey { impl FillUniform for GLWETensorKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { - self.keys - .iter_mut() - .for_each(|key: &mut GGLWE| key.fill_uniform(log_bound, source)) + self.0.fill_uniform(log_bound, source) } } impl fmt::Display for GLWETensorKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { writeln!(f, "(GLWETensorKey)",)?; - for (i, key) in self.keys.iter().enumerate() { - write!(f, "{i}: {key}")?; - } + write!(f, "{}", self.0)?; Ok(()) } } @@ -151,11 +144,7 @@ impl GLWETensorKey> { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); - GLWETensorKey { - keys: (0..pairs) - .map(|_| GGLWE::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKey(GGLWE::alloc(n, base2k, k, Rank(pairs), rank, dnum, dsize)) } pub fn bytes_of_from_infos(infos: &A) -> usize @@ -178,85 +167,39 @@ impl GLWETensorKey> { } pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWE::bytes_of(n, base2k, k, Rank(1), rank, dnum, dsize) - } -} - -impl GLWETensorKey { - // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWE { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl GLWETensorKey { - // Returns a reference to GGLWE_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWE { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] + let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); + GGLWE::bytes_of(n, base2k, k, Rank(pairs), rank, dnum, dsize) } } impl ReaderFrom for GLWETensorKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { - let len: usize = reader.read_u64::()? as usize; - if self.keys.len() != len { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - format!("self.keys.len()={} != read len={}", self.keys.len(), len), - )); - } - for key in &mut self.keys { - key.read_from(reader)?; - } + self.0.read_from(reader)?; Ok(()) } } impl WriterTo for GLWETensorKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { - writer.write_u64::(self.keys.len() as u64)?; - for key in &self.keys { - key.write_to(writer)?; - } + self.0.write_to(writer)?; Ok(()) } } -pub trait GLWETensorKeyToRef { - fn to_ref(&self) -> GLWETensorKey<&[u8]>; -} - -impl GLWETensorKeyToRef for GLWETensorKey +impl GGLWEToRef for GLWETensorKey where GGLWE: GGLWEToRef, { - fn to_ref(&self) -> GLWETensorKey<&[u8]> { - GLWETensorKey { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWE<&[u8]> { + self.0.to_ref() } } -pub trait GLWETensorKeyToMut { - fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]>; -} - -impl GLWETensorKeyToMut for GLWETensorKey +impl GGLWEToMut for GLWETensorKey where GGLWE: GGLWEToMut, { - fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]> { - GLWETensorKey { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWE<&mut [u8]> { + self.0.to_mut() } } diff --git a/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/glwe_to_lwe_key.rs similarity index 79% rename from poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/glwe_to_lwe_key.rs index bc3ee4b..2541d32 100644 --- a/poulpy-core/src/layouts/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/glwe_to_lwe_key.rs @@ -59,9 +59,9 @@ impl GGLWEInfos for GLWEToLWEKeyLayout { /// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE]. #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWESwitchingKey(pub(crate) GLWESwitchingKey); +pub struct GLWEToLWEKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for GLWEToLWESwitchingKey { +impl LWEInfos for GLWEToLWEKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -79,12 +79,12 @@ impl LWEInfos for GLWEToLWESwitchingKey { } } -impl GLWEInfos for GLWEToLWESwitchingKey { +impl GLWEInfos for GLWEToLWEKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWESwitchingKey { +impl GGLWEInfos for GLWEToLWEKey { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -102,37 +102,37 @@ impl GGLWEInfos for GLWEToLWESwitchingKey { } } -impl fmt::Debug for GLWEToLWESwitchingKey { +impl fmt::Debug for GLWEToLWEKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GLWEToLWESwitchingKey { +impl FillUniform for GLWEToLWEKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for GLWEToLWESwitchingKey { +impl fmt::Display for GLWEToLWEKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(GLWEToLWESwitchingKey) {}", self.0) + write!(f, "(GLWEToLWEKey) {}", self.0) } } -impl ReaderFrom for GLWEToLWESwitchingKey { +impl ReaderFrom for GLWEToLWEKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for GLWEToLWESwitchingKey { +impl WriterTo for GLWEToLWEKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl GLWEToLWESwitchingKey> { +impl GLWEToLWEKey> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -140,12 +140,12 @@ impl GLWEToLWESwitchingKey> { assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKey" + "rank_out > 1 is not supported for GLWEToLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKey" + "dsize > 1 is not supported for GLWEToLWEKey" ); Self::alloc( infos.n(), @@ -157,7 +157,7 @@ impl GLWEToLWESwitchingKey> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - GLWEToLWESwitchingKey(GLWESwitchingKey::alloc( + GLWEToLWEKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -196,19 +196,19 @@ impl GLWEToLWESwitchingKey> { } } -impl GGLWEToRef for GLWEToLWESwitchingKey { +impl GGLWEToRef for GLWEToLWEKey { fn to_ref(&self) -> GGLWE<&[u8]> { self.0.to_ref() } } -impl GGLWEToMut for GLWEToLWESwitchingKey { +impl GGLWEToMut for GLWEToLWEKey { fn to_mut(&mut self) -> GGLWE<&mut [u8]> { self.0.to_mut() } } -impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey { +impl GLWESwitchingKeyDegreesMut for GLWEToLWEKey { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -218,7 +218,7 @@ impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey { } } -impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKey { +impl GLWESwitchingKeyDegrees for GLWEToLWEKey { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/lwe_to_glwe_key.rs similarity index 73% rename from poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/lwe_to_glwe_key.rs index caa676d..5a44f61 100644 --- a/poulpy-core/src/layouts/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/lwe_to_glwe_key.rs @@ -11,7 +11,7 @@ use crate::layouts::{ }; #[derive(PartialEq, Eq, Copy, Clone, Debug)] -pub struct LWEToGLWESwitchingKeyLayout { +pub struct LWEToGLWEKeyLayout { pub n: Degree, pub base2k: Base2K, pub k: TorusPrecision, @@ -19,7 +19,7 @@ pub struct LWEToGLWESwitchingKeyLayout { pub dnum: Dnum, } -impl LWEInfos for LWEToGLWESwitchingKeyLayout { +impl LWEInfos for LWEToGLWEKeyLayout { fn base2k(&self) -> Base2K { self.base2k } @@ -33,13 +33,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyLayout { } } -impl GLWEInfos for LWEToGLWESwitchingKeyLayout { +impl GLWEInfos for LWEToGLWEKeyLayout { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { +impl GGLWEInfos for LWEToGLWEKeyLayout { fn rank_in(&self) -> Rank { Rank(1) } @@ -58,9 +58,9 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { } #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKey(pub(crate) GLWESwitchingKey); +pub struct LWEToGLWEKey(pub(crate) GLWESwitchingKey); -impl LWEInfos for LWEToGLWESwitchingKey { +impl LWEInfos for LWEToGLWEKey { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -78,12 +78,12 @@ impl LWEInfos for LWEToGLWESwitchingKey { } } -impl GLWEInfos for LWEToGLWESwitchingKey { +impl GLWEInfos for LWEToGLWEKey { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKey { +impl GGLWEInfos for LWEToGLWEKey { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -101,37 +101,37 @@ impl GGLWEInfos for LWEToGLWESwitchingKey { } } -impl fmt::Debug for LWEToGLWESwitchingKey { +impl fmt::Debug for LWEToGLWEKey { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for LWEToGLWESwitchingKey { +impl FillUniform for LWEToGLWEKey { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.0.fill_uniform(log_bound, source); } } -impl fmt::Display for LWEToGLWESwitchingKey { +impl fmt::Display for LWEToGLWEKey { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "(LWEToGLWESwitchingKey) {}", self.0) + write!(f, "(LWEToGLWEKey) {}", self.0) } } -impl ReaderFrom for LWEToGLWESwitchingKey { +impl ReaderFrom for LWEToGLWEKey { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.0.read_from(reader) } } -impl WriterTo for LWEToGLWESwitchingKey { +impl WriterTo for LWEToGLWEKey { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { self.0.write_to(writer) } } -impl LWEToGLWESwitchingKey> { +impl LWEToGLWEKey> { pub fn alloc_from_infos(infos: &A) -> Self where A: GGLWEInfos, @@ -139,12 +139,12 @@ impl LWEToGLWESwitchingKey> { assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); Self::alloc( @@ -157,7 +157,7 @@ impl LWEToGLWESwitchingKey> { } pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - LWEToGLWESwitchingKey(GLWESwitchingKey::alloc( + LWEToGLWEKey(GLWESwitchingKey::alloc( n, base2k, k, @@ -175,12 +175,12 @@ impl LWEToGLWESwitchingKey> { assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); Self::bytes_of( infos.n(), @@ -196,19 +196,19 @@ impl LWEToGLWESwitchingKey> { } } -impl GGLWEToRef for LWEToGLWESwitchingKey { +impl GGLWEToRef for LWEToGLWEKey { fn to_ref(&self) -> GGLWE<&[u8]> { self.0.to_ref() } } -impl GGLWEToMut for LWEToGLWESwitchingKey { +impl GGLWEToMut for LWEToGLWEKey { fn to_mut(&mut self) -> GGLWE<&mut [u8]> { self.0.to_mut() } } -impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey { +impl GLWESwitchingKeyDegreesMut for LWEToGLWEKey { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -218,7 +218,7 @@ impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey { } } -impl GLWESwitchingKeyDegrees for LWEToGLWESwitchingKey { +impl GLWESwitchingKeyDegrees for LWEToGLWEKey { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/mod.rs b/poulpy-core/src/layouts/mod.rs index 7c5cc5b..2dbc700 100644 --- a/poulpy-core/src/layouts/mod.rs +++ b/poulpy-core/src/layouts/mod.rs @@ -1,40 +1,44 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; mod glwe_plaintext; mod glwe_public_key; mod glwe_secret; +mod glwe_secret_tensor; mod glwe_switching_key; mod glwe_tensor; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe; mod lwe_plaintext; mod lwe_secret; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub mod compressed; pub mod prepared; pub use compressed::*; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; pub use glwe_plaintext::*; pub use glwe_public_key::*; pub use glwe_secret::*; +pub use glwe_secret_tensor::*; pub use glwe_switching_key::*; pub use glwe_tensor::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe::*; pub use lwe_plaintext::*; pub use lwe_secret::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; pub use prepared::*; use poulpy_hal::layouts::{Backend, Module}; diff --git a/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..d63dca6 --- /dev/null +++ b/poulpy-core/src/layouts/prepared/gglwe_to_ggsw_key.rs @@ -0,0 +1,252 @@ +use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; + +use crate::layouts::{ + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, + GGLWEToGGSWKey, GGLWEToGGSWKeyToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, +}; + +pub struct GGLWEToGGSWKeyPrepared { + pub(crate) keys: Vec>, +} + +impl LWEInfos for GGLWEToGGSWKeyPrepared { + fn n(&self) -> Degree { + self.keys[0].n() + } + + fn base2k(&self) -> Base2K { + self.keys[0].base2k() + } + + fn k(&self) -> TorusPrecision { + self.keys[0].k() + } + + fn size(&self) -> usize { + self.keys[0].size() + } +} + +impl GLWEInfos for GGLWEToGGSWKeyPrepared { + fn rank(&self) -> Rank { + self.keys[0].rank_out() + } +} + +impl GGLWEInfos for GGLWEToGGSWKeyPrepared { + fn rank_in(&self) -> Rank { + self.rank_out() + } + + fn rank_out(&self) -> Rank { + self.keys[0].rank_out() + } + + fn dsize(&self) -> Dsize { + self.keys[0].dsize() + } + + fn dnum(&self) -> Dnum { + self.keys[0].dnum() + } +} + +pub trait GGLWEToGGSWKeyPreparedFactory { + fn alloc_gglwe_to_ggsw_key_prepared_from_infos(&self, infos: &A) -> GGLWEToGGSWKeyPrepared, BE> + where + A: GGLWEInfos; + + fn alloc_gglwe_to_ggsw_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEToGGSWKeyPrepared, BE>; + + fn bytes_of_gglwe_to_ggsw_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize; + + fn prepare_gglwe_to_ggsw_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos; + + fn prepare_gglwe_to_ggsw_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEToGGSWKeyPreparedToMut, + O: GGLWEToGGSWKeyToRef; +} + +impl GGLWEToGGSWKeyPreparedFactory for Module +where + Self: GGLWEPreparedFactory, +{ + fn alloc_gglwe_to_ggsw_key_prepared_from_infos(&self, infos: &A) -> GGLWEToGGSWKeyPrepared, BE> + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" + ); + self.alloc_gglwe_to_ggsw_key_prepared( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn alloc_gglwe_to_ggsw_key_prepared( + &self, + base2k: Base2K, + k: TorusPrecision, + rank: Rank, + dnum: Dnum, + dsize: Dsize, + ) -> GGLWEToGGSWKeyPrepared, BE> { + GGLWEToGGSWKeyPrepared { + keys: (0..rank.as_usize()) + .map(|_| self.alloc_gglwe_prepared(base2k, k, rank, rank, dnum, dsize)) + .collect(), + } + } + + fn bytes_of_gglwe_to_ggsw_from_infos(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + assert_eq!( + infos.rank_in(), + infos.rank_out(), + "rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared" + ); + self.bytes_of_gglwe_to_ggsw( + infos.base2k(), + infos.k(), + infos.rank(), + infos.dnum(), + infos.dsize(), + ) + } + + fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { + rank.as_usize() * self.bytes_of_gglwe_prepared(base2k, k, rank, rank, dnum, dsize) + } + + fn prepare_gglwe_to_ggsw_key_tmp_bytes(&self, infos: &A) -> usize + where + A: GGLWEInfos, + { + self.prepare_gglwe_tmp_bytes(infos) + } + + fn prepare_gglwe_to_ggsw_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + where + R: GGLWEToGGSWKeyPreparedToMut, + O: GGLWEToGGSWKeyToRef, + { + let res: &mut GGLWEToGGSWKeyPrepared<&mut [u8], BE> = &mut res.to_mut(); + let other: &GGLWEToGGSWKey<&[u8]> = &other.to_ref(); + + assert_eq!(res.keys.len(), other.keys.len()); + + for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { + self.prepare_gglwe(a, b, scratch); + } + } +} + +impl GGLWEToGGSWKeyPrepared, BE> { + pub fn alloc_from_infos(module: &M, infos: &A) -> Self + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyPreparedFactory, + { + module.alloc_gglwe_to_ggsw_key_prepared_from_infos(infos) + } + + pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self + where + M: GGLWEToGGSWKeyPreparedFactory, + { + module.alloc_gglwe_to_ggsw_key_prepared(base2k, k, rank, dnum, dsize) + } + + pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize + where + A: GGLWEInfos, + M: GGLWEToGGSWKeyPreparedFactory, + { + module.bytes_of_gglwe_to_ggsw_from_infos(infos) + } + + pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize + where + M: GGLWEToGGSWKeyPreparedFactory, + { + module.bytes_of_gglwe_to_ggsw(base2k, k, rank, dnum, dsize) + } +} + +impl GGLWEToGGSWKeyPrepared { + pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) + where + M: GGLWEToGGSWKeyPreparedFactory, + O: GGLWEToGGSWKeyToRef, + { + module.prepare_gglwe_to_ggsw_key(self, other, scratch); + } +} + +impl GGLWEToGGSWKeyPrepared { + // Returns a mutable reference to GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at_mut(&mut self, i: usize) -> &mut GGLWEPrepared { + assert!((i as u32) < self.rank()); + &mut self.keys[i] + } +} + +impl GGLWEToGGSWKeyPrepared { + // Returns a reference to GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]) + pub fn at(&self, i: usize) -> &GGLWEPrepared { + assert!((i as u32) < self.rank()); + &self.keys[i] + } +} + +pub trait GGLWEToGGSWKeyPreparedToRef { + fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE>; +} + +impl GGLWEToGGSWKeyPreparedToRef for GGLWEToGGSWKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToRef, +{ + fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE> { + GGLWEToGGSWKeyPrepared { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} + +pub trait GGLWEToGGSWKeyPreparedToMut { + fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE>; +} + +impl GGLWEToGGSWKeyPreparedToMut for GGLWEToGGSWKeyPrepared +where + GGLWEPrepared: GGLWEPreparedToMut, +{ + fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE> { + GGLWEToGGSWKeyPrepared { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs index d73d17d..f73299b 100644 --- a/poulpy-core/src/layouts/prepared/glwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_switching_key.rs @@ -109,7 +109,7 @@ where ) } - fn bytes_of_glwe_switching_key_prepared( + fn bytes_of_glwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, @@ -125,7 +125,7 @@ where where A: GGLWEInfos, { - self.bytes_of_glwe_switching_key_prepared( + self.bytes_of_glwe_key_prepared( infos.base2k(), infos.k(), infos.rank_in(), @@ -199,7 +199,7 @@ impl GLWESwitchingKeyPrepared, B> { where M: GLWESwitchingKeyPreparedFactory, { - module.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) + module.bytes_of_glwe_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) } } diff --git a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs index bd63c75..0304b37 100644 --- a/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_tensor_key.rs @@ -2,29 +2,27 @@ use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch}; use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, - GLWEInfos, GLWETensorKey, GLWETensorKeyToRef, LWEInfos, Rank, TorusPrecision, + GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, }; #[derive(PartialEq, Eq)] -pub struct GLWETensorKeyPrepared { - pub(crate) keys: Vec>, -} +pub struct GLWETensorKeyPrepared(pub(crate) GGLWEPrepared); impl LWEInfos for GLWETensorKeyPrepared { fn n(&self) -> Degree { - self.keys[0].n() + self.0.n() } fn base2k(&self) -> Base2K { - self.keys[0].base2k() + self.0.base2k() } fn k(&self) -> TorusPrecision { - self.keys[0].k() + self.0.k() } fn size(&self) -> usize { - self.keys[0].size() + self.0.size() } } @@ -40,15 +38,15 @@ impl GGLWEInfos for GLWETensorKeyPrepared { } fn rank_out(&self) -> Rank { - self.keys[0].rank_out() + self.0.rank_out() } fn dsize(&self) -> Dsize { - self.keys[0].dsize() + self.0.dsize() } fn dnum(&self) -> Dnum { - self.keys[0].dnum() + self.0.dnum() } } @@ -65,11 +63,7 @@ where rank: Rank, ) -> GLWETensorKeyPrepared, B> { let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); - GLWETensorKeyPrepared { - keys: (0..pairs) - .map(|_| self.alloc_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize)) - .collect(), - } + GLWETensorKeyPrepared(self.alloc_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize)) } fn alloc_tensor_key_prepared_from_infos(&self, infos: &A) -> GLWETensorKeyPrepared, B> @@ -91,8 +85,8 @@ where } fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * self.bytes_of_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize) + let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); + self.bytes_of_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize) } fn bytes_of_tensor_key_prepared_from_infos(&self, infos: &A) -> usize @@ -117,17 +111,10 @@ where fn prepare_tensor_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where - R: GLWETensorKeyPreparedToMut, - O: GLWETensorKeyToRef, + R: GGLWEPreparedToMut, + O: GGLWEToRef, { - let mut res: GLWETensorKeyPrepared<&mut [u8], B> = res.to_mut(); - let other: GLWETensorKey<&[u8]> = other.to_ref(); - - assert_eq!(res.keys.len(), other.keys.len()); - - for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) { - self.prepare_gglwe(a, b, scratch); - } + self.prepare_gglwe(res, other, scratch); } } @@ -165,28 +152,6 @@ impl GLWETensorKeyPrepared, B> { } } -impl GLWETensorKeyPrepared { - // Returns a mutable reference to GGLWE_{s}(s[i] * s[j]) - pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWEPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &mut self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - -impl GLWETensorKeyPrepared { - // Returns a reference to GGLWE_{s}(s[i] * s[j]) - pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWEPrepared { - if i > j { - std::mem::swap(&mut i, &mut j); - }; - let rank: usize = self.rank_out().into(); - &self.keys[i * rank + j - (i * (i + 1) / 2)] - } -} - impl GLWETensorKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) -> usize where @@ -200,39 +165,27 @@ impl GLWETensorKeyPrepared, B> { impl GLWETensorKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where - O: GLWETensorKeyToRef, + O: GGLWEToRef, M: GLWETensorKeyPreparedFactory, { module.prepare_tensor_key(self, other, scratch); } } -pub trait GLWETensorKeyPreparedToMut { - fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B>; -} - -impl GLWETensorKeyPreparedToMut for GLWETensorKeyPrepared +impl GGLWEPreparedToMut for GLWETensorKeyPrepared where GGLWEPrepared: GGLWEPreparedToMut, { - fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B> { - GLWETensorKeyPrepared { - keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), - } + fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> { + self.0.to_mut() } } -pub trait GLWETensorKeyPreparedToRef { - fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B>; -} - -impl GLWETensorKeyPreparedToRef for GLWETensorKeyPrepared +impl GGLWEPreparedToRef for GLWETensorKeyPrepared where GGLWEPrepared: GGLWEPreparedToRef, { - fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B> { - GLWETensorKeyPrepared { - keys: self.keys.iter().map(|c| c.to_ref()).collect(), - } + fn to_ref(&self) -> GGLWEPrepared<&[u8], B> { + self.0.to_ref() } } diff --git a/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs similarity index 54% rename from poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs rename to poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs index 6edac5e..675a73f 100644 --- a/poulpy-core/src/layouts/prepared/glwe_to_lwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/glwe_to_lwe_key.rs @@ -7,9 +7,9 @@ use crate::layouts::{ }; #[derive(PartialEq, Eq)] -pub struct GLWEToLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); +pub struct GLWEToLWEKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); -impl LWEInfos for GLWEToLWESwitchingKeyPrepared { +impl LWEInfos for GLWEToLWEKeyPrepared { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -27,13 +27,13 @@ impl LWEInfos for GLWEToLWESwitchingKeyPrepared { } } -impl GLWEInfos for GLWEToLWESwitchingKeyPrepared { +impl GLWEInfos for GLWEToLWEKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { +impl GGLWEInfos for GLWEToLWEKeyPrepared { fn rank_in(&self) -> Rank { self.0.rank_in() } @@ -51,65 +51,65 @@ impl GGLWEInfos for GLWEToLWESwitchingKeyPrepared { } } -pub trait GLWEToLWESwitchingKeyPreparedFactory +pub trait GLWEToLWEKeyPreparedFactory where Self: GLWESwitchingKeyPreparedFactory, { - fn alloc_glwe_to_lwe_switching_key_prepared( + fn alloc_glwe_to_lwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, - ) -> GLWEToLWESwitchingKeyPrepared, B> { - GLWEToLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) + ) -> GLWEToLWEKeyPrepared, B> { + GLWEToLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) } - fn alloc_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWESwitchingKeyPrepared, B> + fn alloc_glwe_to_lwe_key_prepared_from_infos(&self, infos: &A) -> GLWEToLWEKeyPrepared, B> where A: GGLWEInfos, { debug_assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "dsize > 1 is not supported for GLWEToLWEKeyPrepared" ); - self.alloc_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + self.alloc_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - fn bytes_of_glwe_to_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) + fn bytes_of_glwe_to_lwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) } - fn bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + fn bytes_of_glwe_to_lwe_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_out().0, 1, - "rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "rank_out > 1 is not supported for GLWEToLWEKeyPrepared" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" + "dsize > 1 is not supported for GLWEToLWEKeyPrepared" ); - self.bytes_of_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) + self.bytes_of_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) } - fn prepare_glwe_to_lwe_switching_key_tmp_bytes(&self, infos: &A) -> usize + fn prepare_glwe_to_lwe_key_tmp_bytes(&self, infos: &A) -> usize where A: GGLWEInfos, { self.prepare_glwe_switching_key_tmp_bytes(infos) } - fn prepare_glwe_to_lwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + fn prepare_glwe_to_lwe_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, O: GGLWEToRef + GLWESwitchingKeyDegrees, @@ -118,61 +118,61 @@ where } } -impl GLWEToLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} +impl GLWEToLWEKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} -impl GLWEToLWESwitchingKeyPrepared, B> { +impl GLWEToLWEKeyPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.alloc_glwe_to_lwe_switching_key_prepared_from_infos(infos) + module.alloc_glwe_to_lwe_key_prepared_from_infos(infos) } pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self where - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.alloc_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + module.alloc_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum) } pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(infos) + module.bytes_of_glwe_to_lwe_key_prepared_from_infos(infos) } pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize where - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.bytes_of_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) + module.bytes_of_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum) } } -impl GLWEToLWESwitchingKeyPrepared, B> { +impl GLWEToLWEKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) where A: GGLWEInfos, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.prepare_glwe_to_lwe_switching_key_tmp_bytes(infos); + module.prepare_glwe_to_lwe_key_tmp_bytes(infos); } } -impl GLWEToLWESwitchingKeyPrepared { +impl GLWEToLWEKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where O: GGLWEToRef + GLWESwitchingKeyDegrees, - M: GLWEToLWESwitchingKeyPreparedFactory, + M: GLWEToLWEKeyPreparedFactory, { - module.prepare_glwe_to_lwe_switching_key(self, other, scratch); + module.prepare_glwe_to_lwe_key(self, other, scratch); } } -impl GGLWEPreparedToRef for GLWEToLWESwitchingKeyPrepared +impl GGLWEPreparedToRef for GLWEToLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -181,7 +181,7 @@ where } } -impl GGLWEPreparedToMut for GLWEToLWESwitchingKeyPrepared +impl GGLWEPreparedToMut for GLWEToLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -190,7 +190,7 @@ where } } -impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegreesMut for GLWEToLWEKeyPrepared { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } @@ -200,7 +200,7 @@ impl GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKe } } -impl GLWESwitchingKeyDegrees for GLWEToLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegrees for GLWEToLWEKeyPrepared { fn input_degree(&self) -> &Degree { &self.0.input_degree } diff --git a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs index 327d001..16f77eb 100644 --- a/poulpy-core/src/layouts/prepared/lwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_switching_key.rs @@ -86,7 +86,7 @@ where } fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) + self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) } fn bytes_of_lwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize diff --git a/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs similarity index 53% rename from poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs rename to poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs index 30ed131..25f08f8 100644 --- a/poulpy-core/src/layouts/prepared/lwe_to_glwe_switching_key.rs +++ b/poulpy-core/src/layouts/prepared/lwe_to_glwe_key.rs @@ -8,9 +8,9 @@ use crate::layouts::{ /// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE]. #[derive(PartialEq, Eq)] -pub struct LWEToGLWESwitchingKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); +pub struct LWEToGLWEKeyPrepared(pub(crate) GLWESwitchingKeyPrepared); -impl LWEInfos for LWEToGLWESwitchingKeyPrepared { +impl LWEInfos for LWEToGLWEKeyPrepared { fn base2k(&self) -> Base2K { self.0.base2k() } @@ -28,13 +28,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyPrepared { } } -impl GLWEInfos for LWEToGLWESwitchingKeyPrepared { +impl GLWEInfos for LWEToGLWEKeyPrepared { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { +impl GGLWEInfos for LWEToGLWEKeyPrepared { fn dsize(&self) -> Dsize { self.0.dsize() } @@ -52,71 +52,65 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyPrepared { } } -pub trait LWEToGLWESwitchingKeyPreparedFactory +pub trait LWEToGLWEKeyPreparedFactory where Self: GLWESwitchingKeyPreparedFactory, { - fn alloc_lwe_to_glwe_switching_key_prepared( + fn alloc_lwe_to_glwe_key_prepared( &self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum, - ) -> LWEToGLWESwitchingKeyPrepared, B> { - LWEToGLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) + ) -> LWEToGLWEKeyPrepared, B> { + LWEToGLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) } - fn alloc_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWESwitchingKeyPrepared, B> + fn alloc_lwe_to_glwe_key_prepared_from_infos(&self, infos: &A) -> LWEToGLWEKeyPrepared, B> where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); - self.alloc_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + self.alloc_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - fn bytes_of_lwe_to_glwe_switching_key_prepared( - &self, - base2k: Base2K, - k: TorusPrecision, - rank_out: Rank, - dnum: Dnum, - ) -> usize { - self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) + fn bytes_of_lwe_to_glwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize { + self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)) } - fn bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(&self, infos: &A) -> usize + fn bytes_of_lwe_to_glwe_key_prepared_from_infos(&self, infos: &A) -> usize where A: GGLWEInfos, { debug_assert_eq!( infos.rank_in().0, 1, - "rank_in > 1 is not supported for LWEToGLWESwitchingKey" + "rank_in > 1 is not supported for LWEToGLWEKey" ); debug_assert_eq!( infos.dsize().0, 1, - "dsize > 1 is not supported for LWEToGLWESwitchingKey" + "dsize > 1 is not supported for LWEToGLWEKey" ); - self.bytes_of_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) + self.bytes_of_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) } - fn prepare_lwe_to_glwe_switching_key_tmp_bytes(&self, infos: &A) + fn prepare_lwe_to_glwe_key_tmp_bytes(&self, infos: &A) where A: GGLWEInfos, { self.prepare_glwe_switching_key_tmp_bytes(infos); } - fn prepare_lwe_to_glwe_switching_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) + fn prepare_lwe_to_glwe_key(&self, res: &mut R, other: &O, scratch: &mut Scratch) where R: GGLWEPreparedToMut + GLWESwitchingKeyDegreesMut, O: GGLWEToRef + GLWESwitchingKeyDegrees, @@ -125,61 +119,61 @@ where } } -impl LWEToGLWESwitchingKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} +impl LWEToGLWEKeyPreparedFactory for Module where Self: GLWESwitchingKeyPreparedFactory {} -impl LWEToGLWESwitchingKeyPrepared, B> { +impl LWEToGLWEKeyPrepared, B> { pub fn alloc_from_infos(module: &M, infos: &A) -> Self where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) + module.alloc_lwe_to_glwe_key_prepared_from_infos(infos) } pub fn alloc(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self where - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + module.alloc_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum) } pub fn bytes_of_from_infos(module: &M, infos: &A) -> usize where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) + module.bytes_of_lwe_to_glwe_key_prepared_from_infos(infos) } pub fn bytes_of(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize where - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) + module.bytes_of_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum) } } -impl LWEToGLWESwitchingKeyPrepared, B> { +impl LWEToGLWEKeyPrepared, B> { pub fn prepare_tmp_bytes(&self, module: &M, infos: &A) where A: GGLWEInfos, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.prepare_lwe_to_glwe_switching_key_tmp_bytes(infos); + module.prepare_lwe_to_glwe_key_tmp_bytes(infos); } } -impl LWEToGLWESwitchingKeyPrepared { +impl LWEToGLWEKeyPrepared { pub fn prepare(&mut self, module: &M, other: &O, scratch: &mut Scratch) where O: GGLWEToRef + GLWESwitchingKeyDegrees, - M: LWEToGLWESwitchingKeyPreparedFactory, + M: LWEToGLWEKeyPreparedFactory, { - module.prepare_lwe_to_glwe_switching_key(self, other, scratch); + module.prepare_lwe_to_glwe_key(self, other, scratch); } } -impl GGLWEPreparedToRef for LWEToGLWESwitchingKeyPrepared +impl GGLWEPreparedToRef for LWEToGLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToRef, { @@ -188,7 +182,7 @@ where } } -impl GGLWEPreparedToMut for LWEToGLWESwitchingKeyPrepared +impl GGLWEPreparedToMut for LWEToGLWEKeyPrepared where GLWESwitchingKeyPrepared: GGLWEPreparedToMut, { @@ -197,7 +191,7 @@ where } } -impl GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKeyPrepared { +impl GLWESwitchingKeyDegreesMut for LWEToGLWEKeyPrepared { fn input_degree(&mut self) -> &mut Degree { &mut self.0.input_degree } diff --git a/poulpy-core/src/layouts/prepared/mod.rs b/poulpy-core/src/layouts/prepared/mod.rs index 8944b97..4d76cfb 100644 --- a/poulpy-core/src/layouts/prepared/mod.rs +++ b/poulpy-core/src/layouts/prepared/mod.rs @@ -1,4 +1,5 @@ mod gglwe; +mod gglwe_to_ggsw_key; mod ggsw; mod glwe; mod glwe_automorphism_key; @@ -6,11 +7,12 @@ mod glwe_public_key; mod glwe_secret; mod glwe_switching_key; mod glwe_tensor_key; -mod glwe_to_lwe_switching_key; +mod glwe_to_lwe_key; mod lwe_switching_key; -mod lwe_to_glwe_switching_key; +mod lwe_to_glwe_key; pub use gglwe::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw::*; pub use glwe::*; pub use glwe_automorphism_key::*; @@ -18,6 +20,6 @@ pub use glwe_public_key::*; pub use glwe_secret::*; pub use glwe_switching_key::*; pub use glwe_tensor_key::*; -pub use glwe_to_lwe_switching_key::*; +pub use glwe_to_lwe_key::*; pub use lwe_switching_key::*; -pub use lwe_to_glwe_switching_key::*; +pub use lwe_to_glwe_key::*; diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index ccad084..e9c6499 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -4,6 +4,7 @@ mod decryption; mod dist; mod encryption; mod external_product; +mod glwe_packer; mod glwe_packing; mod glwe_trace; mod keyswitching; @@ -20,6 +21,7 @@ pub use decryption::*; pub use dist::*; pub use encryption::*; pub use external_product::*; +pub use glwe_packer::*; pub use glwe_packing::*; pub use glwe_trace::*; pub use keyswitching::*; diff --git a/poulpy-core/src/noise/gglwe.rs b/poulpy-core/src/noise/gglwe.rs index dc32d57..c6dd278 100644 --- a/poulpy-core/src/noise/gglwe.rs +++ b/poulpy-core/src/noise/gglwe.rs @@ -62,7 +62,7 @@ where let noise_have: f64 = pt.data.std(base2k, 0).log2(); - // println!("noise_have: {noise_have}"); + println!("noise_have: {noise_have}"); assert!( noise_have <= max_noise, diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 492b611..9802c14 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -67,6 +67,14 @@ where ); } } + + // fn glwe_relinearize(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch) + // where + // R: GLWEToRef, + // A: GLWETensorToRef, + // T: GLWETensorKeyPreparedToRef, + // { + // } } pub trait GLWEAdd diff --git a/poulpy-core/src/scratch.rs b/poulpy-core/src/scratch.rs index 2220dc4..944fbd7 100644 --- a/poulpy-core/src/scratch.rs +++ b/poulpy-core/src/scratch.rs @@ -7,7 +7,7 @@ use crate::{ dist::Distribution, layouts::{ Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, - GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank, + GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, Rank, prepared::{ GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, @@ -232,6 +232,18 @@ where ) } + fn take_glwe_secret_tensor(&mut self, n: Degree, rank: Rank) -> (GLWESecretTensor<&mut [u8]>, &mut Self) { + let (data, scratch) = self.take_scalar_znx(n.into(), GLWESecretTensor::pairs(rank.into())); + ( + GLWESecretTensor { + data, + rank, + dist: Distribution::NONE, + }, + scratch, + ) + } + fn take_glwe_secret_prepared(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) where M: ModuleN + SvpPPolBytesOf, @@ -313,25 +325,12 @@ where infos.rank_out(), "rank_in != rank_out is not supported for GLWETensorKey" ); - let mut keys: Vec> = Vec::new(); - let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - - let mut scratch: &mut Self = self; + let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1); let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); - ksk_infos.rank_in = Rank(1); - - if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe(&ksk_infos); - scratch = s; - keys.push(gglwe); - } - for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe(&ksk_infos); - scratch = s; - keys.push(gglwe); - } - (GLWETensorKey { keys }, scratch) + ksk_infos.rank_in = Rank(pairs); + let (data, scratch) = self.take_gglwe(infos); + (GLWETensorKey(data), scratch) } fn take_glwe_tensor_key_prepared(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self) @@ -346,25 +345,11 @@ where "rank_in != rank_out is not supported for GGLWETensorKeyPrepared" ); - let mut keys: Vec> = Vec::new(); - let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize; - - let mut scratch: &mut Self = self; - + let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1); let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); - ksk_infos.rank_in = Rank(1); - - if pairs != 0 { - let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); - scratch = s; - keys.push(gglwe); - } - for _ in 1..pairs { - let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos); - scratch = s; - keys.push(gglwe); - } - (GLWETensorKeyPrepared { keys }, scratch) + ksk_infos.rank_in = Rank(pairs); + let (data, scratch) = self.take_gglwe_prepared(module, infos); + (GLWETensorKeyPrepared(data), scratch) } } diff --git a/poulpy-core/src/tests/mod.rs b/poulpy-core/src/tests/mod.rs index dd16db0..aab0ec9 100644 --- a/poulpy-core/src/tests/mod.rs +++ b/poulpy-core/src/tests/mod.rs @@ -36,6 +36,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_ gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk, // GGLWE Keyswitching gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, @@ -75,7 +76,7 @@ backend_test_suite!( glwe_encrypt_zero_sk => crate::tests::test_suite::encryption::test_glwe_encrypt_zero_sk, glwe_encrypt_pk => crate::tests::test_suite::encryption::test_glwe_encrypt_pk, // GLWE Keyswitch - glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, +glwe_keyswitch => crate::tests::test_suite::keyswitch::test_glwe_keyswitch, glwe_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_glwe_keyswitch_inplace, // GLWE Automorphism glwe_automorphism => crate::tests::test_suite::automorphism::test_glwe_automorphism, @@ -93,6 +94,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_ gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, +gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk, // GGLWE Keyswitching gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index c67d87d..14e62bb 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -1,12 +1,12 @@ use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWESwitchingKey, - LWE, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWEKey, LWE, + LWESwitchingKey, LWEToGLWEKey, Rank, TorusPrecision, compressed::{ GGLWECompressed, GGSWCompressed, GLWEAutomorphismKeyCompressed, GLWECompressed, GLWESwitchingKeyCompressed, GLWETensorKeyCompressed, GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed, - LWEToGLWESwitchingKeyCompressed, + LWEToGLWEKeyCompressed, }, }; @@ -93,28 +93,27 @@ fn test_tensor_key_compressed_serialization() { } #[test] -fn glwe_to_lwe_switching_key_serialization() { - let original: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn glwe_to_lwe_key_serialization() { + let original: GLWEToLWEKey> = GLWEToLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn glwe_to_lwe_switching_key_compressed_serialization() { +fn glwe_to_lwe_key_compressed_serialization() { let original: GLWEToLWESwitchingKeyCompressed> = GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn lwe_to_glwe_switching_key_serialization() { - let original: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn lwe_to_glwe_key_serialization() { + let original: LWEToGLWEKey> = LWEToGLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } #[test] -fn lwe_to_glwe_switching_key_compressed_serialization() { - let original: LWEToGLWESwitchingKeyCompressed> = - LWEToGLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); +fn lwe_to_glwe_key_compressed_serialization() { + let original: LWEToGLWEKeyCompressed> = LWEToGLWEKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); test_reader_writer_interface(original); } diff --git a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs index b978a9d..6e2a226 100644 --- a/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/automorphism/ggsw_ct.rs @@ -5,12 +5,12 @@ use poulpy_hal::{ }; use crate::{ - GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWEToGGSWKeyEncryptSk, GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GLWEAutomorphismKey, GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, - GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, - prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, GLWETensorKeyPrepared}, + GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWEAutomorphismKey, + GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, + prepared::{GGLWEToGGSWKeyPrepared, GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, }, noise::noise_ggsw_keyswitch, }; @@ -21,8 +21,8 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GGSWAutomorphism - + GLWETensorKeyPreparedFactory - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyPreparedFactory + + GGLWEToGGSWKeyEncryptSk + GLWESecretPreparedFactory + VecZnxAutomorphismInplace + GGSWNoise, @@ -64,7 +64,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -73,7 +73,7 @@ where rank: rank.into(), }; - let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -84,7 +84,7 @@ where let mut ct_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_layout); let mut ct_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout); let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -95,8 +95,8 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ct_in) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) - | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk) + | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tsk), ); let var_xs: f64 = 0.5; @@ -115,7 +115,7 @@ where &mut source_xe, scratch.borrow(), ); - tensor_key.encrypt_sk( + tsk.encrypt_sk( module, &sk, &mut source_xa, @@ -138,9 +138,8 @@ where GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = - GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ct_out.automorphism( module, @@ -180,8 +179,8 @@ where + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory + GGSWAutomorphism - + GLWETensorKeyPreparedFactory - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyPreparedFactory + + GGLWEToGGSWKeyEncryptSk + GLWESecretPreparedFactory + VecZnxAutomorphismInplace + GGSWNoise, @@ -211,7 +210,7 @@ where rank: rank.into(), }; - let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_tsk.into(), @@ -220,7 +219,7 @@ where rank: rank.into(), }; - let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { + let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { n: n.into(), base2k: base2k.into(), k: k_ksk.into(), @@ -230,7 +229,7 @@ where }; let mut ct: GGSW> = GGSW::alloc_from_infos(&ggsw_out_layout); - let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout); let mut auto_key: GLWEAutomorphismKey> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -241,8 +240,8 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ct) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) - | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tensor_key), + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk) + | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tsk), ); let var_xs: f64 = 0.5; @@ -261,7 +260,7 @@ where &mut source_xe, scratch.borrow(), ); - tensor_key.encrypt_sk( + tsk.encrypt_sk( module, &sk, &mut source_xa, @@ -284,9 +283,8 @@ where GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = - GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); - tsk_prepared.prepare(module, &tensor_key, scratch.borrow()); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); + tsk_prepared.prepare(module, &tsk, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/conversion.rs b/poulpy-core/src/tests/test_suite/conversion.rs index c6e7d00..2412411 100644 --- a/poulpy-core/src/tests/test_suite/conversion.rs +++ b/poulpy-core/src/tests/test_suite/conversion.rs @@ -8,10 +8,10 @@ use crate::{ GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore, layouts::{ - Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKeyLayout, - GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, - LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, LWEToGLWESwitchingKeyPreparedFactory, Rank, TorusPrecision, - prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared}, + Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey, + GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, + LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision, + prepared::GLWESecretPrepared, }, }; @@ -22,7 +22,7 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + LWEEncryptSk - + LWEToGLWESwitchingKeyPreparedFactory, + + LWEToGLWEKeyPreparedFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -36,7 +36,7 @@ where let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let lwe_to_glwe_infos: LWEToGLWESwitchingKeyLayout = LWEToGLWESwitchingKeyLayout { + let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout { n: n_glwe, base2k: Base2K(17), k: TorusPrecision(51), @@ -58,7 +58,7 @@ where }; let mut scratch: ScratchOwned = ScratchOwned::alloc( - LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) + LWEToGLWEKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); @@ -80,7 +80,7 @@ where let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); - let mut ksk: LWEToGLWESwitchingKey> = LWEToGLWESwitchingKey::alloc_from_infos(&lwe_to_glwe_infos); + let mut ksk: LWEToGLWEKey> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); ksk.encrypt_sk( module, @@ -93,8 +93,7 @@ where let mut glwe_ct: GLWE> = GLWE::alloc_from_infos(&glwe_infos); - let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared, BE> = - LWEToGLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + let mut ksk_prepared: LWEToGLWEKeyPrepared, BE> = LWEToGLWEKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); @@ -114,7 +113,7 @@ where + GLWEDecrypt + GLWESecretPreparedFactory + GLWEToLWESwitchingKeyEncryptSk - + GLWEToLWESwitchingKeyPreparedFactory, + + GLWEToLWEKeyPreparedFactory, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -150,7 +149,7 @@ where let mut source_xe: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::alloc( - GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) + GLWEToLWEKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) | GLWE::decrypt_tmp_bytes(module, &glwe_infos), ); @@ -178,7 +177,7 @@ where scratch.borrow(), ); - let mut ksk: GLWEToLWESwitchingKey> = GLWEToLWESwitchingKey::alloc_from_infos(&glwe_to_lwe_infos); + let mut ksk: GLWEToLWEKey> = GLWEToLWEKey::alloc_from_infos(&glwe_to_lwe_infos); ksk.encrypt_sk( module, @@ -191,8 +190,7 @@ where let mut lwe_ct: LWE> = LWE::alloc_from_infos(&lwe_infos); - let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared, BE> = - GLWEToLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); + let mut ksk_prepared: GLWEToLWEKeyPrepared, BE> = GLWEToLWEKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs new file mode 100644 index 0000000..884e21a --- /dev/null +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_to_ggsw_key.rs @@ -0,0 +1,144 @@ +use poulpy_hal::{ + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy}, + layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned}, + source::Source, +}; + +use crate::{ + GGLWENoise, GGLWEToGGSWKeyCompressedEncryptSk, GGLWEToGGSWKeyEncryptSk, ScratchTakeCore, + decryption::GLWEDecrypt, + encryption::SIGMA, + layouts::{ + Dsize, GGLWEDecompress, GGLWEToGGSWKey, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyDecompress, GGLWEToGGSWKeyLayout, + GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, LWEInfos, prepared::GLWESecretPrepared, + }, +}; + +pub fn test_gglwe_to_ggsw_key_encrypt_sk(module: &Module) +where + Module: GGLWEToGGSWKeyEncryptSk + + GLWESecretTensorFactory + + GLWESecretPreparedFactory + + GLWEDecrypt + + GGLWENoise + + VecZnxCopy, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let base2k: usize = 8; + let k: usize = 54; + + for rank in 2_usize..3 { + let n: usize = module.n(); + let dnum: usize = k / base2k; + + let key_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + dnum: dnum.into(), + dsize: Dsize(1), + rank: rank.into(), + }; + + let mut key: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&key_infos); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &key_infos)); + + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&key_infos); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + key.encrypt_sk( + module, + &sk, + &mut source_xa, + &mut source_xe, + scratch.borrow(), + ); + + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); + + let max_noise = SIGMA.log2() + 0.5 - (key.k().as_u32() as f64); + + let mut pt_want: ScalarZnx> = ScalarZnx::alloc(module.n(), rank); + + for i in 0..rank { + for j in 0..rank { + module.vec_znx_copy( + &mut pt_want.as_vec_znx_mut(), + j, + &sk_tensor.at(i, j).as_vec_znx(), + 0, + ); + } + + println!("pt_want: {}", pt_want.as_vec_znx()); + + module.gglwe_assert_noise(key.at(i), &sk_prepared, &pt_want, max_noise); + } + } +} + +pub fn test_gglwe_to_ggsw_compressed_encrypt_sk(module: &Module) +where + Module: GGLWEToGGSWKeyCompressedEncryptSk + + GLWESecretPreparedFactory + + GLWEDecrypt + + GLWESecretTensorFactory + + GGLWENoise + + GGLWEDecompress + + GGLWEToGGSWKeyDecompress, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, + Scratch: ScratchAvailable + ScratchTakeCore, +{ + let base2k = 8; + let k = 54; + for rank in 1_usize..3 { + let n: usize = module.n(); + let dnum: usize = k / base2k; + + let key_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { + n: n.into(), + base2k: base2k.into(), + k: k.into(), + dnum: dnum.into(), + dsize: Dsize(1), + rank: rank.into(), + }; + + let mut key_compressed: GGLWEToGGSWKeyCompressed> = GGLWEToGGSWKeyCompressed::alloc_from_infos(&key_infos); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes( + module, &key_infos, + )); + + let mut sk: GLWESecret> = GLWESecret::alloc_from_infos(&key_infos); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_prepared: GLWESecretPrepared, BE> = GLWESecretPrepared::alloc(module, rank.into()); + sk_prepared.prepare(module, &sk); + + let seed_xa: [u8; 32] = [1u8; 32]; + + key_compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow()); + + let mut key: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&key_infos); + key.decompress(module, &key_compressed); + + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); + + for i in 0..rank { + module.gglwe_assert_noise(key.at(i), &sk_prepared, &sk_tensor.data, SIGMA + 0.5); + } + } +} diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs index 940f917..26baa92 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_tsk.rs @@ -1,20 +1,16 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, VecZnxBigAlloc, VecZnxBigNormalize, - VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxSubScalarInplace, - VecZnxSwitchRing, - }, - layouts::{Backend, Module, Scratch, ScratchOwned, VecZnxBig, VecZnxDft}, + api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, + layouts::{Backend, Module, Scratch, ScratchOwned}, source::Source, }; use crate::{ - GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWENoise, GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, decryption::GLWEDecrypt, encryption::SIGMA, layouts::{ - Dsize, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWETensorKey, GLWETensorKeyCompressed, GLWETensorKeyLayout, - prepared::GLWESecretPrepared, + Dsize, GGLWEDecompress, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWETensorKey, + GLWETensorKeyCompressed, GLWETensorKeyLayout, prepared::GLWESecretPrepared, }, }; @@ -23,20 +19,15 @@ where Module: GLWETensorKeyEncryptSk + GLWESecretPreparedFactory + GLWEDecrypt - + VecZnxDftAlloc - + VecZnxBigAlloc - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxBigNormalize - + VecZnxSubScalarInplace, + + GLWESecretTensorFactory + + GGLWENoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { let base2k: usize = 8; let k: usize = 54; - for rank in 1_usize..3 { + for rank in 2_usize..3 { let n: usize = module.n(); let dnum: usize = k / base2k; @@ -73,42 +64,10 @@ where scratch.borrow(), ); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); - let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); - - for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - for i in 0..rank { - for j in 0..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - base2k, - &sk_ij_big, - 0, - scratch.borrow(), - ); - for row_i in 0..dnum { - let ct = tensor_key.at(i, j).at(row_i, 0); - - ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - - let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); - } - } - } + module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); } } @@ -118,15 +77,9 @@ where + GLWESecretPreparedFactory + GLWETensorKeyCompressedEncryptSk + GLWEDecrypt - + VecZnxDftAlloc - + VecZnxBigAlloc - + VecZnxDftApply - + SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxSubScalarInplace - + VecZnxFillUniform - + VecZnxCopy - + VecZnxSwitchRing, + + GLWESecretTensorFactory + + GGLWENoise + + GGLWEDecompress, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchAvailable + ScratchTakeCore, { @@ -168,42 +121,9 @@ where let mut tensor_key: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tensor_key_infos); tensor_key.decompress(module, &tensor_key_compressed); - let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); + let mut sk_tensor: GLWESecretTensor> = GLWESecretTensor::alloc_from_infos(&sk); + sk_tensor.prepare(module, &sk, scratch.borrow()); - let mut sk_ij_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(1, 1); - let mut sk_ij_big: VecZnxBig, BE> = module.vec_znx_big_alloc(1, 1); - let mut sk_ij: GLWESecret> = GLWESecret::alloc(n.into(), 1_u32.into()); - let mut sk_dft: VecZnxDft, BE> = module.vec_znx_dft_alloc(rank, 1); - - for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); - } - - for i in 0..rank { - for j in 0..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - base2k, - &mut sk_ij.data.as_vec_znx_mut(), - 0, - base2k, - &sk_ij_big, - 0, - scratch.borrow(), - ); - for row_i in 0..dnum { - tensor_key - .at(i, j) - .at(row_i, 0) - .decrypt(module, &mut pt, &sk_prepared, scratch.borrow()); - - module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0); - - let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2(); - assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}"); - } - } - } + module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5); } } diff --git a/poulpy-core/src/tests/test_suite/encryption/mod.rs b/poulpy-core/src/tests/test_suite/encryption/mod.rs index d871177..0fe0f49 100644 --- a/poulpy-core/src/tests/test_suite/encryption/mod.rs +++ b/poulpy-core/src/tests/test_suite/encryption/mod.rs @@ -1,11 +1,13 @@ mod gglwe_atk; mod gglwe_ct; +mod gglwe_to_ggsw_key; mod ggsw_ct; mod glwe_ct; mod glwe_tsk; pub use gglwe_atk::*; pub use gglwe_ct::*; +pub use gglwe_to_ggsw_key::*; pub use ggsw_ct::*; pub use glwe_ct::*; pub use glwe_tsk::*; diff --git a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs index 28e71a5..c4191fa 100644 --- a/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs +++ b/poulpy-core/src/tests/test_suite/keyswitch/ggsw_ct.rs @@ -5,12 +5,13 @@ use poulpy_hal::{ }; use crate::{ - GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, + GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore, encryption::SIGMA, layouts::{ - GGSW, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, - GLWESwitchingKeyPreparedFactory, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, - prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared}, + GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWESecret, + GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory, + GLWETensorKeyLayout, + prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, }, noise::noise_ggsw_keyswitch, }; @@ -20,10 +21,10 @@ pub fn test_ggsw_keyswitch(module: &Module) where Module: GGSWEncryptSk + GLWESwitchingKeyEncryptSk - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyEncryptSk + GGSWKeyswitch + GLWESecretPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWESwitchingKeyPreparedFactory + GGSWNoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, @@ -82,7 +83,7 @@ where let mut ggsw_in: GGSW> = GGSW::alloc_from_infos(&ggsw_in_infos); let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos); let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -93,7 +94,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, @@ -148,7 +149,7 @@ where GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch( @@ -185,10 +186,10 @@ pub fn test_ggsw_keyswitch_inplace(module: &Module) where Module: GGSWEncryptSk + GLWESwitchingKeyEncryptSk - + GLWETensorKeyEncryptSk + + GGLWEToGGSWKeyEncryptSk + GGSWKeyswitch + GLWESecretPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWESwitchingKeyPreparedFactory + GGSWNoise, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, @@ -236,7 +237,7 @@ where }; let mut ggsw_out: GGSW> = GGSW::alloc_from_infos(&ggsw_out_infos); - let mut tsk: GLWETensorKey> = GLWETensorKey::alloc_from_infos(&tsk_infos); + let mut tsk: GGLWEToGGSWKey> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos); let mut ksk: GLWESwitchingKey> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut pt_scalar: ScalarZnx> = ScalarZnx::alloc(n, 1); @@ -247,7 +248,7 @@ where let mut scratch: ScratchOwned = ScratchOwned::alloc( GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) - | GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) + | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGSW::keyswitch_tmp_bytes( module, &ggsw_out_infos, @@ -302,7 +303,7 @@ where GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); ksk_prepared.prepare(module, &ksk, scratch.borrow()); - let mut tsk_prepared: GLWETensorKeyPrepared, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); + let mut tsk_prepared: GGLWEToGGSWKeyPrepared, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk); tsk_prepared.prepare(module, &tsk, scratch.borrow()); ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); diff --git a/poulpy-core/src/tests/test_suite/packing.rs b/poulpy-core/src/tests/test_suite/packing.rs index 029e059..cf8284a 100644 --- a/poulpy-core/src/tests/test_suite/packing.rs +++ b/poulpy-core/src/tests/test_suite/packing.rs @@ -7,7 +7,7 @@ use poulpy_hal::{ }; use crate::{ - GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPacking, GLWERotate, GLWESub, ScratchTakeCore, + GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPackerOps, GLWERotate, GLWESub, ScratchTakeCore, layouts::{ GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, @@ -20,7 +20,7 @@ where Module: GLWEEncryptSk + GLWEAutomorphismKeyEncryptSk + GLWEAutomorphismKeyPreparedFactory - + GLWEPacking + + GLWEPackerOps + GLWESecretPreparedFactory + GLWESub + GLWEDecrypt diff --git a/poulpy-hal/Cargo.toml b/poulpy-hal/Cargo.toml index 93325b3..f114681 100644 --- a/poulpy-hal/Cargo.toml +++ b/poulpy-hal/Cargo.toml @@ -19,7 +19,7 @@ rand_core = {workspace = true} byteorder = {workspace = true} once_cell = {workspace = true} rand_chacha = "0.9.0" -bytemuck = "1.23.2" +bytemuck = {workspace = true} [build-dependencies] diff --git a/poulpy-hal/src/api/convolution.rs b/poulpy-hal/src/api/convolution.rs index 10caf6b..d1c6c5e 100644 --- a/poulpy-hal/src/api/convolution.rs +++ b/poulpy-hal/src/api/convolution.rs @@ -78,6 +78,7 @@ where self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size) } + #[allow(clippy::too_many_arguments)] /// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the /// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K} /// @@ -139,6 +140,7 @@ where } } + #[allow(clippy::too_many_arguments)] fn bivariate_convolution( &self, k: i64, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs index 373eb7c..6a541cf 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/key.rs @@ -10,8 +10,8 @@ use crate::tfhe::{ use poulpy_core::{ GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore, layouts::{ - GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, - GLWEToLWESwitchingKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWESwitchingKeyPrepared, + GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKey, GLWEToLWEKeyLayout, + GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWEKeyPrepared, }, }; use poulpy_hal::{ @@ -46,7 +46,7 @@ where BRA: BlindRotationAlgo, { cbt: CircuitBootstrappingKey, - ks: GLWEToLWESwitchingKey, + ks: GLWEToLWEKey, } impl BDDKey, BRA> @@ -59,7 +59,7 @@ where { Self { cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), - ks: GLWEToLWESwitchingKey::alloc_from_infos(&infos.ks_infos()), + ks: GLWEToLWEKey::alloc_from_infos(&infos.ks_infos()), } } } @@ -130,12 +130,12 @@ where BE: Backend, { pub(crate) cbt: CircuitBootstrappingKeyPrepared, - pub(crate) ks: GLWEToLWESwitchingKeyPrepared, + pub(crate) ks: GLWEToLWEKeyPrepared, } pub trait BDDKeyPreparedFactory where - Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory, + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWEKeyPreparedFactory, { fn alloc_bdd_key_from_infos(&self, infos: &A) -> BDDKeyPrepared, BRA, BE> where @@ -143,7 +143,7 @@ where { BDDKeyPrepared { cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()), - ks: GLWEToLWESwitchingKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), + ks: GLWEToLWEKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), } } @@ -152,7 +152,7 @@ where A: BDDKeyInfos, { self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos()) - .max(self.prepare_glwe_to_lwe_switching_key_tmp_bytes(&infos.ks_infos())) + .max(self.prepare_glwe_to_lwe_key_tmp_bytes(&infos.ks_infos())) } fn prepare_bdd_key(&self, res: &mut BDDKeyPrepared, other: &BDDKey, scratch: &mut Scratch) @@ -166,7 +166,7 @@ where } } impl BDDKeyPreparedFactory for Module where - Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWESwitchingKeyPreparedFactory + Self: Sized + CircuitBootstrappingKeyPreparedFactory + GLWEToLWEKeyPreparedFactory { } diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs index a627e0c..a5adc52 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/circuit.rs @@ -6,7 +6,7 @@ use poulpy_hal::{ }; use poulpy_core::{ - GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWETrace, ScratchTakeCore, + GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore, layouts::{ Dsize, GGLWELayout, GGSWInfos, GGSWToMut, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, LWEInfos, LWEToRef, }, @@ -115,7 +115,8 @@ where + GLWEPacking + GGSWFromGGLWE + GLWESecretPreparedFactory - + GLWEDecrypt, + + GLWEDecrypt + + GLWERotate, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -216,7 +217,9 @@ pub fn circuit_bootstrap_core( + GLWEPacking + GGSWFromGGLWE + GLWESecretPreparedFactory - + GLWEDecrypt, + + GLWEDecrypt + + GLWERotate + + ModuleLogN, ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, Scratch: ScratchTakeCore, { @@ -332,7 +335,7 @@ fn post_process( ) where R: GLWEToMut, A: GLWEToRef, - M: ModuleLogN + GLWETrace + GLWEPacking, + M: ModuleLogN + GLWETrace + GLWEPacking + GLWERotate, Scratch: ScratchTakeCore, { let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs index c6b8adc..de57832 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key.rs @@ -1,8 +1,8 @@ use poulpy_core::{ - Distribution, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, + Distribution, GGLWEToGGSWKeyEncryptSk, GLWEAutomorphismKeyEncryptSk, GetDistribution, ScratchTakeCore, layouts::{ - GGLWEInfos, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecretPreparedFactory, - GLWESecretToRef, GLWETensorKey, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, + GGLWEInfos, GGLWEToGGSWKey, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, + GLWESecretPreparedFactory, GLWESecretToRef, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, }, trace_galois_elements, }; @@ -81,14 +81,14 @@ impl CircuitBootstrappingKey, BRA> { (gal_el, key) }) .collect(), - tsk: GLWETensorKey::alloc_from_infos(trk_infos), + tsk: GGLWEToGGSWKey::alloc_from_infos(trk_infos), } } } pub struct CircuitBootstrappingKey { pub(crate) brk: BlindRotationKey, - pub(crate) tsk: GLWETensorKey>, + pub(crate) tsk: GGLWEToGGSWKey>, pub(crate) atk: HashMap>>, } @@ -112,7 +112,7 @@ impl CircuitBootstrappingKey { impl CircuitBootstrappingKeyEncryptSk for Module where - Self: GLWETensorKeyEncryptSk + Self: GGLWEToGGSWKeyEncryptSk + BlindRotationKeyEncryptSk + GLWEAutomorphismKeyEncryptSk + GLWESecretPreparedFactory, diff --git a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs index 6adca70..c611846 100644 --- a/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs +++ b/poulpy-schemes/src/tfhe/circuit_bootstrapping/key_prepared.rs @@ -1,8 +1,8 @@ use poulpy_core::{ layouts::{ - GGLWEInfos, GGSWInfos, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, - GLWETensorKeyPreparedFactory, LWEInfos, - prepared::{GLWEAutomorphismKeyPrepared, GLWETensorKeyPrepared}, + GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSWInfos, GLWEAutomorphismKeyLayout, + GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, LWEInfos, + prepared::GLWEAutomorphismKeyPrepared, }, trace_galois_elements, }; @@ -50,7 +50,7 @@ pub trait CircuitBootstrappingKeyPreparedFactory - + GLWETensorKeyPreparedFactory + + GGLWEToGGSWKeyPreparedFactory + GLWEAutomorphismKeyPreparedFactory, { fn circuit_bootstrapping_key_prepared_alloc_from_infos( @@ -65,7 +65,7 @@ where CircuitBootstrappingKeyPrepared { brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()), - tsk: GLWETensorKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), + tsk: GGLWEToGGSWKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), atk: gal_els .iter() .map(|&gal_el| { @@ -81,7 +81,7 @@ where A: CircuitBootstrappingKeyInfos, { self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos()) - .max(self.prepare_tensor_key_tmp_bytes(&infos.tsk_infos())) + .max(self.prepare_gglwe_to_ggsw_key_tmp_bytes(&infos.tsk_infos())) .max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos())) } @@ -105,7 +105,7 @@ where pub struct CircuitBootstrappingKeyPrepared { pub(crate) brk: BlindRotationKeyPrepared, - pub(crate) tsk: GLWETensorKeyPrepared, B>, + pub(crate) tsk: GGLWEToGGSWKeyPrepared, B>, pub(crate) atk: HashMap, B>>, }