diff --git a/poulpy-core/src/glwe_packing.rs b/poulpy-core/src/glwe_packing.rs index 7dacb97..129de7e 100644 --- a/poulpy-core/src/glwe_packing.rs +++ b/poulpy-core/src/glwe_packing.rs @@ -2,18 +2,13 @@ use std::collections::HashMap; use poulpy_hal::{ api::{ - ScratchAvailable, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, - VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, - VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, + ModuleLogN, ModuleN, ScratchAvailable, VecZnxAddInplace, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallNegateInplace, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNegateInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + layouts::{Backend, DataMut, DataRef, GaloisElement, Module, Scratch}, }; use crate::{ - GLWEOperations, - layouts::{GGLWEInfos, GLWE, GLWEInfos, LWEInfos, prepared::AutomorphismKeyPrepared}, + layouts::{prepared::{AutomorphismKeyPrepared, AutomorphismKeyPreparedToRef}, GGLWEInfos, GLWEAlloc, GLWEInfos, GLWEToRef, LWEInfos, GLWE}, GLWEAutomorphism, GLWEOperations, ScratchTakeCore }; /// [GLWEPacker] enables only the fly GLWE packing @@ -43,9 +38,10 @@ impl Accumulator { /// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation. /// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus. /// * `rank`: rank of the GLWE ciphertext. - pub fn alloc(module: &Module, infos: &A) -> Self + pub fn alloc(module: &M, infos: &A) -> Self where A: GLWEInfos, + M: GLWEAlloc { Self { data: GLWE::alloc_from_infos(module, infos), @@ -66,9 +62,10 @@ impl GLWEPacker { /// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients /// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts /// can be packed. - pub fn new(module: Module, infos: &A, log_batch: usize) -> Self + pub fn new(module: &M, infos: &A, log_batch: usize) -> Self where A: GLWEInfos, + M: GLWEAlloc { let mut accumulators: Vec = Vec::::new(); let log_n: usize = infos.n().log2(); @@ -90,13 +87,13 @@ impl GLWEPacker { } /// Number of scratch space bytes required to call [Self::add]. - pub fn tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize + pub fn tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - pack_core_tmp_bytes(module, out_infos, key_infos) + pack_core_tmp_bytes(module, res_infos, key_infos) } pub fn galois_elements(module: &Module) -> Vec { @@ -112,37 +109,17 @@ impl GLWEPacker { /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext. /// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s. /// * `scratch`: scratch space of size at least [Self::tmp_bytes]. - pub fn add( + pub fn add( &mut self, - module: &Module, - a: Option<&GLWE>, - auto_keys: &HashMap>, + module: &M, + a: Option<&A>, + auto_keys: &HashMap, scratch: &mut Scratch, ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + A: GLWEToRef, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism, + Scratch: ScratchTakeCore, { assert!( (self.counter as u32) < self.accumulators[0].data.n(), @@ -177,47 +154,27 @@ impl GLWEPacker { } } -fn pack_core_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn pack_core_tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - combine_tmp_bytes(module, out_infos, key_infos) + combine_tmp_bytes(module, res_infos, key_infos) } -fn pack_core( - module: &Module, - a: Option<&GLWE>, +fn pack_core( + module: &M, + a: Option<&A>, accumulators: &mut [Accumulator], i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + A: GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism + ModuleLogN + VecZnxCopy, + Scratch: ScratchTakeCore, { let log_n: usize = module.log_n(); @@ -268,49 +225,29 @@ fn pack_core( } } -fn combine_tmp_bytes(module: &Module, out_infos: &OUT, key_infos: &KEY) -> usize +fn combine_tmp_bytes(module: &M, res_infos: &R, key_infos: &K) -> usize where - OUT: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + K: GGLWEInfos, + M: GLWEAlloc + GLWEAutomorphism, { - GLWE::bytes_of_from_infos(module, out_infos) - + (GLWE::rsh_tmp_bytes(module.n()) | GLWE::automorphism_inplace_tmp_bytes(module, out_infos, key_infos)) + GLWE::bytes_of_from_infos(module, res_infos) + + (GLWE::rsh_tmp_bytes(module.n()) | module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos)) } /// [combine] merges two ciphertexts together. -fn combine( - module: &Module, +fn combine( + module: &M, acc: &mut Accumulator, - b: Option<&GLWE>, + b: Option<&B>, i: usize, - auto_keys: &HashMap>, - scratch: &mut Scratch, + auto_keys: &HashMap, + scratch: &mut Scratch, ) where - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxCopy - + VecZnxRotateInplace - + VecZnxSub - + VecZnxNegateInplace - + VecZnxRshInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxSubInplace - + VecZnxRotate - + VecZnxAutomorphismInplace - + VecZnxBigSubSmallNegateInplace - + VecZnxBigAutomorphismInplace - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + B: GLWEToRef + GLWEInfos, + K: AutomorphismKeyPreparedToRef, + M: GLWEAutomorphism + GaloisElement + VecZnxRotateInplace, + Scratch: ScratchTakeCore, { let log_n: usize = acc.data.n().log2(); let a: &mut GLWE> = &mut acc.data; @@ -335,7 +272,7 @@ fn combine( // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q. if acc.value { if let Some(b) = b { - let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(a); + let (mut tmp_b, scratch_1) = scratch.take_glwe_ct(module, a); // a = a * X^-t a.rotate_inplace(module, -t, scratch_1);