diff --git a/poulpy-core/src/automorphism/mod.rs b/poulpy-core/src/automorphism/mod.rs index 1cd7bea..fd10f33 100644 --- a/poulpy-core/src/automorphism/mod.rs +++ b/poulpy-core/src/automorphism/mod.rs @@ -4,4 +4,4 @@ mod glwe_ct; pub use gglwe_atk::*; pub use ggsw_ct::*; -pub use glwe_ct::*; \ No newline at end of file +pub use glwe_ct::*; diff --git a/poulpy-core/src/conversion/gglwe_to_ggsw.rs b/poulpy-core/src/conversion/gglwe_to_ggsw.rs index a7b86fa..24d02bd 100644 --- a/poulpy-core/src/conversion/gglwe_to_ggsw.rs +++ b/poulpy-core/src/conversion/gglwe_to_ggsw.rs @@ -1,19 +1,18 @@ use poulpy_hal::{ api::{ - ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxCopy, VecZnxDftAddInplace, - VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, - VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, + VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, + VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, }, layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, }; use crate::{ - ScratchTakeCore, + GLWECopy, ScratchTakeCore, layouts::{ GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, prepared::{TensorKeyPrepared, TensorKeyPreparedToRef}, }, - operations::GLWEOperations, }; impl GGLWE> { @@ -39,11 +38,11 @@ impl GGSW { } } -impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + VecZnxCopy {} +impl GGSWFromGGLWE for Module where Self: GGSWExpandRows + GLWECopy {} pub trait GGSWFromGGLWE where - Self: GGSWExpandRows + VecZnxCopy, + Self: GGSWExpandRows + GLWECopy, { fn ggsw_from_gglwe_tmp_bytes(&self, res_infos: &R, tsk_infos: &A) -> usize where @@ -71,7 +70,7 @@ where assert_eq!(tsk.n(), self.n() as u32); for row in 0..res.dnum().into() { - res.at_mut(row, 0).copy(self, &a.at(row, 0)); + self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); } self.ggsw_expand_row(res, tsk, scratch); diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 129de7e..6fbdb1d 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -1,14 +1,13 @@ use std::collections::HashMap; use poulpy_hal::{ - api::{ - ModuleLogN, ModuleN, ScratchAvailable, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes - }, + api::{ModuleLogN, VecZnxCopy, VecZnxRotateInplace}, layouts::{Backend, DataMut, DataRef, GaloisElement, Module, Scratch}, }; use crate::{ - layouts::{prepared::{AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef}, GGLWEInfos, GLWEAlloc, GLWEInfos, GLWEToRef, LWEInfos, GLWE}, GLWEAutomorphism, GLWEOperations, ScratchTakeCore + GLWEAutomorphism, ScratchTakeCore, + layouts::{GGLWEInfos, GLWE, GLWEAlloc, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::AutomorphismKeyPreparedToRef}, }; /// [GLWEPacker] enables only the fly GLWE packing @@ -41,7 +40,7 @@ impl Accumulator { pub fn alloc(module: &M, infos: &A) -> Self where A: GLWEInfos, - M: GLWEAlloc + M: GLWEAlloc, { Self { data: GLWE::alloc_from_infos(module, infos), @@ -65,7 +64,7 @@ impl GLWEPacker { pub fn new(module: &M, infos: &A, log_batch: usize) -> Self where A: GLWEInfos, - M: GLWEAlloc + M: GLWEAlloc, { let mut accumulators: Vec = Vec::::new(); let log_n: usize = infos.n().log2(); @@ -109,13 +108,8 @@ impl GLWEPacker { /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. - pub fn add( - &mut self, - module: &M, - a: Option<&A>, - auto_keys: &HashMap, - scratch: &mut Scratch, - ) where + pub fn add(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap, scratch: &mut Scratch) + where A: GLWEToRef, K: AutomorphismKeyPreparedToRef, M: GLWEAutomorphism, @@ -327,110 +321,69 @@ fn combine( } } -/// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] -/// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] -pub fn glwe_packing( - module: &Module, - cts: &mut HashMap>, - log_gap_out: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, -) where - ATK: DataRef, - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxSwitchRing - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: ScratchAvailable, +pub trait GLWEPacking +where + Self: GLWEAutomorphism + GaloisElement + ModuleLogN, { - #[cfg(debug_assertions)] + /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] + /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] + fn glwe_pack( + &self, + cts: &mut HashMap, + log_gap_out: usize, + keys: &HashMap, + scratch: &mut Scratch, + ) where + R: GLWEToMut, + K: AutomorphismKeyPreparedToRef, + Scratch: ScratchTakeCore, { - assert!(*cts.keys().max().unwrap() < module.n()) - } + #[cfg(debug_assertions)] + { + assert!(*cts.keys().max().unwrap() < self.n()) + } - let log_n: usize = module.log_n(); + let log_n: usize = self.log_n(); - (0..log_n - log_gap_out).for_each(|i| { - let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); + (0..log_n - log_gap_out).for_each(|i| { + let t: usize = (1 << log_n).min(1 << (log_n - 1 - i)); - let auto_key: &AutomorphismKeyPrepared = if i == 0 { - auto_keys.get(&-1).unwrap() - } else { - auto_keys.get(&module.galois_element(1 << (i - 1))).unwrap() - }; + let key: &K = if i == 0 { + keys.get(&-1).unwrap() + } else { + keys.get(&self.galois_element(1 << (i - 1))).unwrap() + }; - (0..t).for_each(|j| { - let mut a: Option<&mut GLWE> = cts.remove(&j); - let mut b: Option<&mut GLWE> = cts.remove(&(j + t)); + (0..t).for_each(|j| { + let mut a: Option<&mut R> = cts.remove(&j); + let mut b: Option<&mut R> = cts.remove(&(j + t)); - pack_internal(module, &mut a, &mut b, i, auto_key, scratch); + pack_internal(self, &mut a, &mut b, i, key, scratch); - if let Some(a) = a { - cts.insert(j, a); - } else if let Some(b) = b { - cts.insert(j, b); - } + if let Some(a) = a { + cts.insert(j, a); + } else if let Some(b) = b { + cts.insert(j, b); + } + }); }); - }); + } } #[allow(clippy::too_many_arguments)] -fn pack_internal( - module: &Module, - a: &mut Option<&mut GLWE>, - b: &mut Option<&mut GLWE>, +fn pack_internal( + module: &M, + a: &mut Option<&mut A>, + b: &mut Option<&mut B>, i: usize, - auto_key: &AutomorphismKeyPrepared, - scratch: &mut Scratch, + auto_key: &K, + scratch: &mut Scratch, ) where - Module: VecZnxRotateInplace - + VecZnxNormalizeInplace - + VecZnxNormalizeTmpBytes - + VecZnxBigAutomorphismInplace - + VecZnxRshInplace - + VecZnxDftCopy - + VecZnxIdftApplyTmpA - + VecZnxSub - + VecZnxAddInplace - + VecZnxNegateInplace - + VecZnxCopy - + VecZnxSubInplace - + VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxRotate - + VecZnxNormalize, - Scratch: ScratchAvailable, + M: GLWEAutomorphism, + A: GLWEToMut + GLWEInfos, + B: GLWEToMut + GLWEInfos, + K: AutomorphismKeyPreparedToRef, + Scratch: ScratchTakeCore, { // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t)) // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g) @@ -446,7 +399,7 @@ fn pack_internal( let t: i64 = 1 << (a.n().log2() - i - 1); if let Some(b) = b.as_deref_mut() { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, a); // a = a * X^-t a.rotate_inplace(module, -t, scratch_1); diff --git a/poulpy-core/src/glwe_trace.rs b/poulpy-core/src/glwe_trace.rs index 36aabb9..dc5582c 100644 --- a/poulpy-core/src/glwe_trace.rs +++ b/poulpy-core/src/glwe_trace.rs @@ -9,10 +9,7 @@ use poulpy_hal::{ layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx}, }; -use crate::{ - layouts::{Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWEInfos, prepared::AutomorphismKeyPrepared}, - operations::GLWEOperations, -}; +use crate::layouts::{Base2K, GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWEInfos, prepared::AutomorphismKeyPrepared}; impl GLWE> { pub fn trace_galois_elements(module: &Module) -> Vec { diff --git a/poulpy-core/src/keyswitching/mod.rs b/poulpy-core/src/keyswitching/mod.rs index 462e474..7071680 100644 --- a/poulpy-core/src/keyswitching/mod.rs +++ b/poulpy-core/src/keyswitching/mod.rs @@ -4,6 +4,6 @@ mod glwe_ct; mod lwe_ct; pub use gglwe_ct::*; -//pub use gglwe_ct::*; +// pub use gglwe_ct::*; pub use glwe_ct::*; pub use lwe_ct::*; diff --git a/poulpy-core/src/lib.rs b/poulpy-core/src/lib.rs index 49c39b0..2dcc77a 100644 --- a/poulpy-core/src/lib.rs +++ b/poulpy-core/src/lib.rs @@ -14,12 +14,12 @@ mod utils; pub use operations::*; pub mod layouts; +pub use automorphism::*; pub use conversion::*; pub use dist::*; pub use external_product::*; pub use glwe_packing::*; pub use keyswitching::*; -pub use automorphism::*; pub use encryption::SIGMA; diff --git a/poulpy-core/src/operations/glwe.rs b/poulpy-core/src/operations/glwe.rs index 021d10e..4ad713c 100644 --- a/poulpy-core/src/operations/glwe.rs +++ b/poulpy-core/src/operations/glwe.rs @@ -1,320 +1,292 @@ use poulpy_hal::{ api::{ - VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, + ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSubNegateInplace, }, - layouts::{Backend, DataMut, Scratch, VecZnx, ZnxZero}, + layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, }; -use crate::layouts::{GLWE, GLWEInfos, GLWEPlaintext, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}; +use crate::{ + ScratchTakeCore, + layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}, +}; -impl GLWEOperations for GLWEPlaintext +pub trait GLWEAdd where - D: DataMut, - GLWEPlaintext: GLWEToMut + GLWEInfos, + Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, { + fn glwe_add(&self, res: &mut R, a: &A, b: &B) + where + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &mut GLWE<&[u8]> = &mut a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); + + let min_col: usize = (a.rank().min(b.rank()) + 1).into(); + let max_col: usize = (a.rank().max(b.rank() + 1)).into(); + let self_col: usize = (res.rank() + 1).into(); + + (0..min_col).for_each(|i| { + self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i); + }); + + if a.rank() > b.rank() { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); + }); + } else { + (min_col..max_col).for_each(|i| { + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + }); + } + + let size: usize = res.size(); + (max_col..self_col).for_each(|i| { + (0..size).for_each(|j| { + res.data.zero_at(i, j); + }); + }); + + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); + } + + fn glwe_add_inplace(&self, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); + + (0..(a.rank() + 1).into()).for_each(|i| { + self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i); + }); + + res.set_k(set_k_unary(res, a)) + } } -impl GLWEOperations for GLWE where GLWE: GLWEToMut + GLWEInfos {} +impl GLWEAdd for Module where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {} -pub trait GLWEOperations: GLWEToMut + GLWEInfos + SetGLWEInfos + Sized { - fn add(&mut self, module: &M, a: &A, b: &B) +pub trait GLWESub +where + Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace, +{ + fn glwe_sub(&self, res: &mut R, a: &A, b: &B) where - A: GLWEToRef + GLWEInfos, - B: GLWEToRef + GLWEInfos, - M: VecZnxAdd + VecZnxCopy, + R: GLWEToMut, + A: GLWEToRef, + B: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let b: &GLWE<&[u8]> = &b.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(b.n(), self.n() as u32); + assert_eq!(a.base2k(), b.base2k()); + assert!(res.rank() >= a.rank().max(b.rank())); let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - let b_ref: &GLWE<&[u8]> = &b.to_ref(); + let self_col: usize = (res.rank() + 1).into(); (0..min_col).for_each(|i| { - module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); + self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i); }); if a.rank() > b.rank() { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, a.data(), i); }); } else { (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); + self.vec_znx_copy(res.data_mut(), i, b.data(), i); + self.vec_znx_negate_inplace(res.data_mut(), i); }); } - let size: usize = self_mut.size(); + let size: usize = res.size(); (max_col..self_col).for_each(|i| { (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); + res.data.zero_at(i, j); }); }); - self.set_base2k(a.base2k()); - self.set_k(set_k_binary(self, a, b)); + res.set_base2k(a.base2k()); + res.set_k(set_k_binary(res, a, b)); } - fn add_inplace(&mut self, module: &M, a: &A) + fn glwe_sub_inplace(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - M: VecZnxAddInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } - fn sub(&mut self, module: &M, a: &A, b: &B) + fn glwe_sub_negate_inplace(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - B: GLWEToRef + GLWEInfos, - M: VecZnxSub + VecZnxCopy + VecZnxNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(b.n(), self.n()); - assert_eq!(a.base2k(), b.base2k()); - assert!(self.rank() >= a.rank().max(b.rank())); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let min_col: usize = (a.rank().min(b.rank()) + 1).into(); - let max_col: usize = (a.rank().max(b.rank() + 1)).into(); - let self_col: usize = (self.rank() + 1).into(); - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - let b_ref: &GLWE<&[u8]> = &b.to_ref(); - - (0..min_col).for_each(|i| { - module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i); - }); - - if a.rank() > b.rank() { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - } else { - (min_col..max_col).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i); - module.vec_znx_negate_inplace(&mut self_mut.data, i); - }); - } - - let size: usize = self_mut.size(); - (max_col..self_col).for_each(|i| { - (0..size).for_each(|j| { - self_mut.data.zero_at(i, j); - }); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_binary(self, a, b)); - } - - fn sub_inplace_ab(&mut self, module: &M, a: &A) - where - A: GLWEToRef + GLWEInfos, - M: VecZnxSubInplace, - { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.base2k(), a.base2k()); + assert!(res.rank() >= a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_k(set_k_unary(res, a)) } +} - fn sub_inplace_ba(&mut self, module: &M, a: &A) +pub trait GLWERotate +where + Self: ModuleN + VecZnxRotate + VecZnxRotateInplace, +{ + fn glwe_rotate(&self, k: i64, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - M: VecZnxSubNegateInplace, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.base2k(), a.base2k()); - assert!(self.rank() >= a.rank()) - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i); + self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i); }); - self.set_k(set_k_unary(self, a)) + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) } - fn rotate(&mut self, module: &M, k: i64, a: &A) + fn glwe_rotate_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) where - A: GLWEToRef + GLWEInfos, - M: VecZnxRotate, + R: GLWEToMut, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + + (0..(res.rank() + 1).into()).for_each(|i| { + self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch); + }); + } +} + +pub trait GLWEMulXpMinusOne +where + Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace, +{ + fn glwe_mul_xp_minus_one(&self, k: i64, res: &mut R, a: &A) + where + R: GLWEToMut, + A: GLWEToRef, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i); } - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_unary(self, a)) + res.set_base2k(a.base2k()); + res.set_k(set_k_unary(res, a)) } - fn rotate_inplace(&mut self, module: &M, k: i64, scratch: &mut Scratch) + fn glwe_mul_xp_minus_one_inplace(&self, k: i64, res: &mut R, scratch: &mut Scratch) where - M: VecZnxRotateInplace, + R: GLWEToMut, { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch); - }); + assert_eq!(res.n(), self.n() as u32); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch); + } } +} - fn mul_xp_minus_one(&mut self, module: &M, k: i64, a: &A) +pub trait GLWECopy +where + Self: ModuleN + VecZnxCopy, +{ + fn glwe_copy(&self, res: &mut R, a: &A) where - A: GLWEToRef + GLWEInfos, - M: VecZnxMulXpMinusOne, + R: GLWEToMut, + A: GLWEToRef, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), self.n()); - assert_eq!(self.rank(), a.rank()) + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_copy(res.data_mut(), i, a.data(), i); } - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(a.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_base2k(a.base2k()); - self.set_k(set_k_unary(self, a)) + res.set_k(a.k().min(res.max_k())); + res.set_base2k(a.base2k()); } +} - fn mul_xp_minus_one_inplace(&mut self, module: &M, k: i64, scratch: &mut Scratch) +pub trait GLWEShift +where + Self: ModuleN + VecZnxRshInplace, +{ + fn glwe_rsh(&self, k: usize, res: &mut R, scratch: &mut Scratch) where - M: VecZnxMulXpMinusOneInplace, + R: GLWEToMut, + Scratch: ScratchTakeCore, { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch); - }); - } - - fn copy(&mut self, module: &M, a: &A) - where - A: GLWEToRef + GLWEInfos, - M: VecZnxCopy, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let base2k: usize = res.base2k().into(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch); } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i); - }); - - self.set_k(a.k().min(self.max_k())); - self.set_base2k(a.base2k()); - } - - fn rsh(&mut self, module: &M, k: usize, scratch: &mut Scratch) - where - M: VecZnxRshInplace, - { - let base2k: usize = self.base2k().into(); - (0..(self.rank() + 1).into()).for_each(|i| { - module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch); - }) - } - - fn normalize(&mut self, module: &M, a: &A, scratch: &mut Scratch) - where - A: GLWEToRef + GLWEInfos, - M: VecZnxNormalize, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), a.n()); - assert_eq!(self.rank(), a.rank()); - } - - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - let a_ref: &GLWE<&[u8]> = &a.to_ref(); - - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize( - a.base2k().into(), - &mut self_mut.data, - i, - a.base2k().into(), - &a_ref.data, - i, - scratch, - ); - }); - self.set_base2k(a.base2k()); - self.set_k(a.k().min(self.k())); - } - - fn normalize_inplace(&mut self, module: &M, scratch: &mut Scratch) - where - M: VecZnxNormalizeInplace, - { - let self_mut: &mut GLWE<&mut [u8]> = &mut self.to_mut(); - (0..(self_mut.rank() + 1).into()).for_each(|i| { - module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch); - }); } } @@ -324,6 +296,50 @@ impl GLWE> { } } +pub trait GLWENormalize +where + Self: ModuleN + VecZnxNormalize + VecZnxNormalizeInplace, +{ + fn glwe_normalize(&self, res: &mut R, a: &A, scratch: &mut Scratch) + where + R: GLWEToMut, + A: GLWEToRef, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert_eq!(res.n(), self.n() as u32); + assert_eq!(a.n(), self.n() as u32); + assert_eq!(res.rank(), a.rank()); + + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize( + res.base2k().into(), + res.data_mut(), + i, + a.base2k().into(), + a.data(), + i, + scratch, + ); + } + + res.set_k(a.k().min(res.k())); + } + + fn glwe_normalize_inplace(&mut self, res: &mut R, scratch: &mut Scratch) + where + R: GLWEToMut, + Scratch: ScratchTakeCore, + { + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + for i in 0..res.rank().as_usize() + 1 { + self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch); + } + } +} + // c = op(a, b) fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { // If either operands is a ciphertext diff --git a/poulpy-hal/src/api/module.rs b/poulpy-hal/src/api/module.rs index c2e2f1c..a18af44 100644 --- a/poulpy-hal/src/api/module.rs +++ b/poulpy-hal/src/api/module.rs @@ -9,8 +9,11 @@ pub trait ModuleN { fn n(&self) -> usize; } -pub trait ModuleLogN where Self: ModuleN{ - fn log_n(&self) -> usize{ - (u64::BITS - (self.n() as u64-1).leading_zeros()) as usize +pub trait ModuleLogN +where + Self: ModuleN, +{ + fn log_n(&self) -> usize { + (u64::BITS - (self.n() as u64 - 1).leading_zeros()) as usize } -} \ No newline at end of file +} diff --git a/poulpy-hal/src/layouts/module.rs b/poulpy-hal/src/layouts/module.rs index ccefb3a..3382774 100644 --- a/poulpy-hal/src/layouts/module.rs +++ b/poulpy-hal/src/layouts/module.rs @@ -2,7 +2,10 @@ use std::{fmt::Display, marker::PhantomData, ptr::NonNull}; use rand_distr::num_traits::Zero; -use crate::{api::{ModuleLogN, ModuleN}, GALOISGENERATOR}; +use crate::{ + GALOISGENERATOR, + api::{ModuleLogN, ModuleN}, +}; #[allow(clippy::missing_safety_doc)] pub trait Backend: Sized { @@ -86,7 +89,7 @@ where } } -impl ModuleLogN for Module where Self: ModuleN{} +impl ModuleLogN for Module where Self: ModuleN {} impl CyclotomicOrder for Module where Self: ModuleN {}