diff --git a/poulpy-core/src/encryption/compressed/gglwe_atk.rs b/poulpy-core/src/encryption/compressed/gglwe_atk.rs index 95dcf20..f202223 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_atk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_atk.rs @@ -1,19 +1,15 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes, - VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, - }, + api::{ScratchAvailable, SvpPPolAllocBytes, VecZnxAutomorphism, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, }; use crate::{ - TakeGLWESecret, TakeGLWESecretPrepared, + TakeGLWESecret, + encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ - GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, - compressed::{GGLWEAutomorphismKeyCompressed, GGLWESwitchingKeyCompressed}, + GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, + compressed::{GGLWEAutomorphismKeyCompressed, GGLWEAutomorphismKeyCompressedToMut, GGLWEKeyCompressed}, }, }; @@ -24,8 +20,75 @@ impl GGLWEAutomorphismKeyCompressed> { Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes, { assert_eq!(module.n() as u32, infos.n()); - GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space(module, infos) - + GLWESecret::alloc_bytes_with(infos.n(), infos.rank_out()) + GGLWEKeyCompressed::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes_with(infos.n(), infos.rank_out()) + } +} + +pub trait GGLWEAutomorphismKeyCompressedEncryptSk { + fn gglwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEAutomorphismKeyCompressedToMut, + S: GLWESecretToRef; +} + +impl GGLWEAutomorphismKeyCompressedEncryptSk for Module +where + Module: + GGLWEKeyCompressedEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + SvpPPolAllocBytes + VecZnxAutomorphism, + Scratch: TakeGLWESecret + ScratchAvailable, +{ + fn gglwe_automorphism_key_compressed_encrypt_sk( + &self, + res: &mut R, + p: i64, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEAutomorphismKeyCompressedToMut, + S: GLWESecretToRef, + { + let res: &mut GGLWEAutomorphismKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); + assert!( + scratch.available() >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(self, res), + "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", + scratch.available(), + GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(self, res) + ) + } + + let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); + + { + (0..res.rank_out().into()).for_each(|i| { + self.vec_znx_automorphism( + self.galois_element_inv(p), + &mut sk_out.data.as_vec_znx_mut(), + i, + &sk.data.as_vec_znx(), + i, + ); + }); + } + + self.gglwe_key_compressed_encrypt_sk(&mut res.key, sk, &sk_out, seed_xa, source_xe, scratch_1); + + res.p = p; } } @@ -40,56 +103,8 @@ impl GGLWEAutomorphismKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAutomorphism - + SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + Module: GGLWEAutomorphismKeyCompressedEncryptSk, { - #[cfg(debug_assertions)] - { - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); - assert!( - scratch.available() >= GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {}", - scratch.available(), - GGLWEAutomorphismKeyCompressed::encrypt_sk_scratch_space(module, self) - ) - } - - let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); - - { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), - &mut sk_out.data.as_vec_znx_mut(), - i, - &sk.data.as_vec_znx(), - i, - ); - }); - } - - self.key - .encrypt_sk(module, sk, &sk_out, seed_xa, source_xe, scratch_1); - - self.p = p; + module.gglwe_automorphism_key_compressed_encrypt_sk(self, p, sk, seed_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_ct.rs b/poulpy-core/src/encryption/compressed/gglwe_ct.rs index 76871da..c2b4de3 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ct.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ct.rs @@ -1,29 +1,22 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftAllocBytes, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, + ZnNormalizeInplace, }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, - layouts::{GGLWECiphertext, GGLWEInfos, LWEInfos, compressed::GGLWECiphertextCompressed, prepared::GLWESecretPrepared}, + encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, + layouts::{ + GGLWECiphertext, GGLWEInfos, LWEInfos, + compressed::{GGLWECiphertextCompressed, GGLWECiphertextCompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; -impl GGLWECiphertextCompressed> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize - where - A: GGLWEInfos, - Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, - { - GGLWECiphertext::encrypt_sk_scratch_space(module, infos) - } -} - impl GGLWECiphertextCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( @@ -35,83 +28,124 @@ impl GGLWECiphertextCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGLWECompressedEncryptSk, { + module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch); + } +} + +impl GGLWECiphertextCompressed> { + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + where + A: GGLWEInfos, + Module: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + { + GGLWECiphertext::encrypt_sk_scratch_space(module, infos) + } +} + +pub trait GGLWECompressedEncryptSk { + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECiphertextCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWECompressedEncryptSk for Module +where + Module: GLWEEncryptSkInternal + + VecZnxNormalizeInplace + + VecZnxNormalizeTmpBytes + + VecZnxDftAllocBytes + + VecZnxAddScalarInplace + + ZnNormalizeInplace, + Scratch: TakeGLWEPt + ScratchAvailable, +{ + fn gglwe_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECiphertextCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGLWECiphertextCompressed<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + #[cfg(debug_assertions)] { use poulpy_hal::layouts::ZnxInfos; + let sk = &sk.to_ref(); assert_eq!( - self.rank_in(), + res.rank_in(), pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), pt.cols() ); assert_eq!( - self.rank_out(), + res.rank_out(), sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), sk.rank() ); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self), + scratch.available() >= GGLWECiphertextCompressed::encrypt_sk_scratch_space(self, res), "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GGLWECiphertextCompressed::encrypt_sk_scratch_space(module, self) + GGLWECiphertextCompressed::encrypt_sk_scratch_space(self, res) ); assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() ); } - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); - let cols: usize = (self.rank_out() + 1).into(); + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); + let cols: usize = (res.rank_out() + 1).into(); let mut source_xa = Source::new(seed); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); (0..rank_in).for_each(|col_i| { (0..dnum).for_each(|d_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); let (seed, mut source_xa_tmp) = source_xa.branch(); - self.seed[col_i * dnum + d_i] = seed; + res.seed[col_i * dnum + d_i] = seed; - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(d_i, col_i).data, + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(d_i, col_i).data, cols, true, Some((&tmp_pt, 0)), diff --git a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs index 8dd177f..e5b3716 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_ksk.rs @@ -1,9 +1,7 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft, - VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, - VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, - VecZnxSubInplace, VecZnxSwitchRing, + ScratchAvailable, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, + VecZnxSwitchRing, }, layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch}, source::Source, @@ -11,13 +9,15 @@ use poulpy_hal::{ use crate::{ TakeGLWESecretPrepared, + encryption::compressed::gglwe_ct::GGLWECompressedEncryptSk, layouts::{ - Degree, GGLWECiphertext, GGLWEInfos, GLWEInfos, GLWESecret, LWEInfos, compressed::GGLWESwitchingKeyCompressed, + Degree, GGLWECiphertext, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, + compressed::{GGLWEKeyCompressed, GGLWEKeyCompressedToMut}, prepared::GLWESecretPrepared, }, }; -impl GGLWESwitchingKeyCompressed> { +impl GGLWEKeyCompressed> { pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where A: GGLWEInfos, @@ -29,7 +29,7 @@ impl GGLWESwitchingKeyCompressed> { } } -impl GGLWESwitchingKeyCompressed { +impl GGLWEKeyCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( &mut self, @@ -40,36 +40,65 @@ impl GGLWESwitchingKeyCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpPPolAllocBytes - + VecZnxSwitchRing - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxAddScalarInplace, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + Module: GGLWEKeyCompressedEncryptSk, { + module.gglwe_key_compressed_encrypt_sk(self, sk_in, sk_out, seed_xa, source_xe, scratch); + } +} + +pub trait GGLWEKeyCompressedEncryptSk { + fn gglwe_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &SI, + sk_out: &SO, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEKeyCompressedToMut, + SI: GLWESecretToRef, + SO: GLWESecretToRef; +} + +impl GGLWEKeyCompressedEncryptSk for Module +where + Module: GGLWECompressedEncryptSk + + SvpPPolAllocBytes + + VecZnxNormalizeTmpBytes + + VecZnxDftAllocBytes + + VecZnxSwitchRing + + SvpPrepare, + Scratch: ScratchAvailable + TakeScalarZnx + TakeGLWESecretPrepared, +{ + fn gglwe_key_compressed_encrypt_sk( + &self, + res: &mut R, + sk_in: &SI, + sk_out: &SO, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWEKeyCompressedToMut, + SI: GLWESecretToRef, + SO: GLWESecretToRef, + { + let res: &mut GGLWEKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk_in: &GLWESecret<&[u8]> = &sk_in.to_ref(); + let sk_out: &GLWESecret<&[u8]> = &sk_out.to_ref(); + #[cfg(debug_assertions)] { use crate::layouts::GGLWESwitchingKey; - assert!(sk_in.n().0 <= module.n() as u32); - assert!(sk_out.n().0 <= module.n() as u32); + assert!(sk_in.n().0 <= self.n() as u32); + assert!(sk_out.n().0 <= self.n() as u32); assert!( - scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self), + scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(self, res), "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}", scratch.available(), - GGLWESwitchingKey::encrypt_sk_scratch_space(module, self) + GGLWESwitchingKey::encrypt_sk_scratch_space(self, res) ) } @@ -77,7 +106,7 @@ impl GGLWESwitchingKeyCompressed { let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into()); (0..sk_in.rank().into()).for_each(|i| { - module.vec_znx_switch_ring( + self.vec_znx_switch_ring( &mut sk_in_tmp.as_vec_znx_mut(), i, &sk_in.data.as_vec_znx(), @@ -89,20 +118,20 @@ impl GGLWESwitchingKeyCompressed { { let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1); (0..sk_out.rank().into()).for_each(|i| { - module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); - module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); + self.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i); + self.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0); }); } - self.key.encrypt_sk( - module, + self.gglwe_compressed_encrypt_sk( + &mut res.key, &sk_in_tmp, &sk_out_tmp, seed_xa, source_xe, scratch_2, ); - self.sk_in_n = sk_in.n().into(); - self.sk_out_n = sk_out.n().into(); + res.sk_in_n = sk_in.n().into(); + res.sk_out_n = sk_out.n().into(); } } diff --git a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs index 6a75a57..be78f3e 100644 --- a/poulpy-core/src/encryption/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/encryption/compressed/gglwe_tsk.rs @@ -1,9 +1,7 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, - TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, - VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, - VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, + SvpApplyDftToDft, SvpPPolAllocBytes, SvpPrepare, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAllocBytes, VecZnxBigNormalize, + VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, }, layouts::{Backend, DataMut, DataRef, Module, Scratch}, source::Source, @@ -11,8 +9,10 @@ use poulpy_hal::{ use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, + encryption::compressed::gglwe_ksk::GGLWEKeyCompressedEncryptSk, layouts::{ - GGLWEInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed, + GGLWEInfos, GGLWETensorKey, GLWEInfos, GLWESecret, GLWESecretToRef, LWEInfos, Rank, + compressed::{GGLWETensorKeyCompressed, GGLWETensorKeyCompressedToMut}, prepared::Prepare, }, }; @@ -28,59 +28,59 @@ impl GGLWETensorKeyCompressed> { } } -impl GGLWETensorKeyCompressed { - pub fn encrypt_sk( - &mut self, - module: &Module, - sk: &GLWESecret, +pub trait GGLWETensorKeyCompressedEncryptSk { + fn gglwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, seed_xa: [u8; 32], source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpApplyDftToDft - + VecZnxIdftApplyTmpA - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + VecZnxSwitchRing - + VecZnxAddScalarInplace - + SvpPrepare - + SvpPPolAllocBytes - + SvpPPolAlloc, - Scratch: ScratchAvailable - + TakeScalarZnx - + TakeVecZnxDft - + TakeGLWESecretPrepared - + ScratchAvailable - + TakeVecZnx - + TakeVecZnxBig, + R: GGLWETensorKeyCompressedToMut, + S: GLWESecretToRef; +} + +impl GGLWETensorKeyCompressedEncryptSk for Module +where + Module: GGLWEKeyCompressedEncryptSk + + VecZnxDftApply + + SvpApplyDftToDft + + VecZnxIdftApplyTmpA + + VecZnxBigNormalize + + SvpPrepare, + Scratch: TakeGLWESecretPrepared + TakeVecZnxDft + TakeVecZnxBig + TakeGLWESecret, +{ + fn gglwe_tensor_key_encrypt_sk( + &self, + res: &mut R, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWETensorKeyCompressedToMut, + S: GLWESecretToRef, { + let res: &mut GGLWETensorKeyCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.rank_out(), sk.rank()); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.rank_out(), sk.rank()); + assert_eq!(res.n(), sk.n()); } let n: usize = sk.n().into(); - let rank: usize = self.rank_out().into(); + let rank: usize = res.rank_out().into(); - let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out()); - sk_dft_prep.prepare(module, sk, scratch_1); + let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), res.rank_out()); + sk_dft_prep.prepare(self, sk, scratch_1); let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1); for i in 0..rank { - module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); + self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); } let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1); @@ -91,14 +91,14 @@ impl GGLWETensorKeyCompressed { for i in 0..rank { for j in i..rank { - module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); + self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i); - module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); - module.vec_znx_big_normalize( - self.base2k().into(), + self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0); + self.vec_znx_big_normalize( + res.base2k().into(), &mut sk_ij.data.as_vec_znx_mut(), 0, - self.base2k().into(), + res.base2k().into(), &sk_ij_big, 0, scratch_5, @@ -106,9 +106,30 @@ impl GGLWETensorKeyCompressed { let (seed_xa_tmp, _) = source_xa.branch(); - self.at_mut(i, j) - .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5); + self.gglwe_key_compressed_encrypt_sk( + res.at_mut(i, j), + &sk_ij, + sk, + seed_xa_tmp, + source_xe, + scratch_5, + ); } } } } + +impl GGLWETensorKeyCompressed { + pub fn encrypt_sk( + &mut self, + module: &Module, + sk: &GLWESecret, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + Module: GGLWETensorKeyCompressedEncryptSk, + { + module.gglwe_tensor_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch); + } +} diff --git a/poulpy-core/src/encryption/compressed/ggsw_ct.rs b/poulpy-core/src/encryption/compressed/ggsw_ct.rs index e49f246..bdad6be 100644 --- a/poulpy-core/src/encryption/compressed/ggsw_ct.rs +++ b/poulpy-core/src/encryption/compressed/ggsw_ct.rs @@ -1,18 +1,16 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + api::{VecZnxAddScalarInplace, VecZnxDftAllocBytes, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ TakeGLWEPt, - encryption::{SIGMA, glwe_encrypt_sk_internal}, + encryption::{SIGMA, glwe_ct::GLWEEncryptSkInternal}, layouts::{ - GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared, + GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, + compressed::{GGSWCiphertextCompressed, GGSWCiphertextCompressedToMut}, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, }, }; @@ -26,6 +24,95 @@ impl GGSWCiphertextCompressed> { } } +pub trait GGSWCompressedEncryptSk { + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCiphertextCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWCompressedEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: TakeGLWEPt, +{ + fn ggsw_compressed_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + seed_xa: [u8; 32], + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCiphertextCompressedToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGSWCiphertextCompressed<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::layouts::ZnxInfos; + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), sk.n()); + assert_eq!(pt.n() as u32, sk.n()); + } + + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let cols: usize = rank + 1; + let dsize: usize = res.dsize().into(); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + + let mut source = Source::new(seed_xa); + + res.seed = vec![[0u8; 32]; res.dnum().0 as usize * cols]; + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + + for col_j in 0..rank + 1 { + // rlwe encrypt of vec_znx_pt into vec_znx_ct + + let (seed, mut source_xa_tmp) = source.branch(); + + res.seed[row_i * cols + col_j] = seed; + + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + &mut res.at_mut(row_i, col_j).data, + cols, + true, + Some((&tmp_pt, col_j)), + sk, + &mut source_xa_tmp, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } +} + impl GGSWCiphertextCompressed { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( @@ -37,71 +124,8 @@ impl GGSWCiphertextCompressed { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGSWCompressedEncryptSk, { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let cols: usize = rank + 1; - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - let mut source = Source::new(seed_xa); - - self.seed = vec![[0u8; 32]; self.dnum().0 as usize * cols]; - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - let (seed, mut source_xa_tmp) = source.branch(); - - self.seed[row_i * cols + col_j] = seed; - - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.at_mut(row_i, col_j).data, - cols, - true, - Some((&tmp_pt, col_j)), - sk, - &mut source_xa_tmp, - source_xe, - SIGMA, - scratch_1, - ); - }); - }); + module.ggsw_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/gglwe_atk.rs b/poulpy-core/src/encryption/gglwe_atk.rs index 6d45b37..9065c2e 100644 --- a/poulpy-core/src/encryption/gglwe_atk.rs +++ b/poulpy-core/src/encryption/gglwe_atk.rs @@ -5,20 +5,23 @@ use poulpy_hal::{ VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing, }, - layouts::{Backend, DataMut, DataRef, Module, Scratch}, + layouts::{Backend, DataMut, Module, Scratch}, source::Source, }; use crate::{ TakeGLWESecret, TakeGLWESecretPrepared, - layouts::{GGLWEAutomorphismKey, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos}, + layouts::{ + GGLWEAutomorphismKey, GGLWEAutomorphismKeyToMut, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, GLWESecretToRef, + LWEInfos, + }, }; impl GGLWEAutomorphismKey> { - pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize + pub fn encrypt_sk_scratch_space(module: &Module, infos: &A) -> usize where A: GGLWEInfos, - Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, + Module: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes, { assert_eq!( infos.rank_in(), @@ -28,7 +31,7 @@ impl GGLWEAutomorphismKey> { GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes(&infos.glwe_layout()) } - pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize + pub fn encrypt_pk_scratch_space(module: &Module, _infos: &A) -> usize where A: GGLWEInfos, { @@ -41,58 +44,98 @@ impl GGLWEAutomorphismKey> { } } -impl GGLWEAutomorphismKey { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, +pub trait GGLWEAutomorphismKeyEncryptSk { + fn gglwe_automorphism_key_encrypt_sk( + &self, + res: &mut A, p: i64, - sk: &GLWESecret, + sk: &B, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, + scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub - + SvpPrepare - + VecZnxSwitchRing - + SvpPPolAllocBytes - + VecZnxAutomorphism, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, + A: GGLWEAutomorphismKeyToMut, + B: GLWESecretToRef; +} + +impl GGLWEAutomorphismKey +where + Self: GGLWEAutomorphismKeyToMut, +{ + pub fn encrypt_sk( + &mut self, + module: &Module, + p: i64, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretToRef, + Module: GGLWEAutomorphismKeyEncryptSk, { + module.gglwe_automorphism_key_encrypt_sk(self, p, sk, source_xa, source_xe, scratch); + } +} + +impl GGLWEAutomorphismKeyEncryptSk for Module +where + Module: VecZnxAddScalarInplace + + VecZnxDftAllocBytes + + VecZnxBigNormalize + + VecZnxDftApply + + SvpApplyDftToDftInplace + + VecZnxIdftApplyConsume + + VecZnxNormalizeTmpBytes + + VecZnxFillUniform + + VecZnxSubInplace + + VecZnxAddInplace + + VecZnxNormalizeInplace + + VecZnxAddNormal + + VecZnxNormalize + + VecZnxSub + + SvpPrepare + + VecZnxSwitchRing + + SvpPPolAllocBytes + + VecZnxAutomorphism, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared, +{ + fn gglwe_automorphism_key_encrypt_sk( + &self, + res: &mut A, + p: i64, + sk: &B, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + A: GGLWEAutomorphismKeyToMut, + B: GLWESecretToRef, + { + let res: &mut GGLWEAutomorphismKey<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecret<&[u8]> = &sk.to_ref(); + #[cfg(debug_assertions)] { use crate::layouts::{GLWEInfos, LWEInfos}; - assert_eq!(self.n(), sk.n()); - assert_eq!(self.rank_out(), self.rank_in()); - assert_eq!(sk.rank(), self.rank_out()); + assert_eq!(res.n(), sk.n()); + assert_eq!(res.rank_out(), res.rank_in()); + assert_eq!(sk.rank(), res.rank_out()); assert!( - scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self), + scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(self, res), "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}", scratch.available(), - GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self) + GGLWEAutomorphismKey::encrypt_sk_scratch_space(self, res) ) } let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank()); { - (0..self.rank_out().into()).for_each(|i| { - module.vec_znx_automorphism( - module.galois_element_inv(p), + (0..res.rank_out().into()).for_each(|i| { + self.vec_znx_automorphism( + self.galois_element_inv(p), &mut sk_out.data.as_vec_znx_mut(), i, &sk.data.as_vec_znx(), @@ -101,9 +144,9 @@ impl GGLWEAutomorphismKey { }); } - self.key - .encrypt_sk(module, sk, &sk_out, source_xa, source_xe, scratch_1); + res.key + .encrypt_sk(self, sk, &sk_out, source_xa, source_xe, scratch_1); - self.p = p; + res.p = p; } } diff --git a/poulpy-core/src/encryption/gglwe_ct.rs b/poulpy-core/src/encryption/gglwe_ct.rs index 51054cb..3928cf4 100644 --- a/poulpy-core/src/encryption/gglwe_ct.rs +++ b/poulpy-core/src/encryption/gglwe_ct.rs @@ -1,16 +1,19 @@ use poulpy_hal::{ api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, + ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAddScalarInplace, VecZnxDftAllocBytes, VecZnxNormalizeInplace, + VecZnxNormalizeTmpBytes, }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero}, source::Source, }; use crate::{ TakeGLWEPt, - layouts::{GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared}, + encryption::glwe_ct::GLWEEncryptSk, + layouts::{ + GGLWECiphertext, GGLWECiphertextToMut, GGLWEInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; impl GGLWECiphertext> { @@ -31,78 +34,89 @@ impl GGLWECiphertext> { } } -impl GGLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( - &mut self, - module: &Module, - pt: &ScalarZnx, - sk: &GLWESecretPrepared, +pub trait GGLWEEncryptSk { + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + R: GGLWECiphertextToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGLWEEncryptSk for Module +where + Module: + GLWEEncryptSk + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, +{ + fn gglwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGLWECiphertextToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, { + let res: &mut GGLWECiphertext<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + #[cfg(debug_assertions)] { use poulpy_hal::layouts::ZnxInfos; + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); assert_eq!( - self.rank_in(), + res.rank_in(), pt.cols() as u32, - "self.rank_in(): {} != pt.cols(): {}", - self.rank_in(), + "res.rank_in(): {} != pt.cols(): {}", + res.rank_in(), pt.cols() ); assert_eq!( - self.rank_out(), + res.rank_out(), sk.rank(), - "self.rank_out(): {} != sk.rank(): {}", - self.rank_out(), + "res.rank_out(): {} != sk.rank(): {}", + res.rank_out(), sk.rank() ); - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), sk.n()); assert_eq!(pt.n() as u32, sk.n()); assert!( - scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}", + scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(self, res), + "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(self, res.rank()={}, res.size()={}): {}", scratch.available(), - self.rank_out(), - self.size(), - GGLWECiphertext::encrypt_sk_scratch_space(module, self) + res.rank_out(), + res.size(), + GGLWECiphertext::encrypt_sk_scratch_space(self, res) ); assert!( - self.dnum().0 * self.dsize().0 * self.base2k().0 <= self.k().0, - "self.dnum() : {} * self.dsize() : {} * self.base2k() : {} = {} >= self.k() = {}", - self.dnum(), - self.dsize(), - self.base2k(), - self.dnum().0 * self.dsize().0 * self.base2k().0, - self.k() + res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0, + "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}", + res.dnum(), + res.dsize(), + res.base2k(), + res.dnum().0 * res.dsize().0 * res.base2k().0, + res.k() ); } - let dnum: usize = self.dnum().into(); - let dsize: usize = self.dsize().into(); - let base2k: usize = self.base2k().into(); - let rank_in: usize = self.rank_in().into(); + let dnum: usize = res.dnum().into(); + let dsize: usize = res.dsize().into(); + let base2k: usize = res.base2k().into(); + let rank_in: usize = res.rank_in().into(); - let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self); + let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(res); // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns // // Example for ksk rank 2 to rank 3: @@ -114,17 +128,39 @@ impl GGLWECiphertext { // // (-(a*s) + s0, a) // (-(b*s) + s1, b) - (0..rank_in).for_each(|col_i| { - (0..dnum).for_each(|row_i| { + + for col_i in 0..rank_in { + for row_i in 0..dnum { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); - - // rlwe encrypt of vec_znx_pt into vec_znx_ct - self.at_mut(row_i, col_i) - .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1); - }); - }); + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1); + self.glwe_encrypt_sk( + &mut res.at_mut(row_i, col_i), + &tmp_pt, + sk, + source_xa, + source_xe, + scrach_1, + ); + } + } + } +} + +impl GGLWECiphertext { + #[allow(clippy::too_many_arguments)] + pub fn encrypt_sk( + &mut self, + module: &Module, + pt: &ScalarZnx, + sk: &GLWESecretPrepared, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + Module: GGLWEEncryptSk, + { + module.gglwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/ggsw_ct.rs b/poulpy-core/src/encryption/ggsw_ct.rs index 6195458..d9e35bd 100644 --- a/poulpy-core/src/encryption/ggsw_ct.rs +++ b/poulpy-core/src/encryption/ggsw_ct.rs @@ -1,16 +1,16 @@ use poulpy_hal::{ - api::{ - ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, - VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, - VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero}, + api::{VecZnxAddScalarInplace, VecZnxDftAllocBytes, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes}, + layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToRef, Scratch, VecZnx, ZnxZero}, source::Source, }; use crate::{ - TakeGLWEPt, - layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared}, + SIGMA, TakeGLWEPt, + encryption::glwe_ct::GLWEEncryptSkInternal, + layouts::{ + GGSWCiphertext, GGSWCiphertextToMut, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, }; impl GGSWCiphertext> { @@ -27,6 +27,85 @@ impl GGSWCiphertext> { } } +pub trait GGSWEncryptSk { + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCiphertextToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef; +} + +impl GGSWEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxAddScalarInplace + VecZnxNormalizeInplace, + Scratch: TakeGLWEPt, +{ + fn ggsw_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GGSWCiphertextToMut, + P: ScalarZnxToRef, + S: GLWESecretPreparedToRef, + { + let res: &mut GGSWCiphertext<&mut [u8]> = &mut res.to_mut(); + let pt: &ScalarZnx<&[u8]> = &pt.to_ref(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + + #[cfg(debug_assertions)] + { + use poulpy_hal::layouts::ZnxInfos; + + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(pt.n(), self.n()); + assert_eq!(sk.n(), self.n() as u32); + } + + let k: usize = res.k().into(); + let base2k: usize = res.base2k().into(); + let rank: usize = res.rank().into(); + let dsize: usize = res.dsize().into(); + let cols: usize = (rank + 1).into(); + + let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&res.glwe_layout()); + + for row_i in 0..res.dnum().into() { + tmp_pt.data.zero(); + // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt + self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); + self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); + for col_j in 0..rank + 1 { + self.glwe_encrypt_sk_internal( + base2k, + k, + res.at_mut(row_i, col_j).data_mut(), + cols, + false, + Some((&tmp_pt, col_j)), + sk, + source_xa, + source_xe, + SIGMA, + scratch_1, + ); + } + } + } +} + impl GGSWCiphertext { #[allow(clippy::too_many_arguments)] pub fn encrypt_sk( @@ -38,56 +117,8 @@ impl GGSWCiphertext { source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxAddScalarInplace - + VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + Module: GGSWEncryptSk, { - #[cfg(debug_assertions)] - { - use poulpy_hal::layouts::ZnxInfos; - - assert_eq!(self.rank(), sk.rank()); - assert_eq!(self.n(), sk.n()); - assert_eq!(pt.n() as u32, sk.n()); - } - - let base2k: usize = self.base2k().into(); - let rank: usize = self.rank().into(); - let dsize: usize = self.dsize().into(); - - let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout()); - - (0..self.dnum().into()).for_each(|row_i| { - tmp_pt.data.zero(); - - // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0); - module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1); - - (0..rank + 1).for_each(|col_j| { - // rlwe encrypt of vec_znx_pt into vec_znx_ct - - self.at_mut(row_i, col_j).encrypt_sk_internal( - module, - Some((&tmp_pt, col_j)), - sk, - source_xa, - source_xe, - scratch_1, - ); - }); - }); + module.ggsw_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); } } diff --git a/poulpy-core/src/encryption/glwe_ct.rs b/poulpy-core/src/encryption/glwe_ct.rs index 8ecacc6..d05ffc6 100644 --- a/poulpy-core/src/encryption/glwe_ct.rs +++ b/poulpy-core/src/encryption/glwe_ct.rs @@ -5,7 +5,7 @@ use poulpy_hal::{ VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, }, - layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, ZnxInfos, ZnxZero}, + layouts::{Backend, DataMut, Module, ScalarZnx, Scratch, VecZnx, VecZnxBig, VecZnxToMut, ZnxInfos, ZnxZero}, source::Source, }; @@ -13,8 +13,8 @@ use crate::{ dist::Distribution, encryption::{SIGMA, SIGMA_BOUND}, layouts::{ - GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, - prepared::{GLWEPublicKeyPrepared, GLWESecretPrepared}, + GLWECiphertext, GLWECiphertextToMut, GLWEInfos, GLWEPlaintext, GLWEPlaintextToRef, LWEInfos, + prepared::{GLWEPublicKeyPrepared, GLWEPublicKeyPreparedToRef, GLWESecretPrepared, GLWESecretPreparedToRef}, }, }; @@ -44,126 +44,127 @@ impl GLWECiphertext> { } } -impl GLWECiphertext { - #[allow(clippy::too_many_arguments)] - pub fn encrypt_sk( +impl GLWECiphertext { + pub fn encrypt_sk( &mut self, module: &Module, - pt: &GLWEPlaintext, - sk: &GLWESecretPrepared, + pt: &P, + sk: &S, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + Module: GLWEEncryptSk, { + module.glwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_zero_sk( + &mut self, + module: &Module, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + S: GLWESecretPreparedToRef, + Module: GLWEEncryptZeroSk, + { + module.glwe_encrypt_zero_sk(self, sk, source_xa, source_xe, scratch); + } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptPk, + { + module.glwe_encrypt_pk(self, pt, pk, source_xu, source_xe, scratch); + } + + pub fn encrypt_zero_pk( + &mut self, + module: &Module, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + K: GLWEPublicKeyPreparedToRef, + Module: GLWEEncryptZeroPk, + { + module.glwe_encrypt_zero_pk(self, pk, source_xu, source_xe, scratch); + } +} + +pub trait GLWEEncryptSk { + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Scratch: ScratchAvailable, +{ + fn glwe_encrypt_sk( + &self, + res: &mut R, + pt: &P, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, + { + let mut res: GLWECiphertext<&mut [u8]> = res.to_mut(); + let pt: GLWEPlaintext<&[u8]> = pt.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert_eq!(pt.n(), self.n()); + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert_eq!(pt.n(), self.n() as u32); assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(self, &res), "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) + GLWECiphertext::encrypt_sk_scratch_space(self, &res) ) } - self.encrypt_sk_internal(module, Some((pt, 0)), sk, source_xa, source_xe, scratch); - } - - pub fn encrypt_zero_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - #[cfg(debug_assertions)] - { - assert_eq!(self.rank(), sk.rank()); - assert_eq!(sk.n(), self.n()); - assert!( - scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self), - "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", - scratch.available(), - GLWECiphertext::encrypt_sk_scratch_space(module, self) - ) - } - self.encrypt_sk_internal( - module, - None::<(&GLWEPlaintext>, usize)>, - sk, - source_xa, - source_xe, - scratch, - ); - } - - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_sk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - ) where - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - Scratch: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, - { - let cols: usize = (self.rank() + 1).into(); - glwe_encrypt_sk_internal( - module, - self.base2k().into(), - self.k().into(), - &mut self.data, + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), cols, false, - pt, + Some((&pt, 0)), sk, source_xa, source_xe, @@ -171,46 +172,136 @@ impl GLWECiphertext { scratch, ); } +} - #[allow(clippy::too_many_arguments)] - pub fn encrypt_pk( - &mut self, - module: &Module, - pt: &GLWEPlaintext, - pk: &GLWEPublicKeyPrepared, - source_xu: &mut Source, +pub trait GLWEEncryptZeroSk { + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWECiphertextToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptZeroSk for Module +where + Module: GLWEEncryptSkInternal + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + Scratch: ScratchAvailable, +{ + fn glwe_encrypt_zero_sk( + &self, + res: &mut R, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + S: GLWESecretPreparedToRef, { - self.encrypt_pk_internal::(module, Some((pt, 0)), pk, source_xu, source_xe, scratch); + let mut res: GLWECiphertext<&mut [u8]> = res.to_mut(); + + #[cfg(debug_assertions)] + { + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); + assert_eq!(res.rank(), sk.rank()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); + assert!( + scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(self, &res), + "scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}", + scratch.available(), + GLWECiphertext::encrypt_sk_scratch_space(self, &res) + ) + } + + let cols: usize = (res.rank() + 1).into(); + self.glwe_encrypt_sk_internal( + res.base2k().into(), + res.k().into(), + res.data_mut(), + cols, + false, + None::<(&GLWEPlaintext>, usize)>, + sk, + source_xa, + source_xe, + SIGMA, + scratch, + ); } +} - pub fn encrypt_zero_pk( - &mut self, - module: &Module, - pk: &GLWEPublicKeyPrepared, +pub trait GLWEEncryptPk { + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, source_xu: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptPk for Module +where + Module: GLWEEncryptPkInternal, +{ + fn glwe_encrypt_pk( + &self, + res: &mut R, + pt: &P, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, { - self.encrypt_pk_internal::, DataPk, B>( - module, + self.glwe_encrypt_pk_internal(res, Some((pt, 0)), pk, source_xu, source_xe, scratch); + } +} + +pub trait GLWEEncryptZeroPk { + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptZeroPk for Module +where + Module: GLWEEncryptPkInternal, +{ + fn glwe_encrypt_zero_pk( + &self, + res: &mut R, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + K: GLWEPublicKeyPreparedToRef, + { + self.glwe_encrypt_pk_internal( + res, None::<(&GLWEPlaintext>, usize)>, pk, source_xu, @@ -218,45 +309,69 @@ impl GLWECiphertext { scratch, ); } +} - #[allow(clippy::too_many_arguments)] - pub(crate) fn encrypt_pk_internal( - &mut self, - module: &Module, - pt: Option<(&GLWEPlaintext, usize)>, - pk: &GLWEPublicKeyPrepared, +pub(crate) trait GLWEEncryptPkInternal { + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, source_xu: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, ) where - Module: SvpPrepare - + SvpApplyDftToDft - + VecZnxIdftApplyConsume - + VecZnxBigAddNormal - + VecZnxBigAddSmallInplace - + VecZnxBigNormalize, - Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef; +} + +impl GLWEEncryptPkInternal for Module +where + Module: SvpPrepare + + SvpApplyDftToDft + + VecZnxIdftApplyConsume + + VecZnxBigAddNormal + + VecZnxBigAddSmallInplace + + VecZnxBigNormalize, + Scratch: TakeSvpPPol + TakeScalarZnx + TakeVecZnxDft, +{ + fn glwe_encrypt_pk_internal( + &self, + res: &mut R, + pt: Option<(&P, usize)>, + pk: &K, + source_xu: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + ) where + R: GLWECiphertextToMut, + P: GLWEPlaintextToRef, + K: GLWEPublicKeyPreparedToRef, { + let res: &mut GLWECiphertext<&mut [u8]> = &mut res.to_mut(); + let pk: &GLWEPublicKeyPrepared<&[u8], B> = &pk.to_ref(); + #[cfg(debug_assertions)] { - assert_eq!(self.base2k(), pk.base2k()); - assert_eq!(self.n(), pk.n()); - assert_eq!(self.rank(), pk.rank()); + assert_eq!(res.base2k(), pk.base2k()); + assert_eq!(res.n(), pk.n()); + assert_eq!(res.rank(), pk.rank()); if let Some((pt, _)) = pt { - assert_eq!(pt.base2k(), pk.base2k()); - assert_eq!(pt.n(), pk.n()); + assert_eq!(pt.to_ref().base2k(), pk.base2k()); + assert_eq!(pt.to_ref().n(), pk.n()); } } let base2k: usize = pk.base2k().into(); let size_pk: usize = pk.size(); - let cols: usize = (self.rank() + 1).into(); + let cols: usize = (res.rank() + 1).into(); // Generates u according to the underlying secret distribution. - let (mut u_dft, scratch_1) = scratch.take_svp_ppol(self.n().into(), 1); + let (mut u_dft, scratch_1) = scratch.take_svp_ppol(res.n().into(), 1); { - let (mut u, _) = scratch_1.take_scalar_znx(self.n().into(), 1); + let (mut u, _) = scratch_1.take_scalar_znx(res.n().into(), 1); match pk.dist { Distribution::NONE => panic!( "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \ @@ -270,20 +385,20 @@ impl GLWECiphertext { Distribution::ZERO => {} } - module.svp_prepare(&mut u_dft, 0, &u, 0); + self.svp_prepare(&mut u_dft, 0, &u, 0); } // ct[i] = pk[i] * u + ei (+ m if col = i) (0..cols).for_each(|i| { - let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(self.n().into(), 1, size_pk); + let (mut ci_dft, scratch_2) = scratch_1.take_vec_znx_dft(res.n().into(), 1, size_pk); // ci_dft = DFT(u) * DFT(pk[i]) - module.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); + self.svp_apply_dft_to_dft(&mut ci_dft, 0, &u_dft, 0, &pk.data, i); // ci_big = u * p[i] - let mut ci_big = module.vec_znx_idft_apply_consume(ci_dft); + let mut ci_big = self.vec_znx_idft_apply_consume(ci_dft); // ci_big = u * pk[i] + e - module.vec_znx_big_add_normal( + self.vec_znx_big_add_normal( base2k, &mut ci_big, 0, @@ -297,30 +412,37 @@ impl GLWECiphertext { if let Some((pt, col)) = pt && col == i { - module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0); + self.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.to_ref().data, 0); } // ct[i] = norm(ci_big) - module.vec_znx_big_normalize(base2k, &mut self.data, i, base2k, &ci_big, 0, scratch_2); + self.vec_znx_big_normalize(base2k, &mut res.data, i, base2k, &ci_big, 0, scratch_2); }); } } -#[allow(clippy::too_many_arguments)] -pub(crate) fn glwe_encrypt_sk_internal( - module: &Module, - base2k: usize, - k: usize, - ct: &mut VecZnx, - cols: usize, - compressed: bool, - pt: Option<(&GLWEPlaintext, usize)>, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - sigma: f64, - scratch: &mut Scratch, -) where +pub(crate) trait GLWEEncryptSkInternal { + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef; +} + +impl GLWEEncryptSkInternal for Module +where Module: VecZnxDftAllocBytes + VecZnxBigNormalize + VecZnxDftApply @@ -336,72 +458,94 @@ pub(crate) fn glwe_encrypt_sk_internal: TakeVecZnxDft + ScratchAvailable + TakeVecZnx, { - #[cfg(debug_assertions)] + fn glwe_encrypt_sk_internal( + &self, + base2k: usize, + k: usize, + res: &mut R, + cols: usize, + compressed: bool, + pt: Option<(&P, usize)>, + sk: &S, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + scratch: &mut Scratch, + ) where + R: VecZnxToMut, + P: GLWEPlaintextToRef, + S: GLWESecretPreparedToRef, { - if compressed { - assert_eq!( - ct.cols(), - 1, - "invalid ciphertext: compressed tag=true but #cols={} != 1", - ct.cols() - ) - } - } + let ct: &mut VecZnx<&mut [u8]> = &mut res.to_mut(); + let sk: GLWESecretPrepared<&[u8], B> = sk.to_ref(); - let size: usize = ct.size(); - - let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); - c0.zero(); - - { - let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); - - // ct[i] = uniform - // ct[0] -= c[i] * s[i], - (1..cols).for_each(|i| { - let col_ct: usize = if compressed { 0 } else { i }; - - // ct[i] = uniform (+ pt) - module.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); - - let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); - - // ci = ct[i] - pt - // i.e. we act as we sample ct[i] already as uniform + pt - // and if there is a pt, then we subtract it before applying DFT - if let Some((pt, col)) = pt { - if i == col { - module.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.data, 0); - module.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); - } - } else { - module.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + #[cfg(debug_assertions)] + { + if compressed { + assert_eq!( + ct.cols(), + 1, + "invalid ciphertext: compressed tag=true but #cols={} != 1", + ct.cols() + ) } + } - module.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); - let ci_big: VecZnxBig<&mut [u8], B> = module.vec_znx_idft_apply_consume(ci_dft); + let size: usize = ct.size(); - // use c[0] as buffer, which is overwritten later by the normalization step - module.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + let (mut c0, scratch_1) = scratch.take_vec_znx(ct.n(), 1, size); + c0.zero(); - // c0_tmp = -c[i] * s[i] (use c[0] as buffer) - module.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); - }); + { + let (mut ci, scratch_2) = scratch_1.take_vec_znx(ct.n(), 1, size); + + // ct[i] = uniform + // ct[0] -= c[i] * s[i], + (1..cols).for_each(|i| { + let col_ct: usize = if compressed { 0 } else { i }; + + // ct[i] = uniform (+ pt) + self.vec_znx_fill_uniform(base2k, ct, col_ct, source_xa); + + let (mut ci_dft, scratch_3) = scratch_2.take_vec_znx_dft(ct.n(), 1, size); + + // ci = ct[i] - pt + // i.e. we act as we sample ct[i] already as uniform + pt + // and if there is a pt, then we subtract it before applying DFT + if let Some((pt, col)) = pt { + if i == col { + self.vec_znx_sub(&mut ci, 0, ct, col_ct, &pt.to_ref().data, 0); + self.vec_znx_normalize_inplace(base2k, &mut ci, 0, scratch_3); + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &ci, 0); + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + } else { + self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, ct, col_ct); + } + + self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1); + let ci_big: VecZnxBig<&mut [u8], B> = self.vec_znx_idft_apply_consume(ci_dft); + + // use c[0] as buffer, which is overwritten later by the normalization step + self.vec_znx_big_normalize(base2k, &mut ci, 0, base2k, &ci_big, 0, scratch_3); + + // c0_tmp = -c[i] * s[i] (use c[0] as buffer) + self.vec_znx_sub_inplace(&mut c0, 0, &ci, 0); + }); + } + + // c[0] += e + self.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); + + // c[0] += m if col = 0 + if let Some((pt, col)) = pt + && col == 0 + { + self.vec_znx_add_inplace(&mut c0, 0, &pt.to_ref().data, 0); + } + + // c[0] = norm(c[0]) + self.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); } - - // c[0] += e - module.vec_znx_add_normal(base2k, &mut c0, 0, k, source_xe, sigma, SIGMA_BOUND); - - // c[0] += m if col = 0 - if let Some((pt, col)) = pt - && col == 0 - { - module.vec_znx_add_inplace(&mut c0, 0, &pt.data, 0); - } - - // c[0] = norm(c[0]) - module.vec_znx_normalize(base2k, ct, 0, base2k, &c0, 0, scratch_1); } diff --git a/poulpy-core/src/encryption/glwe_pk.rs b/poulpy-core/src/encryption/glwe_pk.rs index c7cdaeb..c0de69b 100644 --- a/poulpy-core/src/encryption/glwe_pk.rs +++ b/poulpy-core/src/encryption/glwe_pk.rs @@ -1,50 +1,43 @@ use poulpy_hal::{ - api::{ - ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddInplace, VecZnxAddNormal, VecZnxBigNormalize, - VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, - VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, - }, + api::{ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes}, layouts::{Backend, DataMut, DataRef, Module, ScratchOwned}, - oep::{ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxDftImpl, TakeVecZnxImpl}, source::Source, }; -use crate::layouts::{GLWECiphertext, GLWEPublicKey, prepared::GLWESecretPrepared}; +use crate::{ + encryption::glwe_ct::GLWEEncryptZeroSk, + layouts::{ + GLWECiphertext, GLWEPublicKey, GLWEPublicKeyToMut, + prepared::{GLWESecretPrepared, GLWESecretPreparedToRef}, + }, +}; -impl GLWEPublicKey { - pub fn generate_from_sk( - &mut self, - module: &Module, - sk: &GLWESecretPrepared, - source_xa: &mut Source, - source_xe: &mut Source, - ) where - Module:, - Module: VecZnxDftAllocBytes - + VecZnxBigNormalize - + VecZnxDftApply - + SvpApplyDftToDftInplace - + VecZnxIdftApplyConsume - + VecZnxNormalizeTmpBytes - + VecZnxFillUniform - + VecZnxSubInplace - + VecZnxAddInplace - + VecZnxNormalizeInplace - + VecZnxAddNormal - + VecZnxNormalize - + VecZnxSub, - B: Backend - + ScratchOwnedAllocImpl - + ScratchOwnedBorrowImpl - + TakeVecZnxDftImpl - + ScratchAvailableImpl - + TakeVecZnxImpl, +pub trait GLWEPublicKeyGenerate { + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEPublicKeyToMut, + S: GLWESecretPreparedToRef; +} + +impl GLWEPublicKeyGenerate for Module +where + Module: GLWEEncryptZeroSk + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes, + ScratchOwned: ScratchOwnedAlloc + ScratchOwnedBorrow, +{ + fn glwe_public_key_generate(&self, res: &mut R, sk: &S, source_xa: &mut Source, source_xe: &mut Source) + where + R: GLWEPublicKeyToMut, + S: GLWESecretPreparedToRef, { + let res: &mut GLWEPublicKey<&mut [u8]> = &mut res.to_mut(); + let sk: &GLWESecretPrepared<&[u8], B> = &sk.to_ref(); + #[cfg(debug_assertions)] { use crate::{Distribution, layouts::LWEInfos}; - assert_eq!(self.n(), sk.n()); + assert_eq!(res.n(), self.n() as u32); + assert_eq!(sk.n(), self.n() as u32); if sk.dist == Distribution::NONE { panic!("invalid sk: SecretDistribution::NONE") @@ -52,10 +45,25 @@ impl GLWEPublicKey { } // Its ok to allocate scratch space here since pk is usually generated only once. - let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(module, self)); + let mut scratch: ScratchOwned = ScratchOwned::alloc(GLWECiphertext::encrypt_sk_scratch_space(self, res)); - let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(self); - tmp.encrypt_zero_sk(module, sk, source_xa, source_xe, scratch.borrow()); - self.dist = sk.dist; + let mut tmp: GLWECiphertext> = GLWECiphertext::alloc(res); + + tmp.encrypt_zero_sk(self, sk, source_xa, source_xe, scratch.borrow()); + res.dist = sk.dist; + } +} + +impl GLWEPublicKey { + pub fn generate( + &mut self, + module: &Module, + sk: &GLWESecretPrepared, + source_xa: &mut Source, + source_xe: &mut Source, + ) where + Module: GLWEPublicKeyGenerate, + { + module.glwe_public_key_generate(self, sk, source_xa, source_xe); } } diff --git a/poulpy-core/src/encryption/mod.rs b/poulpy-core/src/encryption/mod.rs index 9380933..fb9a459 100644 --- a/poulpy-core/src/encryption/mod.rs +++ b/poulpy-core/src/encryption/mod.rs @@ -11,7 +11,5 @@ mod lwe_ct; mod lwe_ksk; mod lwe_to_glwe_ksk; -pub(crate) use glwe_ct::glwe_encrypt_sk_internal; - pub const SIGMA: f64 = 3.2; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; diff --git a/poulpy-core/src/layouts/compressed/gglwe_atk.rs b/poulpy-core/src/layouts/compressed/gglwe_atk.rs index 2a10765..2b6a54d 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_atk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_atk.rs @@ -6,14 +6,14 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + compressed::{Decompress, GGLWEKeyCompressed, GGLWEKeyCompressedToMut, GGLWEKeyCompressedToRef}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GGLWEAutomorphismKeyCompressed { - pub(crate) key: GGLWESwitchingKeyCompressed, + pub(crate) key: GGLWEKeyCompressed, pub(crate) p: i64, } @@ -83,14 +83,14 @@ impl GGLWEAutomorphismKeyCompressed> { { debug_assert_eq!(infos.rank_in(), infos.rank_out()); Self { - key: GGLWESwitchingKeyCompressed::alloc(infos), + key: GGLWEKeyCompressed::alloc(infos), p: 0, } } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { Self { - key: GGLWESwitchingKeyCompressed::alloc_with(n, base2k, k, rank, rank, dnum, dsize), + key: GGLWEKeyCompressed::alloc_with(n, base2k, k, rank, rank, dnum, dsize), p: 0, } } @@ -100,11 +100,11 @@ impl GGLWEAutomorphismKeyCompressed> { A: GGLWEInfos, { debug_assert_eq!(infos.rank_in(), infos.rank_out()); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GGLWEKeyCompressed::alloc_bytes(infos) } pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank, dnum, dsize) + GGLWEKeyCompressed::alloc_bytes_with(n, base2k, k, rank, dnum, dsize) } } @@ -131,3 +131,35 @@ where self.p = other.p; } } + +pub trait GGLWEAutomorphismKeyCompressedToRef { + fn to_ref(&self) -> GGLWEAutomorphismKeyCompressed<&[u8]>; +} + +impl GGLWEAutomorphismKeyCompressedToRef for GGLWEAutomorphismKeyCompressed +where + GGLWEKeyCompressed: GGLWEKeyCompressedToRef, +{ + fn to_ref(&self) -> GGLWEAutomorphismKeyCompressed<&[u8]> { + GGLWEAutomorphismKeyCompressed { + key: self.key.to_ref(), + p: self.p, + } + } +} + +pub trait GGLWEAutomorphismKeyCompressedToMut { + fn to_mut(&mut self) -> GGLWEAutomorphismKeyCompressed<&mut [u8]>; +} + +impl GGLWEAutomorphismKeyCompressedToMut for GGLWEAutomorphismKeyCompressed +where + GGLWEKeyCompressed: GGLWEKeyCompressedToMut, +{ + fn to_mut(&mut self) -> GGLWEAutomorphismKeyCompressed<&mut [u8]> { + GGLWEAutomorphismKeyCompressed { + p: self.p, + key: self.key.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/gglwe_ct.rs b/poulpy-core/src/layouts/compressed/gglwe_ct.rs index f7a4df9..6e14756 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ct.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ct.rs @@ -1,6 +1,8 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; @@ -289,3 +291,37 @@ where }); } } + +pub trait GGLWECiphertextCompressedToMut { + fn to_mut(&mut self) -> GGLWECiphertextCompressed<&mut [u8]>; +} + +impl GGLWECiphertextCompressedToMut for GGLWECiphertextCompressed { + fn to_mut(&mut self) -> GGLWECiphertextCompressed<&mut [u8]> { + GGLWECiphertextCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_mut(), + } + } +} + +pub trait GGLWECiphertextCompressedToRef { + fn to_ref(&self) -> GGLWECiphertextCompressed<&[u8]>; +} + +impl GGLWECiphertextCompressedToRef for GGLWECiphertextCompressed { + fn to_ref(&self) -> GGLWECiphertextCompressed<&[u8]> { + GGLWECiphertextCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + seed: self.seed.clone(), + rank_out: self.rank_out, + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs index 60d9316..cbb2f8c 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_ksk.rs @@ -6,19 +6,19 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWECiphertextCompressed}, + compressed::{Decompress, GGLWECiphertextCompressed, GGLWECiphertextCompressedToMut, GGLWECiphertextCompressedToRef}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct GGLWESwitchingKeyCompressed { +pub struct GGLWEKeyCompressed { pub(crate) key: GGLWECiphertextCompressed, pub(crate) sk_in_n: usize, // Degree of sk_in pub(crate) sk_out_n: usize, // Degree of sk_out } -impl LWEInfos for GGLWESwitchingKeyCompressed { +impl LWEInfos for GGLWEKeyCompressed { fn n(&self) -> Degree { self.key.n() } @@ -35,13 +35,13 @@ impl LWEInfos for GGLWESwitchingKeyCompressed { self.key.size() } } -impl GLWEInfos for GGLWESwitchingKeyCompressed { +impl GLWEInfos for GGLWEKeyCompressed { fn rank(&self) -> Rank { self.rank_out() } } -impl GGLWEInfos for GGLWESwitchingKeyCompressed { +impl GGLWEInfos for GGLWEKeyCompressed { fn rank_in(&self) -> Rank { self.key.rank_in() } @@ -59,19 +59,19 @@ impl GGLWEInfos for GGLWESwitchingKeyCompressed { } } -impl fmt::Debug for GGLWESwitchingKeyCompressed { +impl fmt::Debug for GGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{self}") } } -impl FillUniform for GGLWESwitchingKeyCompressed { +impl FillUniform for GGLWEKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.key.fill_uniform(log_bound, source); } } -impl fmt::Display for GGLWESwitchingKeyCompressed { +impl fmt::Display for GGLWEKeyCompressed { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, @@ -81,12 +81,12 @@ impl fmt::Display for GGLWESwitchingKeyCompressed { } } -impl GGLWESwitchingKeyCompressed> { +impl GGLWEKeyCompressed> { pub fn alloc(infos: &A) -> Self where A: GGLWEInfos, { - GGLWESwitchingKeyCompressed { + GGLWEKeyCompressed { key: GGLWECiphertextCompressed::alloc(infos), sk_in_n: 0, sk_out_n: 0, @@ -102,7 +102,7 @@ impl GGLWESwitchingKeyCompressed> { dnum: Dnum, dsize: Dsize, ) -> Self { - GGLWESwitchingKeyCompressed { + GGLWEKeyCompressed { key: GGLWECiphertextCompressed::alloc_with(n, base2k, k, rank_in, rank_out, dnum, dsize), sk_in_n: 0, sk_out_n: 0, @@ -121,7 +121,7 @@ impl GGLWESwitchingKeyCompressed> { } } -impl ReaderFrom for GGLWESwitchingKeyCompressed { +impl ReaderFrom for GGLWEKeyCompressed { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.sk_in_n = reader.read_u64::()? as usize; self.sk_out_n = reader.read_u64::()? as usize; @@ -129,7 +129,7 @@ impl ReaderFrom for GGLWESwitchingKeyCompressed { } } -impl WriterTo for GGLWESwitchingKeyCompressed { +impl WriterTo for GGLWEKeyCompressed { fn write_to(&self, writer: &mut W) -> std::io::Result<()> { writer.write_u64::(self.sk_in_n as u64)?; writer.write_u64::(self.sk_out_n as u64)?; @@ -137,13 +137,47 @@ impl WriterTo for GGLWESwitchingKeyCompressed { } } -impl Decompress> for GGLWESwitchingKey +impl Decompress> for GGLWESwitchingKey where Module: VecZnxFillUniform + VecZnxCopy, { - fn decompress(&mut self, module: &Module, other: &GGLWESwitchingKeyCompressed) { + fn decompress(&mut self, module: &Module, other: &GGLWEKeyCompressed) { self.key.decompress(module, &other.key); self.sk_in_n = other.sk_in_n; self.sk_out_n = other.sk_out_n; } } + +pub trait GGLWEKeyCompressedToMut { + fn to_mut(&mut self) -> GGLWEKeyCompressed<&mut [u8]>; +} + +impl GGLWEKeyCompressedToMut for GGLWEKeyCompressed +where + GGLWECiphertextCompressed: GGLWECiphertextCompressedToMut, +{ + fn to_mut(&mut self) -> GGLWEKeyCompressed<&mut [u8]> { + GGLWEKeyCompressed { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_mut(), + } + } +} + +pub trait GGLWEKeyCompressedToRef { + fn to_ref(&self) -> GGLWEKeyCompressed<&[u8]>; +} + +impl GGLWEKeyCompressedToRef for GGLWEKeyCompressed +where + GGLWECiphertextCompressed: GGLWECiphertextCompressedToRef, +{ + fn to_ref(&self) -> GGLWEKeyCompressed<&[u8]> { + GGLWEKeyCompressed { + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + key: self.key.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs index fef4647..0206788 100644 --- a/poulpy-core/src/layouts/compressed/gglwe_tsk.rs +++ b/poulpy-core/src/layouts/compressed/gglwe_tsk.rs @@ -6,14 +6,14 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWETensorKey, GLWEInfos, LWEInfos, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + compressed::{Decompress, GGLWEKeyCompressed, GGLWEKeyCompressedToMut, GGLWEKeyCompressedToRef}, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use std::fmt; #[derive(PartialEq, Eq, Clone)] pub struct GGLWETensorKeyCompressed { - pub(crate) keys: Vec>, + pub(crate) keys: Vec>, } impl LWEInfos for GGLWETensorKeyCompressed { @@ -66,7 +66,7 @@ impl FillUniform for GGLWETensorKeyCompressed { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { self.keys .iter_mut() - .for_each(|key: &mut GGLWESwitchingKeyCompressed| key.fill_uniform(log_bound, source)) + .for_each(|key: &mut GGLWEKeyCompressed| key.fill_uniform(log_bound, source)) } } @@ -101,10 +101,10 @@ impl GGLWETensorKeyCompressed> { } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { - let mut keys: Vec>> = Vec::new(); + let mut keys: Vec>> = Vec::new(); let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GGLWESwitchingKeyCompressed::alloc_with( + keys.push(GGLWEKeyCompressed::alloc_with( n, base2k, k, @@ -129,7 +129,7 @@ impl GGLWETensorKeyCompressed> { let rank_out: usize = infos.rank_out().into(); let pairs: usize = (((rank_out + 1) * rank_out) >> 1).max(1); pairs - * GGLWESwitchingKeyCompressed::alloc_bytes_with( + * GGLWEKeyCompressed::alloc_bytes_with( infos.n(), infos.base2k(), infos.k(), @@ -141,7 +141,7 @@ impl GGLWETensorKeyCompressed> { pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; - pairs * GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize) + pairs * GGLWEKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, dsize) } } @@ -172,7 +172,7 @@ impl WriterTo for GGLWETensorKeyCompressed { } impl GGLWETensorKeyCompressed { - pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWESwitchingKeyCompressed { + pub(crate) fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWEKeyCompressed { if i > j { std::mem::swap(&mut i, &mut j); }; @@ -205,3 +205,33 @@ where }); } } + +pub trait GGLWETensorKeyCompressedToMut { + fn to_mut(&mut self) -> GGLWETensorKeyCompressed<&mut [u8]>; +} + +impl GGLWETensorKeyCompressedToMut for GGLWETensorKeyCompressed +where + GGLWEKeyCompressed: GGLWEKeyCompressedToMut, +{ + fn to_mut(&mut self) -> GGLWETensorKeyCompressed<&mut [u8]> { + GGLWETensorKeyCompressed { + keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(), + } + } +} + +pub trait GGLWETensorKeyCompressedToRef { + fn to_ref(&self) -> GGLWETensorKeyCompressed<&[u8]>; +} + +impl GGLWETensorKeyCompressedToRef for GGLWETensorKeyCompressed +where + GGLWEKeyCompressed: GGLWEKeyCompressedToRef, +{ + fn to_ref(&self) -> GGLWETensorKeyCompressed<&[u8]> { + GGLWETensorKeyCompressed { + keys: self.keys.iter().map(|c| c.to_ref()).collect(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/ggsw_ct.rs b/poulpy-core/src/layouts/compressed/ggsw_ct.rs index f0a62cc..b751bfc 100644 --- a/poulpy-core/src/layouts/compressed/ggsw_ct.rs +++ b/poulpy-core/src/layouts/compressed/ggsw_ct.rs @@ -1,6 +1,8 @@ use poulpy_hal::{ api::{VecZnxCopy, VecZnxFillUniform}, - layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{ + Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos, + }, source::Source, }; @@ -235,3 +237,37 @@ where }); } } + +pub trait GGSWCiphertextCompressedToMut { + fn to_mut(&mut self) -> GGSWCiphertextCompressed<&mut [u8]>; +} + +impl GGSWCiphertextCompressedToMut for GGSWCiphertextCompressed { + fn to_mut(&mut self) -> GGSWCiphertextCompressed<&mut [u8]> { + GGSWCiphertextCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_mut(), + } + } +} + +pub trait GGSWCiphertextCompressedToRef { + fn to_ref(&self) -> GGSWCiphertextCompressed<&[u8]>; +} + +impl GGSWCiphertextCompressedToRef for GGSWCiphertextCompressed { + fn to_ref(&self) -> GGSWCiphertextCompressed<&[u8]> { + GGSWCiphertextCompressed { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + rank: self.rank(), + seed: self.seed.clone(), + data: self.data.to_ref(), + } + } +} diff --git a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs index 63933e8..317e0e5 100644 --- a/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/glwe_to_lwe_ksk.rs @@ -6,11 +6,11 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::GGLWESwitchingKeyCompressed, + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, Rank, TorusPrecision, compressed::GGLWEKeyCompressed, }; #[derive(PartialEq, Eq, Clone)] -pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct GLWEToLWESwitchingKeyCompressed(pub(crate) GGLWEKeyCompressed); impl LWEInfos for GLWEToLWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -98,11 +98,11 @@ impl GLWEToLWESwitchingKeyCompressed> { 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self(GGLWEKeyCompressed::alloc(infos)) } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + Self(GGLWEKeyCompressed::alloc_with( n, base2k, k, @@ -127,10 +127,10 @@ impl GLWEToLWESwitchingKeyCompressed> { 1, "dsize > 1 is unsupported for GLWEToLWESwitchingKeyCompressed" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GGLWEKeyCompressed::alloc_bytes(infos) } pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum, rank_in: Rank) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, Dsize(1)) + GGLWEKeyCompressed::alloc_bytes_with(n, base2k, k, rank_in, dnum, Dsize(1)) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_ksk.rs index 480707b..9d45d0a 100644 --- a/poulpy-core/src/layouts/compressed/lwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_ksk.rs @@ -6,12 +6,12 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + compressed::{Decompress, GGLWEKeyCompressed}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWESwitchingKeyCompressed(pub(crate) GGLWEKeyCompressed); impl LWEInfos for LWESwitchingKeyCompressed { fn base2k(&self) -> Base2K { @@ -103,11 +103,11 @@ impl LWESwitchingKeyCompressed> { 1, "rank_out > 1 is not supported for LWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self(GGLWEKeyCompressed::alloc(infos)) } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + Self(GGLWEKeyCompressed::alloc_with( n, base2k, k, @@ -137,11 +137,11 @@ impl LWESwitchingKeyCompressed> { 1, "rank_out > 1 is not supported for LWESwitchingKey" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GGLWEKeyCompressed::alloc_bytes(infos) } pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + GGLWEKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) } } diff --git a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs index 86c353b..bcc69a4 100644 --- a/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs +++ b/poulpy-core/src/layouts/compressed/lwe_to_glwe_ksk.rs @@ -6,12 +6,12 @@ use poulpy_hal::{ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEInfos, GLWEInfos, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + compressed::{Decompress, GGLWEKeyCompressed}, }; use std::fmt; #[derive(PartialEq, Eq, Clone)] -pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GGLWESwitchingKeyCompressed); +pub struct LWEToGLWESwitchingKeyCompressed(pub(crate) GGLWEKeyCompressed); impl LWEInfos for LWEToGLWESwitchingKeyCompressed { fn n(&self) -> Degree { @@ -98,11 +98,11 @@ impl LWEToGLWESwitchingKeyCompressed> { 1, "rank_in > 1 is not supported for LWEToGLWESwitchingKeyCompressed" ); - Self(GGLWESwitchingKeyCompressed::alloc(infos)) + Self(GGLWEKeyCompressed::alloc(infos)) } pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { - Self(GGLWESwitchingKeyCompressed::alloc_with( + Self(GGLWEKeyCompressed::alloc_with( n, base2k, k, @@ -127,11 +127,11 @@ impl LWEToGLWESwitchingKeyCompressed> { 1, "dsize > 1 is not supported for LWEToGLWESwitchingKey" ); - GGLWESwitchingKeyCompressed::alloc_bytes(infos) + GGLWEKeyCompressed::alloc_bytes(infos) } pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { - GGLWESwitchingKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) + GGLWEKeyCompressed::alloc_bytes_with(n, base2k, k, Rank(1), dnum, Dsize(1)) } } diff --git a/poulpy-core/src/layouts/gglwe_atk.rs b/poulpy-core/src/layouts/gglwe_atk.rs index 5c786d2..a88db86 100644 --- a/poulpy-core/src/layouts/gglwe_atk.rs +++ b/poulpy-core/src/layouts/gglwe_atk.rs @@ -4,7 +4,8 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWESwitchingKey, GGLWESwitchingKeyToMut, GLWECiphertext, GLWEInfos, LWEInfos, + Rank, TorusPrecision, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -170,6 +171,26 @@ impl GGLWEAutomorphismKey> { } } +pub trait GGLWEAutomorphismKeyToMut { + fn to_mut(&mut self) -> GGLWEAutomorphismKey<&mut [u8]>; +} + +impl GGLWEAutomorphismKeyToMut for GGLWEAutomorphismKey +where + GGLWESwitchingKey: GGLWESwitchingKeyToMut, +{ + fn to_mut(&mut self) -> GGLWEAutomorphismKey<&mut [u8]> { + GGLWEAutomorphismKey { + key: self.key.to_mut(), + p: self.p, + } + } +} + +pub trait GGLWEAutomorphismKeyToRef { + fn to_ref(&self) -> GGLWEAutomorphismKey<&[u8]>; +} + impl GGLWEAutomorphismKey { pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { self.key.at(row, col) diff --git a/poulpy-core/src/layouts/gglwe_ct.rs b/poulpy-core/src/layouts/gglwe_ct.rs index ca8236c..713df90 100644 --- a/poulpy-core/src/layouts/gglwe_ct.rs +++ b/poulpy-core/src/layouts/gglwe_ct.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; @@ -389,6 +389,36 @@ impl GGLWECiphertext> { } } +pub trait GGLWECiphertextToMut { + fn to_mut(&mut self) -> GGLWECiphertext<&mut [u8]>; +} + +impl GGLWECiphertextToMut for GGLWECiphertext { + fn to_mut(&mut self) -> GGLWECiphertext<&mut [u8]> { + GGLWECiphertext { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_mut(), + } + } +} + +pub trait GGLWECiphertextToRef { + fn to_ref(&self) -> GGLWECiphertext<&[u8]>; +} + +impl GGLWECiphertextToRef for GGLWECiphertext { + fn to_ref(&self) -> GGLWECiphertext<&[u8]> { + GGLWECiphertext { + k: self.k(), + base2k: self.base2k(), + dsize: self.dsize(), + data: self.data.to_ref(), + } + } +} + impl ReaderFrom for GGLWECiphertext { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { self.k = TorusPrecision(reader.read_u32::()?); diff --git a/poulpy-core/src/layouts/gglwe_ksk.rs b/poulpy-core/src/layouts/gglwe_ksk.rs index 31a483b..64b1694 100644 --- a/poulpy-core/src/layouts/gglwe_ksk.rs +++ b/poulpy-core/src/layouts/gglwe_ksk.rs @@ -4,7 +4,8 @@ use poulpy_hal::{ }; use crate::layouts::{ - Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, Rank, TorusPrecision, + Base2K, Degree, Dnum, Dsize, GGLWECiphertext, GGLWECiphertextToMut, GGLWEInfos, GLWECiphertext, GLWEInfos, LWEInfos, Rank, + TorusPrecision, }; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; @@ -180,6 +181,23 @@ impl GGLWESwitchingKey> { } } +pub trait GGLWESwitchingKeyToMut { + fn to_mut(&mut self) -> GGLWESwitchingKey<&mut [u8]>; +} + +impl GGLWESwitchingKeyToMut for GGLWESwitchingKey +where + GGLWECiphertext: GGLWECiphertextToMut, +{ + fn to_mut(&mut self) -> GGLWESwitchingKey<&mut [u8]> { + GGLWESwitchingKey { + key: self.key.to_mut(), + sk_in_n: self.sk_in_n, + sk_out_n: self.sk_out_n, + } + } +} + impl GGLWESwitchingKey { pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> { self.key.at(row, col) diff --git a/poulpy-core/src/layouts/ggsw_ct.rs b/poulpy-core/src/layouts/ggsw_ct.rs index f1bb228..dc64bce 100644 --- a/poulpy-core/src/layouts/ggsw_ct.rs +++ b/poulpy-core/src/layouts/ggsw_ct.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos}, + layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, ReaderFrom, WriterTo, ZnxInfos}, source::Source, }; use std::fmt; @@ -370,3 +370,35 @@ impl WriterTo for GGSWCiphertext { self.data.write_to(writer) } } + +pub trait GGSWCiphertextToMut { + fn to_mut(&mut self) -> GGSWCiphertext<&mut [u8]>; +} + +impl GGSWCiphertextToMut for GGSWCiphertext { + fn to_mut(&mut self) -> GGSWCiphertext<&mut [u8]> { + GGSWCiphertext::builder() + .base2k(self.base2k()) + .dsize(self.dsize()) + .k(self.k()) + .data(self.data.to_mut()) + .build() + .unwrap() + } +} + +pub trait GGSWCiphertextToRef { + fn to_ref(&self) -> GGSWCiphertext<&[u8]>; +} + +impl GGSWCiphertextToRef for GGSWCiphertext { + fn to_ref(&self) -> GGSWCiphertext<&[u8]> { + GGSWCiphertext::builder() + .base2k(self.base2k()) + .dsize(self.dsize()) + .k(self.k()) + .data(self.data.to_ref()) + .build() + .unwrap() + } +} diff --git a/poulpy-core/src/layouts/glwe_pk.rs b/poulpy-core/src/layouts/glwe_pk.rs index fc4b0fa..6e947f1 100644 --- a/poulpy-core/src/layouts/glwe_pk.rs +++ b/poulpy-core/src/layouts/glwe_pk.rs @@ -1,4 +1,4 @@ -use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, WriterTo, ZnxInfos}; +use poulpy_hal::layouts::{Data, DataMut, DataRef, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos}; use crate::{ dist::Distribution, @@ -207,3 +207,33 @@ impl WriterTo for GLWEPublicKey { self.data.write_to(writer) } } + +pub trait GLWEPublicKeyToRef { + fn to_ref(&self) -> GLWEPublicKey<&[u8]>; +} + +impl GLWEPublicKeyToRef for GLWEPublicKey { + fn to_ref(&self) -> GLWEPublicKey<&[u8]> { + GLWEPublicKey { + data: self.data.to_ref(), + base2k: self.base2k, + k: self.k, + dist: self.dist, + } + } +} + +pub trait GLWEPublicKeyToMut { + fn to_mut(&mut self) -> GLWEPublicKey<&mut [u8]>; +} + +impl GLWEPublicKeyToMut for GLWEPublicKey { + fn to_mut(&mut self) -> GLWEPublicKey<&mut [u8]> { + GLWEPublicKey { + base2k: self.base2k, + k: self.k, + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_pt.rs b/poulpy-core/src/layouts/glwe_pt.rs index b565055..c32f229 100644 --- a/poulpy-core/src/layouts/glwe_pt.rs +++ b/poulpy-core/src/layouts/glwe_pt.rs @@ -200,3 +200,31 @@ impl GLWECiphertextToMut for GLWEPlaintext { .unwrap() } } + +pub trait GLWEPlaintextToRef { + fn to_ref(&self) -> GLWEPlaintext<&[u8]>; +} + +impl GLWEPlaintextToRef for GLWEPlaintext { + fn to_ref(&self) -> GLWEPlaintext<&[u8]> { + GLWEPlaintext { + data: self.data.to_ref(), + base2k: self.base2k, + k: self.k, + } + } +} + +pub trait GLWEPlaintextToMut { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]>; +} + +impl GLWEPlaintextToMut for GLWEPlaintext { + fn to_ref(&mut self) -> GLWEPlaintext<&mut [u8]> { + GLWEPlaintext { + base2k: self.base2k, + k: self.k, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/layouts/glwe_sk.rs b/poulpy-core/src/layouts/glwe_sk.rs index 8870d35..e954c3b 100644 --- a/poulpy-core/src/layouts/glwe_sk.rs +++ b/poulpy-core/src/layouts/glwe_sk.rs @@ -1,5 +1,5 @@ use poulpy_hal::{ - layouts::{Data, DataMut, DataRef, ReaderFrom, ScalarZnx, WriterTo, ZnxInfos, ZnxZero}, + layouts::{Data, DataMut, DataRef, ReaderFrom, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, WriterTo, ZnxInfos, ZnxZero}, source::Source, }; @@ -136,6 +136,32 @@ impl GLWESecret { } } +pub trait GLWESecretToMut { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]>; +} + +impl GLWESecretToMut for GLWESecret { + fn to_mut(&mut self) -> GLWESecret<&mut [u8]> { + GLWESecret { + dist: self.dist, + data: self.data.to_mut(), + } + } +} + +pub trait GLWESecretToRef { + fn to_ref(&self) -> GLWESecret<&[u8]>; +} + +impl GLWESecretToRef for GLWESecret { + fn to_ref(&self) -> GLWESecret<&[u8]> { + GLWESecret { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + impl ReaderFrom for GLWESecret { fn read_from(&mut self, reader: &mut R) -> std::io::Result<()> { match Distribution::read_from(reader) { diff --git a/poulpy-core/src/layouts/prepared/ggsw_ct.rs b/poulpy-core/src/layouts/prepared/ggsw_ct.rs index eb79a5a..a3871f6 100644 --- a/poulpy-core/src/layouts/prepared/ggsw_ct.rs +++ b/poulpy-core/src/layouts/prepared/ggsw_ct.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{VmpPMatAlloc, VmpPMatAllocBytes, VmpPrepare, VmpPrepareTmpBytes}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToRef, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos}, oep::VmpPMatAllocBytesImpl, }; @@ -295,6 +295,22 @@ where } } +pub trait GGSWCiphertextPreparedToMut { + fn to_ref(&mut self) -> GGSWCiphertextPrepared<&mut [u8], B>; +} + +impl GGSWCiphertextPreparedToMut for GGSWCiphertextPrepared { + fn to_ref(&mut self) -> GGSWCiphertextPrepared<&mut [u8], B> { + GGSWCiphertextPrepared::builder() + .base2k(self.base2k()) + .dsize(self.dsize()) + .k(self.k()) + .data(self.data.to_mut()) + .build() + .unwrap() + } +} + pub trait GGSWCiphertextPreparedToRef { fn to_ref(&self) -> GGSWCiphertextPrepared<&[u8], B>; } diff --git a/poulpy-core/src/layouts/prepared/glwe_pk.rs b/poulpy-core/src/layouts/prepared/glwe_pk.rs index 6834f58..3211017 100644 --- a/poulpy-core/src/layouts/prepared/glwe_pk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_pk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos}, oep::VecZnxDftAllocBytesImpl, }; @@ -205,3 +205,33 @@ where self.dist = other.dist; } } + +pub trait GLWEPublicKeyPreparedToMut { + fn to_mut(&mut self) -> GLWEPublicKeyPrepared<&mut [u8], B>; +} + +impl GLWEPublicKeyPreparedToMut for GLWEPublicKeyPrepared { + fn to_mut(&mut self) -> GLWEPublicKeyPrepared<&mut [u8], B> { + GLWEPublicKeyPrepared { + dist: self.dist, + k: self.k, + base2k: self.base2k, + data: self.data.to_mut(), + } + } +} + +pub trait GLWEPublicKeyPreparedToRef { + fn to_ref(&self) -> GLWEPublicKeyPrepared<&[u8], B>; +} + +impl GLWEPublicKeyPreparedToRef for GLWEPublicKeyPrepared { + fn to_ref(&self) -> GLWEPublicKeyPrepared<&[u8], B> { + GLWEPublicKeyPrepared { + data: self.data.to_ref(), + dist: self.dist, + k: self.k, + base2k: self.base2k, + } + } +} diff --git a/poulpy-core/src/layouts/prepared/glwe_sk.rs b/poulpy-core/src/layouts/prepared/glwe_sk.rs index d3f638b..234ea2a 100644 --- a/poulpy-core/src/layouts/prepared/glwe_sk.rs +++ b/poulpy-core/src/layouts/prepared/glwe_sk.rs @@ -1,6 +1,6 @@ use poulpy_hal::{ api::{SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare}, - layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, SvpPPol, ZnxInfos}, + layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, SvpPPol, SvpPPolToMut, SvpPPolToRef, ZnxInfos}, }; use crate::{ @@ -113,3 +113,29 @@ where self.dist = other.dist } } + +pub trait GLWESecretPreparedToRef { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B>; +} + +impl GLWESecretPreparedToRef for GLWESecretPrepared { + fn to_ref(&self) -> GLWESecretPrepared<&[u8], B> { + GLWESecretPrepared { + data: self.data.to_ref(), + dist: self.dist, + } + } +} + +pub trait GLWESecretPreparedToMut { + fn to_ref(&mut self) -> GLWESecretPrepared<&mut [u8], B>; +} + +impl GLWESecretPreparedToMut for GLWESecretPrepared { + fn to_ref(&mut self) -> GLWESecretPrepared<&mut [u8], B> { + GLWESecretPrepared { + dist: self.dist, + data: self.data.to_mut(), + } + } +} diff --git a/poulpy-core/src/tests/serialization.rs b/poulpy-core/src/tests/serialization.rs index 8fe477c..6e08395 100644 --- a/poulpy-core/src/tests/serialization.rs +++ b/poulpy-core/src/tests/serialization.rs @@ -4,7 +4,7 @@ use crate::layouts::{ Base2K, Degree, Dnum, Dsize, GGLWEAutomorphismKey, GGLWECiphertext, GGLWESwitchingKey, GGLWETensorKey, GGSWCiphertext, GLWECiphertext, GLWEToLWEKey, LWECiphertext, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, compressed::{ - GGLWEAutomorphismKeyCompressed, GGLWECiphertextCompressed, GGLWESwitchingKeyCompressed, GGLWETensorKeyCompressed, + GGLWEAutomorphismKeyCompressed, GGLWECiphertextCompressed, GGLWEKeyCompressed, GGLWETensorKeyCompressed, GGSWCiphertextCompressed, GLWECiphertextCompressed, GLWEToLWESwitchingKeyCompressed, LWECiphertextCompressed, LWESwitchingKeyCompressed, LWEToGLWESwitchingKeyCompressed, }, @@ -63,8 +63,7 @@ fn test_glwe_switching_key_serialization() { #[test] fn test_glwe_switching_key_compressed_serialization() { - let original: GGLWESwitchingKeyCompressed> = - GGLWESwitchingKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); + let original: GGLWEKeyCompressed> = GGLWEKeyCompressed::alloc_with(N_GLWE, BASE2K, K, RANK, RANK, DNUM, DSIZE); test_reader_writer_interface(original); } diff --git a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs index 60bb7e2..249161b 100644 --- a/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/gglwe_ct.rs @@ -18,7 +18,7 @@ use crate::{ encryption::SIGMA, layouts::{ GGLWECiphertextLayout, GGLWESwitchingKey, GLWESecret, - compressed::{Decompress, GGLWESwitchingKeyCompressed}, + compressed::{Decompress, GGLWEKeyCompressed}, prepared::{GLWESecretPrepared, PrepareAlloc}, }, }; @@ -173,12 +173,12 @@ where rank_out: rank_out.into(), }; - let mut ksk_compressed: GGLWESwitchingKeyCompressed> = GGLWESwitchingKeyCompressed::alloc(&gglwe_infos); + let mut ksk_compressed: GGLWEKeyCompressed> = GGLWEKeyCompressed::alloc(&gglwe_infos); let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); - let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWESwitchingKeyCompressed::encrypt_sk_scratch_space( + let mut scratch: ScratchOwned = ScratchOwned::alloc(GGLWEKeyCompressed::encrypt_sk_scratch_space( module, &gglwe_infos, )); diff --git a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs index a1169f6..35938c3 100644 --- a/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs +++ b/poulpy-core/src/tests/test_suite/encryption/glwe_ct.rs @@ -363,7 +363,7 @@ where let sk_prepared: GLWESecretPrepared, B> = sk.prepare_alloc(module, scratch.borrow()); let mut pk: GLWEPublicKey> = GLWEPublicKey::alloc(&glwe_infos); - pk.generate_from_sk(module, &sk_prepared, &mut source_xa, &mut source_xe); + pk.generate(module, &sk_prepared, &mut source_xa, &mut source_xe); module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa);