diff --git a/poulpy-core/src/conversion/glwe_to_lwe.rs b/poulpy-core/src/conversion/glwe_to_lwe.rs index b6c6ed1..49744ba 100644 --- a/poulpy-core/src/conversion/glwe_to_lwe.rs +++ b/poulpy-core/src/conversion/glwe_to_lwe.rs @@ -1,100 +1,125 @@ use poulpy_hal::{ - api::{ - VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, - VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, + api::ModuleN, + layouts::{Backend, DataMut, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero}, }; -use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, Rank, prepared::GLWEToLWESwitchingKeyPrepared}; +use crate::{ + GLWEKeyswitch, ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWE, GLWEAlloc, GLWEInfos, GLWELayout, GLWEToRef, LWE, LWEInfos, LWEToMut, Rank, + prepared::{LWEToGLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPreparedToRef}, + }, +}; -impl LWE> { - pub fn from_glwe_tmp_bytes( - module: &Module, - lwe_infos: &OUT, - glwe_infos: &IN, - key_infos: &KEY, - ) -> usize +pub trait LWESampleExtract +where + Self: ModuleN, +{ + fn lwe_sample_extract(&self, res: &mut R, a: &A) where - OUT: LWEInfos, - IN: GLWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: LWEToMut, + A: GLWEToRef, { - let glwe_layout: GLWELayout = GLWELayout { - n: module.n().into(), + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + + assert!(res.n() <= a.n()); + assert_eq!(a.n(), self.n() as u32); + assert!(res.base2k() == a.base2k()); + + let min_size: usize = res.size().min(a.size()); + let n: usize = res.n().into(); + + res.data.zero(); + (0..min_size).for_each(|i| { + let data_lwe: &mut [i64] = res.data.at_mut(0, i); + data_lwe[0] = a.data.at(0, i)[0]; + data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); + }); + } +} + +impl LWESampleExtract for Module where Self: ModuleN {} +impl LWEFromGLWE for Module where Self: GLWEKeyswitch + GLWEAlloc + LWESampleExtract {} + +pub trait LWEFromGLWE +where + Self: GLWEKeyswitch + GLWEAlloc + LWESampleExtract, +{ + fn lwe_from_glwe_tmp_bytes(&self, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + { + let res_infos: GLWELayout = GLWELayout { + n: self.n().into(), base2k: lwe_infos.base2k(), k: lwe_infos.k(), rank: Rank(1), }; - GLWE::bytes_of( - module.n().into(), - lwe_infos.base2k(), - lwe_infos.k(), - 1u32.into(), - ) + GLWE::keyswitch_tmp_bytes(module, &glwe_layout, glwe_infos, key_infos) - } -} - -impl LWE { - pub fn sample_extract(&mut self, a: &GLWE) { - #[cfg(debug_assertions)] - { - assert!(self.n() <= a.n()); - assert!(self.base2k() == a.base2k()); - } - - let min_size: usize = self.size().min(a.size()); - let n: usize = self.n().into(); - - self.data.zero(); - (0..min_size).for_each(|i| { - let data_lwe: &mut [i64] = self.data.at_mut(0, i); - data_lwe[0] = a.data.at(0, i)[0]; - data_lwe[1..].copy_from_slice(&a.data.at(1, i)[..n]); - }); + self.bytes_of_glwe(lwe_infos.base2k(), lwe_infos.k(), 1u32.into()) + + self.glwe_keyswitch_tmp_bytes(&res_infos, glwe_infos, key_infos) } - pub fn from_glwe( - &mut self, - module: &Module, - a: &GLWE, - ks: &GLWEToLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - DGlwe: DataRef, - DKs: DataRef, - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch:, + fn lwe_from_glwe(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: LWEToMut, + A: GLWEToRef, + K: LWEToGLWESwitchingKeyPreparedToRef + GGLWEInfos, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(a.n(), module.n() as u32); - assert_eq!(ks.n(), module.n() as u32); - assert!(self.n() <= module.n() as u32); - } + let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); + let a: &GLWE<&[u8]> = &a.to_ref(); + let key: &LWEToGLWESwitchingKeyPrepared<&[u8], BE> = &key.to_ref(); + + assert_eq!(a.n(), self.n() as u32); + assert_eq!(key.n(), self.n() as u32); + assert!(res.n() <= self.n() as u32); let glwe_layout: GLWELayout = GLWELayout { - n: module.n().into(), - base2k: self.base2k(), - k: self.k(), + n: self.n().into(), + base2k: res.base2k(), + k: res.k(), rank: Rank(1), }; - let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(&glwe_layout); - tmp_glwe.keyswitch(module, a, &ks.0, scratch_1); - self.sample_extract(&tmp_glwe); + let (mut tmp_glwe, scratch_1) = scratch.take_glwe_ct(self, &glwe_layout); + self.glwe_keyswitch(&mut tmp_glwe, a, &key.0, scratch_1); + self.lwe_sample_extract(res, &tmp_glwe); + } +} + +impl LWE> { + pub fn from_glwe_tmp_bytes(module: &M, lwe_infos: &R, glwe_infos: &A, key_infos: &K) -> usize + where + R: LWEInfos, + A: GLWEInfos, + K: GGLWEInfos, + M: LWEFromGLWE, + { + module.lwe_from_glwe_tmp_bytes(lwe_infos, glwe_infos, key_infos) + } +} + +impl LWE { + pub fn sample_extract(&mut self, module: &M, a: &A) + where + A: GLWEToRef, + M: LWESampleExtract, + { + module.lwe_sample_extract(self, a); + } + + pub fn from_glwe(&self, module: &M, res: &mut R, a: &A, key: &K, scratch: &mut Scratch) + where + R: LWEToMut, + A: GLWEToRef, + K: LWEToGLWESwitchingKeyPreparedToRef + GGLWEInfos, + M: LWEFromGLWE, + Scratch: ScratchTakeCore, + { + module.lwe_from_glwe(res, a, key, scratch); } } diff --git a/poulpy-core/src/conversion/lwe_to_glwe.rs b/poulpy-core/src/conversion/lwe_to_glwe.rs index c4a3b88..7be6574 100644 --- a/poulpy-core/src/conversion/lwe_to_glwe.rs +++ b/poulpy-core/src/conversion/lwe_to_glwe.rs @@ -1,79 +1,67 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, - VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, - VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, - }, - layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, + api::ScratchTakeBasic, + layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, }; -use crate::layouts::{GGLWEInfos, GLWE, GLWEInfos, GLWELayout, LWE, LWEInfos, prepared::LWEToGLWESwitchingKeyPrepared}; +use crate::{ + GLWEKeyswitch, ScratchTakeCore, + layouts::{ + GGLWEInfos, GLWE, GLWEAlloc, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef, + prepared::{LWEToGLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPreparedToRef}, + }, +}; -impl GLWE> { - pub fn from_lwe_tmp_bytes( - module: &Module, - glwe_infos: &OUT, - lwe_infos: &IN, - key_infos: &KEY, - ) -> usize +impl GLWEFromLWE for Module where Self: GLWEKeyswitch + GLWEAlloc {} + +pub trait GLWEFromLWE +where + Self: GLWEKeyswitch + GLWEAlloc, +{ + fn glwe_from_lwe_tmp_bytes(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize where - OUT: GLWEInfos, - IN: LWEInfos, - KEY: GGLWEInfos, - Module: VecZnxDftBytesOf + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes, + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos, { - let ct: usize = GLWE::bytes_of( - module.n().into(), + let ct: usize = self.bytes_of_glwe( key_infos.base2k(), lwe_infos.k().max(glwe_infos.k()), 1u32.into(), ); - let ks: usize = GLWE::keyswitch_inplace_tmp_bytes(module, glwe_infos, key_infos); + + let ks: usize = self.glwe_keyswitch_tmp_bytes(glwe_infos, glwe_infos, key_infos); if lwe_infos.base2k() == key_infos.base2k() { ct + ks } else { - let a_conv = VecZnx::bytes_of(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes(); + let a_conv = VecZnx::bytes_of(self.n(), 1, lwe_infos.size()) + self.vec_znx_normalize_tmp_bytes(); ct + a_conv + ks } } -} -impl GLWE { - pub fn from_lwe( - &mut self, - module: &Module, - lwe: &LWE, - ksk: &LWEToGLWESwitchingKeyPrepared, - scratch: &mut Scratch, - ) where - DLwe: DataRef, - DKsk: DataRef, - Module: VecZnxDftBytesOf - + VmpApplyDftToDftTmpBytes - + VecZnxBigNormalizeTmpBytes - + VmpApplyDftToDft - + VmpApplyDftToDftAdd - + VecZnxDftApply - + VecZnxIdftApplyConsume - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize - + VecZnxNormalize - + VecZnxNormalizeTmpBytes, - Scratch: ScratchAvailable, + fn glwe_from_lwe(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + R: GLWEToMut, + A: LWEToRef, + K: LWEToGLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), module.n() as u32); - assert_eq!(ksk.n(), module.n() as u32); - assert!(lwe.n() <= module.n() as u32); - } + let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); + let lwe: &LWE<&[u8]> = &lwe.to_ref(); + let ksk: &LWEToGLWESwitchingKeyPrepared<&[u8], BE> = &ksk.to_ref(); - let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWELayout { - n: ksk.n(), - base2k: ksk.base2k(), - k: lwe.k(), - rank: 1u32.into(), - }); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(ksk.n(), self.n() as u32); + assert!(lwe.n() <= self.n() as u32); + + let (mut glwe, scratch_1) = scratch.take_glwe_ct( + self, + &GLWELayout { + n: ksk.n(), + base2k: ksk.base2k(), + k: lwe.k(), + rank: 1u32.into(), + }, + ); glwe.data.zero(); let n_lwe: usize = lwe.n().into(); @@ -85,14 +73,14 @@ impl GLWE { glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); } } else { - let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, lwe.size()); + let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self, 1, lwe.size()); a_conv.zero(); for j in 0..lwe.size() { let data_lwe: &[i64] = lwe.data.at(0, j); a_conv.at_mut(0, j)[0] = data_lwe[0] } - module.vec_znx_normalize( + self.vec_znx_normalize( ksk.base2k().into(), &mut glwe.data, 0, @@ -108,7 +96,7 @@ impl GLWE { a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]); } - module.vec_znx_normalize( + self.vec_znx_normalize( ksk.base2k().into(), &mut glwe.data, 1, @@ -119,6 +107,30 @@ impl GLWE { ); } - self.keyswitch(module, &glwe, &ksk.0, scratch_1); + self.glwe_keyswitch(res, &glwe, &ksk.0, scratch_1); + } +} + +impl GLWE> { + pub fn from_lwe_tmp_bytes(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize + where + R: GLWEInfos, + A: LWEInfos, + K: GGLWEInfos, + M: GLWEFromLWE, + { + module.glwe_from_lwe_tmp_bytes(glwe_infos, lwe_infos, key_infos) + } +} + +impl GLWE { + pub fn from_lwe(&mut self, module: &M, lwe: &A, ksk: &K, scratch: &mut Scratch) + where + M: GLWEFromLWE, + A: LWEToRef, + K: LWEToGLWESwitchingKeyPreparedToRef, + Scratch: ScratchTakeCore, + { + module.glwe_from_lwe(self, lwe, ksk, scratch); } }