Merge pull request #109 from phantomzone-org/dev_bdd_selector

Add GGSW based blind rotation + refactor of tensorkey
This commit is contained in:
Jean-Philippe Bossuat
2025-10-27 18:08:27 +01:00
committed by GitHub
93 changed files with 4042 additions and 1887 deletions

17
Cargo.lock generated
View File

@@ -49,9 +49,9 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.23.2" version = "1.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3995eaeebcdf32f91f980d360f78732ddc061097ab4e39991ae7a6ace9194677" checksum = "1fbdf580320f38b612e485521afda1ee26d10cc9884efaaa750d383e13e3c5f4"
[[package]] [[package]]
name = "byteorder" name = "byteorder"
@@ -353,7 +353,7 @@ dependencies = [
[[package]] [[package]]
name = "poulpy-backend" name = "poulpy-backend"
version = "0.2.0" version = "0.3.1"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"cmake", "cmake",
@@ -370,8 +370,9 @@ dependencies = [
[[package]] [[package]]
name = "poulpy-core" name = "poulpy-core"
version = "0.2.0" version = "0.3.1"
dependencies = [ dependencies = [
"bytemuck",
"byteorder", "byteorder",
"criterion", "criterion",
"itertools 0.14.0", "itertools 0.14.0",
@@ -383,7 +384,7 @@ dependencies = [
[[package]] [[package]]
name = "poulpy-hal" name = "poulpy-hal"
version = "0.2.0" version = "0.3.1"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"byteorder", "byteorder",
@@ -400,7 +401,7 @@ dependencies = [
[[package]] [[package]]
name = "poulpy-schemes" name = "poulpy-schemes"
version = "0.2.0" version = "0.3.0"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"criterion", "criterion",
@@ -534,9 +535,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]] [[package]]
name = "rug" name = "rug"
version = "1.27.0" version = "1.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4207e8d668e5b8eb574bda8322088ccd0d7782d3d03c7e8d562e82ed82bdcbc3" checksum = "58ad2e973fe3c3214251a840a621812a4f40468da814b1a3d6947d433c2af11f"
dependencies = [ dependencies = [
"az", "az",
"gmp-mpfr-sys", "gmp-mpfr-sys",

View File

@@ -7,8 +7,8 @@ poulpy-hal = {path = "poulpy-hal"}
poulpy-core = {path = "poulpy-core"} poulpy-core = {path = "poulpy-core"}
poulpy-backend = {path = "poulpy-backend"} poulpy-backend = {path = "poulpy-backend"}
poulpy-schemes = {path = "poulpy-schemes"} poulpy-schemes = {path = "poulpy-schemes"}
rug = "1.27" rug = "1.28.0"
rand = "0.9.1" rand = "0.9.2"
rand_chacha = "0.9.0" rand_chacha = "0.9.0"
rand_core = "0.9.3" rand_core = "0.9.3"
rand_distr = "0.5.1" rand_distr = "0.5.1"
@@ -17,3 +17,4 @@ criterion = "0.7.0"
byteorder = "1.5.0" byteorder = "1.5.0"
zstd = "0.13.3" zstd = "0.13.3"
once_cell = "1.21.3" once_cell = "1.21.3"
bytemuck = "1.24.0"

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "poulpy-backend" name = "poulpy-backend"
version = "0.2.0" version = "0.3.1"
edition = "2024" edition = "2024"
license = "Apache-2.0" license = "Apache-2.0"
readme = "README.md" readme = "README.md"

View File

@@ -1,5 +1,6 @@
use poulpy_hal::{ use poulpy_hal::{
api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module, test_suite::convolution::test_convolution, api::ModuleNew, backend_test_suite, cross_backend_test_suite, layouts::Module,
test_suite::convolution::test_bivariate_tensoring,
}; };
use crate::FFT64Avx; use crate::FFT64Avx;
@@ -123,5 +124,5 @@ backend_test_suite! {
#[test] #[test]
fn test_convolution_fft64_avx() { fn test_convolution_fft64_avx() {
let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64); let module: Module<FFT64Avx> = Module::<FFT64Avx>::new(64);
test_convolution(&module); test_bivariate_tensoring(&module);
} }

View File

@@ -14,7 +14,7 @@ use poulpy_hal::{
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl,
}, },
reference::vec_znx::{ reference::vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace,
@@ -25,13 +25,22 @@ use poulpy_hal::{
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero,
}, },
source::Source, source::Source,
}; };
use crate::cpu_fft64_avx::FFT64Avx; use crate::cpu_fft64_avx::FFT64Avx;
unsafe impl VecZnxZeroImpl<Self> for FFT64Avx {
fn vec_znx_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
vec_znx_zero::<_, FFT64Avx>(res, res_col);
}
}
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Avx { unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Avx {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize { fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_normalize_tmp_bytes(module.n()) vec_znx_normalize_tmp_bytes(module.n())

View File

@@ -194,10 +194,10 @@ unsafe impl VecZnxDftCopyImpl<Self> for FFT64Avx {
} }
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Avx { unsafe impl VecZnxDftZeroImpl<Self> for FFT64Avx {
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R) fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<Self>, R: VecZnxDftToMut<Self>,
{ {
vec_znx_dft_zero(res); vec_znx_dft_zero(res, res_col);
} }
} }

View File

@@ -1,9 +1,9 @@
use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_convolution}; use poulpy_hal::{api::ModuleNew, layouts::Module, test_suite::convolution::test_bivariate_tensoring};
use crate::FFT64Ref; use crate::FFT64Ref;
#[test] #[test]
fn test_convolution_fft64_ref() { fn test_convolution_fft64_ref() {
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(64); let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(8);
test_convolution(&module); test_bivariate_tensoring(&module);
} }

View File

@@ -14,7 +14,7 @@ use poulpy_hal::{
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl,
}, },
reference::vec_znx::{ reference::vec_znx::{
vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace, vec_znx_add, vec_znx_add_inplace, vec_znx_add_normal_ref, vec_znx_add_scalar, vec_znx_add_scalar_inplace,
@@ -25,13 +25,22 @@ use poulpy_hal::{
vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace, vec_znx_normalize_inplace, vec_znx_normalize_tmp_bytes, vec_znx_rotate, vec_znx_rotate_inplace,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar, vec_znx_split_ring_tmp_bytes, vec_znx_sub, vec_znx_sub_inplace, vec_znx_sub_negate_inplace, vec_znx_sub_scalar,
vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_sub_scalar_inplace, vec_znx_switch_ring, vec_znx_zero,
}, },
source::Source, source::Source,
}; };
use crate::cpu_fft64_ref::FFT64Ref; use crate::cpu_fft64_ref::FFT64Ref;
unsafe impl VecZnxZeroImpl<Self> for FFT64Ref {
fn vec_znx_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
vec_znx_zero::<_, FFT64Ref>(res, res_col);
}
}
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Ref { unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Ref {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize { fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_normalize_tmp_bytes(module.n()) vec_znx_normalize_tmp_bytes(module.n())

View File

@@ -194,10 +194,10 @@ unsafe impl VecZnxDftCopyImpl<Self> for FFT64Ref {
} }
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Ref { unsafe impl VecZnxDftZeroImpl<Self> for FFT64Ref {
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R) fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<Self>, R: VecZnxDftToMut<Self>,
{ {
vec_znx_dft_zero(res); vec_znx_dft_zero(res, res_col);
} }
} }

View File

@@ -15,7 +15,7 @@ use poulpy_hal::{
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl,
}, },
reference::{ reference::{
vec_znx::{ vec_znx::{
@@ -23,7 +23,7 @@ use poulpy_hal::{
vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings, vec_znx_fill_uniform_ref, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_lsh_tmp_bytes, vec_znx_merge_rings,
vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes, vec_znx_merge_rings_tmp_bytes, vec_znx_mul_xp_minus_one_inplace_tmp_bytes, vec_znx_normalize_tmp_bytes,
vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring, vec_znx_rotate_inplace_tmp_bytes, vec_znx_rsh, vec_znx_rsh_inplace, vec_znx_rsh_tmp_bytes, vec_znx_split_ring,
vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, vec_znx_split_ring_tmp_bytes, vec_znx_switch_ring, vec_znx_zero,
}, },
znx::{znx_copy_ref, znx_zero_ref}, znx::{znx_copy_ref, znx_zero_ref},
}, },
@@ -35,6 +35,15 @@ use crate::cpu_spqlios::{
ffi::{module::module_info_t, vec_znx, znx}, ffi::{module::module_info_t, vec_znx, znx},
}; };
unsafe impl VecZnxZeroImpl<Self> for FFT64Spqlios {
fn vec_znx_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
vec_znx_zero::<_, FFT64Spqlios>(res, res_col);
}
}
unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios { unsafe impl VecZnxNormalizeTmpBytesImpl<Self> for FFT64Spqlios {
fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize { fn vec_znx_normalize_tmp_bytes_impl(module: &Module<Self>) -> usize {
vec_znx_normalize_tmp_bytes(module.n()) vec_znx_normalize_tmp_bytes(module.n())

View File

@@ -12,7 +12,7 @@ use poulpy_hal::{
reference::{ reference::{
fft64::{ fft64::{
reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref}, reim::{ReimCopy, ReimZero, reim_copy_ref, reim_negate_inplace_ref, reim_negate_ref, reim_zero_ref},
vec_znx_dft::vec_znx_dft_copy, vec_znx_dft::{vec_znx_dft_copy, vec_znx_dft_zero},
}, },
znx::znx_zero_ref, znx::znx_zero_ref,
}, },
@@ -426,10 +426,10 @@ impl ReimZero for FFT64Spqlios {
} }
unsafe impl VecZnxDftZeroImpl<Self> for FFT64Spqlios { unsafe impl VecZnxDftZeroImpl<Self> for FFT64Spqlios {
fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R) fn vec_znx_dft_zero_impl<R>(_module: &Module<Self>, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<Self>, R: VecZnxDftToMut<Self>,
{ {
res.to_mut().data.fill(0); vec_znx_dft_zero(res, res_col);
} }
} }

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "poulpy-core" name = "poulpy-core"
version = "0.2.0" version = "0.3.1"
edition = "2024" edition = "2024"
license = "Apache-2.0" license = "Apache-2.0"
description = "A backend agnostic crate implementing RLWE-based encryption & arithmetic." description = "A backend agnostic crate implementing RLWE-based encryption & arithmetic."
@@ -15,6 +15,7 @@ poulpy-hal = {workspace = true}
poulpy-backend = {workspace = true} poulpy-backend = {workspace = true}
itertools = {workspace = true} itertools = {workspace = true}
byteorder = {workspace = true} byteorder = {workspace = true}
bytemuck = {workspace = true}
once_cell = {workspace = true} once_cell = {workspace = true}
[[bench]] [[bench]]

View File

@@ -1,11 +1,10 @@
use poulpy_hal::{ use poulpy_hal::{
api::VecZnxAutomorphism, api::{VecZnxAutomorphism, VecZnxAutomorphismInplace},
layouts::{Backend, DataMut, GaloisElement, Module, Scratch}, layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch},
}; };
use crate::{ use crate::{
ScratchTakeCore, GLWEKeyswitch, ScratchTakeCore,
automorphism::glwe_ct::GLWEAutomorphism,
layouts::{ layouts::{
GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWE, GLWEAutomorphismKey, GetGaloisElement, GGLWE, GGLWEInfos, GGLWEPreparedToRef, GGLWEToMut, GGLWEToRef, GLWE, GLWEAutomorphismKey, GetGaloisElement,
SetGaloisElement, SetGaloisElement,
@@ -45,14 +44,10 @@ impl<DataSelf: DataMut> GLWEAutomorphismKey<DataSelf> {
} }
} }
impl<BE: Backend> GLWEAutomorphismKeyAutomorphism<BE> for Module<BE> where impl<BE: Backend> GLWEAutomorphismKeyAutomorphism<BE> for Module<BE>
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism
{
}
pub trait GLWEAutomorphismKeyAutomorphism<BE: Backend>
where where
Self: GaloisElement + GLWEAutomorphism<BE> + VecZnxAutomorphism, Self: GaloisElement + GLWEKeyswitch<BE> + VecZnxAutomorphism + VecZnxAutomorphismInplace<BE> + CyclotomicOrder,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_automorphism_key_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize fn glwe_automorphism_key_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where where
@@ -68,7 +63,6 @@ where
R: GGLWEToMut + SetGaloisElement + GGLWEInfos, R: GGLWEToMut + SetGaloisElement + GGLWEInfos,
A: GGLWEToRef + GetGaloisElement + GGLWEInfos, A: GGLWEToRef + GetGaloisElement + GGLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
assert!( assert!(
res.dnum().as_u32() <= a.dnum().as_u32(), res.dnum().as_u32() <= a.dnum().as_u32(),
@@ -163,3 +157,22 @@ where
res.set_p((res.p() * key.p()) % self.cyclotomic_order()); res.set_p((res.p() * key.p()) % self.cyclotomic_order());
} }
} }
pub trait GLWEAutomorphismKeyAutomorphism<BE: Backend> {
fn glwe_automorphism_key_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GGLWEInfos,
A: GGLWEInfos,
K: GGLWEInfos;
fn glwe_automorphism_key_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut + SetGaloisElement + GGLWEInfos,
A: GGLWEToRef + GetGaloisElement + GGLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos;
fn glwe_automorphism_key_automorphism_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GGLWEToMut + SetGaloisElement + GetGaloisElement + GGLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos;
}

View File

@@ -7,8 +7,8 @@ use crate::{
GGSWExpandRows, ScratchTakeCore, GGSWExpandRows, ScratchTakeCore,
automorphism::glwe_ct::GLWEAutomorphism, automorphism::glwe_ct::GLWEAutomorphism,
layouts::{ layouts::{
GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GetGaloisElement, GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut,
prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, GGSWToRef, GetGaloisElement,
}, },
}; };
@@ -36,7 +36,7 @@ impl<D: DataMut> GGSW<D> {
where where
A: GGSWToRef, A: GGSWToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWAutomorphism<BE>, M: GGSWAutomorphism<BE>,
{ {
@@ -46,7 +46,7 @@ impl<D: DataMut> GGSW<D> {
pub fn automorphism_inplace<K, T, M, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>) pub fn automorphism_inplace<K, T, M, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where where
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWAutomorphism<BE>, M: GGSWAutomorphism<BE>,
{ {
@@ -67,11 +67,8 @@ where
K: GGLWEInfos, K: GGLWEInfos,
T: GGLWEInfos, T: GGLWEInfos,
{ {
let out_size: usize = res_infos.size(); self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos)
let ci_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), out_size); .max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos))
let ks_internal: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos);
let expand: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos);
ci_dft + (ks_internal.max(expand))
} }
fn ggsw_automorphism<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_automorphism<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
@@ -79,12 +76,12 @@ where
R: GGSWToMut, R: GGSWToMut,
A: GGSWToRef, A: GGSWToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref(); let a: &GGSW<&[u8]> = &a.to_ref();
let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
assert_eq!(res.dsize(), a.dsize()); assert_eq!(res.dsize(), a.dsize());
assert!(res.dnum() <= a.dnum()); assert!(res.dnum() <= a.dnum());
@@ -104,11 +101,11 @@ where
where where
R: GGSWToMut, R: GGSWToMut,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
for row in 0..res.dnum().as_usize() { for row in 0..res.dnum().as_usize() {

View File

@@ -1,13 +1,13 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAutomorphismInplace, VecZnxBigSubSmallInplace, ScratchTakeBasic, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
VecZnxBigSubSmallNegateInplace, VecZnxBigSubSmallInplace, VecZnxBigSubSmallNegateInplace, VecZnxNormalize,
}, },
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig},
}; };
use crate::{ use crate::{
GLWEKeyswitch, ScratchTakeCore, keyswitch_internal, GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos},
}; };
@@ -101,13 +101,71 @@ impl<DataSelf: DataMut> GLWE<DataSelf> {
} }
} }
pub trait GLWEAutomorphism<BE: Backend> pub trait GLWEAutomorphism<BE: Backend> {
fn glwe_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where where
Self: GLWEKeyswitch<BE> R: GLWEInfos,
A: GLWEInfos,
K: GGLWEInfos;
fn glwe_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_add<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_add_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_sub<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_sub_negate<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_sub_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
fn glwe_automorphism_sub_negate_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos;
}
impl<BE: Backend> GLWEAutomorphism<BE> for Module<BE>
where
Self: Sized
+ GLWEKeyswitch<BE>
+ GLWEKeySwitchInternal<BE>
+ VecZnxNormalize<BE>
+ VecZnxAutomorphismInplace<BE> + VecZnxAutomorphismInplace<BE>
+ VecZnxBigAutomorphismInplace<BE> + VecZnxBigAutomorphismInplace<BE>
+ VecZnxBigSubSmallInplace<BE> + VecZnxBigSubSmallInplace<BE>
+ VecZnxBigSubSmallNegateInplace<BE>, + VecZnxBigSubSmallNegateInplace<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize fn glwe_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where where
@@ -160,7 +218,7 @@ where
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -186,7 +244,7 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -214,7 +272,7 @@ where
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -242,7 +300,7 @@ where
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, a, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -268,7 +326,7 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -294,7 +352,7 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
let mut res_big: VecZnxBig<_, BE> = keyswitch_internal(self, res_dft, res, key, scratch_1); let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
@@ -311,12 +369,3 @@ where
} }
} }
} }
impl<BE: Backend> GLWEAutomorphism<BE> for Module<BE> where
Self: GLWEKeyswitch<BE>
+ VecZnxAutomorphismInplace<BE>
+ VecZnxBigAutomorphismInplace<BE>
+ VecZnxBigSubSmallInplace<BE>
+ VecZnxBigSubSmallNegateInplace<BE>
{
}

View File

@@ -1,17 +1,16 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftAddInplace, VecZnxDftApply, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
}, },
layouts::{Backend, DataMut, Module, Scratch, VmpPMat, ZnxInfos}, layouts::{Backend, DataMut, Module, Scratch, VecZnxBig},
}; };
use crate::{ use crate::{
GLWECopy, ScratchTakeCore, GGLWEProduct, GLWECopy, ScratchTakeCore,
layouts::{ layouts::{
GGLWE, GGLWEInfos, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE,
prepared::{GLWETensorKeyPrepared, GLWETensorKeyPreparedToRef}, GLWEInfos, LWEInfos,
}, },
}; };
@@ -31,7 +30,7 @@ impl<D: DataMut> GGSW<D> {
where where
M: GGSWFromGGLWE<BE>, M: GGSWFromGGLWE<BE>,
G: GGLWEToRef, G: GGLWEToRef,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.ggsw_from_gglwe(self, gglwe, tsk, scratch); module.ggsw_from_gglwe(self, gglwe, tsk, scratch);
@@ -54,12 +53,12 @@ where
where where
R: GGSWToMut, R: GGSWToMut,
A: GGLWEToRef, A: GGLWEToRef,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref(); let a: &GGLWE<&[u8]> = &a.to_ref();
let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
assert_eq!(res.rank(), a.rank_out()); assert_eq!(res.rank(), a.rank_out());
assert_eq!(res.dnum(), a.dnum()); assert_eq!(res.dnum(), a.dnum());
@@ -85,115 +84,104 @@ pub trait GGSWFromGGLWE<BE: Backend> {
where where
R: GGSWToMut, R: GGSWToMut,
A: GGLWEToRef, A: GGLWEToRef,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>; Scratch<BE>: ScratchTakeCore<BE>;
} }
impl<BE: Backend> GGSWExpandRows<BE> for Module<BE> where pub trait GGSWExpandRows<BE: Backend> {
Self: Sized fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
+ ModuleN where
+ VecZnxDftBytesOf R: GGSWInfos,
+ VmpApplyDftToDftTmpBytes A: GGLWEInfos;
+ VecZnxBigBytesOf
+ VecZnxNormalizeTmpBytes fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
+ VecZnxDftBytesOf where
+ VmpApplyDftToDftTmpBytes R: GGSWToMut,
+ VecZnxBigBytesOf T: GGLWEToGGSWKeyPreparedToRef<BE>,
+ VecZnxNormalizeTmpBytes Scratch<BE>: ScratchTakeCore<BE>;
+ VecZnxDftApply<BE>
+ VecZnxDftCopy<BE>
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftAddInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxNormalize<BE>
{
} }
pub trait GGSWExpandRows<BE: Backend> impl<BE: Backend> GGSWExpandRows<BE> for Module<BE>
where where
Self: Sized Self: GGLWEProduct<BE>
+ ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigBytesOf
+ VecZnxNormalizeTmpBytes
+ VecZnxDftApply<BE>
+ VecZnxDftCopy<BE>
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftAddInplace<BE>
+ VecZnxBigNormalize<BE> + VecZnxBigNormalize<BE>
+ VecZnxIdftApplyTmpA<BE> + VecZnxBigNormalizeTmpBytes
+ VecZnxNormalize<BE>, + VecZnxBigBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftApply<BE>
+ VecZnxNormalize<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxIdftApplyConsume<BE>,
{ {
fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
where where
R: GGSWInfos, R: GGSWInfos,
A: GGLWEInfos, A: GGLWEInfos,
{ {
let tsk_size: usize = tsk_infos.k().div_ceil(tsk_infos.base2k()) as usize; let base2k_in: usize = res_infos.base2k().into();
let size_in: usize = res_infos let base2k_tsk: usize = tsk_infos.base2k().into();
.k()
.div_ceil(tsk_infos.base2k())
.div_ceil(tsk_infos.dsize().into()) as usize;
let tmp_dft_i: usize = self.bytes_of_vec_znx_dft((tsk_infos.rank_out() + 1).into(), tsk_size); let rank: usize = res_infos.rank().into();
let tmp_a: usize = self.bytes_of_vec_znx_dft(1, size_in); let cols: usize = rank + 1;
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
tsk_size,
size_in,
size_in,
(tsk_infos.rank_in()).into(), // Verify if rank+1
(tsk_infos.rank_out()).into(), // Verify if rank+1
tsk_size,
);
let tmp_idft: usize = self.bytes_of_vec_znx_big(1, tsk_size);
let norm: usize = self.vec_znx_normalize_tmp_bytes();
tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm)) let res_size = res_infos.size();
let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk);
let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size);
let res_dft = self.bytes_of_vec_znx_dft(cols, a_size);
let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos);
let normalize = self.vec_znx_big_normalize_tmp_bytes();
(a_dft + res_dft + gglwe_prod).max(normalize)
} }
fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
where where
R: GGSWToMut, R: GGSWToMut,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let tsk: &GLWETensorKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
let basek_in: usize = res.base2k().into(); let base2k_in: usize = res.base2k().into();
let basek_tsk: usize = tsk.base2k().into(); let base2k_tsk: usize = tsk.base2k().into();
assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));
let rank: usize = res.rank().into(); let rank: usize = res.rank().into();
let cols: usize = rank + 1; let cols: usize = rank + 1;
let a_size: usize = (res.size() * basek_in).div_ceil(basek_tsk); let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk);
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
for row_i in 0..res.dnum().into() { for row in 0..res.dnum().as_usize() {
let a = &res.at(row_i, 0).data; let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
// Pre-compute DFT of (a0, a1, a2) {
let (mut ci_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size); let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);
if basek_in == basek_tsk { if base2k_in == base2k_tsk {
for i in 0..cols { for col_i in 0..cols - 1 {
self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, a, i); self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
} }
} else { } else {
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size); let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size);
for i in 0..cols { for i in 0..cols - 1 {
self.vec_znx_normalize(basek_tsk, &mut a_conv, 0, basek_in, a, i, scratch_2); self.vec_znx_normalize(
self.vec_znx_dft_apply(1, 0, &mut ci_dft, i, &a_conv, 0); base2k_tsk,
&mut a_conv,
0,
base2k_in,
glwe_mi_1.data(),
i + 1,
scratch_2,
);
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0);
}
} }
} }
for col_j in 1..cols {
// Example for rank 3: // Example for rank 3:
// //
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
@@ -213,13 +201,9 @@ where
// col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 ) // col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 )
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i])
for col in 1..cols {
let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise
let dsize: usize = tsk.dsize().into();
let (mut tmp_dft_i, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size());
let (mut tmp_a, scratch_3) = scratch_2.take_vec_znx_dft(self, 1, ci_dft.size().div_ceil(dsize));
{
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
// //
// # Example for col=1 // # Example for col=1
@@ -231,31 +215,9 @@ where
// a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
// = // =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
for col_i in 1..cols { self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2);
let pmat: &VmpPMat<&[u8], BE> = &tsk.at(col_i - 1, col_j - 1).data; // Selects Enc(s[i]s[j])
// Extracts a[i] and multipies with Enc(s[i]s[j]) let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
for di in 0..dsize {
tmp_a.set_size((ci_dft.size() + di) / dsize);
// Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
tmp_dft_i.set_size(tsk.size() - ((dsize - di) as isize - 2).max(0) as usize);
self.vec_znx_dft_copy(dsize, dsize - 1 - di, &mut tmp_a, 0, &ci_dft, col_i);
if di == 0 && col_i == 1 {
self.vmp_apply_dft_to_dft(&mut tmp_dft_i, &tmp_a, pmat, scratch_3);
} else {
self.vmp_apply_dft_to_dft_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch_3);
}
}
}
}
// Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i
// //
@@ -266,18 +228,17 @@ where
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
// = // =
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
self.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, &ci_dft, 0); self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0);
let (mut tmp_idft, scratch_3) = scratch_2.take_vec_znx_big(self, 1, tsk.size());
for i in 0..cols { for j in 0..cols {
self.vec_znx_idft_apply_tmpa(&mut tmp_idft, 0, &mut tmp_dft_i, i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
basek_in, res.base2k().as_usize(),
&mut res.at_mut(row_i, col_j).data, res.at_mut(row, col).data_mut(),
i, j,
basek_tsk, tsk.base2k().as_usize(),
&tmp_idft, &res_big,
0, j,
scratch_3, scratch_2,
); );
} }
} }

View File

@@ -1,5 +1,5 @@
use poulpy_hal::{ use poulpy_hal::{
api::ScratchTakeBasic, api::{ScratchTakeBasic, VecZnxNormalize, VecZnxNormalizeTmpBytes},
layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero}, layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero},
}; };
@@ -8,11 +8,10 @@ use crate::{
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef}, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef},
}; };
impl<BE: Backend> GLWEFromLWE<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {} impl<BE: Backend> GLWEFromLWE<BE> for Module<BE>
pub trait GLWEFromLWE<BE: Backend>
where where
Self: GLWEKeyswitch<BE>, Self: GLWEKeyswitch<BE> + VecZnxNormalizeTmpBytes + VecZnxNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_from_lwe_tmp_bytes<R, A, K>(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize fn glwe_from_lwe_tmp_bytes<R, A, K>(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize
where where
@@ -41,7 +40,6 @@ where
R: GLWEToMut, R: GLWEToMut,
A: LWEToRef, A: LWEToRef,
K: GGLWEPreparedToRef<BE> + GGLWEInfos, K: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let lwe: &LWE<&[u8]> = &lwe.to_ref(); let lwe: &LWE<&[u8]> = &lwe.to_ref();
@@ -105,6 +103,23 @@ where
} }
} }
pub trait GLWEFromLWE<BE: Backend>
where
Self: GLWEKeyswitch<BE>,
{
fn glwe_from_lwe_tmp_bytes<R, A, K>(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize
where
R: GLWEInfos,
A: LWEInfos,
K: GGLWEInfos;
fn glwe_from_lwe<R, A, K>(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: LWEToRef,
K: GGLWEPreparedToRef<BE> + GGLWEInfos;
}
impl GLWE<Vec<u8>> { impl GLWE<Vec<u8>> {
pub fn from_lwe_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize pub fn from_lwe_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize
where where

View File

@@ -0,0 +1,124 @@
use poulpy_hal::{
api::{ModuleN, ScratchTakeBasic, VecZnxCopy},
layouts::{Backend, DataMut, Module, Scratch},
source::Source,
};
use crate::{
GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{
GGLWEInfos, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyCompressedToMut, GLWEInfos, GLWESecret, GLWESecretTensor,
GLWESecretTensorFactory, GLWESecretToRef,
prepared::{GLWESecretPrepared, GLWESecretPreparedFactory},
},
};
impl GGLWEToGGSWKeyCompressed<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
where
A: GGLWEInfos,
M: GGLWEToGGSWKeyCompressedEncryptSk<BE>,
{
module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos)
}
}
impl<DataSelf: DataMut> GGLWEToGGSWKeyCompressed<DataSelf> {
pub fn encrypt_sk<M, S, BE: Backend>(
&mut self,
module: &M,
sk: &S,
seed_xa: [u8; 32],
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
M: GGLWEToGGSWKeyCompressedEncryptSk<BE>,
S: GLWESecretToRef + GetDistribution + GLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.gglwe_to_ggsw_key_encrypt_sk(self, sk, seed_xa, source_xe, scratch);
}
}
pub trait GGLWEToGGSWKeyCompressedEncryptSk<BE: Backend> {
fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos;
fn gglwe_to_ggsw_key_encrypt_sk<R, S>(
&self,
res: &mut R,
sk: &S,
seed_xa: [u8; 32],
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos,
S: GLWESecretToRef + GetDistribution + GLWEInfos;
}
impl<BE: Backend> GGLWEToGGSWKeyCompressedEncryptSk<BE> for Module<BE>
where
Self: ModuleN + GGLWECompressedEncryptSk<BE> + GLWESecretTensorFactory<BE> + GLWESecretPreparedFactory<BE> + VecZnxCopy,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank());
let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos);
let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(infos);
let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank());
(sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank()))
}
fn gglwe_to_ggsw_key_encrypt_sk<R, S>(
&self,
res: &mut R,
sk: &S,
seed_xa: [u8; 32],
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: GGLWEToGGSWKeyCompressedToMut + GGLWEInfos,
S: GLWESecretToRef + GetDistribution + GLWEInfos,
{
assert_eq!(res.rank(), sk.rank());
assert_eq!(res.n(), sk.n());
let res: &mut GGLWEToGGSWKeyCompressed<&mut [u8]> = &mut res.to_mut();
let rank: usize = res.rank_out().as_usize();
let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank());
let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank());
sk_prepared.prepare(self, sk);
sk_tensor.prepare(self, sk, scratch_2);
let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank);
let mut source_xa = Source::new(seed_xa);
for i in 0..rank {
for j in 0..rank {
self.vec_znx_copy(
&mut sk_ij.as_vec_znx_mut(),
j,
&sk_tensor.at(i, j).as_vec_znx(),
0,
);
}
let (seed_xa_tmp, _) = source_xa.branch();
res.at_mut(i).encrypt_sk(
self,
&sk_ij,
&sk_prepared,
seed_xa_tmp,
source_xe,
scratch_3,
);
}
}
}

View File

@@ -1,17 +1,15 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::ScratchTakeBasic,
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolBytesOf, SvpPrepare, VecZnxBigBytesOf, VecZnxBigNormalize,
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
},
layouts::{Backend, DataMut, Module, Scratch}, layouts::{Backend, DataMut, Module, Scratch},
source::Source, source::Source,
}; };
use crate::{ use crate::{
GGLWECompressedEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, GGLWECompressedEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{ layouts::{
GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretToRef, GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWEInfos, GGLWELayout, GLWEInfos, GLWESecretPrepared,
GLWETensorKeyCompressedAtMut, LWEInfos, Rank, compressed::GLWETensorKeyCompressed, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef,
compressed::GLWETensorKeyCompressed,
}, },
}; };
@@ -34,7 +32,7 @@ impl<DataSelf: DataMut> GLWETensorKeyCompressed<DataSelf> {
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
S: GLWESecretToRef + GetDistribution, S: GLWESecretToRef + GetDistribution + GLWEInfos,
M: GLWETensorKeyCompressedEncryptSk<BE>, M: GLWETensorKeyCompressedEncryptSk<BE>,
{ {
module.glwe_tensor_key_compressed_encrypt_sk(self, sk, seed_xa, source_xe, scratch); module.glwe_tensor_key_compressed_encrypt_sk(self, sk, seed_xa, source_xe, scratch);
@@ -46,7 +44,7 @@ pub trait GLWETensorKeyCompressedEncryptSk<BE: Backend> {
where where
A: GGLWEInfos; A: GGLWEInfos;
fn glwe_tensor_key_compressed_encrypt_sk<R, S, D>( fn glwe_tensor_key_compressed_encrypt_sk<R, S>(
&self, &self,
res: &mut R, res: &mut R,
sk: &S, sk: &S,
@@ -54,40 +52,38 @@ pub trait GLWETensorKeyCompressedEncryptSk<BE: Backend> {
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
D: DataMut, R: GGLWECompressedToMut + GGLWEInfos + GGLWECompressedSeedMut,
R: GLWETensorKeyCompressedAtMut<D> + GGLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos;
S: GLWESecretToRef + GetDistribution;
} }
impl<BE: Backend> GLWETensorKeyCompressedEncryptSk<BE> for Module<BE> impl<BE: Backend> GLWETensorKeyCompressedEncryptSk<BE> for Module<BE>
where where
Self: ModuleN Self: GGLWECompressedEncryptSk<BE> + GLWESecretPreparedFactory<BE> + GLWESecretTensorFactory<BE>,
+ GGLWECompressedEncryptSk<BE>
+ GLWETensorKeyEncryptSk<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxBigBytesOf
+ GLWESecretPreparedFactory<BE>,
Scratch<BE>: ScratchTakeBasic + ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeBasic + ScratchTakeCore<BE>,
{ {
fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn glwe_tensor_key_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
GLWESecretPrepared::bytes_of(self, infos.rank_out()) let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out());
+ self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos);
+ self.bytes_of_vec_znx_big(1, 1)
+ self.bytes_of_vec_znx_dft(1, 1) let tensor_infos: GGLWELayout = GGLWELayout {
+ GLWESecret::bytes_of(self.n().into(), Rank(1)) n: infos.n(),
+ self.gglwe_compressed_encrypt_sk_tmp_bytes(infos) base2k: infos.base2k(),
k: infos.k(),
rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(),
rank_out: infos.rank_out(),
dnum: infos.dnum(),
dsize: infos.dsize(),
};
let gglwe_encrypt: usize = self.gglwe_compressed_encrypt_sk_tmp_bytes(&tensor_infos);
(sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank()))
} }
fn glwe_tensor_key_compressed_encrypt_sk<R, S, D>( fn glwe_tensor_key_compressed_encrypt_sk<R, S>(
&self, &self,
res: &mut R, res: &mut R,
sk: &S, sk: &S,
@@ -95,62 +91,24 @@ where
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
D: DataMut, R: GGLWEInfos + GGLWECompressedToMut + GGLWECompressedSeedMut,
R: GGLWEInfos + GLWETensorKeyCompressedAtMut<D>, S: GLWESecretToRef + GetDistribution + GLWEInfos,
S: GLWESecretToRef + GetDistribution,
{
let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank());
sk_dft_prep.prepare(self, sk);
let sk: &GLWESecret<&[u8]> = &sk.to_ref();
#[cfg(debug_assertions)]
{ {
assert_eq!(res.rank_out(), sk.rank()); assert_eq!(res.rank_out(), sk.rank());
assert_eq!(res.n(), sk.n()); assert_eq!(res.n(), sk.n());
}
// let n: usize = sk.n().into(); let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank());
let rank: usize = res.rank_out().into(); let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank());
sk_prepared.prepare(self, sk);
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1); sk_tensor.prepare(self, sk, scratch_2);
for i in 0..rank {
self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
}
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1));
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1);
let mut source_xa: Source = Source::new(seed_xa);
for i in 0..rank {
for j in i..rank {
self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
self.vec_znx_big_normalize(
res.base2k().into(),
&mut sk_ij.data.as_vec_znx_mut(),
0,
res.base2k().into(),
&sk_ij_big,
0,
scratch_5,
);
let (seed_xa_tmp, _) = source_xa.branch();
self.gglwe_compressed_encrypt_sk( self.gglwe_compressed_encrypt_sk(
res.at_mut(i, j), res,
&sk_ij.data, &sk_tensor.data,
&sk_dft_prep, &sk_prepared,
seed_xa_tmp, seed_xa,
source_xe, source_xe,
scratch_5, scratch_2,
); );
} }
} }
}
}

View File

@@ -1,4 +1,5 @@
mod gglwe; mod gglwe;
mod gglwe_to_ggsw_key;
mod ggsw; mod ggsw;
mod glwe_automorphism_key; mod glwe_automorphism_key;
mod glwe_ct; mod glwe_ct;
@@ -6,6 +7,7 @@ mod glwe_switching_key;
mod glwe_tensor_key; mod glwe_tensor_key;
pub use gglwe::*; pub use gglwe::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw::*; pub use ggsw::*;
pub use glwe_automorphism_key::*; pub use glwe_automorphism_key::*;
pub use glwe_ct::*; pub use glwe_ct::*;

View File

@@ -148,7 +148,7 @@ where
// Example for ksk rank 2 to rank 3: // Example for ksk rank 2 to rank 3:
// //
// (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2) // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
// (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2) // (-(b0*s0 + b1*s1 + b2*s2) + s1', b0, b1, b2)
// //
// Example ksk rank 2 to rank 1 // Example ksk rank 2 to rank 1
// //

View File

@@ -0,0 +1,112 @@
use poulpy_hal::{
api::{ModuleN, ScratchTakeBasic, VecZnxCopy},
layouts::{Backend, DataMut, Module, Scratch},
source::Source,
};
use crate::{
GGLWEEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{
GGLWEInfos, GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, GLWESecret, GLWESecretTensor, GLWESecretTensorFactory,
GLWESecretToRef,
prepared::{GLWESecretPrepared, GLWESecretPreparedFactory},
},
};
impl GGLWEToGGSWKey<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
where
A: GGLWEInfos,
M: GGLWEToGGSWKeyEncryptSk<BE>,
{
module.gglwe_to_ggsw_key_encrypt_sk_tmp_bytes(infos)
}
}
impl<DataSelf: DataMut> GGLWEToGGSWKey<DataSelf> {
pub fn encrypt_sk<M, S, BE: Backend>(
&mut self,
module: &M,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
M: GGLWEToGGSWKeyEncryptSk<BE>,
S: GLWESecretToRef + GetDistribution + GLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
module.gglwe_to_ggsw_key_encrypt_sk(self, sk, source_xa, source_xe, scratch);
}
}
pub trait GGLWEToGGSWKeyEncryptSk<BE: Backend> {
fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos;
fn gglwe_to_ggsw_key_encrypt_sk<R, S>(
&self,
res: &mut R,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: GGLWEToGGSWKeyToMut,
S: GLWESecretToRef + GetDistribution + GLWEInfos;
}
impl<BE: Backend> GGLWEToGGSWKeyEncryptSk<BE> for Module<BE>
where
Self: ModuleN + GGLWEEncryptSk<BE> + GLWESecretTensorFactory<BE> + GLWESecretPreparedFactory<BE> + VecZnxCopy,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn gglwe_to_ggsw_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank());
let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos);
let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(infos);
let sk_ij = GLWESecret::bytes_of(self.n().into(), infos.rank());
(sk_prepared + sk_tensor + sk_ij) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank()))
}
fn gglwe_to_ggsw_key_encrypt_sk<R, S>(
&self,
res: &mut R,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: GGLWEToGGSWKeyToMut,
S: GLWESecretToRef + GetDistribution + GLWEInfos,
{
let res: &mut GGLWEToGGSWKey<&mut [u8]> = &mut res.to_mut();
let rank: usize = res.rank_out().as_usize();
let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank());
let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank());
sk_prepared.prepare(self, sk);
sk_tensor.prepare(self, sk, scratch_2);
let (mut sk_ij, scratch_3) = scratch_2.take_scalar_znx(self.n(), rank);
for i in 0..rank {
for j in 0..rank {
self.vec_znx_copy(
&mut sk_ij.as_vec_znx_mut(),
j,
&sk_tensor.at(i, j).as_vec_znx(),
0,
);
}
res.at_mut(i)
.encrypt_sk(self, &sk_ij, &sk_prepared, source_xa, source_xe, scratch_3);
}
}
}

View File

@@ -1,8 +1,5 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::ModuleN,
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
VecZnxIdftApplyTmpA,
},
layouts::{Backend, DataMut, Module, Scratch}, layouts::{Backend, DataMut, Module, Scratch},
source::Source, source::Source,
}; };
@@ -10,7 +7,8 @@ use poulpy_hal::{
use crate::{ use crate::{
GGLWEEncryptSk, GetDistribution, ScratchTakeCore, GGLWEEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{ layouts::{
GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, GGLWEInfos, GGLWELayout, GGLWEToMut, GLWEInfos, GLWESecretTensor, GLWESecretTensorFactory, GLWESecretToRef,
GLWETensorKey,
prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory},
}, },
}; };
@@ -55,33 +53,35 @@ pub trait GLWETensorKeyEncryptSk<BE: Backend> {
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
R: GLWETensorKeyToMut, R: GGLWEToMut + GGLWEInfos,
S: GLWESecretToRef + GetDistribution + GLWEInfos; S: GLWESecretToRef + GetDistribution + GLWEInfos;
} }
impl<BE: Backend> GLWETensorKeyEncryptSk<BE> for Module<BE> impl<BE: Backend> GLWETensorKeyEncryptSk<BE> for Module<BE>
where where
Self: ModuleN Self: ModuleN + GGLWEEncryptSk<BE> + GLWESecretPreparedFactory<BE> + GLWESecretTensorFactory<BE>,
+ GGLWEEncryptSk<BE>
+ VecZnxDftBytesOf
+ VecZnxBigBytesOf
+ GLWESecretPreparedFactory<BE>
+ VecZnxDftApply<BE>
+ SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn glwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
GLWESecretPrepared::bytes_of(self, infos.rank_out()) let sk_prepared: usize = GLWESecretPrepared::bytes_of(self, infos.rank_out());
+ self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1) let sk_tensor: usize = GLWESecretTensor::bytes_of_from_infos(infos);
+ self.bytes_of_vec_znx_big(1, 1)
+ self.bytes_of_vec_znx_dft(1, 1) let tensor_infos: GGLWELayout = GGLWELayout {
+ GLWESecret::bytes_of(self.n().into(), Rank(1)) n: infos.n(),
+ GGLWE::encrypt_sk_tmp_bytes(self, infos) base2k: infos.base2k(),
k: infos.k(),
rank_in: GLWESecretTensor::pairs(infos.rank().into()).into(),
rank_out: infos.rank_out(),
dnum: infos.dnum(),
dsize: infos.dsize(),
};
let gglwe_encrypt: usize = self.gglwe_encrypt_sk_tmp_bytes(&tensor_infos);
(sk_prepared + sk_tensor) + gglwe_encrypt.max(self.glwe_secret_tensor_prepare_tmp_bytes(infos.rank()))
} }
fn glwe_tensor_key_encrypt_sk<R, S>( fn glwe_tensor_key_encrypt_sk<R, S>(
@@ -92,56 +92,24 @@ where
source_xe: &mut Source, source_xe: &mut Source,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
R: GLWETensorKeyToMut, R: GGLWEToMut + GGLWEInfos,
S: GLWESecretToRef + GetDistribution + GLWEInfos, S: GLWESecretToRef + GetDistribution + GLWEInfos,
{ {
let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut();
// let n: RingDegree = sk.n();
let rank: Rank = res.rank_out();
let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank());
sk_prepared.prepare(self, sk);
let sk: &GLWESecret<&[u8]> = &sk.to_ref();
assert_eq!(res.rank_out(), sk.rank()); assert_eq!(res.rank_out(), sk.rank());
assert_eq!(res.n(), sk.n()); assert_eq!(res.n(), sk.n());
let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1); let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, res.rank());
let (mut sk_tensor, scratch_2) = scratch_1.take_glwe_secret_tensor(self.n().into(), res.rank());
sk_prepared.prepare(self, sk);
sk_tensor.prepare(self, sk, scratch_2);
(0..rank.into()).for_each(|i| { self.gglwe_encrypt_sk(
self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i); res,
}); &sk_tensor.data,
let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1));
let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1);
(0..rank.into()).for_each(|i| {
(i..rank.into()).for_each(|j| {
self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i);
self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
self.vec_znx_big_normalize(
res.base2k().into(),
&mut sk_ij.data.as_vec_znx_mut(),
0,
res.base2k().into(),
&sk_ij_big,
0,
scratch_5,
);
res.at_mut(i, j).encrypt_sk(
self,
&sk_ij.data,
&sk_prepared, &sk_prepared,
source_xa, source_xa,
source_xe, source_xe,
scratch_5, scratch_2,
); );
});
})
} }
} }

View File

@@ -7,23 +7,22 @@ use poulpy_hal::{
use crate::{ use crate::{
GGLWEEncryptSk, ScratchTakeCore, GGLWEEncryptSk, ScratchTakeCore,
layouts::{ layouts::{
GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWESwitchingKey, LWEInfos, LWESecret, LWESecretToRef, GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretToRef, GLWEToLWEKey, LWEInfos, LWESecret, LWESecretToRef, Rank,
Rank,
prepared::{GLWESecretPrepared, GLWESecretPreparedFactory}, prepared::{GLWESecretPrepared, GLWESecretPreparedFactory},
}, },
}; };
impl GLWEToLWESwitchingKey<Vec<u8>> { impl GLWEToLWEKey<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
M: GLWEToLWESwitchingKeyEncryptSk<BE>, M: GLWEToLWESwitchingKeyEncryptSk<BE>,
{ {
module.glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes(infos) module.glwe_to_lwe_key_encrypt_sk_tmp_bytes(infos)
} }
} }
impl<D: DataMut> GLWEToLWESwitchingKey<D> { impl<D: DataMut> GLWEToLWEKey<D> {
pub fn encrypt_sk<M, S1, S2, BE: Backend>( pub fn encrypt_sk<M, S1, S2, BE: Backend>(
&mut self, &mut self,
module: &M, module: &M,
@@ -38,16 +37,16 @@ impl<D: DataMut> GLWEToLWESwitchingKey<D> {
S2: GLWESecretToRef, S2: GLWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.glwe_to_lwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); module.glwe_to_lwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
} }
} }
pub trait GLWEToLWESwitchingKeyEncryptSk<BE: Backend> { pub trait GLWEToLWESwitchingKeyEncryptSk<BE: Backend> {
fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn glwe_to_lwe_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos; A: GGLWEInfos;
fn glwe_to_lwe_switching_key_encrypt_sk<R, S1, S2>( fn glwe_to_lwe_key_encrypt_sk<R, S1, S2>(
&self, &self,
res: &mut R, res: &mut R,
sk_lwe: &S1, sk_lwe: &S1,
@@ -70,7 +69,7 @@ where
+ VecZnxAutomorphismInplaceTmpBytes, + VecZnxAutomorphismInplaceTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_to_lwe_switching_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn glwe_to_lwe_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
@@ -79,7 +78,7 @@ where
.max(GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + self.vec_znx_automorphism_inplace_tmp_bytes()) .max(GLWESecret::bytes_of(self.n().into(), infos.rank_in()) + self.vec_znx_automorphism_inplace_tmp_bytes())
} }
fn glwe_to_lwe_switching_key_encrypt_sk<R, S1, S2>( fn glwe_to_lwe_key_encrypt_sk<R, S1, S2>(
&self, &self,
res: &mut R, res: &mut R,
sk_lwe: &S1, sk_lwe: &S1,

View File

@@ -8,21 +8,21 @@ use crate::{
GGLWEEncryptSk, ScratchTakeCore, GGLWEEncryptSk, ScratchTakeCore,
layouts::{ layouts::{
GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretPreparedFactory, GLWESecretPreparedToRef, LWEInfos, LWESecret, GGLWE, GGLWEInfos, GGLWEToMut, GLWESecret, GLWESecretPreparedFactory, GLWESecretPreparedToRef, LWEInfos, LWESecret,
LWESecretToRef, LWEToGLWESwitchingKey, Rank, LWESecretToRef, LWEToGLWEKey, Rank,
}, },
}; };
impl LWEToGLWESwitchingKey<Vec<u8>> { impl LWEToGLWEKey<Vec<u8>> {
pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
M: LWEToGLWESwitchingKeyEncryptSk<BE>, M: LWEToGLWESwitchingKeyEncryptSk<BE>,
{ {
module.lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes(infos) module.lwe_to_glwe_key_encrypt_sk_tmp_bytes(infos)
} }
} }
impl<D: DataMut> LWEToGLWESwitchingKey<D> { impl<D: DataMut> LWEToGLWEKey<D> {
pub fn encrypt_sk<S1, S2, M, BE: Backend>( pub fn encrypt_sk<S1, S2, M, BE: Backend>(
&mut self, &mut self,
module: &M, module: &M,
@@ -37,16 +37,16 @@ impl<D: DataMut> LWEToGLWESwitchingKey<D> {
M: LWEToGLWESwitchingKeyEncryptSk<BE>, M: LWEToGLWESwitchingKeyEncryptSk<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.lwe_to_glwe_switching_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch); module.lwe_to_glwe_key_encrypt_sk(self, sk_lwe, sk_glwe, source_xa, source_xe, scratch);
} }
} }
pub trait LWEToGLWESwitchingKeyEncryptSk<BE: Backend> { pub trait LWEToGLWESwitchingKeyEncryptSk<BE: Backend> {
fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn lwe_to_glwe_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos; A: GGLWEInfos;
fn lwe_to_glwe_switching_key_encrypt_sk<R, S1, S2>( fn lwe_to_glwe_key_encrypt_sk<R, S1, S2>(
&self, &self,
res: &mut R, res: &mut R,
sk_lwe: &S1, sk_lwe: &S1,
@@ -69,20 +69,20 @@ where
+ VecZnxAutomorphismInplaceTmpBytes, + VecZnxAutomorphismInplaceTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn lwe_to_glwe_switching_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize fn lwe_to_glwe_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
debug_assert_eq!( debug_assert_eq!(
infos.rank_in(), infos.rank_in(),
Rank(1), Rank(1),
"rank_in != 1 is not supported for LWEToGLWESwitchingKey" "rank_in != 1 is not supported for LWEToGLWEKeyPrepared"
); );
GLWESecret::bytes_of(self.n().into(), infos.rank_in()) GLWESecret::bytes_of(self.n().into(), infos.rank_in())
+ GGLWE::encrypt_sk_tmp_bytes(self, infos).max(self.vec_znx_automorphism_inplace_tmp_bytes()) + GGLWE::encrypt_sk_tmp_bytes(self, infos).max(self.vec_znx_automorphism_inplace_tmp_bytes())
} }
fn lwe_to_glwe_switching_key_encrypt_sk<R, S1, S2>( fn lwe_to_glwe_key_encrypt_sk<R, S1, S2>(
&self, &self,
res: &mut R, res: &mut R,
sk_lwe: &S1, sk_lwe: &S1,

View File

@@ -1,28 +1,30 @@
mod compressed; mod compressed;
mod gglwe; mod gglwe;
mod gglwe_to_ggsw_key;
mod ggsw; mod ggsw;
mod glwe; mod glwe;
mod glwe_automorphism_key; mod glwe_automorphism_key;
mod glwe_public_key; mod glwe_public_key;
mod glwe_switching_key; mod glwe_switching_key;
mod glwe_tensor_key; mod glwe_tensor_key;
mod glwe_to_lwe_switching_key; mod glwe_to_lwe_key;
mod lwe; mod lwe;
mod lwe_switching_key; mod lwe_switching_key;
mod lwe_to_glwe_switching_key; mod lwe_to_glwe_key;
pub use compressed::*; pub use compressed::*;
pub use gglwe::*; pub use gglwe::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw::*; pub use ggsw::*;
pub use glwe::*; pub use glwe::*;
pub use glwe_automorphism_key::*; pub use glwe_automorphism_key::*;
pub use glwe_public_key::*; pub use glwe_public_key::*;
pub use glwe_switching_key::*; pub use glwe_switching_key::*;
pub use glwe_tensor_key::*; pub use glwe_tensor_key::*;
pub use glwe_to_lwe_switching_key::*; pub use glwe_to_lwe_key::*;
pub use lwe::*; pub use lwe::*;
pub use lwe_switching_key::*; pub use lwe_switching_key::*;
pub use lwe_to_glwe_switching_key::*; pub use lwe_to_glwe_key::*;
pub const SIGMA: f64 = 3.2; pub const SIGMA: f64 = 3.2;
pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA; pub(crate) const SIGMA_BOUND: f64 = 6.0 * SIGMA;

View File

@@ -0,0 +1,388 @@
use std::collections::HashMap;
use poulpy_hal::{
api::ModuleLogN,
layouts::{Backend, GaloisElement, Module, Scratch},
};
use crate::{
GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore,
glwe_trace::GLWETrace,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos},
};
/// [GLWEPacker] enables on-the-fly GLWE packing
/// with constant memory of Log(N) ciphertexts.
/// Main difference with usual GLWE packing is that
/// the output is bit-reversed.
pub struct GLWEPacker {
    /// One accumulator per packing level; at most Log(N) - log_batch are allocated.
    accumulators: Vec<Accumulator>,
    /// Packs coefficients that are multiples of X^{N/2^log_batch}.
    log_batch: usize,
    /// Number of coefficient slots consumed so far (advances by 2^log_batch per add).
    counter: usize,
}
/// [Accumulator] stores an intermediate packing result.
/// There are Log(N) such accumulators in a [GLWEPacker].
struct Accumulator {
    data: GLWE<Vec<u8>>,
    value: bool, // Implicit flag: false means `data` represents a zero ciphertext
    control: bool, // True when this level holds a value and can be combined with an incoming one
}
impl Accumulator {
    /// Allocates a new [Accumulator] whose backing GLWE ciphertext is sized from `infos`.
    ///
    /// # Arguments
    ///
    /// * `infos`: GLWE layout metadata (ring degree, base2k digit representation,
    ///   torus precision and rank) used to allocate the internal ciphertext.
    ///
    /// The accumulator starts empty: both `value` and `control` are false.
    pub fn alloc<A>(infos: &A) -> Self
    where
        A: GLWEInfos,
    {
        Self {
            data: GLWE::alloc_from_infos(infos),
            value: false,
            control: false,
        }
    }
}
impl GLWEPacker {
    /// Instantiates a new [GLWEPacker].
    ///
    /// # Arguments
    ///
    /// * `infos`: GLWE layout metadata used to allocate the internal accumulators.
    /// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}.
    ///   i.e. with `log_batch=0` only the constant coefficient is packed
    ///   and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients
    ///   which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts
    ///   can be packed.
    pub fn alloc<A>(infos: &A, log_batch: usize) -> Self
    where
        A: GLWEInfos,
    {
        let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
        let log_n: usize = infos.n().log2();
        (0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
        GLWEPacker {
            accumulators,
            log_batch,
            counter: 0,
        }
    }
    /// Implicit reset of the internal state (to be called before a new packing procedure).
    fn reset(&mut self) {
        for i in 0..self.accumulators.len() {
            self.accumulators[i].value = false;
            self.accumulators[i].control = false;
        }
        self.counter = 0;
    }
    /// Number of scratch space bytes required to call [Self::add].
    ///
    /// # Arguments
    ///
    /// * `module`: backend providing the scratch-size queries.
    /// * `res_infos`: layout of the GLWE ciphertexts being packed.
    /// * `key_infos`: layout of the automorphism keys used during packing.
    pub fn tmp_bytes<R, K, M, BE: Backend>(module: &M, res_infos: &R, key_infos: &K) -> usize
    where
        R: GLWEInfos,
        K: GGLWEInfos,
        M: GLWEPackerOps<BE>,
    {
        GLWE::bytes_of_from_infos(res_infos)
            + module
                .glwe_rsh_tmp_byte()
                .max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos))
    }
    /// Returns the Galois elements for which automorphism keys are required by [Self::add].
    pub fn galois_elements<M, BE: Backend>(module: &M) -> Vec<i64>
    where
        M: GLWETrace<BE>,
    {
        module.glwe_trace_galois_elements()
    }
    /// Adds a GLWE ciphertext to the [GLWEPacker].
    ///
    /// # Arguments
    ///
    /// * `module`: static backend FFT tables.
    /// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext.
    /// * `auto_keys`: a [HashMap] mapping Galois elements to prepared automorphism keys.
    /// * `scratch`: scratch space of size at least [Self::tmp_bytes].
    ///
    /// # Panics
    ///
    /// Panics if the packing limit of N/2^log_batch ciphertexts has been reached
    /// (call [Self::flush] to retrieve the result and reset the state).
    pub fn add<A, K, M, BE: Backend>(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
    where
        A: GLWEToRef + GLWEInfos,
        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
        M: GLWEPackerOps<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        assert!(
            (self.counter as u32) < self.accumulators[0].data.n(),
            "Packing limit of {} reached",
            self.accumulators[0].data.n().0 as usize >> self.log_batch
        );
        module.packer_add(self, a, self.log_batch, auto_keys, scratch);
        self.counter += 1 << self.log_batch;
    }
    /// Flushes the fully packed result to `res` and resets the internal state.
    ///
    /// # Panics
    ///
    /// Panics unless exactly N/2^log_batch ciphertexts have been added.
    pub fn flush<R, M, BE: Backend>(&mut self, module: &M, res: &mut R)
    where
        R: GLWEToMut,
        M: GLWEPackerOps<BE>,
    {
        assert!(self.counter as u32 == self.accumulators[0].data.n());
        // Copy result GLWE (held in the last accumulator level) into res GLWE
        module.glwe_copy(
            res,
            &self.accumulators[module.log_n() - self.log_batch - 1].data,
        );
        self.reset();
    }
}
/// Blanket implementation: any [Module] providing the required GLWE
/// operations automatically implements [GLWEPackerOps].
/// (The bound list is stated once; duplicated bounds were redundant.)
impl<BE: Backend> GLWEPackerOps<BE> for Module<BE> where
    Self: Sized
        + ModuleLogN
        + GLWEAutomorphism<BE>
        + GaloisElement
        + GLWERotate<BE>
        + GLWESub
        + GLWEShift<BE>
        + GLWEAdd
        + GLWENormalize<BE>
        + GLWECopy
{
}
/// Backend operations required by [GLWEPacker].
///
/// The supertrait bound list is stated once (the previous duplicated
/// bounds were redundant and have been removed).
pub trait GLWEPackerOps<BE: Backend>
where
    Self: Sized
        + ModuleLogN
        + GLWEAutomorphism<BE>
        + GaloisElement
        + GLWERotate<BE>
        + GLWESub
        + GLWEShift<BE>
        + GLWEAdd
        + GLWENormalize<BE>
        + GLWECopy,
{
    /// Folds `a` (or a zero ciphertext when `None`) into the packer's
    /// accumulator chain, starting at level `i`.
    ///
    /// # Arguments
    ///
    /// * `packer`: the packer whose accumulators receive the ciphertext.
    /// * `a`: ciphertext to pack; `None` packs a zero ciphertext.
    /// * `i`: starting accumulator level (the packer passes its `log_batch`).
    /// * `auto_keys`: prepared automorphism keys indexed by Galois element.
    /// * `scratch`: scratch space of size at least [GLWEPacker::tmp_bytes].
    fn packer_add<A, K>(
        &self,
        packer: &mut GLWEPacker,
        a: Option<&A>,
        i: usize,
        auto_keys: &HashMap<i64, K>,
        scratch: &mut Scratch<BE>,
    ) where
        A: GLWEToRef + GLWEInfos,
        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        pack_core(self, a, &mut packer.accumulators, i, auto_keys, scratch)
    }
}
/// Recursively folds `a` into the accumulator chain at level `i`.
///
/// When level `i` is free (`control == false`) the incoming ciphertext is
/// simply stored there. When it is occupied, it is merged with the incoming
/// ciphertext via [combine] and the merged result is propagated to level
/// `i + 1`. Recursion stops once all Log(N) levels have been traversed.
/// (Trait bounds are stated once; duplicated bounds were redundant.)
fn pack_core<A, K, M, BE: Backend>(
    module: &M,
    a: Option<&A>,
    accumulators: &mut [Accumulator],
    i: usize,
    auto_keys: &HashMap<i64, K>,
    scratch: &mut Scratch<BE>,
) where
    A: GLWEToRef + GLWEInfos,
    M: ModuleLogN
        + GLWEAutomorphism<BE>
        + GaloisElement
        + GLWERotate<BE>
        + GLWESub
        + GLWEShift<BE>
        + GLWEAdd
        + GLWENormalize<BE>
        + GLWECopy,
    K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let log_n: usize = module.log_n();
    if i == log_n {
        return;
    }

    // Isolate the first accumulator
    let (acc_prev, acc_next) = accumulators.split_at_mut(1);

    // control == false: accumulator is free to overwrite
    if !acc_prev[0].control {
        let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut
        // No previous value -> copies and sets flags accordingly
        if let Some(a_ref) = a {
            module.glwe_copy(&mut acc_mut_ref.data, a_ref);
            acc_mut_ref.value = true
        } else {
            acc_mut_ref.value = false
        }
        acc_mut_ref.control = true; // Able to be combined on next call
    } else {
        // Compresses acc_prev <- combine(acc_prev, a).
        combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
        acc_prev[0].control = false;
        // Propagates to next accumulator
        if acc_prev[0].value {
            pack_core(
                module,
                Some(&acc_prev[0].data),
                acc_next,
                i + 1,
                auto_keys,
                scratch,
            );
        } else {
            pack_core(
                module,
                None::<&GLWE<Vec<u8>>>,
                acc_next,
                i + 1,
                auto_keys,
                scratch,
            );
        }
    }
}
/// [combine] merges the incoming ciphertext `b` (or zero when `None`) into
/// the accumulator `acc` at packing level `i`, evaluating
/// a = a + b*X^t + phi(a - b*X^t) with t = 2^(log_n - i - 1).
/// (Trait bounds are stated once; the duplicated `B` and `M` bounds were redundant.)
fn combine<B, K, M, BE: Backend>(
    module: &M,
    acc: &mut Accumulator,
    b: Option<&B>,
    i: usize,
    auto_keys: &HashMap<i64, K>,
    scratch: &mut Scratch<BE>,
) where
    B: GLWEToRef + GLWEInfos,
    M: ModuleLogN
        + GLWEAutomorphism<BE>
        + GaloisElement
        + GLWERotate<BE>
        + GLWESub
        + GLWEShift<BE>
        + GLWEAdd
        + GLWENormalize<BE>
        + GLWECopy,
    K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let log_n: usize = acc.data.n().log2();
    let a: &mut GLWE<Vec<u8>> = &mut acc.data;
    let gal_el: i64 = if i == 0 {
        -1
    } else {
        module.galois_element(1 << (i - 1))
    };
    let t: i64 = 1 << (log_n - i - 1);
    // Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t))
    // We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g)
    // where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)}
    // Different cases for whether a and/or b are zero.
    //
    // Implicit RSH without modulus switch, introduces extra I(X) * Q/2 on decryption.
    // Necessary so that the scaling of the plaintext remains constant.
    // It however is ok to do so here because coefficients are eventually
    // either mapped to garbage or twice their value which vanishes I(X)
    // since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
    if acc.value {
        if let Some(b) = b {
            let (mut tmp_b, scratch_1) = scratch.take_glwe(a);
            // a = a * X^-t
            module.glwe_rotate_inplace(-t, a, scratch_1);
            // tmp_b = a * X^-t - b
            module.glwe_sub(&mut tmp_b, a, b);
            module.glwe_rsh(1, &mut tmp_b, scratch_1);
            // a = a * X^-t + b
            module.glwe_add_inplace(a, b);
            module.glwe_rsh(1, a, scratch_1);
            module.glwe_normalize_inplace(&mut tmp_b, scratch_1);
            // tmp_b = phi(a * X^-t - b)
            if let Some(auto_key) = auto_keys.get(&gal_el) {
                module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1);
            } else {
                panic!("auto_key[{gal_el}] not found");
            }
            // a = a * X^-t + b - phi(a * X^-t - b)
            module.glwe_sub_inplace(a, &tmp_b);
            module.glwe_normalize_inplace(a, scratch_1);
            // a = a + b * X^t - phi(a * X^-t - b) * X^t
            //   = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
            //   = a + b * X^t + phi(a - b * X^t)
            module.glwe_rotate_inplace(t, a, scratch_1);
        } else {
            module.glwe_rsh(1, a, scratch);
            // a = a + phi(a)
            if let Some(auto_key) = auto_keys.get(&gal_el) {
                module.glwe_automorphism_add_inplace(a, auto_key, scratch);
            } else {
                panic!("auto_key[{gal_el}] not found");
            }
        }
    } else if let Some(b) = b {
        let (mut tmp_b, scratch_1) = scratch.take_glwe(a);
        module.glwe_rotate(t, &mut tmp_b, b);
        module.glwe_rsh(1, &mut tmp_b, scratch_1);
        // a = (b* X^t - phi(b* X^t))
        if let Some(auto_key) = auto_keys.get(&gal_el) {
            module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1);
        } else {
            panic!("auto_key[{gal_el}] not found");
        }
        acc.value = true;
    }
}

View File

@@ -7,166 +7,23 @@ use poulpy_hal::{
use crate::{ use crate::{
GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore, GLWEAdd, GLWEAutomorphism, GLWECopy, GLWENormalize, GLWERotate, GLWEShift, GLWESub, ScratchTakeCore,
glwe_trace::GLWETrace, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement},
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos},
}; };
pub trait GLWEPacking<BE: Backend> {
/// [GLWEPacker] enables only the fly GLWE packing /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)]
/// with constant memory of Log(N) ciphertexts. /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)]
/// Main difference with usual GLWE packing is that fn glwe_pack<R, K>(
/// the output is bit-reversed. &self,
pub struct GLWEPacker { cts: &mut HashMap<usize, &mut R>,
accumulators: Vec<Accumulator>, log_gap_out: usize,
log_batch: usize, keys: &HashMap<i64, K>,
counter: usize, scratch: &mut Scratch<BE>,
) where
R: GLWEToMut + GLWEToRef + GLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos;
} }
/// [Accumulator] stores intermediate packing result. impl<BE: Backend> GLWEPacking<BE> for Module<BE>
/// There are Log(N) such accumulators in a [GLWEPacker].
struct Accumulator {
data: GLWE<Vec<u8>>,
value: bool, // Implicit flag for zero ciphertext
control: bool, // Can be combined with incoming value
}
impl Accumulator {
/// Allocates a new [Accumulator].
///
/// #Arguments
///
/// * `module`: static backend FFT tables.
/// * `base2k`: base 2 logarithm of the GLWE ciphertext in memory digit representation.
/// * `k`: base 2 precision of the GLWE ciphertext precision over the Torus.
/// * `rank`: rank of the GLWE ciphertext.
pub fn alloc<A>(infos: &A) -> Self
where
A: GLWEInfos,
{
Self {
data: GLWE::alloc_from_infos(infos),
value: false,
control: false,
}
}
}
impl GLWEPacker {
/// Instantiates a new [GLWEPacker].
///
/// # Arguments
///
/// * `log_batch`: packs coefficients which are multiples of X^{N/2^log_batch}.
/// i.e. with `log_batch=0` only the constant coefficient is packed
/// and N GLWE ciphertext can be packed. With `log_batch=2` all coefficients
/// which are multiples of X^{N/4} are packed. Meaning that N/4 ciphertexts
/// can be packed.
pub fn alloc<A>(infos: &A, log_batch: usize) -> Self
where
A: GLWEInfos,
{
let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
let log_n: usize = infos.n().log2();
(0..log_n - log_batch).for_each(|_| accumulators.push(Accumulator::alloc(infos)));
GLWEPacker {
accumulators,
log_batch,
counter: 0,
}
}
/// Implicit reset of the internal state (to be called before a new packing procedure).
fn reset(&mut self) {
for i in 0..self.accumulators.len() {
self.accumulators[i].value = false;
self.accumulators[i].control = false;
}
self.counter = 0;
}
/// Number of scratch space bytes required to call [Self::add].
pub fn tmp_bytes<R, K, M, BE: Backend>(module: &M, res_infos: &R, key_infos: &K) -> usize
where
R: GLWEInfos,
K: GGLWEInfos,
M: GLWEPacking<BE>,
{
GLWE::bytes_of_from_infos(res_infos)
+ module
.glwe_rsh_tmp_byte()
.max(module.glwe_automorphism_tmp_bytes(res_infos, res_infos, key_infos))
}
pub fn galois_elements<M, BE: Backend>(module: &M) -> Vec<i64>
where
M: GLWETrace<BE>,
{
module.glwe_trace_galois_elements()
}
/// Adds a GLWE ciphertext to the [GLWEPacker].
/// #Arguments
///
/// * `module`: static backend FFT tables.
/// * `res`: space to append fully packed ciphertext. Only when the number
/// of packed ciphertexts reaches N/2^log_batch is a result written.
/// * `a`: ciphertext to pack. Can optionally give None to pack a 0 ciphertext.
/// * `auto_keys`: a [HashMap] containing the [AutomorphismKeyExec]s.
/// * `scratch`: scratch space of size at least [Self::tmp_bytes].
pub fn add<A, K, M, BE: Backend>(&mut self, module: &M, a: Option<&A>, auto_keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
where
A: GLWEToRef + GLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
M: GLWEPacking<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
assert!(
(self.counter as u32) < self.accumulators[0].data.n(),
"Packing limit of {} reached",
self.accumulators[0].data.n().0 as usize >> self.log_batch
);
pack_core(
module,
a,
&mut self.accumulators,
self.log_batch,
auto_keys,
scratch,
);
self.counter += 1 << self.log_batch;
}
/// Flush result to`res`.
pub fn flush<R, M, BE: Backend>(&mut self, module: &M, res: &mut R)
where
R: GLWEToMut,
M: GLWEPacking<BE>,
{
assert!(self.counter as u32 == self.accumulators[0].data.n());
// Copy result GLWE into res GLWE
module.glwe_copy(
res,
&self.accumulators[module.log_n() - self.log_batch - 1].data,
);
self.reset();
}
}
impl<BE: Backend> GLWEPacking<BE> for Module<BE> where
Self: GLWEAutomorphism<BE>
+ GaloisElement
+ ModuleLogN
+ GLWERotate<BE>
+ GLWESub
+ GLWEShift<BE>
+ GLWEAdd
+ GLWENormalize<BE>
+ GLWECopy
{
}
pub trait GLWEPacking<BE: Backend>
where where
Self: GLWEAutomorphism<BE> Self: GLWEAutomorphism<BE>
+ GaloisElement + GaloisElement
@@ -177,6 +34,7 @@ where
+ GLWEAdd + GLWEAdd
+ GLWENormalize<BE> + GLWENormalize<BE>
+ GLWECopy, + GLWECopy,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
/// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)] /// Packs [x_0: GLWE(m_0), x_1: GLWE(m_1), ..., x_i: GLWE(m_i)]
/// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)] /// to [0: GLWE(m_0 * X^x_0 + m_1 * X^x_1 + ... + m_i * X^x_i)]
@@ -189,7 +47,6 @@ where
) where ) where
R: GLWEToMut + GLWEToRef + GLWEInfos, R: GLWEToMut + GLWEToRef + GLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -223,169 +80,6 @@ where
} }
} }
fn pack_core<A, K, M, BE: Backend>(
module: &M,
a: Option<&A>,
accumulators: &mut [Accumulator],
i: usize,
auto_keys: &HashMap<i64, K>,
scratch: &mut Scratch<BE>,
) where
A: GLWEToRef + GLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
M: ModuleLogN
+ GLWEAutomorphism<BE>
+ GaloisElement
+ GLWERotate<BE>
+ GLWESub
+ GLWEShift<BE>
+ GLWEAdd
+ GLWENormalize<BE>
+ GLWECopy,
Scratch<BE>: ScratchTakeCore<BE>,
{
let log_n: usize = module.log_n();
if i == log_n {
return;
}
// Isolate the first accumulator
let (acc_prev, acc_next) = accumulators.split_at_mut(1);
// Control = true accumlator is free to overide
if !acc_prev[0].control {
let acc_mut_ref: &mut Accumulator = &mut acc_prev[0]; // from split_at_mut
// No previous value -> copies and sets flags accordingly
if let Some(a_ref) = a {
module.glwe_copy(&mut acc_mut_ref.data, a_ref);
acc_mut_ref.value = true
} else {
acc_mut_ref.value = false
}
acc_mut_ref.control = true; // Able to be combined on next call
} else {
// Compresses acc_prev <- combine(acc_prev, a).
combine(module, &mut acc_prev[0], a, i, auto_keys, scratch);
acc_prev[0].control = false;
// Propagates to next accumulator
if acc_prev[0].value {
pack_core(
module,
Some(&acc_prev[0].data),
acc_next,
i + 1,
auto_keys,
scratch,
);
} else {
pack_core(
module,
None::<&GLWE<Vec<u8>>>,
acc_next,
i + 1,
auto_keys,
scratch,
);
}
}
}
/// [combine] merges two ciphertexts together.
fn combine<B, M, K, BE: Backend>(
module: &M,
acc: &mut Accumulator,
b: Option<&B>,
i: usize,
auto_keys: &HashMap<i64, K>,
scratch: &mut Scratch<BE>,
) where
B: GLWEToRef + GLWEInfos,
M: GLWEAutomorphism<BE> + GaloisElement + GLWERotate<BE> + GLWESub + GLWEShift<BE> + GLWEAdd + GLWENormalize<BE>,
B: GLWEToRef + GLWEInfos,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
let log_n: usize = acc.data.n().log2();
let a: &mut GLWE<Vec<u8>> = &mut acc.data;
let gal_el: i64 = if i == 0 {
-1
} else {
module.galois_element(1 << (i - 1))
};
let t: i64 = 1 << (log_n - i - 1);
// Goal is to evaluate: a = a + b*X^t + phi(a - b*X^t))
// We also use the identity: AUTO(a * X^t, g) = -X^t * AUTO(a, g)
// where t = 2^(log_n - i - 1) and g = 5^{2^(i - 1)}
// Different cases for wether a and/or b are zero.
//
// Implicite RSH without modulus switch, introduces extra I(X) * Q/2 on decryption.
// Necessary so that the scaling of the plaintext remains constant.
// It however is ok to do so here because coefficients are eventually
// either mapped to garbage or twice their value which vanishes I(X)
// since 2*(I(X) * Q/2) = I(X) * Q = 0 mod Q.
if acc.value {
if let Some(b) = b {
let (mut tmp_b, scratch_1) = scratch.take_glwe(a);
// a = a * X^-t
module.glwe_rotate_inplace(-t, a, scratch_1);
// tmp_b = a * X^-t - b
module.glwe_sub(&mut tmp_b, a, b);
module.glwe_rsh(1, &mut tmp_b, scratch_1);
// a = a * X^-t + b
module.glwe_add_inplace(a, b);
module.glwe_rsh(1, a, scratch_1);
module.glwe_normalize_inplace(&mut tmp_b, scratch_1);
// tmp_b = phi(a * X^-t - b)
if let Some(auto_key) = auto_keys.get(&gal_el) {
module.glwe_automorphism_inplace(&mut tmp_b, auto_key, scratch_1);
} else {
panic!("auto_key[{gal_el}] not found");
}
// a = a * X^-t + b - phi(a * X^-t - b)
module.glwe_sub_inplace(a, &tmp_b);
module.glwe_normalize_inplace(a, scratch_1);
// a = a + b * X^t - phi(a * X^-t - b) * X^t
// = a + b * X^t - phi(a * X^-t - b) * - phi(X^t)
// = a + b * X^t + phi(a - b * X^t)
module.glwe_rotate_inplace(t, a, scratch_1);
} else {
module.glwe_rsh(1, a, scratch);
// a = a + phi(a)
if let Some(auto_key) = auto_keys.get(&gal_el) {
module.glwe_automorphism_add_inplace(a, auto_key, scratch);
} else {
panic!("auto_key[{gal_el}] not found");
}
}
} else if let Some(b) = b {
let (mut tmp_b, scratch_1) = scratch.take_glwe(a);
module.glwe_rotate(t, &mut tmp_b, b);
module.glwe_rsh(1, &mut tmp_b, scratch_1);
// a = (b* X^t - phi(b* X^t))
if let Some(auto_key) = auto_keys.get(&gal_el) {
module.glwe_automorphism_sub_negate(a, &tmp_b, auto_key, scratch_1);
} else {
panic!("auto_key[{gal_el}] not found");
}
acc.value = true;
}
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn pack_internal<M, A, B, K, BE: Backend>( fn pack_internal<M, A, B, K, BE: Backend>(
module: &M, module: &M,

View File

@@ -1,8 +1,8 @@
use std::collections::HashMap; use std::collections::HashMap;
use poulpy_hal::{ use poulpy_hal::{
api::ModuleLogN, api::{ModuleLogN, VecZnxNormalize, VecZnxNormalizeTmpBytes},
layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element}, layouts::{Backend, CyclotomicOrder, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element},
}; };
use crate::{ use crate::{
@@ -27,7 +27,7 @@ impl GLWE<Vec<u8>> {
K: GGLWEInfos, K: GGLWEInfos,
M: GLWETrace<BE>, M: GLWETrace<BE>,
{ {
module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos) module.glwe_trace_tmp_bytes(res_infos, a_infos, key_infos)
} }
} }
@@ -65,11 +65,6 @@ impl<D: DataMut> GLWE<D> {
} }
} }
impl<BE: Backend> GLWETrace<BE> for Module<BE> where
Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy
{
}
#[inline(always)] #[inline(always)]
pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec<i64> { pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec<i64> {
(0..log_n) (0..log_n)
@@ -83,9 +78,17 @@ pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec<i64> {
.collect() .collect()
} }
pub trait GLWETrace<BE: Backend> impl<BE: Backend> GLWETrace<BE> for Module<BE>
where where
Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy, Self: ModuleLogN
+ GaloisElement
+ GLWEAutomorphism<BE>
+ GLWEShift<BE>
+ GLWECopy
+ CyclotomicOrder
+ VecZnxNormalizeTmpBytes
+ VecZnxNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_trace_galois_elements(&self) -> Vec<i64> { fn glwe_trace_galois_elements(&self) -> Vec<i64> {
trace_galois_elements(self.log_n(), self.cyclotomic_order()) trace_galois_elements(self.log_n(), self.cyclotomic_order())
@@ -115,7 +118,6 @@ where
R: GLWEToMut, R: GLWEToMut,
A: GLWEToRef, A: GLWEToRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
self.glwe_copy(res, a); self.glwe_copy(res, a);
self.glwe_trace_inplace(res, start, end, keys, scratch); self.glwe_trace_inplace(res, start, end, keys, scratch);
@@ -125,7 +127,6 @@ where
where where
R: GLWEToMut, R: GLWEToMut,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos, K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
@@ -212,3 +213,31 @@ where
} }
} }
} }
pub trait GLWETrace<BE: Backend> {
fn glwe_trace_galois_elements(&self) -> Vec<i64>;
fn glwe_trace_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
K: GGLWEInfos;
fn glwe_trace<R, A, K>(
&self,
res: &mut R,
start: usize,
end: usize,
a: &A,
keys: &HashMap<i64, K>,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
A: GLWEToRef,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos;
fn glwe_trace_inplace<R, K>(&self, res: &mut R, start: usize, end: usize, keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos;
}

View File

@@ -1,9 +1,9 @@
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, VecZnx}; use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch};
use crate::{ use crate::{
GGSWExpandRows, ScratchTakeCore, GGSWExpandRows, ScratchTakeCore,
keyswitching::GLWEKeyswitch, keyswitching::GLWEKeyswitch,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, prepared::GLWETensorKeyPreparedToRef}, layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef},
}; };
impl GGSW<Vec<u8>> { impl GGSW<Vec<u8>> {
@@ -30,7 +30,7 @@ impl<D: DataMut> GGSW<D> {
where where
A: GGSWToRef, A: GGSWToRef,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeyswitch<BE>, M: GGSWKeyswitch<BE>,
{ {
@@ -40,7 +40,7 @@ impl<D: DataMut> GGSW<D> {
pub fn keyswitch_inplace<M, K, T, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>) pub fn keyswitch_inplace<M, K, T, BE: Backend>(&mut self, module: &M, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where where
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWKeyswitch<BE>, M: GGSWKeyswitch<BE>,
{ {
@@ -48,9 +48,7 @@ impl<D: DataMut> GGSW<D> {
} }
} }
impl<BE: Backend> GGSWKeyswitch<BE> for Module<BE> where Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE> {} impl<BE: Backend> GGSWKeyswitch<BE> for Module<BE>
pub trait GGSWKeyswitch<BE: Backend>
where where
Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>, Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
{ {
@@ -65,25 +63,26 @@ where
assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out()); assert_eq!(tsk_infos.rank_in(), tsk_infos.rank_out());
assert_eq!(key_infos.rank_in(), tsk_infos.rank_in()); assert_eq!(key_infos.rank_in(), tsk_infos.rank_in());
let rank: usize = key_infos.rank_out().into(); self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
.max(self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos))
let size_out: usize = res_infos.k().div_ceil(res_infos.base2k()) as usize;
let res_znx: usize = VecZnx::bytes_of(self.n(), rank + 1, size_out);
let ci_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
let ks: usize = self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos);
let expand_rows: usize = self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos);
let res_dft: usize = self.bytes_of_vec_znx_dft(rank + 1, size_out);
if a_infos.base2k() == tsk_infos.base2k() {
res_znx + ci_dft + (ks | expand_rows | res_dft)
} else {
let a_conv: usize = VecZnx::bytes_of(
self.n(),
1,
res_infos.k().div_ceil(tsk_infos.base2k()) as usize,
) + self.vec_znx_normalize_tmp_bytes();
res_znx + ci_dft + (a_conv | ks | expand_rows | res_dft)
} }
fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
K: GGLWEPreparedToRef<BE>,
T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
for row in 0..res.dnum().into() {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch);
}
self.ggsw_expand_row(res, tsk, scratch);
} }
fn ggsw_keyswitch<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_keyswitch<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
@@ -91,7 +90,7 @@ where
R: GGSWToMut, R: GGSWToMut,
A: GGSWToRef, A: GGSWToRef,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
@@ -108,22 +107,31 @@ where
self.ggsw_expand_row(res, tsk, scratch); self.ggsw_expand_row(res, tsk, scratch);
} }
}
pub trait GGSWKeyswitch<BE: Backend>
where
Self: GLWEKeyswitch<BE> + GGSWExpandRows<BE>,
{
fn ggsw_keyswitch_tmp_bytes<R, A, K, T>(&self, res_infos: &R, a_infos: &A, key_infos: &K, tsk_infos: &T) -> usize
where
R: GGSWInfos,
A: GGSWInfos,
K: GGLWEInfos,
T: GGLWEInfos;
fn ggsw_keyswitch<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
A: GGSWToRef,
K: GGLWEPreparedToRef<BE>,
T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>;
fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_keyswitch_inplace<R, K, T>(&self, res: &mut R, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where where
R: GGSWToMut, R: GGSWToMut,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
T: GLWETensorKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>;
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
for row in 0..res.dnum().into() {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
self.glwe_keyswitch_inplace(&mut res.at_mut(row, 0), key, scratch);
}
self.ggsw_expand_row(res, tsk, scratch);
}
} }

View File

@@ -1,10 +1,10 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, ModuleN, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VecZnxDftApply, VecZnxDftBytesOf, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
}, },
layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos}, layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VmpPMat, ZnxInfos},
}; };
use crate::{ use crate::{
@@ -45,46 +45,10 @@ impl<D: DataMut> GLWE<D> {
} }
} }
impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE> where impl<BE: Backend> GLWEKeyswitch<BE> for Module<BE>
Self: Sized
+ ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes
{
}
pub trait GLWEKeyswitch<BE: Backend>
where where
Self: Sized Self: Sized + GLWEKeySwitchInternal<BE> + VecZnxBigNormalizeTmpBytes + VecZnxBigNormalize<BE>,
+ ModuleN Scratch<BE>: ScratchTakeCore<BE>,
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxNormalizeTmpBytes
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>
+ VecZnxNormalizeTmpBytes,
{ {
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where where
@@ -92,34 +56,10 @@ where
A: GLWEInfos, A: GLWEInfos,
B: GGLWEInfos, B: GGLWEInfos,
{ {
let in_size: usize = a_infos let cols: usize = res_infos.rank().as_usize() + 1;
.k() self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos)
.div_ceil(key_infos.base2k()) .max(self.vec_znx_big_normalize_tmp_bytes())
.div_ceil(key_infos.dsize().into()) as usize; + self.bytes_of_vec_znx_dft(cols, key_infos.size())
let out_size: usize = res_infos.size();
let ksk_size: usize = key_infos.size();
let res_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_out() + 1).into(), ksk_size); // TODO OPTIMIZE
let ai_dft: usize = self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
out_size,
in_size,
in_size,
(key_infos.rank_in()).into(),
(key_infos.rank_out() + 1).into(),
ksk_size,
) + self.bytes_of_vec_znx_dft((key_infos.rank_in()).into(), in_size);
let normalize_big: usize = self.vec_znx_big_normalize_tmp_bytes();
if a_infos.base2k() == key_infos.base2k() {
res_dft + ((ai_dft + vmp) | normalize_big)
} else if key_infos.dsize() == 1 {
// In this case, we only need one column, temporary, that we can drop once a_dft is computed.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), 1, in_size) + self.vec_znx_normalize_tmp_bytes();
res_dft + (((ai_dft + normalize_conv) | vmp) | normalize_big)
} else {
// Since we stride over a to get a_dft when dsize > 1, we need to store the full columns of a with in the base conversion.
let normalize_conv: usize = VecZnx::bytes_of(self.n(), (key_infos.rank_in()).into(), in_size);
res_dft + ((ai_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
}
} }
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>) fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
@@ -127,7 +67,6 @@ where
R: GLWEToMut, R: GLWEToMut,
A: GLWEToRef, A: GLWEToRef,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
@@ -164,8 +103,8 @@ where
let base2k_out: usize = b.base2k().into(); let base2k_out: usize = b.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), b.size()); // Todo optimise
let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, a, b, scratch_1); let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, b, scratch_1);
(0..(res.rank() + 1).into()).for_each(|i| { for i in 0..(res.rank() + 1).into() {
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
basek_out, basek_out,
&mut res.data, &mut res.data,
@@ -175,37 +114,36 @@ where
i, i,
scratch_1, scratch_1,
); );
}) }
} }
fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>) fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
where where
R: GLWEToMut, R: GLWEToMut,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWEPrepared<&[u8], BE> = &key.to_ref(); let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
assert_eq!( assert_eq!(
res.rank(), res.rank(),
a.rank_in(), key.rank_in(),
"res.rank(): {} != a.rank_in(): {}", "res.rank(): {} != a.rank_in(): {}",
res.rank(), res.rank(),
a.rank_in() key.rank_in()
); );
assert_eq!( assert_eq!(
res.rank(), res.rank(),
a.rank_out(), key.rank_out(),
"res.rank(): {} != b.rank_out(): {}", "res.rank(): {} != b.rank_out(): {}",
res.rank(), res.rank(),
a.rank_out() key.rank_out()
); );
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32); assert_eq!(key.n(), self.n() as u32);
let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, a); let scrach_needed: usize = self.glwe_keyswitch_tmp_bytes(res, res, key);
assert!( assert!(
scratch.available() >= scrach_needed, scratch.available() >= scrach_needed,
@@ -214,11 +152,11 @@ where
); );
let base2k_in: usize = res.base2k().into(); let base2k_in: usize = res.base2k().into();
let base2k_out: usize = a.base2k().into(); let base2k_out: usize = key.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), a.size()); // Todo optimise let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // Todo optimise
let res_big: VecZnxBig<&mut [u8], BE> = keyswitch_internal(self, res_dft, res, a, scratch_1); let res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
(0..(res.rank() + 1).into()).for_each(|i| { for i in 0..(res.rank() + 1).into() {
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
base2k_in, base2k_in,
&mut res.data, &mut res.data,
@@ -228,16 +166,68 @@ where
i, i,
scratch_1, scratch_1,
); );
}) }
} }
} }
impl GLWE<Vec<u8>> {} pub trait GLWEKeyswitch<BE: Backend> {
fn glwe_keyswitch_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, key_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGLWEInfos;
impl<DataSelf: DataMut> GLWE<DataSelf> {} fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
K: GGLWEPreparedToRef<BE>;
pub(crate) fn keyswitch_internal<BE: Backend, M, DR, A, K>( fn glwe_keyswitch_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
module: &M, where
R: GLWEToMut,
K: GGLWEPreparedToRef<BE>;
}
// Blanket implementation: any `Module<BE>` that provides the listed
// GGLWE-product and vec-znx primitives automatically implements
// `GLWEKeySwitchInternal` (all methods have default bodies in the trait).
impl<BE: Backend> GLWEKeySwitchInternal<BE> for Module<BE> where
    Self: GGLWEProduct<BE>
        + VecZnxDftApply<BE>
        + VecZnxNormalize<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxNormalizeTmpBytes
{
}
pub(crate) trait GLWEKeySwitchInternal<BE: Backend>
where
Self: GGLWEProduct<BE>
+ VecZnxDftApply<BE>
+ VecZnxNormalize<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxNormalizeTmpBytes,
{
    /// Scratch bytes needed by `glwe_keyswitch_internal`.
    ///
    /// Accounts for:
    /// - the scratch of the underlying GGLWE product (`gglwe_product_dft_tmp_bytes`),
    /// - a DFT buffer of `rank + 1` columns sized by `a_infos.size()`,
    /// - and, only when the input and key use different `base2k`, a one-column
    ///   buffer plus normalization scratch for the base conversion of `a`.
    fn glwe_keyswitch_internal_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        K: GGLWEInfos,
    {
        let cols: usize = (a_infos.rank() + 1).into();
        let a_size: usize = a_infos.size();
        // Extra scratch for converting `a` into the key's base2k, if needed.
        let a_conv = if a_infos.base2k() == key_infos.base2k() {
            0
        } else {
            VecZnx::bytes_of(self.n(), 1, a_size) + self.vec_znx_normalize_tmp_bytes()
        };
        self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) + a_conv
    }
fn glwe_keyswitch_internal<DR, A, K>(
&self,
mut res: VecZnxDft<DR, BE>, mut res: VecZnxDft<DR, BE>,
a: &A, a: &A,
key: &K, key: &K,
@@ -247,18 +237,6 @@ where
DR: DataMut, DR: DataMut,
A: GLWEToRef, A: GLWEToRef,
K: GGLWEPreparedToRef<BE>, K: GGLWEPreparedToRef<BE>,
M: ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE>
+ VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
@@ -268,103 +246,155 @@ where
let base2k_out: usize = key.base2k().into(); let base2k_out: usize = key.base2k().into();
let cols: usize = (a.rank() + 1).into(); let cols: usize = (a.rank() + 1).into();
let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out); let a_size: usize = (a.size() * base2k_in).div_ceil(base2k_out);
let pmat: &VmpPMat<&[u8], BE> = &key.data;
if key.dsize() == 1 { let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a.size());
if base2k_in == base2k_out { if base2k_in == base2k_out {
(0..cols - 1).for_each(|col_i| { for col_i in 0..cols - 1 {
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, a.data(), col_i + 1); self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, a.data(), col_i + 1);
}); }
} else { } else {
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, a_size); let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size);
(0..cols - 1).for_each(|col_i| { for i in 0..cols - 1 {
module.vec_znx_normalize( self.vec_znx_normalize(
base2k_out, base2k_out,
&mut a_conv, &mut a_conv,
0, 0,
base2k_in, base2k_in,
a.data(), a.data(),
col_i + 1, i + 1,
scratch_2, scratch_2,
); );
module.vec_znx_dft_apply(1, 0, &mut ai_dft, col_i, &a_conv, 0); self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0);
}); }
} }
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); self.gglwe_product_dft(&mut res, &a_dft, key, scratch_1);
let mut res_big: VecZnxBig<DR, BE> = self.vec_znx_idft_apply_consume(res);
self.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
res_big
}
}
// Blanket implementation: any `Module<BE>` exposing the listed VMP / DFT
// primitives automatically implements `GGLWEProduct` (default trait bodies).
impl<BE: Backend> GGLWEProduct<BE> for Module<BE> where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxDftCopy<BE>
{
}
pub(crate) trait GGLWEProduct<BE: Backend>
where
Self: Sized
+ ModuleN
+ VecZnxDftBytesOf
+ VmpApplyDftToDftTmpBytes
+ VmpApplyDftToDft<BE>
+ VmpApplyDftToDftAdd<BE>
+ VecZnxDftCopy<BE>,
{
fn gglwe_product_dft_tmp_bytes<K>(&self, res_size: usize, a_size: usize, key_infos: &K) -> usize
where
K: GGLWEInfos,
{
let dsize: usize = key_infos.dsize().as_usize();
if dsize == 1 {
self.vmp_apply_dft_to_dft_tmp_bytes(
res_size,
a_size,
key_infos.dnum().into(),
(key_infos.rank_in()).into(),
(key_infos.rank_out() + 1).into(),
key_infos.size(),
)
} else {
let dnum: usize = key_infos.dnum().into();
let a_size: usize = a_size.div_ceil(dsize).min(dnum);
let ai_dft: usize = self.bytes_of_vec_znx_dft(key_infos.rank_in().into(), a_size);
let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
res_size,
a_size,
dnum,
(key_infos.rank_in()).into(),
(key_infos.rank_out() + 1).into(),
key_infos.size(),
);
ai_dft + vmp
}
}
fn gglwe_product_dft<DR, A, K>(&self, res: &mut VecZnxDft<DR, BE>, a: &A, key: &K, scratch: &mut Scratch<BE>)
where
DR: DataMut,
A: VecZnxDftToRef<BE>,
K: GGLWEPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let a: &VecZnxDft<&[u8], BE> = &a.to_ref();
let key: &GGLWEPrepared<&[u8], BE> = &key.to_ref();
let cols: usize = a.cols();
let a_size: usize = a.size();
let pmat: &VmpPMat<&[u8], BE> = &key.data;
// If dsize == 1, then the digit decomposition is equal to Base2K and we can simply
// can the vmp API.
if key.dsize() == 1 {
self.vmp_apply_dft_to_dft(res, a, pmat, scratch);
// If dsize != 1, then the digit decomposition is k * Base2K with k > 1.
// As such we need to perform a bivariate polynomial convolution in (X, Y) / (X^{N}+1) with Y = 2^-K
// (instead of yn univariate one in X).
//
// Since the basis in Y is small (in practice degree 6-7 max), we perform it naiveley.
// To do so, we group the different limbs of ai_dft by their respective degree in Y
// which are multiples of the current digit.
// For example if dsize = 3, with ai_dft = [a0, a1, a2, a3, a4, a5, a6],
// we group them as [[a0, a3, a5], [a1, a4, a6], [a2, a5, 0]]
// and evaluate sum(a_di * pmat * 2^{di*Base2k})
} else { } else {
let dsize: usize = key.dsize().into(); let dsize: usize = key.dsize().into();
let dnum: usize = key.dnum().into();
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols - 1, a_size.div_ceil(dsize)); // We bound ai_dft size by the number of rows of the matrix
let (mut ai_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize).min(dnum));
ai_dft.data_mut().fill(0); ai_dft.data_mut().fill(0);
if base2k_in == base2k_out {
for di in 0..dsize { for di in 0..dsize {
ai_dft.set_size((a_size + di) / dsize); // Sets ai_dft size according to the current digit (if dsize does not divides a_size),
// bounded by the number of rows (digits) in the prepared matrix.
ai_dft.set_size(((a_size + di) / dsize).min(dnum));
// Small optimization for dsize > 2 // Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then // VMP produce some error e, and since we aggregate vmp * 2^{di * Base2k}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}. // we also aggregate ei * 2^{di * Base2k}, with the largest error being ei * 2^{(dsize-1) * Base2k}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products. // As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce // It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality. // noise is kept with respect to the ideal functionality.
res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize); res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols - 1 { for j in 0..cols {
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, a.data(), j + 1); self.vec_znx_dft_copy(dsize, dsize - di - 1, &mut ai_dft, j, a, j);
} }
if di == 0 { if di == 0 {
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_1); // res = pmat * ai_dft
self.vmp_apply_dft_to_dft(res, &ai_dft, pmat, scratch_1);
} else { } else {
module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_1); // res = (pmat * ai_dft) * 2^{di * Base2k}
} self.vmp_apply_dft_to_dft_add(res, &ai_dft, pmat, di, scratch_1);
}
} else {
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), cols - 1, a_size);
for j in 0..cols - 1 {
module.vec_znx_normalize(
base2k_out,
&mut a_conv,
j,
base2k_in,
a.data(),
j + 1,
scratch_2,
);
}
for di in 0..dsize {
ai_dft.set_size((a_size + di) / dsize);
// Small optimization for dsize > 2
// VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
// we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(dsize-1) * B}.
// As such we can ignore the last dsize-2 limbs safely of the sum of vmp products.
// It is possible to further ignore the last dsize-1 limbs, but this introduce
// ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
// noise is kept with respect to the ideal functionality.
res.set_size(pmat.size() - ((dsize - di) as isize - 2).max(0) as usize);
for j in 0..cols - 1 {
module.vec_znx_dft_apply(dsize, dsize - di - 1, &mut ai_dft, j, &a_conv, j);
}
if di == 0 {
module.vmp_apply_dft_to_dft(&mut res, &ai_dft, pmat, scratch_2);
} else {
module.vmp_apply_dft_to_dft_add(&mut res, &ai_dft, pmat, di, scratch_2);
}
} }
} }
res.set_size(res.max_size()); res.set_size(res.max_size());
} }
}
let mut res_big: VecZnxBig<DR, BE> = module.vec_znx_idft_apply_consume(res);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a.data(), 0);
res_big
} }

View File

@@ -0,0 +1,237 @@
use poulpy_hal::{
layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo},
source::Source,
};
use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos,
GGLWEToGGSWKey, GGLWEToGGSWKeyToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;
/// Compressed (seed-expandable) form of a GGLWE-to-GGSW conversion key.
///
/// Holds one [`GGLWECompressed`] per component of the output key:
/// `alloc` creates `rank` entries, each with `rank` input and output columns.
#[derive(PartialEq, Eq, Clone)]
pub struct GGLWEToGGSWKeyCompressed<D: Data> {
    // One compressed GGLWE per rank index; entry 0 defines the shared
    // parameters (n, base2k, k, dnum, dsize) reported by the info traits.
    pub(crate) keys: Vec<GGLWECompressed<D>>,
}
impl<D: Data> LWEInfos for GGLWEToGGSWKeyCompressed<D> {
    // All entries share the same parameters, so they are read from `keys[0]`.
    // NOTE(review): these accessors panic if `keys` is empty — confirm that
    // zero-rank keys cannot be constructed.
    fn n(&self) -> Degree {
        self.keys[0].n()
    }

    fn base2k(&self) -> Base2K {
        self.keys[0].base2k()
    }

    fn k(&self) -> TorusPrecision {
        self.keys[0].k()
    }

    fn size(&self) -> usize {
        self.keys[0].size()
    }
}
impl<D: Data> GLWEInfos for GGLWEToGGSWKeyCompressed<D> {
    // The GLWE rank of the key is the output rank of its first entry.
    fn rank(&self) -> Rank {
        self.keys[0].rank_out()
    }
}
impl<D: Data> GGLWEInfos for GGLWEToGGSWKeyCompressed<D> {
    // Input and output ranks are equal by construction (enforced by the
    // assertions in `alloc_from_infos` / `bytes_of_from_infos`).
    fn rank_in(&self) -> Rank {
        self.rank_out()
    }

    fn rank_out(&self) -> Rank {
        self.keys[0].rank_out()
    }

    fn dsize(&self) -> Dsize {
        self.keys[0].dsize()
    }

    fn dnum(&self) -> Dnum {
        self.keys[0].dnum()
    }
}
impl<D: DataRef> fmt::Debug for GGLWEToGGSWKeyCompressed<D> {
    // Debug output reuses the `Display` rendering of the key.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{self}")
    }
}
impl<D: DataMut> FillUniform for GGLWEToGGSWKeyCompressed<D> {
    /// Fills every underlying compressed GGLWE entry with uniform data
    /// bounded by `log_bound`, drawing randomness from `source`.
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        for key in self.keys.iter_mut() {
            key.fill_uniform(log_bound, source);
        }
    }
}
impl<D: DataRef> fmt::Display for GGLWEToGGSWKeyCompressed<D> {
    /// Writes a header line followed by one `index: entry` line group per
    /// underlying compressed GGLWE.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "(GGLWEToGGSWKeyCompressed)")?;
        self.keys
            .iter()
            .enumerate()
            .try_for_each(|(i, key)| write!(f, "{i}: {key}"))
    }
}
impl GGLWEToGGSWKeyCompressed<Vec<u8>> {
    /// Allocates a compressed key whose parameters are copied from `infos`.
    ///
    /// # Panics
    /// Panics if `infos.rank_in() != infos.rank_out()`.
    pub fn alloc_from_infos<A>(infos: &A) -> Self
    where
        A: GGLWEInfos,
    {
        assert_eq!(
            infos.rank_in(),
            infos.rank_out(),
            "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed"
        );
        Self::alloc(
            infos.n(),
            infos.base2k(),
            infos.k(),
            infos.rank(),
            infos.dnum(),
            infos.dsize(),
        )
    }

    /// Allocates a compressed key holding `rank` compressed GGLWE entries,
    /// each with `rank` input and `rank` output columns.
    pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
        let mut keys = Vec::with_capacity(rank.as_usize());
        for _ in 0..rank.as_usize() {
            keys.push(GGLWECompressed::alloc(n, base2k, k, rank, rank, dnum, dsize));
        }
        GGLWEToGGSWKeyCompressed { keys }
    }

    /// Byte size of a compressed key with the parameters of `infos`.
    ///
    /// # Panics
    /// Panics if `infos.rank_in() != infos.rank_out()`.
    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
    where
        A: GGLWEInfos,
    {
        assert_eq!(
            infos.rank_in(),
            infos.rank_out(),
            "rank_in != rank_out is not supported for GGLWEToGGSWKeyCompressed"
        );
        Self::bytes_of(
            infos.n(),
            infos.base2k(),
            infos.k(),
            infos.rank(),
            infos.dnum(),
            infos.dsize(),
        )
    }

    /// Byte size of a compressed key: `rank` entries of identical layout.
    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
        GGLWECompressed::bytes_of(n, base2k, k, rank, dnum, dsize) * rank.as_usize()
    }
}
impl<D: DataMut> GGLWEToGGSWKeyCompressed<D> {
    /// Returns a mutable reference to the `i`-th compressed entry,
    /// i.e. GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    ///
    /// # Panics
    /// Panics if `i >= rank`.
    pub fn at_mut(&mut self, i: usize) -> &mut GGLWECompressed<D> {
        assert!((i as u32) < self.rank());
        &mut self.keys[i]
    }
}
impl<D: DataRef> GGLWEToGGSWKeyCompressed<D> {
    /// Returns a reference to the `i`-th compressed entry,
    /// i.e. GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    /// (The API takes a single index `i`, not a pair `(i, j)`.)
    ///
    /// # Panics
    /// Panics if `i >= rank`.
    pub fn at(&self, i: usize) -> &GGLWECompressed<D> {
        assert!((i as u32) < self.rank());
        &self.keys[i]
    }
}
impl<D: DataMut> ReaderFrom for GGLWEToGGSWKeyCompressed<D> {
    /// Deserializes the key in place; the entry count stored in the stream
    /// must match the preallocated `self.keys.len()`.
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        if self.keys.len() != len {
            let msg = format!("self.keys.len()={} != read len={}", self.keys.len(), len);
            return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, msg));
        }
        self.keys.iter_mut().try_for_each(|key| key.read_from(reader))
    }
}
impl<D: DataRef> WriterTo for GGLWEToGGSWKeyCompressed<D> {
    /// Serializes the entry count (u64, little-endian) followed by every
    /// compressed GGLWE entry in order.
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
        self.keys.iter().try_for_each(|key| key.write_to(writer))
    }
}
pub trait GGLWEToGGSWKeyDecompress
where
    Self: GGLWEDecompress,
{
    /// Decompresses `other` into `res`, entry by entry.
    ///
    /// # Panics
    /// Panics if the two keys do not hold the same number of entries.
    fn decompress_gglwe_to_ggsw_key<R, O>(&self, res: &mut R, other: &O)
    where
        R: GGLWEToGGSWKeyToMut,
        O: GGLWEToGGSWKeyCompressedToRef,
    {
        let mut res: GGLWEToGGSWKey<&mut [u8]> = res.to_mut();
        let other: GGLWEToGGSWKeyCompressed<&[u8]> = other.to_ref();
        assert_eq!(res.keys.len(), other.keys.len());
        for (dst, src) in res.keys.iter_mut().zip(other.keys.iter()) {
            self.decompress_gglwe(dst, src);
        }
    }
}
impl<D: DataMut> GGLWEToGGSWKey<D> {
    /// Convenience wrapper: decompresses `other` into `self` using `module`.
    pub fn decompress<O, M>(&mut self, module: &M, other: &O)
    where
        M: GGLWEToGGSWKeyDecompress,
        O: GGLWEToGGSWKeyCompressedToRef,
    {
        module.decompress_gglwe_to_ggsw_key(self, other);
    }
}
/// Borrowing conversion to a byte-slice-backed view of the compressed key.
pub trait GGLWEToGGSWKeyCompressedToRef {
    fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]>;
}

impl<D: DataRef> GGLWEToGGSWKeyCompressedToRef for GGLWEToGGSWKeyCompressed<D>
where
    GGLWECompressed<D>: GGLWECompressedToRef,
{
    // Builds a view whose entries each borrow the corresponding entry's data.
    fn to_ref(&self) -> GGLWEToGGSWKeyCompressed<&[u8]> {
        GGLWEToGGSWKeyCompressed {
            keys: self.keys.iter().map(|c| c.to_ref()).collect(),
        }
    }
}
/// Mutably-borrowing conversion to a byte-slice-backed view of the key.
pub trait GGLWEToGGSWKeyCompressedToMut {
    fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]>;
}

impl<D: DataMut> GGLWEToGGSWKeyCompressedToMut for GGLWEToGGSWKeyCompressed<D>
where
    GGLWECompressed<D>: GGLWECompressedToMut,
{
    // Builds a mutable view whose entries each borrow the corresponding
    // entry's data.
    fn to_mut(&mut self) -> GGLWEToGGSWKeyCompressed<&mut [u8]> {
        GGLWEToGGSWKeyCompressed {
            keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
        }
    }
}

View File

@@ -4,31 +4,34 @@ use poulpy_hal::{
}; };
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEDecompress, GGLWEInfos, Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedSeedMut, GGLWECompressedToMut, GGLWECompressedToRef,
GLWEInfos, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank, TorusPrecision, GGLWEDecompress, GGLWEInfos, GGLWEToMut, GLWEInfos, GLWETensorKey, LWEInfos, Rank, TorusPrecision,
}; };
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt; use std::fmt;
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct GLWETensorKeyCompressed<D: Data> { pub struct GLWETensorKeyCompressed<D: Data>(pub(crate) GGLWECompressed<D>);
pub(crate) keys: Vec<GGLWECompressed<D>>,
impl<D: DataMut> GGLWECompressedSeedMut for GLWETensorKeyCompressed<D> {
fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
&mut self.0.seed
}
} }
impl<D: Data> LWEInfos for GLWETensorKeyCompressed<D> { impl<D: Data> LWEInfos for GLWETensorKeyCompressed<D> {
fn n(&self) -> Degree { fn n(&self) -> Degree {
self.keys[0].n() self.0.n()
} }
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.keys[0].base2k() self.0.base2k()
} }
fn k(&self) -> TorusPrecision { fn k(&self) -> TorusPrecision {
self.keys[0].k() self.0.k()
} }
fn size(&self) -> usize { fn size(&self) -> usize {
self.keys[0].size() self.0.size()
} }
} }
impl<D: Data> GLWEInfos for GLWETensorKeyCompressed<D> { impl<D: Data> GLWEInfos for GLWETensorKeyCompressed<D> {
@@ -43,15 +46,15 @@ impl<D: Data> GGLWEInfos for GLWETensorKeyCompressed<D> {
} }
fn rank_out(&self) -> Rank { fn rank_out(&self) -> Rank {
self.keys[0].rank_out() self.0.rank_out()
} }
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.keys[0].dsize() self.0.dsize()
} }
fn dnum(&self) -> Dnum { fn dnum(&self) -> Dnum {
self.keys[0].dnum() self.0.dnum()
} }
} }
@@ -63,18 +66,14 @@ impl<D: DataRef> fmt::Debug for GLWETensorKeyCompressed<D> {
impl<D: DataMut> FillUniform for GLWETensorKeyCompressed<D> { impl<D: DataMut> FillUniform for GLWETensorKeyCompressed<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.keys self.0.fill_uniform(log_bound, source);
.iter_mut()
.for_each(|key: &mut GGLWECompressed<D>| key.fill_uniform(log_bound, source))
} }
} }
impl<D: DataRef> fmt::Display for GLWETensorKeyCompressed<D> { impl<D: DataRef> fmt::Display for GLWETensorKeyCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "(GLWETensorKeyCompressed)",)?; writeln!(f, "(GLWETensorKeyCompressed)",)?;
for (i, key) in self.keys.iter().enumerate() { write!(f, "{}", self.0)?;
write!(f, "{i}: {key}")?;
}
Ok(()) Ok(())
} }
} }
@@ -96,11 +95,15 @@ impl GLWETensorKeyCompressed<Vec<u8>> {
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
GLWETensorKeyCompressed { GLWETensorKeyCompressed(GGLWECompressed::alloc(
keys: (0..pairs) n,
.map(|_| GGLWECompressed::alloc(n, base2k, k, Rank(1), rank, dnum, dsize)) base2k,
.collect(), k,
} Rank(pairs),
rank,
dnum,
dsize,
))
} }
pub fn bytes_of_from_infos<A>(infos: &A) -> usize pub fn bytes_of_from_infos<A>(infos: &A) -> usize
@@ -118,88 +121,35 @@ impl GLWETensorKeyCompressed<Vec<u8>> {
} }
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
pairs * GGLWECompressed::bytes_of(n, base2k, k, Rank(1), dnum, dsize) GGLWECompressed::bytes_of(n, base2k, k, Rank(pairs), dnum, dsize)
} }
} }
impl<D: DataMut> ReaderFrom for GLWETensorKeyCompressed<D> { impl<D: DataMut> ReaderFrom for GLWETensorKeyCompressed<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> { fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
let len: usize = reader.read_u64::<LittleEndian>()? as usize; self.0.read_from(reader)?;
if self.keys.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("self.keys.len()={} != read len={}", self.keys.len(), len),
));
}
for key in &mut self.keys {
key.read_from(reader)?;
}
Ok(()) Ok(())
} }
} }
impl<D: DataRef> WriterTo for GLWETensorKeyCompressed<D> { impl<D: DataRef> WriterTo for GLWETensorKeyCompressed<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> { fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.keys.len() as u64)?; self.0.write_to(writer)?;
for key in &self.keys {
key.write_to(writer)?;
}
Ok(()) Ok(())
} }
} }
pub trait GLWETensorKeyCompressedAtRef<D: DataRef> {
fn at(&self, i: usize, j: usize) -> &GGLWECompressed<D>;
}
impl<D: DataRef> GLWETensorKeyCompressedAtRef<D> for GLWETensorKeyCompressed<D> {
fn at(&self, mut i: usize, mut j: usize) -> &GGLWECompressed<D> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
pub trait GLWETensorKeyCompressedAtMut<D: DataMut> {
fn at_mut(&mut self, i: usize, j: usize) -> &mut GGLWECompressed<D>;
}
impl<D: DataMut> GLWETensorKeyCompressedAtMut<D> for GLWETensorKeyCompressed<D> {
fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWECompressed<D> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
pub trait GLWETensorKeyDecompress pub trait GLWETensorKeyDecompress
where where
Self: GGLWEDecompress, Self: GGLWEDecompress,
{ {
fn decompress_tensor_key<R, O>(&self, res: &mut R, other: &O) fn decompress_tensor_key<R, O>(&self, res: &mut R, other: &O)
where where
R: GLWETensorKeyToMut, R: GGLWEToMut,
O: GLWETensorKeyCompressedToRef, O: GGLWECompressedToRef,
{ {
let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut(); self.decompress_gglwe(res, other);
let other: &GLWETensorKeyCompressed<&[u8]> = &other.to_ref();
assert_eq!(
res.keys.len(),
other.keys.len(),
"invalid receiver: res.keys.len()={} != other.keys.len()={}",
res.keys.len(),
other.keys.len()
);
for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
self.decompress_gglwe(a, b);
}
} }
} }
@@ -208,39 +158,27 @@ impl<B: Backend> GLWETensorKeyDecompress for Module<B> where Self: GGLWEDecompre
impl<D: DataMut> GLWETensorKey<D> { impl<D: DataMut> GLWETensorKey<D> {
pub fn decompress<O, M>(&mut self, module: &M, other: &O) pub fn decompress<O, M>(&mut self, module: &M, other: &O)
where where
O: GLWETensorKeyCompressedToRef, O: GGLWECompressedToRef,
M: GLWETensorKeyDecompress, M: GLWETensorKeyDecompress,
{ {
module.decompress_tensor_key(self, other); module.decompress_tensor_key(self, other);
} }
} }
pub trait GLWETensorKeyCompressedToMut { impl<D: DataMut> GGLWECompressedToMut for GLWETensorKeyCompressed<D>
fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]>;
}
impl<D: DataMut> GLWETensorKeyCompressedToMut for GLWETensorKeyCompressed<D>
where where
GGLWECompressed<D>: GGLWECompressedToMut, GGLWECompressed<D>: GGLWECompressedToMut,
{ {
fn to_mut(&mut self) -> GLWETensorKeyCompressed<&mut [u8]> { fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> {
GLWETensorKeyCompressed { self.0.to_mut()
keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
}
} }
} }
pub trait GLWETensorKeyCompressedToRef { impl<D: DataRef> GGLWECompressedToRef for GLWETensorKeyCompressed<D>
fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]>;
}
impl<D: DataRef> GLWETensorKeyCompressedToRef for GLWETensorKeyCompressed<D>
where where
GGLWECompressed<D>: GGLWECompressedToRef, GGLWECompressed<D>: GGLWECompressedToRef,
{ {
fn to_ref(&self) -> GLWETensorKeyCompressed<&[u8]> { fn to_ref(&self) -> GGLWECompressed<&[u8]> {
GLWETensorKeyCompressed { self.0.to_ref()
keys: self.keys.iter().map(|c| c.to_ref()).collect(),
}
} }
} }

View File

@@ -7,7 +7,7 @@ use poulpy_hal::{
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos,
GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWESwitchingKey, LWEInfos, Rank, TorusPrecision, GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, GLWEToLWEKey, LWEInfos, Rank, TorusPrecision,
compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress},
}; };
@@ -147,7 +147,7 @@ pub trait GLWEToLWESwitchingKeyDecompress
where where
Self: GLWESwitchingKeyDecompress, Self: GLWESwitchingKeyDecompress,
{ {
fn decompress_glwe_to_lwe_switching_key<R, O>(&self, res: &mut R, other: &O) fn decompress_glwe_to_lwe_key<R, O>(&self, res: &mut R, other: &O)
where where
R: GGLWEToMut + GLWESwitchingKeyDegreesMut, R: GGLWEToMut + GLWESwitchingKeyDegreesMut,
O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees,
@@ -158,13 +158,13 @@ where
impl<B: Backend> GLWEToLWESwitchingKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {} impl<B: Backend> GLWEToLWESwitchingKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {}
impl<D: DataMut> GLWEToLWESwitchingKey<D> { impl<D: DataMut> GLWEToLWEKey<D> {
pub fn decompress<O, M>(&mut self, module: &M, other: &O) pub fn decompress<O, M>(&mut self, module: &M, other: &O)
where where
O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees,
M: GLWEToLWESwitchingKeyDecompress, M: GLWEToLWESwitchingKeyDecompress,
{ {
module.decompress_glwe_to_lwe_switching_key(self, other); module.decompress_glwe_to_lwe_key(self, other);
} }
} }

View File

@@ -5,15 +5,15 @@ use poulpy_hal::{
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos, Base2K, Degree, Dnum, Dsize, GGLWECompressed, GGLWECompressedToMut, GGLWECompressedToRef, GGLWEInfos, GGLWEToMut, GLWEInfos,
GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWESwitchingKey, Rank, TorusPrecision, GLWESwitchingKeyDegrees, GLWESwitchingKeyDegreesMut, LWEInfos, LWEToGLWEKey, Rank, TorusPrecision,
compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress}, compressed::{GLWESwitchingKeyCompressed, GLWESwitchingKeyDecompress},
}; };
use std::fmt; use std::fmt;
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct LWEToGLWESwitchingKeyCompressed<D: Data>(pub(crate) GLWESwitchingKeyCompressed<D>); pub struct LWEToGLWEKeyCompressed<D: Data>(pub(crate) GLWESwitchingKeyCompressed<D>);
impl<D: Data> LWEInfos for LWEToGLWESwitchingKeyCompressed<D> { impl<D: Data> LWEInfos for LWEToGLWEKeyCompressed<D> {
fn n(&self) -> Degree { fn n(&self) -> Degree {
self.0.n() self.0.n()
} }
@@ -29,13 +29,13 @@ impl<D: Data> LWEInfos for LWEToGLWESwitchingKeyCompressed<D> {
self.0.size() self.0.size()
} }
} }
impl<D: Data> GLWEInfos for LWEToGLWESwitchingKeyCompressed<D> { impl<D: Data> GLWEInfos for LWEToGLWEKeyCompressed<D> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl<D: Data> GGLWEInfos for LWEToGLWESwitchingKeyCompressed<D> { impl<D: Data> GGLWEInfos for LWEToGLWEKeyCompressed<D> {
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.0.dsize() self.0.dsize()
} }
@@ -53,37 +53,37 @@ impl<D: Data> GGLWEInfos for LWEToGLWESwitchingKeyCompressed<D> {
} }
} }
impl<D: DataRef> fmt::Debug for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataRef> fmt::Debug for LWEToGLWEKeyCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{self}") write!(f, "{self}")
} }
} }
impl<D: DataMut> FillUniform for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataMut> FillUniform for LWEToGLWEKeyCompressed<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.0.fill_uniform(log_bound, source); self.0.fill_uniform(log_bound, source);
} }
} }
impl<D: DataRef> fmt::Display for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataRef> fmt::Display for LWEToGLWEKeyCompressed<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "(LWEToGLWESwitchingKeyCompressed) {}", self.0) write!(f, "(LWEToGLWESwitchingKeyCompressed) {}", self.0)
} }
} }
impl<D: DataMut> ReaderFrom for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataMut> ReaderFrom for LWEToGLWEKeyCompressed<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> { fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.0.read_from(reader) self.0.read_from(reader)
} }
} }
impl<D: DataRef> WriterTo for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataRef> WriterTo for LWEToGLWEKeyCompressed<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> { fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
self.0.write_to(writer) self.0.write_to(writer)
} }
} }
impl LWEToGLWESwitchingKeyCompressed<Vec<u8>> { impl LWEToGLWEKeyCompressed<Vec<u8>> {
pub fn alloc_from_infos<A>(infos: &A) -> Self pub fn alloc_from_infos<A>(infos: &A) -> Self
where where
A: GGLWEInfos, A: GGLWEInfos,
@@ -108,7 +108,7 @@ impl LWEToGLWESwitchingKeyCompressed<Vec<u8>> {
} }
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self {
LWEToGLWESwitchingKeyCompressed(GLWESwitchingKeyCompressed::alloc( LWEToGLWEKeyCompressed(GLWESwitchingKeyCompressed::alloc(
n, n,
base2k, base2k,
k, k,
@@ -141,11 +141,11 @@ impl LWEToGLWESwitchingKeyCompressed<Vec<u8>> {
} }
} }
pub trait LWEToGLWESwitchingKeyDecompress pub trait LWEToGLWEKeyDecompress
where where
Self: GLWESwitchingKeyDecompress, Self: GLWESwitchingKeyDecompress,
{ {
fn decompress_lwe_to_glwe_switching_key<R, O>(&self, res: &mut R, other: &O) fn decompress_lwe_to_glwe_key<R, O>(&self, res: &mut R, other: &O)
where where
R: GGLWEToMut + GLWESwitchingKeyDegreesMut, R: GGLWEToMut + GLWESwitchingKeyDegreesMut,
O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees,
@@ -154,25 +154,25 @@ where
} }
} }
impl<B: Backend> LWEToGLWESwitchingKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {} impl<B: Backend> LWEToGLWEKeyDecompress for Module<B> where Self: GLWESwitchingKeyDecompress {}
impl<D: DataMut> LWEToGLWESwitchingKey<D> { impl<D: DataMut> LWEToGLWEKey<D> {
pub fn decompress<O, M>(&mut self, module: &M, other: &O) pub fn decompress<O, M>(&mut self, module: &M, other: &O)
where where
O: GGLWECompressedToRef + GLWESwitchingKeyDegrees, O: GGLWECompressedToRef + GLWESwitchingKeyDegrees,
M: LWEToGLWESwitchingKeyDecompress, M: LWEToGLWEKeyDecompress,
{ {
module.decompress_lwe_to_glwe_switching_key(self, other); module.decompress_lwe_to_glwe_key(self, other);
} }
} }
impl<D: DataRef> GGLWECompressedToRef for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataRef> GGLWECompressedToRef for LWEToGLWEKeyCompressed<D> {
fn to_ref(&self) -> GGLWECompressed<&[u8]> { fn to_ref(&self) -> GGLWECompressed<&[u8]> {
self.0.to_ref() self.0.to_ref()
} }
} }
impl<D: DataMut> GGLWECompressedToMut for LWEToGLWESwitchingKeyCompressed<D> { impl<D: DataMut> GGLWECompressedToMut for LWEToGLWEKeyCompressed<D> {
fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> { fn to_mut(&mut self) -> GGLWECompressed<&mut [u8]> {
self.0.to_mut() self.0.to_mut()
} }

View File

@@ -1,21 +1,23 @@
mod gglwe; mod gglwe;
mod gglwe_to_ggsw_key;
mod ggsw; mod ggsw;
mod glwe; mod glwe;
mod glwe_automorphism_key; mod glwe_automorphism_key;
mod glwe_switching_key; mod glwe_switching_key;
mod glwe_tensor_key; mod glwe_tensor_key;
mod glwe_to_lwe_switching_key; mod glwe_to_lwe_key;
mod lwe; mod lwe;
mod lwe_switching_key; mod lwe_switching_key;
mod lwe_to_glwe_switching_key; mod lwe_to_glwe_key;
pub use gglwe::*; pub use gglwe::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw::*; pub use ggsw::*;
pub use glwe::*; pub use glwe::*;
pub use glwe_automorphism_key::*; pub use glwe_automorphism_key::*;
pub use glwe_switching_key::*; pub use glwe_switching_key::*;
pub use glwe_tensor_key::*; pub use glwe_tensor_key::*;
pub use glwe_to_lwe_switching_key::*; pub use glwe_to_lwe_key::*;
pub use lwe::*; pub use lwe::*;
pub use lwe_switching_key::*; pub use lwe_switching_key::*;
pub use lwe_to_glwe_switching_key::*; pub use lwe_to_glwe_key::*;

View File

@@ -0,0 +1,254 @@
use poulpy_hal::{
layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo},
source::Source,
};
use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision,
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;
/// Parameter layout describing a [`GGLWEToGGSWKey`]: ring degree, decomposition
/// basis, torus precision, rank and gadget dimensions (`dnum`, `dsize`).
/// For this key `rank_in == rank_out == rank` (see [`GGLWEInfos`] impl below).
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub struct GGLWEToGGSWKeyLayout {
    pub n: Degree,
    pub base2k: Base2K,
    pub k: TorusPrecision,
    pub rank: Rank,
    pub dnum: Dnum,
    pub dsize: Dsize,
}

/// Key material used for the GGLWE -> GGSW conversion: one [`GGLWE`] per
/// component of the secret (`alloc` creates exactly `rank` entries).
#[derive(PartialEq, Eq, Clone)]
pub struct GGLWEToGGSWKey<D: Data> {
    pub(crate) keys: Vec<GGLWE<D>>,
}
impl<D: Data> LWEInfos for GGLWEToGGSWKey<D> {
    // Metadata is read off the first inner GGLWE. Assumes `keys` is non-empty
    // and homogeneous, which holds for keys built via `alloc` (one entry per
    // rank component) — panics on an empty key.
    fn n(&self) -> Degree {
        self.keys[0].n()
    }

    fn base2k(&self) -> Base2K {
        self.keys[0].base2k()
    }

    fn k(&self) -> TorusPrecision {
        self.keys[0].k()
    }

    fn size(&self) -> usize {
        self.keys[0].size()
    }
}

impl<D: Data> GLWEInfos for GGLWEToGGSWKey<D> {
    fn rank(&self) -> Rank {
        self.keys[0].rank_out()
    }
}

impl<D: Data> GGLWEInfos for GGLWEToGGSWKey<D> {
    // rank_in == rank_out by construction (enforced in `alloc_from_infos`).
    fn rank_in(&self) -> Rank {
        self.rank_out()
    }

    fn rank_out(&self) -> Rank {
        self.keys[0].rank_out()
    }

    fn dsize(&self) -> Dsize {
        self.keys[0].dsize()
    }

    fn dnum(&self) -> Dnum {
        self.keys[0].dnum()
    }
}
impl LWEInfos for GGLWEToGGSWKeyLayout {
    // Plain field accessors: the layout is a value-type description of the key.
    fn n(&self) -> Degree {
        self.n
    }

    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }
}

impl GLWEInfos for GGLWEToGGSWKeyLayout {
    fn rank(&self) -> Rank {
        self.rank_out()
    }
}

impl GGLWEInfos for GGLWEToGGSWKeyLayout {
    // Both input and output rank map to the single `rank` field: this key
    // only supports square (rank_in == rank_out) configurations.
    fn rank_in(&self) -> Rank {
        self.rank
    }

    fn dsize(&self) -> Dsize {
        self.dsize
    }

    fn rank_out(&self) -> Rank {
        self.rank
    }

    fn dnum(&self) -> Dnum {
        self.dnum
    }
}
impl<D: DataRef> fmt::Debug for GGLWEToGGSWKey<D> {
    /// Debug output delegates to the [`fmt::Display`] representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{self}")
    }
}

impl<D: DataMut> FillUniform for GGLWEToGGSWKey<D> {
    /// Fills every inner GGLWE with uniform noise drawn from `source`.
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        for key in self.keys.iter_mut() {
            key.fill_uniform(log_bound, source);
        }
    }
}

impl<D: DataRef> fmt::Display for GGLWEToGGSWKey<D> {
    /// Prints a header followed by each inner GGLWE, prefixed by its index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "(GGLWEToGGSWKey)",)?;
        self.keys
            .iter()
            .enumerate()
            .try_for_each(|(i, key)| write!(f, "{i}: {key}"))
    }
}
impl GGLWEToGGSWKey<Vec<u8>> {
pub fn alloc_from_infos<A>(infos: &A) -> Self
where
A: GGLWEInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKey"
);
Self::alloc(
infos.n(),
infos.base2k(),
infos.k(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
GGLWEToGGSWKey {
keys: (0..rank.as_usize())
.map(|_| GGLWE::alloc(n, base2k, k, rank, rank, dnum, dsize))
.collect(),
}
}
pub fn bytes_of_from_infos<A>(infos: &A) -> usize
where
A: GGLWEInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKey"
);
Self::bytes_of(
infos.n(),
infos.base2k(),
infos.k(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
rank.as_usize() * GGLWE::bytes_of(n, base2k, k, rank, rank, dnum, dsize)
}
}
impl<D: DataMut> GGLWEToGGSWKey<D> {
    // Returns a mutable reference to the i-th inner GGLWE, i.e.
    // GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    // Panics if `i >= rank`.
    pub fn at_mut(&mut self, i: usize) -> &mut GGLWE<D> {
        assert!((i as u32) < self.rank());
        &mut self.keys[i]
    }
}

impl<D: DataRef> GGLWEToGGSWKey<D> {
    // Returns a reference to the i-th inner GGLWE, i.e.
    // GGLWE_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    // (The key takes a single index; the previous comment mentioning
    // `s[i] * s[j]` was a copy-paste from the tensor key.)
    // Panics if `i >= rank`.
    pub fn at(&self, i: usize) -> &GGLWE<D> {
        assert!((i as u32) < self.rank());
        &self.keys[i]
    }
}
impl<D: DataMut> ReaderFrom for GGLWEToGGSWKey<D> {
    /// Deserializes the key in place. The stream must start with the number of
    /// inner GGLWEs (little-endian u64), which must match `self.keys.len()`;
    /// otherwise an [`std::io::ErrorKind::InvalidData`] error is returned.
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        let len: usize = reader.read_u64::<LittleEndian>()? as usize;
        if self.keys.len() != len {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("self.keys.len()={} != read len={}", self.keys.len(), len),
            ));
        }
        self.keys
            .iter_mut()
            .try_for_each(|key| key.read_from(reader))
    }
}

impl<D: DataRef> WriterTo for GGLWEToGGSWKey<D> {
    /// Serializes the key: a little-endian u64 count followed by each GGLWE.
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
        self.keys.iter().try_for_each(|key| key.write_to(writer))
    }
}
pub trait GGLWEToGGSWKeyToRef {
fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]>;
}
impl<D: DataRef> GGLWEToGGSWKeyToRef for GGLWEToGGSWKey<D>
where
GGLWE<D>: GGLWEToRef,
{
fn to_ref(&self) -> GGLWEToGGSWKey<&[u8]> {
GGLWEToGGSWKey {
keys: self.keys.iter().map(|c| c.to_ref()).collect(),
}
}
}
pub trait GGLWEToGGSWKeyToMut {
fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]>;
}
impl<D: DataMut> GGLWEToGGSWKeyToMut for GGLWEToGGSWKey<D>
where
GGLWE<D>: GGLWEToMut,
{
fn to_mut(&mut self) -> GGLWEToGGSWKey<&mut [u8]> {
GGLWEToGGSWKey {
keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
}
}
}

View File

@@ -0,0 +1,221 @@
use poulpy_hal::{
api::{
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
},
layouts::{
Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, Scratch, ZnxInfos, ZnxView,
ZnxViewMut,
},
};
use crate::{
ScratchTakeCore,
dist::Distribution,
layouts::{
Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank,
TorusPrecision,
},
};
/// Tensor (pairwise) product of a GLWE secret with itself: stores the
/// `rank*(rank+1)/2` distinct products s[i]*s[j] (i <= j) as columns of a
/// single [`ScalarZnx`].
pub struct GLWESecretTensor<D: Data> {
    pub(crate) data: ScalarZnx<D>,
    pub(crate) rank: Rank,
    pub(crate) dist: Distribution,
}

impl GLWESecretTensor<Vec<u8>> {
    // Number of unordered pairs (i, j) with i <= j for the given rank:
    // rank*(rank+1)/2, clamped to at least 1 so a rank-0 input still
    // yields one column.
    pub(crate) fn pairs(rank: usize) -> usize {
        (((rank + 1) * rank) >> 1).max(1)
    }
}
impl<D: Data> LWEInfos for GLWESecretTensor<D> {
    // Secrets carry no torus encoding, so base2k and k are reported as 0 and
    // the limb count as 1; only the ring degree is meaningful here.
    fn base2k(&self) -> Base2K {
        Base2K(0)
    }

    fn k(&self) -> TorusPrecision {
        TorusPrecision(0)
    }

    fn n(&self) -> Degree {
        // NOTE(review): usize -> u32 cast; assumes the ring degree fits in u32.
        Degree(self.data.n() as u32)
    }

    fn size(&self) -> usize {
        1
    }
}
impl<D: DataRef> GLWESecretTensor<D> {
    /// Returns the scalar polynomial holding s[i]*s[j]; symmetric in (i, j).
    pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> {
        // Canonicalize so i <= j: the products are stored as a packed upper triangle.
        if i > j {
            std::mem::swap(&mut i, &mut j);
        };
        let rank: usize = self.rank().into();
        ScalarZnx {
            // Row-major upper-triangular index; cast_slice reinterprets the
            // column's bytes as the element type without copying.
            data: bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)),
            n: self.n().into(),
            cols: 1,
        }
    }
}

impl<D: DataMut> GLWESecretTensor<D> {
    /// Mutable counterpart of [`Self::at`]; same packed-triangle indexing.
    pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> {
        if i > j {
            std::mem::swap(&mut i, &mut j);
        };
        let rank: usize = self.rank().into();
        ScalarZnx {
            n: self.n().into(),
            data: bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)),
            cols: 1,
        }
    }
}
impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
    // Reports the rank of the *original* secret, not the number of stored pairs.
    fn rank(&self) -> Rank {
        self.rank
    }
}

impl<D: DataRef> GLWESecretToRef for GLWESecretTensor<D> {
    // Presents the packed tensor columns as a plain (borrowed) GLWESecret,
    // allowing the generic secret APIs to operate on it.
    fn to_ref(&self) -> GLWESecret<&[u8]> {
        GLWESecret {
            data: self.data.to_ref(),
            dist: self.dist,
        }
    }
}

impl<D: DataMut> GLWESecretToMut for GLWESecretTensor<D> {
    // Mutable counterpart of the GLWESecretToRef impl above.
    fn to_mut(&mut self) -> GLWESecret<&mut [u8]> {
        GLWESecret {
            dist: self.dist,
            data: self.data.to_mut(),
        }
    }
}
impl GLWESecretTensor<Vec<u8>> {
    /// Allocates a tensor secret with degree and rank taken from `infos`.
    pub fn alloc_from_infos<A>(infos: &A) -> Self
    where
        A: GLWEInfos,
    {
        Self::alloc(infos.n(), infos.rank())
    }

    /// Allocates storage for the `pairs(rank)` products s[i]*s[j].
    pub fn alloc(n: Degree, rank: Rank) -> Self {
        GLWESecretTensor {
            data: ScalarZnx::alloc(n.into(), Self::pairs(rank.into())),
            rank,
            dist: Distribution::NONE,
        }
    }

    /// Byte size of the allocation made by [`Self::alloc_from_infos`].
    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        // Pass the GLWE rank through unchanged: `bytes_of` already expands it
        // to `pairs(rank)`. (The previous code passed `pairs(rank)` here, so
        // `bytes_of` computed `pairs(pairs(rank))` and over-counted, e.g.
        // rank=2 -> 6 columns instead of 3, disagreeing with `alloc`.)
        Self::bytes_of(infos.n(), infos.rank())
    }

    /// Byte size of the allocation made by [`Self::alloc`] for the same `n`/`rank`.
    pub fn bytes_of(n: Degree, rank: Rank) -> usize {
        ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
    }
}
impl<D: DataMut> GLWESecretTensor<D> {
    /// Fills `self` with the pairwise products of the components of `other`,
    /// delegating to the backend's [`GLWESecretTensorFactory`] implementation.
    /// `scratch` must provide at least `glwe_secret_tensor_prepare_tmp_bytes` bytes.
    pub fn prepare<M, S, BE: Backend>(&mut self, module: &M, other: &S, scratch: &mut Scratch<BE>)
    where
        M: GLWESecretTensorFactory<BE>,
        S: GLWESecretToRef + GLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_secret_tensor_prepare(self, other, scratch);
    }
}
/// Backend entry point for computing a secret's tensor (pairwise) product.
pub trait GLWESecretTensorFactory<BE: Backend> {
    /// Scratch bytes required by [`Self::glwe_secret_tensor_prepare`] for a
    /// secret of the given `rank`.
    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;

    /// Writes the products other[i]*other[j] (i <= j) into `res`.
    fn glwe_secret_tensor_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
    where
        R: GLWESecretToMut + GLWEInfos,
        O: GLWESecretToRef + GLWEInfos;
}
impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
where
    Self: ModuleN
        + GLWESecretPreparedFactory<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxDftApply<BE>
        + SvpApplyDftToDft<BE>
        + VecZnxIdftApplyTmpA<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxDftBytesOf
        + VecZnxBigBytesOf
        + VecZnxBigNormalizeTmpBytes,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    /// Sum of all scratch buffers taken in `glwe_secret_tensor_prepare`:
    /// prepared secret + per-component DFT + one pair-product DFT/big limb
    /// + normalization temporary.
    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
        self.bytes_of_glwe_secret_prepared(rank)
            + self.bytes_of_vec_znx_dft(rank.into(), 1)
            + self.bytes_of_vec_znx_dft(1, 1)
            + self.bytes_of_vec_znx_big(1, 1)
            + self.vec_znx_big_normalize_tmp_bytes()
    }

    /// Computes the pairwise products a[i]*a[j] (i <= j) of the secret `a` in
    /// the DFT domain and stores them, normalized, into the columns of `res`.
    ///
    /// # Panics
    /// Panics if `res.rank() != pairs(a.rank())` or if either degree differs
    /// from the module's degree.
    fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
    where
        R: GLWESecretToMut + GLWEInfos,
        A: GLWESecretToRef + GLWEInfos,
    {
        let res: &mut GLWESecret<&mut [u8]> = &mut res.to_mut();
        let a: &GLWESecret<&[u8]> = &a.to_ref();

        // `res` holds one column per unordered pair of components of `a`.
        // (A stray debug `println!` of the two ranks was removed here.)
        assert_eq!(
            res.rank(),
            GLWESecretTensor::pairs(a.rank().into()) as u32,
            "res.rank() must equal pairs(a.rank())"
        );
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);

        let rank: usize = a.rank().into();

        // SVP multiplication needs one operand in prepared form.
        let (mut a_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank.into());
        a_prepared.prepare(self, a);

        // NOTE(review): hard-coded normalization basis. Secrets report
        // base2k == 0 (see the LWEInfos impl), so a concrete basis is needed
        // here, but 17 looks arbitrary — confirm it matches the basis expected
        // by consumers of the tensor secret.
        let base2k: usize = 17;

        // Forward-DFT every component of `a` once, reused across all pairs.
        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1);
        for i in 0..rank {
            self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a.data.as_vec_znx(), i);
        }

        let (mut a_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
        let (mut a_ij_dft, scratch_4) = scratch_3.take_vec_znx_dft(self, 1, 1);

        // sk_tensor = sk (x) sk
        // For example: (s0, s1) (x) (s0, s1) = (s0^2, s0s1, s1^2)
        for i in 0..rank {
            for j in i..rank {
                // Packed upper-triangular index of the pair (i, j).
                let idx: usize = i * rank + j - (i * (i + 1) / 2);
                self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i);
                self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
                self.vec_znx_big_normalize(
                    base2k,
                    &mut res.data.as_vec_znx_mut(),
                    idx,
                    base2k,
                    &a_ij_big,
                    0,
                    scratch_4,
                );
            }
        }
    }
}

View File

@@ -0,0 +1,146 @@
use poulpy_hal::{
layouts::{Data, DataMut, DataRef, FillUniform, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos},
source::Source,
};
use crate::layouts::{Base2K, Degree, GLWEInfos, LWEInfos, Rank, SetGLWEInfos, TorusPrecision};
use std::fmt;
/// Ciphertext-side tensor object: holds `pairs(rank) + 1` polynomial columns
/// together with its encoding parameters (`base2k`, `k`).
#[derive(PartialEq, Eq, Clone)]
pub struct GLWETensor<D: Data> {
    pub(crate) data: VecZnx<D>,
    pub(crate) base2k: Base2K,
    pub(crate) rank: Rank,
    pub(crate) k: TorusPrecision,
}

impl<D: DataMut> SetGLWEInfos for GLWETensor<D> {
    // Setters only update the metadata fields; the underlying buffer is
    // not resized, so callers must keep `k`/`base2k` consistent with `data`.
    fn set_base2k(&mut self, base2k: Base2K) {
        self.base2k = base2k
    }

    fn set_k(&mut self, k: TorusPrecision) {
        self.k = k
    }
}

impl<D: DataRef> GLWETensor<D> {
    /// Read-only access to the underlying polynomial storage.
    pub fn data(&self) -> &VecZnx<D> {
        &self.data
    }
}

impl<D: DataMut> GLWETensor<D> {
    /// Mutable access to the underlying polynomial storage.
    pub fn data_mut(&mut self) -> &mut VecZnx<D> {
        &mut self.data
    }
}
impl<D: Data> LWEInfos for GLWETensor<D> {
    // Encoding parameters come from the struct's own fields; degree and limb
    // count are read from the backing VecZnx.
    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }

    fn n(&self) -> Degree {
        // NOTE(review): usize -> u32 cast; assumes the ring degree fits in u32.
        Degree(self.data.n() as u32)
    }

    fn size(&self) -> usize {
        self.data.size()
    }
}

impl<D: Data> GLWEInfos for GLWETensor<D> {
    // Rank of the originating GLWE ciphertext, not the number of stored columns.
    fn rank(&self) -> Rank {
        self.rank
    }
}
impl<D: DataRef> fmt::Debug for GLWETensor<D> {
    /// Debug output delegates to the [`fmt::Display`] representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{self}")
    }
}

impl<D: DataRef> fmt::Display for GLWETensor<D> {
    /// Prints the encoding parameters followed by the underlying data.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let base2k: u32 = self.base2k().0;
        let k: u32 = self.k().0;
        write!(f, "GLWETensor: base2k={} k={}: {}", base2k, k, self.data)
    }
}

impl<D: DataMut> FillUniform for GLWETensor<D> {
    /// Fills the underlying storage with uniform noise drawn from `source`.
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        self.data.fill_uniform(log_bound, source);
    }
}
impl GLWETensor<Vec<u8>> {
    /// Number of unordered component pairs (i <= j) for `rank`:
    /// rank*(rank+1)/2, clamped to at least 1. Extracted so `alloc` and
    /// `bytes_of` cannot drift apart (mirrors `GLWESecretTensor::pairs`).
    pub(crate) fn pairs(rank: Rank) -> usize {
        (((rank + 1) * rank).as_usize() >> 1).max(1)
    }

    /// Allocates a tensor ciphertext with parameters taken from `infos`.
    pub fn alloc_from_infos<A>(infos: &A) -> Self
    where
        A: GLWEInfos,
    {
        Self::alloc(infos.n(), infos.base2k(), infos.k(), infos.rank())
    }

    /// Allocates `pairs(rank) + 1` columns of `ceil(k / base2k)` limbs each.
    pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> Self {
        let pairs: usize = Self::pairs(rank);
        GLWETensor {
            data: VecZnx::alloc(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize),
            base2k,
            k,
            rank,
        }
    }

    /// Byte size of the allocation made by [`Self::alloc`].
    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
    where
        A: GLWEInfos,
    {
        Self::bytes_of(infos.n(), infos.base2k(), infos.k(), infos.rank())
    }

    /// Must stay in sync with [`Self::alloc`]; both go through `Self::pairs`.
    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank) -> usize {
        let pairs: usize = Self::pairs(rank);
        VecZnx::bytes_of(n.into(), pairs + 1, k.0.div_ceil(base2k.0) as usize)
    }
}
/// Borrowing view: re-expresses the tensor over an immutable byte slice.
pub trait GLWETensorToRef {
    fn to_ref(&self) -> GLWETensor<&[u8]>;
}

impl<D: DataRef> GLWETensorToRef for GLWETensor<D> {
    fn to_ref(&self) -> GLWETensor<&[u8]> {
        // Metadata fields are Copy; only the data view is re-borrowed.
        GLWETensor {
            data: self.data.to_ref(),
            base2k: self.base2k,
            rank: self.rank,
            k: self.k,
        }
    }
}

/// Borrowing view: re-expresses the tensor over a mutable byte slice.
pub trait GLWETensorToMut {
    fn to_mut(&mut self) -> GLWETensor<&mut [u8]>;
}

impl<D: DataMut> GLWETensorToMut for GLWETensor<D> {
    fn to_mut(&mut self) -> GLWETensor<&mut [u8]> {
        GLWETensor {
            data: self.data.to_mut(),
            base2k: self.base2k,
            rank: self.rank,
            k: self.k,
        }
    }
}

View File

@@ -6,7 +6,6 @@ use poulpy_hal::{
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision, Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToMut, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision,
}; };
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt; use std::fmt;
@@ -21,31 +20,29 @@ pub struct GLWETensorKeyLayout {
} }
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct GLWETensorKey<D: Data> { pub struct GLWETensorKey<D: Data>(pub(crate) GGLWE<D>);
pub(crate) keys: Vec<GGLWE<D>>,
}
impl<D: Data> LWEInfos for GLWETensorKey<D> { impl<D: Data> LWEInfos for GLWETensorKey<D> {
fn n(&self) -> Degree { fn n(&self) -> Degree {
self.keys[0].n() self.0.n()
} }
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.keys[0].base2k() self.0.base2k()
} }
fn k(&self) -> TorusPrecision { fn k(&self) -> TorusPrecision {
self.keys[0].k() self.0.k()
} }
fn size(&self) -> usize { fn size(&self) -> usize {
self.keys[0].size() self.0.size()
} }
} }
impl<D: Data> GLWEInfos for GLWETensorKey<D> { impl<D: Data> GLWEInfos for GLWETensorKey<D> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.keys[0].rank_out() self.0.rank_out()
} }
} }
@@ -55,15 +52,15 @@ impl<D: Data> GGLWEInfos for GLWETensorKey<D> {
} }
fn rank_out(&self) -> Rank { fn rank_out(&self) -> Rank {
self.keys[0].rank_out() self.0.rank_out()
} }
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.keys[0].dsize() self.0.dsize()
} }
fn dnum(&self) -> Dnum { fn dnum(&self) -> Dnum {
self.keys[0].dnum() self.0.dnum()
} }
} }
@@ -113,18 +110,14 @@ impl<D: DataRef> fmt::Debug for GLWETensorKey<D> {
impl<D: DataMut> FillUniform for GLWETensorKey<D> { impl<D: DataMut> FillUniform for GLWETensorKey<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.keys self.0.fill_uniform(log_bound, source)
.iter_mut()
.for_each(|key: &mut GGLWE<D>| key.fill_uniform(log_bound, source))
} }
} }
impl<D: DataRef> fmt::Display for GLWETensorKey<D> { impl<D: DataRef> fmt::Display for GLWETensorKey<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "(GLWETensorKey)",)?; writeln!(f, "(GLWETensorKey)",)?;
for (i, key) in self.keys.iter().enumerate() { write!(f, "{}", self.0)?;
write!(f, "{i}: {key}")?;
}
Ok(()) Ok(())
} }
} }
@@ -151,11 +144,7 @@ impl GLWETensorKey<Vec<u8>> {
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1); let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
GLWETensorKey { GLWETensorKey(GGLWE::alloc(n, base2k, k, Rank(pairs), rank, dnum, dsize))
keys: (0..pairs)
.map(|_| GGLWE::alloc(n, base2k, k, Rank(1), rank, dnum, dsize))
.collect(),
}
} }
pub fn bytes_of_from_infos<A>(infos: &A) -> usize pub fn bytes_of_from_infos<A>(infos: &A) -> usize
@@ -178,85 +167,39 @@ impl GLWETensorKey<Vec<u8>> {
} }
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; let pairs: u32 = (((rank.0 + 1) * rank.0) >> 1).max(1);
pairs * GGLWE::bytes_of(n, base2k, k, Rank(1), rank, dnum, dsize) GGLWE::bytes_of(n, base2k, k, Rank(pairs), rank, dnum, dsize)
}
}
impl<D: DataMut> GLWETensorKey<D> {
// Returns a mutable reference to GGLWE_{s}(s[i] * s[j])
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWE<D> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
impl<D: DataRef> GLWETensorKey<D> {
// Returns a reference to GGLWE_{s}(s[i] * s[j])
pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWE<D> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&self.keys[i * rank + j - (i * (i + 1) / 2)]
} }
} }
impl<D: DataMut> ReaderFrom for GLWETensorKey<D> { impl<D: DataMut> ReaderFrom for GLWETensorKey<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> { fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
let len: usize = reader.read_u64::<LittleEndian>()? as usize; self.0.read_from(reader)?;
if self.keys.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("self.keys.len()={} != read len={}", self.keys.len(), len),
));
}
for key in &mut self.keys {
key.read_from(reader)?;
}
Ok(()) Ok(())
} }
} }
impl<D: DataRef> WriterTo for GLWETensorKey<D> { impl<D: DataRef> WriterTo for GLWETensorKey<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> { fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.keys.len() as u64)?; self.0.write_to(writer)?;
for key in &self.keys {
key.write_to(writer)?;
}
Ok(()) Ok(())
} }
} }
pub trait GLWETensorKeyToRef { impl<D: DataRef> GGLWEToRef for GLWETensorKey<D>
fn to_ref(&self) -> GLWETensorKey<&[u8]>;
}
impl<D: DataRef> GLWETensorKeyToRef for GLWETensorKey<D>
where where
GGLWE<D>: GGLWEToRef, GGLWE<D>: GGLWEToRef,
{ {
fn to_ref(&self) -> GLWETensorKey<&[u8]> { fn to_ref(&self) -> GGLWE<&[u8]> {
GLWETensorKey { self.0.to_ref()
keys: self.keys.iter().map(|c| c.to_ref()).collect(),
}
} }
} }
pub trait GLWETensorKeyToMut { impl<D: DataMut> GGLWEToMut for GLWETensorKey<D>
fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]>;
}
impl<D: DataMut> GLWETensorKeyToMut for GLWETensorKey<D>
where where
GGLWE<D>: GGLWEToMut, GGLWE<D>: GGLWEToMut,
{ {
fn to_mut(&mut self) -> GLWETensorKey<&mut [u8]> { fn to_mut(&mut self) -> GGLWE<&mut [u8]> {
GLWETensorKey { self.0.to_mut()
keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
}
} }
} }

View File

@@ -59,9 +59,9 @@ impl GGLWEInfos for GLWEToLWEKeyLayout {
/// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE]. /// A special [GLWESwitchingKey] required to for the conversion from [GLWE] to [LWE].
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct GLWEToLWESwitchingKey<D: Data>(pub(crate) GLWESwitchingKey<D>); pub struct GLWEToLWEKey<D: Data>(pub(crate) GLWESwitchingKey<D>);
impl<D: Data> LWEInfos for GLWEToLWESwitchingKey<D> { impl<D: Data> LWEInfos for GLWEToLWEKey<D> {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.0.base2k() self.0.base2k()
} }
@@ -79,12 +79,12 @@ impl<D: Data> LWEInfos for GLWEToLWESwitchingKey<D> {
} }
} }
impl<D: Data> GLWEInfos for GLWEToLWESwitchingKey<D> { impl<D: Data> GLWEInfos for GLWEToLWEKey<D> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl<D: Data> GGLWEInfos for GLWEToLWESwitchingKey<D> { impl<D: Data> GGLWEInfos for GLWEToLWEKey<D> {
fn rank_in(&self) -> Rank { fn rank_in(&self) -> Rank {
self.0.rank_in() self.0.rank_in()
} }
@@ -102,37 +102,37 @@ impl<D: Data> GGLWEInfos for GLWEToLWESwitchingKey<D> {
} }
} }
impl<D: DataRef> fmt::Debug for GLWEToLWESwitchingKey<D> { impl<D: DataRef> fmt::Debug for GLWEToLWEKey<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{self}") write!(f, "{self}")
} }
} }
impl<D: DataMut> FillUniform for GLWEToLWESwitchingKey<D> { impl<D: DataMut> FillUniform for GLWEToLWEKey<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.0.fill_uniform(log_bound, source); self.0.fill_uniform(log_bound, source);
} }
} }
impl<D: DataRef> fmt::Display for GLWEToLWESwitchingKey<D> { impl<D: DataRef> fmt::Display for GLWEToLWEKey<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "(GLWEToLWESwitchingKey) {}", self.0) write!(f, "(GLWEToLWEKey) {}", self.0)
} }
} }
impl<D: DataMut> ReaderFrom for GLWEToLWESwitchingKey<D> { impl<D: DataMut> ReaderFrom for GLWEToLWEKey<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> { fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.0.read_from(reader) self.0.read_from(reader)
} }
} }
impl<D: DataRef> WriterTo for GLWEToLWESwitchingKey<D> { impl<D: DataRef> WriterTo for GLWEToLWEKey<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> { fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
self.0.write_to(writer) self.0.write_to(writer)
} }
} }
impl GLWEToLWESwitchingKey<Vec<u8>> { impl GLWEToLWEKey<Vec<u8>> {
pub fn alloc_from_infos<A>(infos: &A) -> Self pub fn alloc_from_infos<A>(infos: &A) -> Self
where where
A: GGLWEInfos, A: GGLWEInfos,
@@ -140,12 +140,12 @@ impl GLWEToLWESwitchingKey<Vec<u8>> {
assert_eq!( assert_eq!(
infos.rank_out().0, infos.rank_out().0,
1, 1,
"rank_out > 1 is not supported for GLWEToLWESwitchingKey" "rank_out > 1 is not supported for GLWEToLWEKey"
); );
assert_eq!( assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for GLWEToLWESwitchingKey" "dsize > 1 is not supported for GLWEToLWEKey"
); );
Self::alloc( Self::alloc(
infos.n(), infos.n(),
@@ -157,7 +157,7 @@ impl GLWEToLWESwitchingKey<Vec<u8>> {
} }
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self {
GLWEToLWESwitchingKey(GLWESwitchingKey::alloc( GLWEToLWEKey(GLWESwitchingKey::alloc(
n, n,
base2k, base2k,
k, k,
@@ -196,19 +196,19 @@ impl GLWEToLWESwitchingKey<Vec<u8>> {
} }
} }
impl<D: DataRef> GGLWEToRef for GLWEToLWESwitchingKey<D> { impl<D: DataRef> GGLWEToRef for GLWEToLWEKey<D> {
fn to_ref(&self) -> GGLWE<&[u8]> { fn to_ref(&self) -> GGLWE<&[u8]> {
self.0.to_ref() self.0.to_ref()
} }
} }
impl<D: DataMut> GGLWEToMut for GLWEToLWESwitchingKey<D> { impl<D: DataMut> GGLWEToMut for GLWEToLWEKey<D> {
fn to_mut(&mut self) -> GGLWE<&mut [u8]> { fn to_mut(&mut self) -> GGLWE<&mut [u8]> {
self.0.to_mut() self.0.to_mut()
} }
} }
impl<D: DataMut> GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey<D> { impl<D: DataMut> GLWESwitchingKeyDegreesMut for GLWEToLWEKey<D> {
fn input_degree(&mut self) -> &mut Degree { fn input_degree(&mut self) -> &mut Degree {
&mut self.0.input_degree &mut self.0.input_degree
} }
@@ -218,7 +218,7 @@ impl<D: DataMut> GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKey<D> {
} }
} }
impl<D: DataRef> GLWESwitchingKeyDegrees for GLWEToLWESwitchingKey<D> { impl<D: DataRef> GLWESwitchingKeyDegrees for GLWEToLWEKey<D> {
fn input_degree(&self) -> &Degree { fn input_degree(&self) -> &Degree {
&self.0.input_degree &self.0.input_degree
} }

View File

@@ -11,7 +11,7 @@ use crate::layouts::{
}; };
#[derive(PartialEq, Eq, Copy, Clone, Debug)] #[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub struct LWEToGLWESwitchingKeyLayout { pub struct LWEToGLWEKeyLayout {
pub n: Degree, pub n: Degree,
pub base2k: Base2K, pub base2k: Base2K,
pub k: TorusPrecision, pub k: TorusPrecision,
@@ -19,7 +19,7 @@ pub struct LWEToGLWESwitchingKeyLayout {
pub dnum: Dnum, pub dnum: Dnum,
} }
impl LWEInfos for LWEToGLWESwitchingKeyLayout { impl LWEInfos for LWEToGLWEKeyLayout {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.base2k self.base2k
} }
@@ -33,13 +33,13 @@ impl LWEInfos for LWEToGLWESwitchingKeyLayout {
} }
} }
impl GLWEInfos for LWEToGLWESwitchingKeyLayout { impl GLWEInfos for LWEToGLWEKeyLayout {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl GGLWEInfos for LWEToGLWESwitchingKeyLayout { impl GGLWEInfos for LWEToGLWEKeyLayout {
fn rank_in(&self) -> Rank { fn rank_in(&self) -> Rank {
Rank(1) Rank(1)
} }
@@ -58,9 +58,9 @@ impl GGLWEInfos for LWEToGLWESwitchingKeyLayout {
} }
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct LWEToGLWESwitchingKey<D: Data>(pub(crate) GLWESwitchingKey<D>); pub struct LWEToGLWEKey<D: Data>(pub(crate) GLWESwitchingKey<D>);
impl<D: Data> LWEInfos for LWEToGLWESwitchingKey<D> { impl<D: Data> LWEInfos for LWEToGLWEKey<D> {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.0.base2k() self.0.base2k()
} }
@@ -78,12 +78,12 @@ impl<D: Data> LWEInfos for LWEToGLWESwitchingKey<D> {
} }
} }
impl<D: Data> GLWEInfos for LWEToGLWESwitchingKey<D> { impl<D: Data> GLWEInfos for LWEToGLWEKey<D> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl<D: Data> GGLWEInfos for LWEToGLWESwitchingKey<D> { impl<D: Data> GGLWEInfos for LWEToGLWEKey<D> {
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.0.dsize() self.0.dsize()
} }
@@ -101,37 +101,37 @@ impl<D: Data> GGLWEInfos for LWEToGLWESwitchingKey<D> {
} }
} }
impl<D: DataRef> fmt::Debug for LWEToGLWESwitchingKey<D> { impl<D: DataRef> fmt::Debug for LWEToGLWEKey<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{self}") write!(f, "{self}")
} }
} }
impl<D: DataMut> FillUniform for LWEToGLWESwitchingKey<D> { impl<D: DataMut> FillUniform for LWEToGLWEKey<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.0.fill_uniform(log_bound, source); self.0.fill_uniform(log_bound, source);
} }
} }
impl<D: DataRef> fmt::Display for LWEToGLWESwitchingKey<D> { impl<D: DataRef> fmt::Display for LWEToGLWEKey<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "(LWEToGLWESwitchingKey) {}", self.0) write!(f, "(LWEToGLWEKey) {}", self.0)
} }
} }
impl<D: DataMut> ReaderFrom for LWEToGLWESwitchingKey<D> { impl<D: DataMut> ReaderFrom for LWEToGLWEKey<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> { fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.0.read_from(reader) self.0.read_from(reader)
} }
} }
impl<D: DataRef> WriterTo for LWEToGLWESwitchingKey<D> { impl<D: DataRef> WriterTo for LWEToGLWEKey<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> { fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
self.0.write_to(writer) self.0.write_to(writer)
} }
} }
impl LWEToGLWESwitchingKey<Vec<u8>> { impl LWEToGLWEKey<Vec<u8>> {
pub fn alloc_from_infos<A>(infos: &A) -> Self pub fn alloc_from_infos<A>(infos: &A) -> Self
where where
A: GGLWEInfos, A: GGLWEInfos,
@@ -139,12 +139,12 @@ impl LWEToGLWESwitchingKey<Vec<u8>> {
assert_eq!( assert_eq!(
infos.rank_in().0, infos.rank_in().0,
1, 1,
"rank_in > 1 is not supported for LWEToGLWESwitchingKey" "rank_in > 1 is not supported for LWEToGLWEKey"
); );
assert_eq!( assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for LWEToGLWESwitchingKey" "dsize > 1 is not supported for LWEToGLWEKey"
); );
Self::alloc( Self::alloc(
@@ -157,7 +157,7 @@ impl LWEToGLWESwitchingKey<Vec<u8>> {
} }
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self {
LWEToGLWESwitchingKey(GLWESwitchingKey::alloc( LWEToGLWEKey(GLWESwitchingKey::alloc(
n, n,
base2k, base2k,
k, k,
@@ -175,12 +175,12 @@ impl LWEToGLWESwitchingKey<Vec<u8>> {
assert_eq!( assert_eq!(
infos.rank_in().0, infos.rank_in().0,
1, 1,
"rank_in > 1 is not supported for LWEToGLWESwitchingKey" "rank_in > 1 is not supported for LWEToGLWEKey"
); );
assert_eq!( assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for LWEToGLWESwitchingKey" "dsize > 1 is not supported for LWEToGLWEKey"
); );
Self::bytes_of( Self::bytes_of(
infos.n(), infos.n(),
@@ -196,19 +196,19 @@ impl LWEToGLWESwitchingKey<Vec<u8>> {
} }
} }
impl<D: DataRef> GGLWEToRef for LWEToGLWESwitchingKey<D> { impl<D: DataRef> GGLWEToRef for LWEToGLWEKey<D> {
fn to_ref(&self) -> GGLWE<&[u8]> { fn to_ref(&self) -> GGLWE<&[u8]> {
self.0.to_ref() self.0.to_ref()
} }
} }
impl<D: DataMut> GGLWEToMut for LWEToGLWESwitchingKey<D> { impl<D: DataMut> GGLWEToMut for LWEToGLWEKey<D> {
fn to_mut(&mut self) -> GGLWE<&mut [u8]> { fn to_mut(&mut self) -> GGLWE<&mut [u8]> {
self.0.to_mut() self.0.to_mut()
} }
} }
impl<D: DataMut> GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey<D> { impl<D: DataMut> GLWESwitchingKeyDegreesMut for LWEToGLWEKey<D> {
fn input_degree(&mut self) -> &mut Degree { fn input_degree(&mut self) -> &mut Degree {
&mut self.0.input_degree &mut self.0.input_degree
} }
@@ -218,7 +218,7 @@ impl<D: DataMut> GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKey<D> {
} }
} }
impl<D: DataRef> GLWESwitchingKeyDegrees for LWEToGLWESwitchingKey<D> { impl<D: DataRef> GLWESwitchingKeyDegrees for LWEToGLWEKey<D> {
fn input_degree(&self) -> &Degree { fn input_degree(&self) -> &Degree {
&self.0.input_degree &self.0.input_degree
} }

View File

@@ -1,38 +1,44 @@
mod gglwe; mod gglwe;
mod gglwe_to_ggsw_key;
mod ggsw; mod ggsw;
mod glwe; mod glwe;
mod glwe_automorphism_key; mod glwe_automorphism_key;
mod glwe_plaintext; mod glwe_plaintext;
mod glwe_public_key; mod glwe_public_key;
mod glwe_secret; mod glwe_secret;
mod glwe_secret_tensor;
mod glwe_switching_key; mod glwe_switching_key;
mod glwe_tensor;
mod glwe_tensor_key; mod glwe_tensor_key;
mod glwe_to_lwe_switching_key; mod glwe_to_lwe_key;
mod lwe; mod lwe;
mod lwe_plaintext; mod lwe_plaintext;
mod lwe_secret; mod lwe_secret;
mod lwe_switching_key; mod lwe_switching_key;
mod lwe_to_glwe_switching_key; mod lwe_to_glwe_key;
pub mod compressed; pub mod compressed;
pub mod prepared; pub mod prepared;
pub use compressed::*; pub use compressed::*;
pub use gglwe::*; pub use gglwe::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw::*; pub use ggsw::*;
pub use glwe::*; pub use glwe::*;
pub use glwe_automorphism_key::*; pub use glwe_automorphism_key::*;
pub use glwe_plaintext::*; pub use glwe_plaintext::*;
pub use glwe_public_key::*; pub use glwe_public_key::*;
pub use glwe_secret::*; pub use glwe_secret::*;
pub use glwe_secret_tensor::*;
pub use glwe_switching_key::*; pub use glwe_switching_key::*;
pub use glwe_tensor::*;
pub use glwe_tensor_key::*; pub use glwe_tensor_key::*;
pub use glwe_to_lwe_switching_key::*; pub use glwe_to_lwe_key::*;
pub use lwe::*; pub use lwe::*;
pub use lwe_plaintext::*; pub use lwe_plaintext::*;
pub use lwe_secret::*; pub use lwe_secret::*;
pub use lwe_switching_key::*; pub use lwe_switching_key::*;
pub use lwe_to_glwe_switching_key::*; pub use lwe_to_glwe_key::*;
pub use prepared::*; pub use prepared::*;
use poulpy_hal::layouts::{Backend, Module}; use poulpy_hal::layouts::{Backend, Module};

View File

@@ -0,0 +1,252 @@
use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch};
use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef,
GGLWEToGGSWKey, GGLWEToGGSWKeyToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision,
};
/// Prepared (backend-domain) variant of a GGLWE-to-GGSW conversion key.
///
/// Stores one prepared GGLWE ciphertext per output-rank component; per the
/// accessors below, `keys[i]` is the GGLWE encrypting
/// `[s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]`.
/// NOTE(review): the `*Infos` impls index `keys[0]`, so `keys` is assumed
/// non-empty — confirm rank 0 is never allocated.
pub struct GGLWEToGGSWKeyPrepared<D: Data, BE: Backend> {
    pub(crate) keys: Vec<GGLWEPrepared<D, BE>>,
}
/// LWE-level parameters, delegated to the first inner key (all inner keys are
/// allocated with identical parameters by the factory below).
///
/// NOTE(review): every accessor indexes `keys[0]` and panics if `keys` is
/// empty.
impl<D: Data, BE: Backend> LWEInfos for GGLWEToGGSWKeyPrepared<D, BE> {
    // Ring degree of the underlying ciphertexts.
    fn n(&self) -> Degree {
        self.keys[0].n()
    }
    // Base-2^k decomposition parameter.
    fn base2k(&self) -> Base2K {
        self.keys[0].base2k()
    }
    // Torus precision of the underlying ciphertexts.
    fn k(&self) -> TorusPrecision {
        self.keys[0].k()
    }
    // Number of limbs per ciphertext.
    fn size(&self) -> usize {
        self.keys[0].size()
    }
}
/// GLWE-level parameters: the GLWE rank is the output rank of the inner keys.
impl<D: Data, BE: Backend> GLWEInfos for GGLWEToGGSWKeyPrepared<D, BE> {
    fn rank(&self) -> Rank {
        self.keys[0].rank_out()
    }
}
/// GGLWE-level parameters. For this key type `rank_in == rank_out` by
/// construction (the factory asserts it), so `rank_in` simply forwards to
/// `rank_out`.
impl<D: Data, BE: Backend> GGLWEInfos for GGLWEToGGSWKeyPrepared<D, BE> {
    fn rank_in(&self) -> Rank {
        self.rank_out()
    }
    fn rank_out(&self) -> Rank {
        self.keys[0].rank_out()
    }
    // Digit size of the gadget decomposition.
    fn dsize(&self) -> Dsize {
        self.keys[0].dsize()
    }
    // Number of gadget decomposition digits.
    fn dnum(&self) -> Dnum {
        self.keys[0].dnum()
    }
}
/// Factory methods for allocating, sizing and preparing a
/// [`GGLWEToGGSWKeyPrepared`]. Implemented by backend [`Module`]s.
pub trait GGLWEToGGSWKeyPreparedFactory<BE: Backend> {
    /// Allocates a prepared key with parameters taken from `infos`.
    /// Implementations assert `rank_in == rank_out`.
    fn alloc_gglwe_to_ggsw_key_prepared_from_infos<A>(&self, infos: &A) -> GGLWEToGGSWKeyPrepared<Vec<u8>, BE>
    where
        A: GGLWEInfos;
    /// Allocates a prepared key with explicit parameters; `rank` is used for
    /// both the input and output rank of each inner GGLWE.
    fn alloc_gglwe_to_ggsw_key_prepared(
        &self,
        base2k: Base2K,
        k: TorusPrecision,
        rank: Rank,
        dnum: Dnum,
        dsize: Dsize,
    ) -> GGLWEToGGSWKeyPrepared<Vec<u8>, BE>;
    /// Byte size of a prepared key described by `infos`.
    /// Implementations assert `rank_in == rank_out`.
    fn bytes_of_gglwe_to_ggsw_from_infos<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos;
    /// Byte size of a prepared key with explicit parameters.
    fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize;
    /// Scratch-space requirement (in bytes) of [`Self::prepare_gglwe_to_ggsw_key`].
    fn prepare_gglwe_to_ggsw_key_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos;
    /// Prepares (backend-transforms) `other` into `res`, key by key.
    fn prepare_gglwe_to_ggsw_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
    where
        R: GGLWEToGGSWKeyPreparedToMut<BE>,
        O: GGLWEToGGSWKeyToRef;
}
impl<BE: Backend> GGLWEToGGSWKeyPreparedFactory<BE> for Module<BE>
where
Self: GGLWEPreparedFactory<BE>,
{
fn alloc_gglwe_to_ggsw_key_prepared_from_infos<A>(&self, infos: &A) -> GGLWEToGGSWKeyPrepared<Vec<u8>, BE>
where
A: GGLWEInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared"
);
self.alloc_gglwe_to_ggsw_key_prepared(
infos.base2k(),
infos.k(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
fn alloc_gglwe_to_ggsw_key_prepared(
&self,
base2k: Base2K,
k: TorusPrecision,
rank: Rank,
dnum: Dnum,
dsize: Dsize,
) -> GGLWEToGGSWKeyPrepared<Vec<u8>, BE> {
GGLWEToGGSWKeyPrepared {
keys: (0..rank.as_usize())
.map(|_| self.alloc_gglwe_prepared(base2k, k, rank, rank, dnum, dsize))
.collect(),
}
}
fn bytes_of_gglwe_to_ggsw_from_infos<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
assert_eq!(
infos.rank_in(),
infos.rank_out(),
"rank_in != rank_out is not supported for GGLWEToGGSWKeyPrepared"
);
self.bytes_of_gglwe_to_ggsw(
infos.base2k(),
infos.k(),
infos.rank(),
infos.dnum(),
infos.dsize(),
)
}
fn bytes_of_gglwe_to_ggsw(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
rank.as_usize() * self.bytes_of_gglwe_prepared(base2k, k, rank, rank, dnum, dsize)
}
fn prepare_gglwe_to_ggsw_key_tmp_bytes<A>(&self, infos: &A) -> usize
where
A: GGLWEInfos,
{
self.prepare_gglwe_tmp_bytes(infos)
}
fn prepare_gglwe_to_ggsw_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
where
R: GGLWEToGGSWKeyPreparedToMut<BE>,
O: GGLWEToGGSWKeyToRef,
{
let res: &mut GGLWEToGGSWKeyPrepared<&mut [u8], BE> = &mut res.to_mut();
let other: &GGLWEToGGSWKey<&[u8]> = &other.to_ref();
assert_eq!(res.keys.len(), other.keys.len());
for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
self.prepare_gglwe(a, b, scratch);
}
}
}
/// Owned-buffer convenience constructors; each forwards to the corresponding
/// [`GGLWEToGGSWKeyPreparedFactory`] method on `module`.
impl<BE: Backend> GGLWEToGGSWKeyPrepared<Vec<u8>, BE> {
    /// Allocates a key with parameters taken from `infos`
    /// (panics if `rank_in != rank_out`, see the factory).
    pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
    where
        A: GGLWEInfos,
        M: GGLWEToGGSWKeyPreparedFactory<BE>,
    {
        module.alloc_gglwe_to_ggsw_key_prepared_from_infos(infos)
    }
    /// Allocates a key with explicit parameters.
    pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self
    where
        M: GGLWEToGGSWKeyPreparedFactory<BE>,
    {
        module.alloc_gglwe_to_ggsw_key_prepared(base2k, k, rank, dnum, dsize)
    }
    /// Byte size of a key described by `infos`.
    pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
    where
        A: GGLWEInfos,
        M: GGLWEToGGSWKeyPreparedFactory<BE>,
    {
        module.bytes_of_gglwe_to_ggsw_from_infos(infos)
    }
    /// Byte size of a key with explicit parameters.
    pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize
    where
        M: GGLWEToGGSWKeyPreparedFactory<BE>,
    {
        module.bytes_of_gglwe_to_ggsw(base2k, k, rank, dnum, dsize)
    }
}
impl<D: DataMut, BE: Backend> GGLWEToGGSWKeyPrepared<D, BE> {
    /// Prepares (backend-transforms) `other` into `self`; forwards to
    /// [`GGLWEToGGSWKeyPreparedFactory::prepare_gglwe_to_ggsw_key`].
    /// `scratch` must provide at least
    /// `prepare_gglwe_to_ggsw_key_tmp_bytes` bytes.
    pub fn prepare<M, O>(&mut self, module: &M, other: &O, scratch: &mut Scratch<BE>)
    where
        M: GGLWEToGGSWKeyPreparedFactory<BE>,
        O: GGLWEToGGSWKeyToRef,
    {
        module.prepare_gglwe_to_ggsw_key(self, other, scratch);
    }
}
impl<D: DataMut, BE: Backend> GGLWEToGGSWKeyPrepared<D, BE> {
    /// Returns a mutable reference to the i-th inner key, i.e.
    /// GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    ///
    /// # Panics
    /// Panics if `i >= self.keys.len()`.
    pub fn at_mut(&mut self, i: usize) -> &mut GGLWEPrepared<D, BE> {
        // Bound-check against the actual storage length (== rank by
        // construction) instead of `(i as u32) < self.rank()`: the u32 cast
        // silently truncates very large indices, and `rank()` itself indexes
        // `keys[0]`, which panics with a less useful message when empty.
        assert!(
            i < self.keys.len(),
            "GGLWEToGGSWKeyPrepared::at_mut: index {i} out of bounds for {} keys",
            self.keys.len()
        );
        &mut self.keys[i]
    }
}
impl<D: DataRef, BE: Backend> GGLWEToGGSWKeyPrepared<D, BE> {
    /// Returns a reference to the i-th inner key, i.e.
    /// GGLWEPrepared_{s}([s[i]*s[0], s[i]*s[1], ..., s[i]*s[rank]]).
    ///
    /// # Panics
    /// Panics if `i >= self.keys.len()`.
    pub fn at(&self, i: usize) -> &GGLWEPrepared<D, BE> {
        // Bound-check against the actual storage length (== rank by
        // construction) instead of `(i as u32) < self.rank()`: the u32 cast
        // silently truncates very large indices, and `rank()` itself indexes
        // `keys[0]`, which panics with a less useful message when empty.
        assert!(
            i < self.keys.len(),
            "GGLWEToGGSWKeyPrepared::at: index {i} out of bounds for {} keys",
            self.keys.len()
        );
        &self.keys[i]
    }
}
/// Conversion into an immutably-borrowed (`&[u8]`-backed) view of the key.
pub trait GGLWEToGGSWKeyPreparedToRef<BE: Backend> {
    fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE>;
}
/// Borrows every inner key immutably, producing a `&[u8]`-backed view with
/// the same number of keys.
impl<D: DataRef, BE: Backend> GGLWEToGGSWKeyPreparedToRef<BE> for GGLWEToGGSWKeyPrepared<D, BE>
where
    GGLWEPrepared<D, BE>: GGLWEPreparedToRef<BE>,
{
    fn to_ref(&self) -> GGLWEToGGSWKeyPrepared<&[u8], BE> {
        let mut views = Vec::with_capacity(self.keys.len());
        for key in self.keys.iter() {
            views.push(key.to_ref());
        }
        GGLWEToGGSWKeyPrepared { keys: views }
    }
}
/// Conversion into a mutably-borrowed (`&mut [u8]`-backed) view of the key.
pub trait GGLWEToGGSWKeyPreparedToMut<BE: Backend> {
    fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE>;
}
/// Borrows every inner key mutably, producing a `&mut [u8]`-backed view with
/// the same number of keys.
impl<D: DataMut, BE: Backend> GGLWEToGGSWKeyPreparedToMut<BE> for GGLWEToGGSWKeyPrepared<D, BE>
where
    GGLWEPrepared<D, BE>: GGLWEPreparedToMut<BE>,
{
    fn to_mut(&mut self) -> GGLWEToGGSWKeyPrepared<&mut [u8], BE> {
        let mut views = Vec::with_capacity(self.keys.len());
        for key in self.keys.iter_mut() {
            views.push(key.to_mut());
        }
        GGLWEToGGSWKeyPrepared { keys: views }
    }
}

View File

@@ -109,7 +109,7 @@ where
) )
} }
fn bytes_of_glwe_switching_key_prepared( fn bytes_of_glwe_key_prepared(
&self, &self,
base2k: Base2K, base2k: Base2K,
k: TorusPrecision, k: TorusPrecision,
@@ -125,7 +125,7 @@ where
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
self.bytes_of_glwe_switching_key_prepared( self.bytes_of_glwe_key_prepared(
infos.base2k(), infos.base2k(),
infos.k(), infos.k(),
infos.rank_in(), infos.rank_in(),
@@ -199,7 +199,7 @@ impl<B: Backend> GLWESwitchingKeyPrepared<Vec<u8>, B> {
where where
M: GLWESwitchingKeyPreparedFactory<B>, M: GLWESwitchingKeyPreparedFactory<B>,
{ {
module.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize) module.bytes_of_glwe_key_prepared(base2k, k, rank_in, rank_out, dnum, dsize)
} }
} }

View File

@@ -2,29 +2,27 @@ use poulpy_hal::layouts::{Backend, Data, DataMut, DataRef, Module, Scratch};
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef, Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEPrepared, GGLWEPreparedFactory, GGLWEPreparedToMut, GGLWEPreparedToRef,
GLWEInfos, GLWETensorKey, GLWETensorKeyToRef, LWEInfos, Rank, TorusPrecision, GGLWEToRef, GLWEInfos, LWEInfos, Rank, TorusPrecision,
}; };
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub struct GLWETensorKeyPrepared<D: Data, B: Backend> { pub struct GLWETensorKeyPrepared<D: Data, B: Backend>(pub(crate) GGLWEPrepared<D, B>);
pub(crate) keys: Vec<GGLWEPrepared<D, B>>,
}
impl<D: Data, B: Backend> LWEInfos for GLWETensorKeyPrepared<D, B> { impl<D: Data, B: Backend> LWEInfos for GLWETensorKeyPrepared<D, B> {
fn n(&self) -> Degree { fn n(&self) -> Degree {
self.keys[0].n() self.0.n()
} }
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.keys[0].base2k() self.0.base2k()
} }
fn k(&self) -> TorusPrecision { fn k(&self) -> TorusPrecision {
self.keys[0].k() self.0.k()
} }
fn size(&self) -> usize { fn size(&self) -> usize {
self.keys[0].size() self.0.size()
} }
} }
@@ -40,15 +38,15 @@ impl<D: Data, B: Backend> GGLWEInfos for GLWETensorKeyPrepared<D, B> {
} }
fn rank_out(&self) -> Rank { fn rank_out(&self) -> Rank {
self.keys[0].rank_out() self.0.rank_out()
} }
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.keys[0].dsize() self.0.dsize()
} }
fn dnum(&self) -> Dnum { fn dnum(&self) -> Dnum {
self.keys[0].dnum() self.0.dnum()
} }
} }
@@ -65,11 +63,7 @@ where
rank: Rank, rank: Rank,
) -> GLWETensorKeyPrepared<Vec<u8>, B> { ) -> GLWETensorKeyPrepared<Vec<u8>, B> {
let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1); let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
GLWETensorKeyPrepared { GLWETensorKeyPrepared(self.alloc_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize))
keys: (0..pairs)
.map(|_| self.alloc_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize))
.collect(),
}
} }
fn alloc_tensor_key_prepared_from_infos<A>(&self, infos: &A) -> GLWETensorKeyPrepared<Vec<u8>, B> fn alloc_tensor_key_prepared_from_infos<A>(&self, infos: &A) -> GLWETensorKeyPrepared<Vec<u8>, B>
@@ -91,8 +85,8 @@ where
} }
fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize { fn bytes_of_tensor_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
let pairs: usize = (((rank.0 + 1) * rank.0) >> 1).max(1) as usize; let pairs: u32 = (((rank.as_u32() + 1) * rank.as_u32()) >> 1).max(1);
pairs * self.bytes_of_gglwe_prepared(base2k, k, Rank(1), rank, dnum, dsize) self.bytes_of_gglwe_prepared(base2k, k, Rank(pairs), rank, dnum, dsize)
} }
fn bytes_of_tensor_key_prepared_from_infos<A>(&self, infos: &A) -> usize fn bytes_of_tensor_key_prepared_from_infos<A>(&self, infos: &A) -> usize
@@ -117,17 +111,10 @@ where
fn prepare_tensor_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>) fn prepare_tensor_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>)
where where
R: GLWETensorKeyPreparedToMut<B>, R: GGLWEPreparedToMut<B>,
O: GLWETensorKeyToRef, O: GGLWEToRef,
{ {
let mut res: GLWETensorKeyPrepared<&mut [u8], B> = res.to_mut(); self.prepare_gglwe(res, other, scratch);
let other: GLWETensorKey<&[u8]> = other.to_ref();
assert_eq!(res.keys.len(), other.keys.len());
for (a, b) in res.keys.iter_mut().zip(other.keys.iter()) {
self.prepare_gglwe(a, b, scratch);
}
} }
} }
@@ -165,28 +152,6 @@ impl<B: Backend> GLWETensorKeyPrepared<Vec<u8>, B> {
} }
} }
impl<D: DataMut, B: Backend> GLWETensorKeyPrepared<D, B> {
// Returns a mutable reference to GGLWE_{s}(s[i] * s[j])
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GGLWEPrepared<D, B> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
impl<D: DataRef, B: Backend> GLWETensorKeyPrepared<D, B> {
// Returns a reference to GGLWE_{s}(s[i] * s[j])
pub fn at(&self, mut i: usize, mut j: usize) -> &GGLWEPrepared<D, B> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank_out().into();
&self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
impl<B: Backend> GLWETensorKeyPrepared<Vec<u8>, B> { impl<B: Backend> GLWETensorKeyPrepared<Vec<u8>, B> {
pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A) -> usize pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A) -> usize
where where
@@ -200,39 +165,27 @@ impl<B: Backend> GLWETensorKeyPrepared<Vec<u8>, B> {
impl<D: DataMut, B: Backend> GLWETensorKeyPrepared<D, B> { impl<D: DataMut, B: Backend> GLWETensorKeyPrepared<D, B> {
pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>) pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
where where
O: GLWETensorKeyToRef, O: GGLWEToRef,
M: GLWETensorKeyPreparedFactory<B>, M: GLWETensorKeyPreparedFactory<B>,
{ {
module.prepare_tensor_key(self, other, scratch); module.prepare_tensor_key(self, other, scratch);
} }
} }
pub trait GLWETensorKeyPreparedToMut<B: Backend> { impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for GLWETensorKeyPrepared<D, B>
fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B>;
}
impl<D: DataMut, B: Backend> GLWETensorKeyPreparedToMut<B> for GLWETensorKeyPrepared<D, B>
where where
GGLWEPrepared<D, B>: GGLWEPreparedToMut<B>, GGLWEPrepared<D, B>: GGLWEPreparedToMut<B>,
{ {
fn to_mut(&mut self) -> GLWETensorKeyPrepared<&mut [u8], B> { fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> {
GLWETensorKeyPrepared { self.0.to_mut()
keys: self.keys.iter_mut().map(|c| c.to_mut()).collect(),
}
} }
} }
pub trait GLWETensorKeyPreparedToRef<B: Backend> { impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for GLWETensorKeyPrepared<D, B>
fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B>;
}
impl<D: DataRef, B: Backend> GLWETensorKeyPreparedToRef<B> for GLWETensorKeyPrepared<D, B>
where where
GGLWEPrepared<D, B>: GGLWEPreparedToRef<B>, GGLWEPrepared<D, B>: GGLWEPreparedToRef<B>,
{ {
fn to_ref(&self) -> GLWETensorKeyPrepared<&[u8], B> { fn to_ref(&self) -> GGLWEPrepared<&[u8], B> {
GLWETensorKeyPrepared { self.0.to_ref()
keys: self.keys.iter().map(|c| c.to_ref()).collect(),
}
} }
} }

View File

@@ -7,9 +7,9 @@ use crate::layouts::{
}; };
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub struct GLWEToLWESwitchingKeyPrepared<D: Data, B: Backend>(pub(crate) GLWESwitchingKeyPrepared<D, B>); pub struct GLWEToLWEKeyPrepared<D: Data, B: Backend>(pub(crate) GLWESwitchingKeyPrepared<D, B>);
impl<D: Data, B: Backend> LWEInfos for GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> LWEInfos for GLWEToLWEKeyPrepared<D, B> {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.0.base2k() self.0.base2k()
} }
@@ -27,13 +27,13 @@ impl<D: Data, B: Backend> LWEInfos for GLWEToLWESwitchingKeyPrepared<D, B> {
} }
} }
impl<D: Data, B: Backend> GLWEInfos for GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> GLWEInfos for GLWEToLWEKeyPrepared<D, B> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl<D: Data, B: Backend> GGLWEInfos for GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> GGLWEInfos for GLWEToLWEKeyPrepared<D, B> {
fn rank_in(&self) -> Rank { fn rank_in(&self) -> Rank {
self.0.rank_in() self.0.rank_in()
} }
@@ -51,65 +51,65 @@ impl<D: Data, B: Backend> GGLWEInfos for GLWEToLWESwitchingKeyPrepared<D, B> {
} }
} }
pub trait GLWEToLWESwitchingKeyPreparedFactory<B: Backend> pub trait GLWEToLWEKeyPreparedFactory<B: Backend>
where where
Self: GLWESwitchingKeyPreparedFactory<B>, Self: GLWESwitchingKeyPreparedFactory<B>,
{ {
fn alloc_glwe_to_lwe_switching_key_prepared( fn alloc_glwe_to_lwe_key_prepared(
&self, &self,
base2k: Base2K, base2k: Base2K,
k: TorusPrecision, k: TorusPrecision,
rank_in: Rank, rank_in: Rank,
dnum: Dnum, dnum: Dnum,
) -> GLWEToLWESwitchingKeyPrepared<Vec<u8>, B> { ) -> GLWEToLWEKeyPrepared<Vec<u8>, B> {
GLWEToLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))) GLWEToLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)))
} }
fn alloc_glwe_to_lwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> GLWEToLWESwitchingKeyPrepared<Vec<u8>, B> fn alloc_glwe_to_lwe_key_prepared_from_infos<A>(&self, infos: &A) -> GLWEToLWEKeyPrepared<Vec<u8>, B>
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
debug_assert_eq!( debug_assert_eq!(
infos.rank_out().0, infos.rank_out().0,
1, 1,
"rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" "rank_out > 1 is not supported for GLWEToLWEKeyPrepared"
); );
debug_assert_eq!( debug_assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" "dsize > 1 is not supported for GLWEToLWEKeyPrepared"
); );
self.alloc_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) self.alloc_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
} }
fn bytes_of_glwe_to_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize { fn bytes_of_glwe_to_lwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize {
self.bytes_of_glwe_switching_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1)) self.bytes_of_glwe_key_prepared(base2k, k, rank_in, Rank(1), dnum, Dsize(1))
} }
fn bytes_of_glwe_to_lwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> usize fn bytes_of_glwe_to_lwe_key_prepared_from_infos<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
debug_assert_eq!( debug_assert_eq!(
infos.rank_out().0, infos.rank_out().0,
1, 1,
"rank_out > 1 is not supported for GLWEToLWESwitchingKeyPrepared" "rank_out > 1 is not supported for GLWEToLWEKeyPrepared"
); );
debug_assert_eq!( debug_assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for GLWEToLWESwitchingKeyPrepared" "dsize > 1 is not supported for GLWEToLWEKeyPrepared"
); );
self.bytes_of_glwe_to_lwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum()) self.bytes_of_glwe_to_lwe_key_prepared(infos.base2k(), infos.k(), infos.rank_in(), infos.dnum())
} }
fn prepare_glwe_to_lwe_switching_key_tmp_bytes<A>(&self, infos: &A) -> usize fn prepare_glwe_to_lwe_key_tmp_bytes<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
self.prepare_glwe_switching_key_tmp_bytes(infos) self.prepare_glwe_switching_key_tmp_bytes(infos)
} }
fn prepare_glwe_to_lwe_switching_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>) fn prepare_glwe_to_lwe_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>)
where where
R: GGLWEPreparedToMut<B> + GLWESwitchingKeyDegreesMut, R: GGLWEPreparedToMut<B> + GLWESwitchingKeyDegreesMut,
O: GGLWEToRef + GLWESwitchingKeyDegrees, O: GGLWEToRef + GLWESwitchingKeyDegrees,
@@ -118,61 +118,61 @@ where
} }
} }
impl<B: Backend> GLWEToLWESwitchingKeyPreparedFactory<B> for Module<B> where Self: GLWESwitchingKeyPreparedFactory<B> {} impl<B: Backend> GLWEToLWEKeyPreparedFactory<B> for Module<B> where Self: GLWESwitchingKeyPreparedFactory<B> {}
impl<B: Backend> GLWEToLWESwitchingKeyPrepared<Vec<u8>, B> { impl<B: Backend> GLWEToLWEKeyPrepared<Vec<u8>, B> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
where where
A: GGLWEInfos, A: GGLWEInfos,
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.alloc_glwe_to_lwe_switching_key_prepared_from_infos(infos) module.alloc_glwe_to_lwe_key_prepared_from_infos(infos)
} }
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> Self
where where
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.alloc_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) module.alloc_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum)
} }
pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.bytes_of_glwe_to_lwe_switching_key_prepared_from_infos(infos) module.bytes_of_glwe_to_lwe_key_prepared_from_infos(infos)
} }
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum) -> usize
where where
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.bytes_of_glwe_to_lwe_switching_key_prepared(base2k, k, rank_in, dnum) module.bytes_of_glwe_to_lwe_key_prepared(base2k, k, rank_in, dnum)
} }
} }
impl<B: Backend> GLWEToLWESwitchingKeyPrepared<Vec<u8>, B> { impl<B: Backend> GLWEToLWEKeyPrepared<Vec<u8>, B> {
pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A) pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A)
where where
A: GGLWEInfos, A: GGLWEInfos,
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.prepare_glwe_to_lwe_switching_key_tmp_bytes(infos); module.prepare_glwe_to_lwe_key_tmp_bytes(infos);
} }
} }
impl<D: DataMut, B: Backend> GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: DataMut, B: Backend> GLWEToLWEKeyPrepared<D, B> {
pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>) pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
where where
O: GGLWEToRef + GLWESwitchingKeyDegrees, O: GGLWEToRef + GLWESwitchingKeyDegrees,
M: GLWEToLWESwitchingKeyPreparedFactory<B>, M: GLWEToLWEKeyPreparedFactory<B>,
{ {
module.prepare_glwe_to_lwe_switching_key(self, other, scratch); module.prepare_glwe_to_lwe_key(self, other, scratch);
} }
} }
impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for GLWEToLWESwitchingKeyPrepared<D, B> impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for GLWEToLWEKeyPrepared<D, B>
where where
GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>, GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>,
{ {
@@ -181,7 +181,7 @@ where
} }
} }
impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for GLWEToLWESwitchingKeyPrepared<D, B> impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for GLWEToLWEKeyPrepared<D, B>
where where
GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>, GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>,
{ {
@@ -190,7 +190,7 @@ where
} }
} }
impl<D: DataMut, B: Backend> GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: DataMut, B: Backend> GLWESwitchingKeyDegreesMut for GLWEToLWEKeyPrepared<D, B> {
fn input_degree(&mut self) -> &mut Degree { fn input_degree(&mut self) -> &mut Degree {
&mut self.0.input_degree &mut self.0.input_degree
} }
@@ -200,7 +200,7 @@ impl<D: DataMut, B: Backend> GLWESwitchingKeyDegreesMut for GLWEToLWESwitchingKe
} }
} }
impl<D: DataRef, B: Backend> GLWESwitchingKeyDegrees for GLWEToLWESwitchingKeyPrepared<D, B> { impl<D: DataRef, B: Backend> GLWESwitchingKeyDegrees for GLWEToLWEKeyPrepared<D, B> {
fn input_degree(&self) -> &Degree { fn input_degree(&self) -> &Degree {
&self.0.input_degree &self.0.input_degree
} }

View File

@@ -86,7 +86,7 @@ where
} }
fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize { fn bytes_of_lwe_switching_key_prepared(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum) -> usize {
self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1)) self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), Rank(1), dnum, Dsize(1))
} }
fn bytes_of_lwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> usize fn bytes_of_lwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> usize

View File

@@ -8,9 +8,9 @@ use crate::layouts::{
/// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE]. /// A special [GLWESwitchingKey] required to for the conversion from [LWE] to [GLWE].
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
pub struct LWEToGLWESwitchingKeyPrepared<D: Data, B: Backend>(pub(crate) GLWESwitchingKeyPrepared<D, B>); pub struct LWEToGLWEKeyPrepared<D: Data, B: Backend>(pub(crate) GLWESwitchingKeyPrepared<D, B>);
impl<D: Data, B: Backend> LWEInfos for LWEToGLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> LWEInfos for LWEToGLWEKeyPrepared<D, B> {
fn base2k(&self) -> Base2K { fn base2k(&self) -> Base2K {
self.0.base2k() self.0.base2k()
} }
@@ -28,13 +28,13 @@ impl<D: Data, B: Backend> LWEInfos for LWEToGLWESwitchingKeyPrepared<D, B> {
} }
} }
impl<D: Data, B: Backend> GLWEInfos for LWEToGLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> GLWEInfos for LWEToGLWEKeyPrepared<D, B> {
fn rank(&self) -> Rank { fn rank(&self) -> Rank {
self.rank_out() self.rank_out()
} }
} }
impl<D: Data, B: Backend> GGLWEInfos for LWEToGLWESwitchingKeyPrepared<D, B> { impl<D: Data, B: Backend> GGLWEInfos for LWEToGLWEKeyPrepared<D, B> {
fn dsize(&self) -> Dsize { fn dsize(&self) -> Dsize {
self.0.dsize() self.0.dsize()
} }
@@ -52,71 +52,65 @@ impl<D: Data, B: Backend> GGLWEInfos for LWEToGLWESwitchingKeyPrepared<D, B> {
} }
} }
pub trait LWEToGLWESwitchingKeyPreparedFactory<B: Backend> pub trait LWEToGLWEKeyPreparedFactory<B: Backend>
where where
Self: GLWESwitchingKeyPreparedFactory<B>, Self: GLWESwitchingKeyPreparedFactory<B>,
{ {
fn alloc_lwe_to_glwe_switching_key_prepared( fn alloc_lwe_to_glwe_key_prepared(
&self, &self,
base2k: Base2K, base2k: Base2K,
k: TorusPrecision, k: TorusPrecision,
rank_out: Rank, rank_out: Rank,
dnum: Dnum, dnum: Dnum,
) -> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B> { ) -> LWEToGLWEKeyPrepared<Vec<u8>, B> {
LWEToGLWESwitchingKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))) LWEToGLWEKeyPrepared(self.alloc_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1)))
} }
fn alloc_lwe_to_glwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B> fn alloc_lwe_to_glwe_key_prepared_from_infos<A>(&self, infos: &A) -> LWEToGLWEKeyPrepared<Vec<u8>, B>
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
debug_assert_eq!( debug_assert_eq!(
infos.rank_in().0, infos.rank_in().0,
1, 1,
"rank_in > 1 is not supported for LWEToGLWESwitchingKey" "rank_in > 1 is not supported for LWEToGLWEKey"
); );
debug_assert_eq!( debug_assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for LWEToGLWESwitchingKey" "dsize > 1 is not supported for LWEToGLWEKey"
); );
self.alloc_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) self.alloc_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
} }
fn bytes_of_lwe_to_glwe_switching_key_prepared( fn bytes_of_lwe_to_glwe_key_prepared(&self, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize {
&self, self.bytes_of_glwe_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))
base2k: Base2K,
k: TorusPrecision,
rank_out: Rank,
dnum: Dnum,
) -> usize {
self.bytes_of_glwe_switching_key_prepared(base2k, k, Rank(1), rank_out, dnum, Dsize(1))
} }
fn bytes_of_lwe_to_glwe_switching_key_prepared_from_infos<A>(&self, infos: &A) -> usize fn bytes_of_lwe_to_glwe_key_prepared_from_infos<A>(&self, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
debug_assert_eq!( debug_assert_eq!(
infos.rank_in().0, infos.rank_in().0,
1, 1,
"rank_in > 1 is not supported for LWEToGLWESwitchingKey" "rank_in > 1 is not supported for LWEToGLWEKey"
); );
debug_assert_eq!( debug_assert_eq!(
infos.dsize().0, infos.dsize().0,
1, 1,
"dsize > 1 is not supported for LWEToGLWESwitchingKey" "dsize > 1 is not supported for LWEToGLWEKey"
); );
self.bytes_of_lwe_to_glwe_switching_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum()) self.bytes_of_lwe_to_glwe_key_prepared(infos.base2k(), infos.k(), infos.rank_out(), infos.dnum())
} }
fn prepare_lwe_to_glwe_switching_key_tmp_bytes<A>(&self, infos: &A) fn prepare_lwe_to_glwe_key_tmp_bytes<A>(&self, infos: &A)
where where
A: GGLWEInfos, A: GGLWEInfos,
{ {
self.prepare_glwe_switching_key_tmp_bytes(infos); self.prepare_glwe_switching_key_tmp_bytes(infos);
} }
fn prepare_lwe_to_glwe_switching_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>) fn prepare_lwe_to_glwe_key<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<B>)
where where
R: GGLWEPreparedToMut<B> + GLWESwitchingKeyDegreesMut, R: GGLWEPreparedToMut<B> + GLWESwitchingKeyDegreesMut,
O: GGLWEToRef + GLWESwitchingKeyDegrees, O: GGLWEToRef + GLWESwitchingKeyDegrees,
@@ -125,61 +119,61 @@ where
} }
} }
impl<B: Backend> LWEToGLWESwitchingKeyPreparedFactory<B> for Module<B> where Self: GLWESwitchingKeyPreparedFactory<B> {} impl<B: Backend> LWEToGLWEKeyPreparedFactory<B> for Module<B> where Self: GLWESwitchingKeyPreparedFactory<B> {}
impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B> { impl<B: Backend> LWEToGLWEKeyPrepared<Vec<u8>, B> {
pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
where where
A: GGLWEInfos, A: GGLWEInfos,
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.alloc_lwe_to_glwe_switching_key_prepared_from_infos(infos) module.alloc_lwe_to_glwe_key_prepared_from_infos(infos)
} }
pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self pub fn alloc<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> Self
where where
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.alloc_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) module.alloc_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum)
} }
pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
where where
A: GGLWEInfos, A: GGLWEInfos,
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.bytes_of_lwe_to_glwe_switching_key_prepared_from_infos(infos) module.bytes_of_lwe_to_glwe_key_prepared_from_infos(infos)
} }
pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize pub fn bytes_of<M>(module: &M, base2k: Base2K, k: TorusPrecision, rank_out: Rank, dnum: Dnum) -> usize
where where
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.bytes_of_lwe_to_glwe_switching_key_prepared(base2k, k, rank_out, dnum) module.bytes_of_lwe_to_glwe_key_prepared(base2k, k, rank_out, dnum)
} }
} }
impl<B: Backend> LWEToGLWESwitchingKeyPrepared<Vec<u8>, B> { impl<B: Backend> LWEToGLWEKeyPrepared<Vec<u8>, B> {
pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A) pub fn prepare_tmp_bytes<A, M>(&self, module: &M, infos: &A)
where where
A: GGLWEInfos, A: GGLWEInfos,
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.prepare_lwe_to_glwe_switching_key_tmp_bytes(infos); module.prepare_lwe_to_glwe_key_tmp_bytes(infos);
} }
} }
impl<D: DataMut, B: Backend> LWEToGLWESwitchingKeyPrepared<D, B> { impl<D: DataMut, B: Backend> LWEToGLWEKeyPrepared<D, B> {
pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>) pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
where where
O: GGLWEToRef + GLWESwitchingKeyDegrees, O: GGLWEToRef + GLWESwitchingKeyDegrees,
M: LWEToGLWESwitchingKeyPreparedFactory<B>, M: LWEToGLWEKeyPreparedFactory<B>,
{ {
module.prepare_lwe_to_glwe_switching_key(self, other, scratch); module.prepare_lwe_to_glwe_key(self, other, scratch);
} }
} }
impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for LWEToGLWESwitchingKeyPrepared<D, B> impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for LWEToGLWEKeyPrepared<D, B>
where where
GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>, GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToRef<B>,
{ {
@@ -188,7 +182,7 @@ where
} }
} }
impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for LWEToGLWESwitchingKeyPrepared<D, B> impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for LWEToGLWEKeyPrepared<D, B>
where where
GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToMut<B>, GLWESwitchingKeyPrepared<D, B>: GGLWEPreparedToMut<B>,
{ {
@@ -197,7 +191,7 @@ where
} }
} }
impl<D: DataMut, B: Backend> GLWESwitchingKeyDegreesMut for LWEToGLWESwitchingKeyPrepared<D, B> { impl<D: DataMut, B: Backend> GLWESwitchingKeyDegreesMut for LWEToGLWEKeyPrepared<D, B> {
fn input_degree(&mut self) -> &mut Degree { fn input_degree(&mut self) -> &mut Degree {
&mut self.0.input_degree &mut self.0.input_degree
} }

View File

@@ -1,4 +1,5 @@
mod gglwe; mod gglwe;
mod gglwe_to_ggsw_key;
mod ggsw; mod ggsw;
mod glwe; mod glwe;
mod glwe_automorphism_key; mod glwe_automorphism_key;
@@ -6,11 +7,12 @@ mod glwe_public_key;
mod glwe_secret; mod glwe_secret;
mod glwe_switching_key; mod glwe_switching_key;
mod glwe_tensor_key; mod glwe_tensor_key;
mod glwe_to_lwe_switching_key; mod glwe_to_lwe_key;
mod lwe_switching_key; mod lwe_switching_key;
mod lwe_to_glwe_switching_key; mod lwe_to_glwe_key;
pub use gglwe::*; pub use gglwe::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw::*; pub use ggsw::*;
pub use glwe::*; pub use glwe::*;
pub use glwe_automorphism_key::*; pub use glwe_automorphism_key::*;
@@ -18,6 +20,6 @@ pub use glwe_public_key::*;
pub use glwe_secret::*; pub use glwe_secret::*;
pub use glwe_switching_key::*; pub use glwe_switching_key::*;
pub use glwe_tensor_key::*; pub use glwe_tensor_key::*;
pub use glwe_to_lwe_switching_key::*; pub use glwe_to_lwe_key::*;
pub use lwe_switching_key::*; pub use lwe_switching_key::*;
pub use lwe_to_glwe_switching_key::*; pub use lwe_to_glwe_key::*;

View File

@@ -4,6 +4,7 @@ mod decryption;
mod dist; mod dist;
mod encryption; mod encryption;
mod external_product; mod external_product;
mod glwe_packer;
mod glwe_packing; mod glwe_packing;
mod glwe_trace; mod glwe_trace;
mod keyswitching; mod keyswitching;
@@ -20,6 +21,7 @@ pub use decryption::*;
pub use dist::*; pub use dist::*;
pub use encryption::*; pub use encryption::*;
pub use external_product::*; pub use external_product::*;
pub use glwe_packer::*;
pub use glwe_packing::*; pub use glwe_packing::*;
pub use glwe_trace::*; pub use glwe_trace::*;
pub use keyswitching::*; pub use keyswitching::*;

View File

@@ -62,7 +62,7 @@ where
let noise_have: f64 = pt.data.std(base2k, 0).log2(); let noise_have: f64 = pt.data.std(base2k, 0).log2();
// println!("noise_have: {noise_have}"); println!("noise_have: {noise_have}");
assert!( assert!(
noise_have <= max_noise, noise_have <= max_noise,

View File

@@ -162,6 +162,7 @@ where
sk_prepared, sk_prepared,
scratch.borrow(), scratch.borrow(),
); );
self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0); self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
let std_pt: f64 = pt_have.data.std(base2k, 0).log2(); let std_pt: f64 = pt_have.data.std(base2k, 0).log2();

View File

@@ -0,0 +1,55 @@
use poulpy_hal::layouts::{Backend, Module, Scratch};
use crate::{
GLWERotate, ScratchTakeCore,
layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWEInfos},
};
impl<BE: Backend> GGSWRotate<BE> for Module<BE> where Module<BE>: GLWERotate<BE> {}
pub trait GGSWRotate<BE: Backend>
where
Self: GLWERotate<BE>,
{
fn ggsw_rotate_tmp_bytes(&self) -> usize {
self.glwe_rotate_tmp_bytes()
}
fn ggsw_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
where
R: GGSWToMut,
A: GGSWToRef,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref();
assert!(res.dnum() <= a.dnum());
assert_eq!(res.dsize(), a.dsize());
assert_eq!(res.rank(), a.rank());
let rows: usize = res.dnum().into();
let cols: usize = (res.rank() + 1).into();
for row in 0..rows {
for col in 0..cols {
self.glwe_rotate(k, &mut res.at_mut(row, col), &a.at(row, col));
}
}
}
fn ggsw_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
where
R: GGSWToMut,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let rows: usize = res.dnum().into();
let cols: usize = (res.rank() + 1).into();
for row in 0..rows {
for col in 0..cols {
self.glwe_rotate_inplace(k, &mut res.at_mut(row, col), scratch);
}
}
}
}

View File

@@ -1,20 +1,85 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace, BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub, VecZnxSubInplace,
VecZnxSubNegateInplace, VecZnxZero,
}, },
layouts::{Backend, Module, Scratch, VecZnx, ZnxZero}, layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
}; };
use crate::{ use crate::{
ScratchTakeCore, ScratchTakeCore,
layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision}, layouts::{
GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
TorusPrecision,
},
}; };
pub trait GLWETensoring<BE: Backend>
where
Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
/// res = (a (x) b) * 2^{k * a_base2k}
///
/// # Requires
/// * a.base2k() == b.base2k()
/// * res.cols() >= a.cols() + b.cols() - 1
///
/// # Behavior
/// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where
R: GLWETensorToMut,
A: GLWEToRef,
B: GLWEPreparedToRef<BE>,
{
let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
assert_eq!(a.base2k(), b.base2k());
assert_eq!(a.rank(), res.rank());
let res_cols: usize = res.data.cols();
// Get tmp buffer of min precision between a_prec * b_prec and res_prec
let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
// DFT(res) = DFT(a) (x) DFT(b)
self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
// res = IDFT(res)
let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
// Normalize and switches basis if required
for res_col in 0..res_cols {
self.vec_znx_big_normalize(
res.base2k().into(),
&mut res.data,
res_col,
a.base2k().into(),
&res_big,
res_col,
scratch_1,
);
}
}
// fn glwe_relinearize<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
// where
// R: GLWEToRef,
// A: GLWETensorToRef,
// T: GLWETensorKeyPreparedToRef<BE>,
// {
// }
}
pub trait GLWEAdd pub trait GLWEAdd
where where
Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace, Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
{ {
fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B) fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
where where
@@ -30,35 +95,38 @@ where
assert_eq!(b.n(), self.n() as u32); assert_eq!(b.n(), self.n() as u32);
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.base2k(), b.base2k()); assert_eq!(a.base2k(), b.base2k());
assert!(res.rank() >= a.rank().max(b.rank())); assert_eq!(res.base2k(), b.base2k());
if a.rank() == 0 {
assert_eq!(res.rank(), b.rank());
} else if b.rank() == 0 {
assert_eq!(res.rank(), a.rank());
} else {
assert_eq!(res.rank(), a.rank());
assert_eq!(res.rank(), b.rank());
}
let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (res.rank() + 1).into(); let self_col: usize = (res.rank() + 1).into();
(0..min_col).for_each(|i| { for i in 0..min_col {
self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i); self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
});
} else {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
});
} }
let size: usize = res.size(); if a.rank() > b.rank() {
(max_col..self_col).for_each(|i| { for i in min_col..max_col {
(0..size).for_each(|j| { self.vec_znx_copy(res.data_mut(), i, a.data(), i);
res.data.zero_at(i, j); }
}); } else {
}); for i in min_col..max_col {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
}
}
res.set_base2k(a.base2k()); for i in max_col..self_col {
res.set_k(set_k_binary(res, a, b)); self.vec_znx_zero(res.data_mut(), i);
}
} }
fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A) fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -74,24 +142,22 @@ where
assert_eq!(res.base2k(), a.base2k()); assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank()); assert!(res.rank() >= a.rank());
(0..(a.rank() + 1).into()).for_each(|i| { for i in 0..(a.rank() + 1).into() {
self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i); self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
}); }
res.set_k(set_k_unary(res, a))
} }
} }
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {} impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
impl<BE: Backend> GLWESub for Module<BE> where impl<BE: Backend> GLWESub for Module<BE> where
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
{ {
} }
pub trait GLWESub pub trait GLWESub
where where
Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace, Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
{ {
fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B) fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
where where
@@ -105,37 +171,40 @@ where
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(b.n(), self.n() as u32); assert_eq!(b.n(), self.n() as u32);
assert_eq!(a.base2k(), b.base2k()); assert_eq!(res.n(), self.n() as u32);
assert!(res.rank() >= a.rank().max(b.rank())); assert_eq!(a.base2k(), res.base2k());
assert_eq!(b.base2k(), res.base2k());
if a.rank() == 0 {
assert_eq!(res.rank(), b.rank());
} else if b.rank() == 0 {
assert_eq!(res.rank(), a.rank());
} else {
assert_eq!(res.rank(), a.rank());
assert_eq!(res.rank(), b.rank());
}
let min_col: usize = (a.rank().min(b.rank()) + 1).into(); let min_col: usize = (a.rank().min(b.rank()) + 1).into();
let max_col: usize = (a.rank().max(b.rank() + 1)).into(); let max_col: usize = (a.rank().max(b.rank() + 1)).into();
let self_col: usize = (res.rank() + 1).into(); let self_col: usize = (res.rank() + 1).into();
(0..min_col).for_each(|i| { for i in 0..min_col {
self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i); self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, a.data(), i);
});
} else {
(min_col..max_col).for_each(|i| {
self.vec_znx_copy(res.data_mut(), i, b.data(), i);
self.vec_znx_negate_inplace(res.data_mut(), i);
});
} }
let size: usize = res.size(); if a.rank() > b.rank() {
(max_col..self_col).for_each(|i| { for i in min_col..max_col {
(0..size).for_each(|j| { self.vec_znx_copy(res.data_mut(), i, a.data(), i);
res.data.zero_at(i, j); }
}); } else {
}); for i in min_col..max_col {
self.vec_znx_negate(res.data_mut(), i, b.data(), i);
}
}
res.set_base2k(a.base2k()); for i in max_col..self_col {
res.set_k(set_k_binary(res, a, b)); self.vec_znx_zero(res.data_mut(), i);
}
} }
fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A) fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -149,13 +218,11 @@ where
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.base2k(), a.base2k()); assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank()); assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| { for i in 0..(a.rank() + 1).into() {
self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i); self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
}); }
res.set_k(set_k_unary(res, a))
} }
fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A) fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
@@ -169,22 +236,24 @@ where
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.base2k(), a.base2k()); assert_eq!(res.base2k(), a.base2k());
assert!(res.rank() >= a.rank()); assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| { for i in 0..(a.rank() + 1).into() {
self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i); self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
}); }
res.set_k(set_k_unary(res, a))
} }
} }
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {} impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
pub trait GLWERotate<BE: Backend> pub trait GLWERotate<BE: Backend>
where where
Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>, Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
{ {
fn glwe_rotate_tmp_bytes(&self) -> usize {
vec_znx_rotate_inplace_tmp_bytes(self.n())
}
fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A) fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
where where
R: GLWEToMut, R: GLWEToMut,
@@ -194,14 +263,18 @@ where
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.rank(), a.rank()); assert_eq!(res.n(), self.n() as u32);
assert!(res.rank() == a.rank() || a.rank() == 0);
(0..(a.rank() + 1).into()).for_each(|i| { let res_cols = (res.rank() + 1).into();
let a_cols = (a.rank() + 1).into();
for i in 0..a_cols {
self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i); self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
}); }
for i in a_cols..res_cols {
res.set_base2k(a.base2k()); self.vec_znx_zero(res.data_mut(), i);
res.set_k(set_k_unary(res, a)) }
} }
fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>) fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -211,9 +284,9 @@ where
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
(0..(res.rank() + 1).into()).for_each(|i| { for i in 0..(res.rank() + 1).into() {
self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch); self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
}); }
} }
} }
@@ -238,9 +311,6 @@ where
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i); self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
} }
res.set_base2k(a.base2k());
res.set_k(set_k_unary(res, a))
} }
fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>) fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
@@ -257,11 +327,11 @@ where
} }
} }
impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy {} impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy + VecZnxZero {}
pub trait GLWECopy pub trait GLWECopy
where where
Self: ModuleN + VecZnxCopy, Self: ModuleN + VecZnxCopy + VecZnxZero,
{ {
fn glwe_copy<R, A>(&self, res: &mut R, a: &A) fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
where where
@@ -273,14 +343,17 @@ where
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(res.rank(), a.rank()); assert!(res.rank() == a.rank() || a.rank() == 0);
for i in 0..res.rank().as_usize() + 1 { let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
for i in 0..min_rank {
self.vec_znx_copy(res.data_mut(), i, a.data(), i); self.vec_znx_copy(res.data_mut(), i, a.data(), i);
} }
res.set_k(a.k().min(res.max_k())); for i in min_rank..(res.rank() + 1).into() {
res.set_base2k(a.base2k()); self.vec_znx_zero(res.data_mut(), i);
}
} }
} }
@@ -346,8 +419,6 @@ where
scratch, scratch,
); );
} }
res.set_k(a.k().min(res.k()));
} }
fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>) fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
@@ -362,6 +433,7 @@ where
} }
} }
#[allow(dead_code)]
// c = op(a, b) // c = op(a, b)
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
// If either operands is a ciphertext // If either operands is a ciphertext
@@ -383,6 +455,7 @@ fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> T
} }
} }
#[allow(dead_code)]
// a = op(a, b) // a = op(a, b)
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision { fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
if a.rank() != 0 || b.rank() != 0 { if a.rank() != 0 || b.rank() != 0 {

View File

@@ -1,3 +1,5 @@
mod ggsw;
mod glwe; mod glwe;
pub use ggsw::*;
pub use glwe::*; pub use glwe::*;

View File

@@ -7,7 +7,7 @@ use crate::{
dist::Distribution, dist::Distribution,
layouts::{ layouts::{
Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext, Degree, GGLWE, GGLWEInfos, GGLWELayout, GGSW, GGSWInfos, GLWE, GLWEAutomorphismKey, GLWEInfos, GLWEPlaintext,
GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESwitchingKey, GLWETensorKey, Rank, GLWEPrepared, GLWEPublicKey, GLWESecret, GLWESecretTensor, GLWESwitchingKey, GLWETensorKey, Rank,
prepared::{ prepared::{
GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared, GGLWEPrepared, GGSWPrepared, GLWEAutomorphismKeyPrepared, GLWEPublicKeyPrepared, GLWESecretPrepared,
GLWESwitchingKeyPrepared, GLWETensorKeyPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared,
@@ -232,6 +232,18 @@ where
) )
} }
fn take_glwe_secret_tensor(&mut self, n: Degree, rank: Rank) -> (GLWESecretTensor<&mut [u8]>, &mut Self) {
let (data, scratch) = self.take_scalar_znx(n.into(), GLWESecretTensor::pairs(rank.into()));
(
GLWESecretTensor {
data,
rank,
dist: Distribution::NONE,
},
scratch,
)
}
fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self) fn take_glwe_secret_prepared<M>(&mut self, module: &M, rank: Rank) -> (GLWESecretPrepared<&mut [u8], B>, &mut Self)
where where
M: ModuleN + SvpPPolBytesOf, M: ModuleN + SvpPPolBytesOf,
@@ -313,25 +325,12 @@ where
infos.rank_out(), infos.rank_out(),
"rank_in != rank_out is not supported for GLWETensorKey" "rank_in != rank_out is not supported for GLWETensorKey"
); );
let mut keys: Vec<GGLWE<&mut [u8]>> = Vec::new();
let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
let mut scratch: &mut Self = self;
let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
ksk_infos.rank_in = Rank(1); ksk_infos.rank_in = Rank(pairs);
let (data, scratch) = self.take_gglwe(infos);
if pairs != 0 { (GLWETensorKey(data), scratch)
let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
scratch = s;
keys.push(gglwe);
}
for _ in 1..pairs {
let (gglwe, s) = scratch.take_gglwe(&ksk_infos);
scratch = s;
keys.push(gglwe);
}
(GLWETensorKey { keys }, scratch)
} }
fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self) fn take_glwe_tensor_key_prepared<A, M>(&mut self, module: &M, infos: &A) -> (GLWETensorKeyPrepared<&mut [u8], B>, &mut Self)
@@ -346,25 +345,11 @@ where
"rank_in != rank_out is not supported for GGLWETensorKeyPrepared" "rank_in != rank_out is not supported for GGLWETensorKeyPrepared"
); );
let mut keys: Vec<GGLWEPrepared<&mut [u8], B>> = Vec::new(); let pairs: u32 = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1);
let pairs: usize = (((infos.rank_out().0 + 1) * infos.rank_out().0) >> 1).max(1) as usize;
let mut scratch: &mut Self = self;
let mut ksk_infos: GGLWELayout = infos.gglwe_layout(); let mut ksk_infos: GGLWELayout = infos.gglwe_layout();
ksk_infos.rank_in = Rank(1); ksk_infos.rank_in = Rank(pairs);
let (data, scratch) = self.take_gglwe_prepared(module, infos);
if pairs != 0 { (GLWETensorKeyPrepared(data), scratch)
let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
for _ in 1..pairs {
let (gglwe, s) = scratch.take_gglwe_prepared(module, &ksk_infos);
scratch = s;
keys.push(gglwe);
}
(GLWETensorKeyPrepared { keys }, scratch)
} }
} }

View File

@@ -36,6 +36,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_
gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk,
gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk,
gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk,
gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk,
// GGLWE Keyswitching // GGLWE Keyswitching
gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch,
gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace,
@@ -93,6 +94,7 @@ gglwe_automorphism_key_encrypt_sk => crate::tests::test_suite::encryption::test_
gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk, gglwe_automorphism_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_automorphism_key_compressed_encrypt_sk,
gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk, gglwe_tensor_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_encrypt_sk,
gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk, gglwe_tensor_key_compressed_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_tensor_key_compressed_encrypt_sk,
gglwe_to_ggsw_key_encrypt_sk => crate::tests::test_suite::encryption::test_gglwe_to_ggsw_key_encrypt_sk,
// GGLWE Keyswitching // GGLWE Keyswitching
gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch, gglwe_switching_key_keyswitch => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch,
gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace, gglwe_switching_key_keyswitch_inplace => crate::tests::test_suite::keyswitch::test_gglwe_switching_key_keyswitch_inplace,

View File

@@ -1,12 +1,12 @@
use poulpy_hal::test_suite::serialization::test_reader_writer_interface; use poulpy_hal::test_suite::serialization::test_reader_writer_interface;
use crate::layouts::{ use crate::layouts::{
Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWESwitchingKey, Base2K, Degree, Dnum, Dsize, GGLWE, GGSW, GLWE, GLWEAutomorphismKey, GLWESwitchingKey, GLWETensorKey, GLWEToLWEKey, LWE,
LWE, LWESwitchingKey, LWEToGLWESwitchingKey, Rank, TorusPrecision, LWESwitchingKey, LWEToGLWEKey, Rank, TorusPrecision,
compressed::{ compressed::{
GGLWECompressed, GGSWCompressed, GLWEAutomorphismKeyCompressed, GLWECompressed, GLWESwitchingKeyCompressed, GGLWECompressed, GGSWCompressed, GLWEAutomorphismKeyCompressed, GLWECompressed, GLWESwitchingKeyCompressed,
GLWETensorKeyCompressed, GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed, GLWETensorKeyCompressed, GLWEToLWESwitchingKeyCompressed, LWECompressed, LWESwitchingKeyCompressed,
LWEToGLWESwitchingKeyCompressed, LWEToGLWEKeyCompressed,
}, },
}; };
@@ -93,28 +93,27 @@ fn test_tensor_key_compressed_serialization() {
} }
#[test] #[test]
fn glwe_to_lwe_switching_key_serialization() { fn glwe_to_lwe_key_serialization() {
let original: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); let original: GLWEToLWEKey<Vec<u8>> = GLWEToLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM);
test_reader_writer_interface(original); test_reader_writer_interface(original);
} }
#[test] #[test]
fn glwe_to_lwe_switching_key_compressed_serialization() { fn glwe_to_lwe_key_compressed_serialization() {
let original: GLWEToLWESwitchingKeyCompressed<Vec<u8>> = let original: GLWEToLWESwitchingKeyCompressed<Vec<u8>> =
GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM); GLWEToLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM);
test_reader_writer_interface(original); test_reader_writer_interface(original);
} }
#[test] #[test]
fn lwe_to_glwe_switching_key_serialization() { fn lwe_to_glwe_key_serialization() {
let original: LWEToGLWESwitchingKey<Vec<u8>> = LWEToGLWESwitchingKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM); let original: LWEToGLWEKey<Vec<u8>> = LWEToGLWEKey::alloc(N_GLWE, BASE2K, K, RANK, DNUM);
test_reader_writer_interface(original); test_reader_writer_interface(original);
} }
#[test] #[test]
fn lwe_to_glwe_switching_key_compressed_serialization() { fn lwe_to_glwe_key_compressed_serialization() {
let original: LWEToGLWESwitchingKeyCompressed<Vec<u8>> = let original: LWEToGLWEKeyCompressed<Vec<u8>> = LWEToGLWEKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM);
LWEToGLWESwitchingKeyCompressed::alloc(N_GLWE, BASE2K, K, RANK, DNUM);
test_reader_writer_interface(original); test_reader_writer_interface(original);
} }

View File

@@ -5,12 +5,12 @@ use poulpy_hal::{
}; };
use crate::{ use crate::{
GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, GGLWEToGGSWKeyEncryptSk, GGSWAutomorphism, GGSWEncryptSk, GGSWNoise, GLWEAutomorphismKeyEncryptSk, ScratchTakeCore,
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
GGSW, GGSWLayout, GLWEAutomorphismKey, GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory, GGLWEToGGSWKey, GGLWEToGGSWKeyLayout, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWEAutomorphismKey,
GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, GLWEAutomorphismKeyPreparedFactory, GLWESecret, GLWESecretPreparedFactory,
prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared, GLWETensorKeyPrepared}, prepared::{GGLWEToGGSWKeyPrepared, GLWEAutomorphismKeyPrepared, GLWESecretPrepared},
}, },
noise::noise_ggsw_keyswitch, noise::noise_ggsw_keyswitch,
}; };
@@ -21,8 +21,8 @@ where
+ GLWEAutomorphismKeyEncryptSk<BE> + GLWEAutomorphismKeyEncryptSk<BE>
+ GLWEAutomorphismKeyPreparedFactory<BE> + GLWEAutomorphismKeyPreparedFactory<BE>
+ GGSWAutomorphism<BE> + GGSWAutomorphism<BE>
+ GLWETensorKeyPreparedFactory<BE> + GGLWEToGGSWKeyPreparedFactory<BE>
+ GLWETensorKeyEncryptSk<BE> + GGLWEToGGSWKeyEncryptSk<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ VecZnxAutomorphismInplace<BE> + VecZnxAutomorphismInplace<BE>
+ GGSWNoise<BE>, + GGSWNoise<BE>,
@@ -64,7 +64,7 @@ where
rank: rank.into(), rank: rank.into(),
}; };
let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k.into(),
k: k_tsk.into(), k: k_tsk.into(),
@@ -73,7 +73,7 @@ where
rank: rank.into(), rank: rank.into(),
}; };
let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
@@ -84,7 +84,7 @@ where
let mut ct_in: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_in_layout); let mut ct_in: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_in_layout);
let mut ct_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_layout); let mut ct_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_layout);
let mut tensor_key: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); let mut tsk: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout);
let mut auto_key: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut auto_key: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout);
let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1); let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
@@ -95,8 +95,8 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GGSW::encrypt_sk_tmp_bytes(module, &ct_in) GGSW::encrypt_sk_tmp_bytes(module, &ct_in)
| GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key)
| GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk)
| GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tensor_key), | GGSW::automorphism_tmp_bytes(module, &ct_out, &ct_in, &auto_key, &tsk),
); );
let var_xs: f64 = 0.5; let var_xs: f64 = 0.5;
@@ -115,7 +115,7 @@ where
&mut source_xe, &mut source_xe,
scratch.borrow(), scratch.borrow(),
); );
tensor_key.encrypt_sk( tsk.encrypt_sk(
module, module,
&sk, &sk,
&mut source_xa, &mut source_xa,
@@ -138,9 +138,8 @@ where
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout);
auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); auto_key_prepared.prepare(module, &auto_key, scratch.borrow());
let mut tsk_prepared: GLWETensorKeyPrepared<Vec<u8>, BE> = let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tsk, scratch.borrow());
tsk_prepared.prepare(module, &tensor_key, scratch.borrow());
ct_out.automorphism( ct_out.automorphism(
module, module,
@@ -180,8 +179,8 @@ where
+ GLWEAutomorphismKeyEncryptSk<BE> + GLWEAutomorphismKeyEncryptSk<BE>
+ GLWEAutomorphismKeyPreparedFactory<BE> + GLWEAutomorphismKeyPreparedFactory<BE>
+ GGSWAutomorphism<BE> + GGSWAutomorphism<BE>
+ GLWETensorKeyPreparedFactory<BE> + GGLWEToGGSWKeyPreparedFactory<BE>
+ GLWETensorKeyEncryptSk<BE> + GGLWEToGGSWKeyEncryptSk<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ VecZnxAutomorphismInplace<BE> + VecZnxAutomorphismInplace<BE>
+ GGSWNoise<BE>, + GGSWNoise<BE>,
@@ -211,7 +210,7 @@ where
rank: rank.into(), rank: rank.into(),
}; };
let tensor_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k.into(),
k: k_tsk.into(), k: k_tsk.into(),
@@ -220,7 +219,7 @@ where
rank: rank.into(), rank: rank.into(),
}; };
let auto_key_layout: GLWETensorKeyLayout = GLWETensorKeyLayout { let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k.into(),
k: k_ksk.into(), k: k_ksk.into(),
@@ -230,7 +229,7 @@ where
}; };
let mut ct: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_layout); let mut ct: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_layout);
let mut tensor_key: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tensor_key_layout); let mut tsk: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&tsk_layout);
let mut auto_key: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout); let mut auto_key: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&auto_key_layout);
let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1); let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
@@ -241,8 +240,8 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GGSW::encrypt_sk_tmp_bytes(module, &ct) GGSW::encrypt_sk_tmp_bytes(module, &ct)
| GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key) | GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key)
| GLWETensorKey::encrypt_sk_tmp_bytes(module, &tensor_key) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk)
| GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tensor_key), | GGSW::automorphism_tmp_bytes(module, &ct, &ct, &auto_key, &tsk),
); );
let var_xs: f64 = 0.5; let var_xs: f64 = 0.5;
@@ -261,7 +260,7 @@ where
&mut source_xe, &mut source_xe,
scratch.borrow(), scratch.borrow(),
); );
tensor_key.encrypt_sk( tsk.encrypt_sk(
module, module,
&sk, &sk,
&mut source_xa, &mut source_xa,
@@ -284,9 +283,8 @@ where
GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout); GLWEAutomorphismKeyPrepared::alloc_from_infos(module, &auto_key_layout);
auto_key_prepared.prepare(module, &auto_key, scratch.borrow()); auto_key_prepared.prepare(module, &auto_key, scratch.borrow());
let mut tsk_prepared: GLWETensorKeyPrepared<Vec<u8>, BE> = let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
GLWETensorKeyPrepared::alloc_from_infos(module, &tensor_key_layout); tsk_prepared.prepare(module, &tsk, scratch.borrow());
tsk_prepared.prepare(module, &tensor_key, scratch.borrow());
ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow()); ct.automorphism_inplace(module, &auto_key_prepared, &tsk_prepared, scratch.borrow());

View File

@@ -8,10 +8,10 @@ use crate::{
GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEFromLWE, GLWEToLWESwitchingKeyEncryptSk, LWEDecrypt, LWEEncryptSk,
LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore, LWEToGLWESwitchingKeyEncryptSk, ScratchTakeCore,
layouts::{ layouts::{
Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKeyLayout, Base2K, Degree, Dnum, GLWE, GLWELayout, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWEToLWEKey,
GLWEToLWESwitchingKey, GLWEToLWESwitchingKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret, GLWEToLWEKeyLayout, GLWEToLWEKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWELayout, LWEPlaintext, LWESecret,
LWEToGLWESwitchingKey, LWEToGLWESwitchingKeyLayout, LWEToGLWESwitchingKeyPreparedFactory, Rank, TorusPrecision, LWEToGLWEKey, LWEToGLWEKeyLayout, LWEToGLWEKeyPrepared, LWEToGLWEKeyPreparedFactory, Rank, TorusPrecision,
prepared::{GLWESecretPrepared, GLWEToLWESwitchingKeyPrepared, LWEToGLWESwitchingKeyPrepared}, prepared::GLWESecretPrepared,
}, },
}; };
@@ -22,7 +22,7 @@ where
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ LWEEncryptSk<BE> + LWEEncryptSk<BE>
+ LWEToGLWESwitchingKeyPreparedFactory<BE>, + LWEToGLWEKeyPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
@@ -36,7 +36,7 @@ where
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
let lwe_to_glwe_infos: LWEToGLWESwitchingKeyLayout = LWEToGLWESwitchingKeyLayout { let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout {
n: n_glwe, n: n_glwe,
base2k: Base2K(17), base2k: Base2K(17),
k: TorusPrecision(51), k: TorusPrecision(51),
@@ -58,7 +58,7 @@ where
}; };
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
LWEToGLWESwitchingKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos) LWEToGLWEKey::encrypt_sk_tmp_bytes(module, &lwe_to_glwe_infos)
| GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos) | GLWE::from_lwe_tmp_bytes(module, &glwe_infos, &lwe_infos, &lwe_to_glwe_infos)
| GLWE::decrypt_tmp_bytes(module, &glwe_infos), | GLWE::decrypt_tmp_bytes(module, &glwe_infos),
); );
@@ -80,7 +80,7 @@ where
let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe);
let mut ksk: LWEToGLWESwitchingKey<Vec<u8>> = LWEToGLWESwitchingKey::alloc_from_infos(&lwe_to_glwe_infos); let mut ksk: LWEToGLWEKey<Vec<u8>> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos);
ksk.encrypt_sk( ksk.encrypt_sk(
module, module,
@@ -93,8 +93,7 @@ where
let mut glwe_ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos); let mut glwe_ct: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);
let mut ksk_prepared: LWEToGLWESwitchingKeyPrepared<Vec<u8>, BE> = let mut ksk_prepared: LWEToGLWEKeyPrepared<Vec<u8>, BE> = LWEToGLWEKeyPrepared::alloc_from_infos(module, &ksk);
LWEToGLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk);
ksk_prepared.prepare(module, &ksk, scratch.borrow()); ksk_prepared.prepare(module, &ksk, scratch.borrow());
glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow()); glwe_ct.from_lwe(module, &lwe_ct, &ksk_prepared, scratch.borrow());
@@ -114,7 +113,7 @@ where
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWEToLWESwitchingKeyEncryptSk<BE> + GLWEToLWESwitchingKeyEncryptSk<BE>
+ GLWEToLWESwitchingKeyPreparedFactory<BE>, + GLWEToLWEKeyPreparedFactory<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
@@ -150,7 +149,7 @@ where
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEToLWESwitchingKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos) GLWEToLWEKey::encrypt_sk_tmp_bytes(module, &glwe_to_lwe_infos)
| LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos) | LWE::from_glwe_tmp_bytes(module, &lwe_infos, &glwe_infos, &glwe_to_lwe_infos)
| GLWE::decrypt_tmp_bytes(module, &glwe_infos), | GLWE::decrypt_tmp_bytes(module, &glwe_infos),
); );
@@ -178,7 +177,7 @@ where
scratch.borrow(), scratch.borrow(),
); );
let mut ksk: GLWEToLWESwitchingKey<Vec<u8>> = GLWEToLWESwitchingKey::alloc_from_infos(&glwe_to_lwe_infos); let mut ksk: GLWEToLWEKey<Vec<u8>> = GLWEToLWEKey::alloc_from_infos(&glwe_to_lwe_infos);
ksk.encrypt_sk( ksk.encrypt_sk(
module, module,
@@ -191,8 +190,7 @@ where
let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
let mut ksk_prepared: GLWEToLWESwitchingKeyPrepared<Vec<u8>, BE> = let mut ksk_prepared: GLWEToLWEKeyPrepared<Vec<u8>, BE> = GLWEToLWEKeyPrepared::alloc_from_infos(module, &ksk);
GLWEToLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk);
ksk_prepared.prepare(module, &ksk, scratch.borrow()); ksk_prepared.prepare(module, &ksk, scratch.borrow());
lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, &ksk_prepared, scratch.borrow());

View File

@@ -0,0 +1,144 @@
use poulpy_hal::{
api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxCopy},
layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned},
source::Source,
};
use crate::{
GGLWENoise, GGLWEToGGSWKeyCompressedEncryptSk, GGLWEToGGSWKeyEncryptSk, ScratchTakeCore,
decryption::GLWEDecrypt,
encryption::SIGMA,
layouts::{
Dsize, GGLWEDecompress, GGLWEToGGSWKey, GGLWEToGGSWKeyCompressed, GGLWEToGGSWKeyDecompress, GGLWEToGGSWKeyLayout,
GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, LWEInfos, prepared::GLWESecretPrepared,
},
};
pub fn test_gglwe_to_ggsw_key_encrypt_sk<BE: Backend>(module: &Module<BE>)
where
Module<BE>: GGLWEToGGSWKeyEncryptSk<BE>
+ GLWESecretTensorFactory<BE>
+ GLWESecretPreparedFactory<BE>
+ GLWEDecrypt<BE>
+ GGLWENoise<BE>
+ VecZnxCopy,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
let base2k: usize = 8;
let k: usize = 54;
for rank in 2_usize..3 {
let n: usize = module.n();
let dnum: usize = k / base2k;
let key_infos: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(),
base2k: base2k.into(),
k: k.into(),
dnum: dnum.into(),
dsize: Dsize(1),
rank: rank.into(),
};
let mut key: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&key_infos);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &key_infos));
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&key_infos);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
sk_prepared.prepare(module, &sk);
key.encrypt_sk(
module,
&sk,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
sk_tensor.prepare(module, &sk, scratch.borrow());
let max_noise = SIGMA.log2() + 0.5 - (key.k().as_u32() as f64);
let mut pt_want: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), rank);
for i in 0..rank {
for j in 0..rank {
module.vec_znx_copy(
&mut pt_want.as_vec_znx_mut(),
j,
&sk_tensor.at(i, j).as_vec_znx(),
0,
);
}
println!("pt_want: {}", pt_want.as_vec_znx());
module.gglwe_assert_noise(key.at(i), &sk_prepared, &pt_want, max_noise);
}
}
}
/// Checks the compressed encryption round-trip for the GGLWE-to-GGSW key:
/// encrypt into `GGLWEToGGSWKeyCompressed` under a fixed public seed, expand
/// it back with `decompress`, and assert that each expanded row encrypts the
/// secret-tensor plaintexts within the expected noise bound.
pub fn test_gglwe_to_ggsw_compressed_encrypt_sk<BE: Backend>(module: &Module<BE>)
where
    Module<BE>: GGLWEToGGSWKeyCompressedEncryptSk<BE>
        + GLWESecretPreparedFactory<BE>
        + GLWEDecrypt<BE>
        + GLWESecretTensorFactory<BE>
        + GGLWENoise<BE>
        + GGLWEDecompress
        + GGLWEToGGSWKeyDecompress,
    ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
    Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{
    let base2k = 8;
    let k = 54;
    for rank in 1_usize..3 {
        let degree: usize = module.n();
        let rows: usize = k / base2k;
        let layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
            n: degree.into(),
            base2k: base2k.into(),
            k: k.into(),
            dnum: rows.into(),
            dsize: Dsize(1),
            rank: rank.into(),
        };

        let mut source_xs: Source = Source::new([0u8; 32]);
        let mut source_xe: Source = Source::new([0u8; 32]);

        let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(GGLWEToGGSWKeyCompressed::encrypt_sk_tmp_bytes(
            module, &layout,
        ));

        // Sample a ternary secret and prepare it for the noise assertions.
        let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&layout);
        sk.fill_ternary_prob(0.5, &mut source_xs);
        let mut sk_prepared: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(module, rank.into());
        sk_prepared.prepare(module, &sk);

        // Encrypt into the seed-compressed representation, then expand it back
        // into a full key.
        let mut compressed: GGLWEToGGSWKeyCompressed<Vec<u8>> = GGLWEToGGSWKeyCompressed::alloc_from_infos(&layout);
        let seed_xa: [u8; 32] = [1u8; 32];
        compressed.encrypt_sk(module, &sk, seed_xa, &mut source_xe, scratch.borrow());

        let mut key: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&layout);
        key.decompress(module, &compressed);

        // Expected plaintexts: the pairwise products s_i * s_j of the secret.
        let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
        sk_tensor.prepare(module, &sk, scratch.borrow());

        // NOTE(review): the bound here is `SIGMA + 0.5` (not log2-scaled by k,
        // unlike the uncompressed variant of this test) — confirm intentional.
        for i in 0..rank {
            module.gglwe_assert_noise(key.at(i), &sk_prepared, &sk_tensor.data, SIGMA + 0.5);
        }
    }
}

View File

@@ -1,20 +1,16 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow},
ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDft, VecZnxBigAlloc, VecZnxBigNormalize, layouts::{Backend, Module, Scratch, ScratchOwned},
VecZnxCopy, VecZnxDftAlloc, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyTmpA, VecZnxSubScalarInplace,
VecZnxSwitchRing,
},
layouts::{Backend, Module, Scratch, ScratchOwned, VecZnxBig, VecZnxDft},
source::Source, source::Source,
}; };
use crate::{ use crate::{
GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, GGLWENoise, GLWETensorKeyCompressedEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore,
decryption::GLWEDecrypt, decryption::GLWEDecrypt,
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
Dsize, GLWEPlaintext, GLWESecret, GLWESecretPreparedFactory, GLWETensorKey, GLWETensorKeyCompressed, GLWETensorKeyLayout, Dsize, GGLWEDecompress, GLWESecret, GLWESecretPreparedFactory, GLWESecretTensor, GLWESecretTensorFactory, GLWETensorKey,
prepared::GLWESecretPrepared, GLWETensorKeyCompressed, GLWETensorKeyLayout, prepared::GLWESecretPrepared,
}, },
}; };
@@ -23,20 +19,15 @@ where
Module<BE>: GLWETensorKeyEncryptSk<BE> Module<BE>: GLWETensorKeyEncryptSk<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ VecZnxDftAlloc<BE> + GLWESecretTensorFactory<BE>
+ VecZnxBigAlloc<BE> + GGLWENoise<BE>,
+ VecZnxDftApply<BE>
+ SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE>
+ VecZnxSubScalarInplace,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 8; let base2k: usize = 8;
let k: usize = 54; let k: usize = 54;
for rank in 1_usize..3 { for rank in 2_usize..3 {
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k / base2k; let dnum: usize = k / base2k;
@@ -73,42 +64,10 @@ where
scratch.borrow(), scratch.borrow(),
); );
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
sk_tensor.prepare(module, &sk, scratch.borrow());
let mut sk_ij_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, 1); module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5);
let mut sk_ij_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, 1);
let mut sk_ij: GLWESecret<Vec<u8>> = GLWESecret::alloc(n.into(), 1_u32.into());
let mut sk_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(rank, 1);
for i in 0..rank {
module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
}
for i in 0..rank {
for j in 0..rank {
module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i);
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
module.vec_znx_big_normalize(
base2k,
&mut sk_ij.data.as_vec_znx_mut(),
0,
base2k,
&sk_ij_big,
0,
scratch.borrow(),
);
for row_i in 0..dnum {
let ct = tensor_key.at(i, j).at(row_i, 0);
ct.decrypt(module, &mut pt, &sk_prepared, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0);
let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2();
assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}");
}
}
}
} }
} }
@@ -118,15 +77,9 @@ where
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWETensorKeyCompressedEncryptSk<BE> + GLWETensorKeyCompressedEncryptSk<BE>
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ VecZnxDftAlloc<BE> + GLWESecretTensorFactory<BE>
+ VecZnxBigAlloc<BE> + GGLWENoise<BE>
+ VecZnxDftApply<BE> + GGLWEDecompress,
+ SvpApplyDftToDft<BE>
+ VecZnxIdftApplyTmpA<BE>
+ VecZnxSubScalarInplace
+ VecZnxFillUniform
+ VecZnxCopy
+ VecZnxSwitchRing,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
@@ -168,42 +121,9 @@ where
let mut tensor_key: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tensor_key_infos); let mut tensor_key: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tensor_key_infos);
tensor_key.decompress(module, &tensor_key_compressed); tensor_key.decompress(module, &tensor_key_compressed);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&tensor_key_infos); let mut sk_tensor: GLWESecretTensor<Vec<u8>> = GLWESecretTensor::alloc_from_infos(&sk);
sk_tensor.prepare(module, &sk, scratch.borrow());
let mut sk_ij_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(1, 1); module.gglwe_assert_noise(&tensor_key, &sk_prepared, &sk_tensor.data, SIGMA + 0.5);
let mut sk_ij_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(1, 1);
let mut sk_ij: GLWESecret<Vec<u8>> = GLWESecret::alloc(n.into(), 1_u32.into());
let mut sk_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(rank, 1);
for i in 0..rank {
module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
}
for i in 0..rank {
for j in 0..rank {
module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i);
module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
module.vec_znx_big_normalize(
base2k,
&mut sk_ij.data.as_vec_znx_mut(),
0,
base2k,
&sk_ij_big,
0,
scratch.borrow(),
);
for row_i in 0..dnum {
tensor_key
.at(i, j)
.at(row_i, 0)
.decrypt(module, &mut pt, &sk_prepared, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij.data, 0);
let std_pt: f64 = pt.data.std(base2k, 0) * (k as f64).exp2();
assert!((SIGMA - std_pt).abs() <= 0.5, "{SIGMA} {std_pt}");
}
}
}
} }
} }

View File

@@ -1,11 +1,13 @@
mod gglwe_atk; mod gglwe_atk;
mod gglwe_ct; mod gglwe_ct;
mod gglwe_to_ggsw_key;
mod ggsw_ct; mod ggsw_ct;
mod glwe_ct; mod glwe_ct;
mod glwe_tsk; mod glwe_tsk;
pub use gglwe_atk::*; pub use gglwe_atk::*;
pub use gglwe_ct::*; pub use gglwe_ct::*;
pub use gglwe_to_ggsw_key::*;
pub use ggsw_ct::*; pub use ggsw_ct::*;
pub use glwe_ct::*; pub use glwe_ct::*;
pub use glwe_tsk::*; pub use glwe_tsk::*;

View File

@@ -5,12 +5,13 @@ use poulpy_hal::{
}; };
use crate::{ use crate::{
GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, GLWETensorKeyEncryptSk, ScratchTakeCore, GGLWEToGGSWKeyEncryptSk, GGSWEncryptSk, GGSWKeyswitch, GGSWNoise, GLWESwitchingKeyEncryptSk, ScratchTakeCore,
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
GGSW, GGSWLayout, GLWESecret, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GGLWEToGGSWKey, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSW, GGSWLayout, GLWESecret,
GLWESwitchingKeyPreparedFactory, GLWETensorKey, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, GLWESecretPreparedFactory, GLWESwitchingKey, GLWESwitchingKeyLayout, GLWESwitchingKeyPreparedFactory,
prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared, GLWETensorKeyPrepared}, GLWETensorKeyLayout,
prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared},
}, },
noise::noise_ggsw_keyswitch, noise::noise_ggsw_keyswitch,
}; };
@@ -20,10 +21,10 @@ pub fn test_ggsw_keyswitch<BE: Backend>(module: &Module<BE>)
where where
Module<BE>: GGSWEncryptSk<BE> Module<BE>: GGSWEncryptSk<BE>
+ GLWESwitchingKeyEncryptSk<BE> + GLWESwitchingKeyEncryptSk<BE>
+ GLWETensorKeyEncryptSk<BE> + GGLWEToGGSWKeyEncryptSk<BE>
+ GGSWKeyswitch<BE> + GGSWKeyswitch<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWETensorKeyPreparedFactory<BE> + GGLWEToGGSWKeyPreparedFactory<BE>
+ GLWESwitchingKeyPreparedFactory<BE> + GLWESwitchingKeyPreparedFactory<BE>
+ GGSWNoise<BE>, + GGSWNoise<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
@@ -82,7 +83,7 @@ where
let mut ggsw_in: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_in_infos); let mut ggsw_in: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_in_infos);
let mut ggsw_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_infos); let mut ggsw_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_infos);
let mut tsk: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tsk_infos); let mut tsk: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos);
let mut ksk: GLWESwitchingKey<Vec<u8>> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut ksk: GLWESwitchingKey<Vec<u8>> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos);
let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1); let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
@@ -93,7 +94,7 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos) GGSW::encrypt_sk_tmp_bytes(module, &ggsw_in_infos)
| GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos)
| GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos)
| GGSW::keyswitch_tmp_bytes( | GGSW::keyswitch_tmp_bytes(
module, module,
&ggsw_out_infos, &ggsw_out_infos,
@@ -148,7 +149,7 @@ where
GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk);
ksk_prepared.prepare(module, &ksk, scratch.borrow()); ksk_prepared.prepare(module, &ksk, scratch.borrow());
let mut tsk_prepared: GLWETensorKeyPrepared<Vec<u8>, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
tsk_prepared.prepare(module, &tsk, scratch.borrow()); tsk_prepared.prepare(module, &tsk, scratch.borrow());
ggsw_out.keyswitch( ggsw_out.keyswitch(
@@ -185,10 +186,10 @@ pub fn test_ggsw_keyswitch_inplace<BE: Backend>(module: &Module<BE>)
where where
Module<BE>: GGSWEncryptSk<BE> Module<BE>: GGSWEncryptSk<BE>
+ GLWESwitchingKeyEncryptSk<BE> + GLWESwitchingKeyEncryptSk<BE>
+ GLWETensorKeyEncryptSk<BE> + GGLWEToGGSWKeyEncryptSk<BE>
+ GGSWKeyswitch<BE> + GGSWKeyswitch<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWETensorKeyPreparedFactory<BE> + GGLWEToGGSWKeyPreparedFactory<BE>
+ GLWESwitchingKeyPreparedFactory<BE> + GLWESwitchingKeyPreparedFactory<BE>
+ GGSWNoise<BE>, + GGSWNoise<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
@@ -236,7 +237,7 @@ where
}; };
let mut ggsw_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_infos); let mut ggsw_out: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_out_infos);
let mut tsk: GLWETensorKey<Vec<u8>> = GLWETensorKey::alloc_from_infos(&tsk_infos); let mut tsk: GGLWEToGGSWKey<Vec<u8>> = GGLWEToGGSWKey::alloc_from_infos(&tsk_infos);
let mut ksk: GLWESwitchingKey<Vec<u8>> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos); let mut ksk: GLWESwitchingKey<Vec<u8>> = GLWESwitchingKey::alloc_from_infos(&ksk_apply_infos);
let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1); let mut pt_scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
@@ -247,7 +248,7 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos) GGSW::encrypt_sk_tmp_bytes(module, &ggsw_out_infos)
| GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos) | GLWESwitchingKey::encrypt_sk_tmp_bytes(module, &ksk_apply_infos)
| GLWETensorKey::encrypt_sk_tmp_bytes(module, &tsk_infos) | GGLWEToGGSWKey::encrypt_sk_tmp_bytes(module, &tsk_infos)
| GGSW::keyswitch_tmp_bytes( | GGSW::keyswitch_tmp_bytes(
module, module,
&ggsw_out_infos, &ggsw_out_infos,
@@ -302,7 +303,7 @@ where
GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk); GLWESwitchingKeyPrepared::alloc_from_infos(module, &ksk);
ksk_prepared.prepare(module, &ksk, scratch.borrow()); ksk_prepared.prepare(module, &ksk, scratch.borrow());
let mut tsk_prepared: GLWETensorKeyPrepared<Vec<u8>, BE> = GLWETensorKeyPrepared::alloc_from_infos(module, &tsk); let mut tsk_prepared: GGLWEToGGSWKeyPrepared<Vec<u8>, BE> = GGLWEToGGSWKeyPrepared::alloc_from_infos(module, &tsk);
tsk_prepared.prepare(module, &tsk, scratch.borrow()); tsk_prepared.prepare(module, &tsk, scratch.borrow());
ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow()); ggsw_out.keyswitch_inplace(module, &ksk_prepared, &tsk_prepared, scratch.borrow());

View File

@@ -7,7 +7,7 @@ use poulpy_hal::{
}; };
use crate::{ use crate::{
GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPacking, GLWERotate, GLWESub, ScratchTakeCore, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWEPacker, GLWEPackerOps, GLWERotate, GLWESub, ScratchTakeCore,
layouts::{ layouts::{
GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext,
GLWESecret, GLWESecretPreparedFactory, GLWESecret, GLWESecretPreparedFactory,
@@ -20,7 +20,7 @@ where
Module<BE>: GLWEEncryptSk<BE> Module<BE>: GLWEEncryptSk<BE>
+ GLWEAutomorphismKeyEncryptSk<BE> + GLWEAutomorphismKeyEncryptSk<BE>
+ GLWEAutomorphismKeyPreparedFactory<BE> + GLWEAutomorphismKeyPreparedFactory<BE>
+ GLWEPacking<BE> + GLWEPackerOps<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWESub + GLWESub
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "poulpy-hal" name = "poulpy-hal"
version = "0.2.0" version = "0.3.1"
edition = "2024" edition = "2024"
license = "Apache-2.0" license = "Apache-2.0"
readme = "README.md" readme = "README.md"
@@ -19,7 +19,7 @@ rand_core = {workspace = true}
byteorder = {workspace = true} byteorder = {workspace = true}
once_cell = {workspace = true} once_cell = {workspace = true}
rand_chacha = "0.9.0" rand_chacha = "0.9.0"
bytemuck = "1.23.2" bytemuck = {workspace = true}
[build-dependencies] [build-dependencies]

View File

@@ -1,74 +1,24 @@
use crate::{ use crate::{
api::{ api::{
ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace, ModuleN, ScratchTakeBasic, SvpApplyDftToDft, SvpPPolAlloc, SvpPPolBytesOf, SvpPrepare, VecZnxDftAddScaledInplace,
VecZnxDftBytesOf, VecZnxDftBytesOf, VecZnxDftZero,
}, },
layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos, ZnxZero}, layouts::{Backend, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, ZnxInfos},
}; };
impl<BE: Backend> Convolution<BE> for Module<BE> impl<BE: Backend> BivariateTensoring<BE> for Module<BE>
where where
Self: Sized Self: BivariateConvolution<BE>,
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>,
Scratch<BE>: ScratchTakeBasic, Scratch<BE>: ScratchTakeBasic,
{ {
} }
pub trait Convolution<BE: Backend> pub trait BivariateTensoring<BE: Backend>
where where
Self: Sized Self: BivariateConvolution<BE>,
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>,
Scratch<BE>: ScratchTakeBasic, Scratch<BE>: ScratchTakeBasic,
{ {
fn convolution_tmp_bytes(&self, res_size: usize) -> usize { fn bivariate_tensoring<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, res_size)
}
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K
/// and scales the result by 2^{res_scale * K}
///
/// # Example
/// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
/// [a01, a11, a21, a31]
///
/// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
/// [b01, b11, b21, b31]
///
/// If res_scale = 0:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
///
/// If res_scale = 1:
/// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
/// [r05, r15, r25, r35]
///
/// If res_scale = -1:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
/// [ 0, 0, 0, 0]
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
///
/// If res.size() < a.size() + b.size() + 1 + res_scale, result is truncated accordingly in the Y dimension.
fn convolution<R, A, B>(&self, res: &mut R, res_scale: i64, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
R: VecZnxDftToMut<BE>, R: VecZnxDftToMut<BE>,
A: VecZnxToRef, A: VecZnxToRef,
@@ -78,32 +28,135 @@ where
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref(); let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref(); let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
assert!(res.cols() >= a.cols() + b.cols() - 1); let res_cols: usize = res.cols();
let a_cols: usize = a.cols();
let b_cols: usize = b.cols();
res.zero(); assert!(res_cols >= a_cols + b_cols - 1);
for res_col in 0..res_cols {
self.vec_znx_dft_zero(res, res_col);
}
for a_col in 0..a_cols {
for b_col in 0..b_cols {
self.bivariate_convolution_add(k, res, a_col + b_col, a, a_col, b, b_col, scratch);
}
}
}
}
impl<BE: Backend> BivariateConvolution<BE> for Module<BE>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
}
pub trait BivariateConvolution<BE: Backend>
where
Self: Sized
+ ModuleN
+ SvpPPolAlloc<BE>
+ SvpApplyDftToDft<BE>
+ SvpPrepare<BE>
+ SvpPPolBytesOf
+ VecZnxDftBytesOf
+ VecZnxDftAddScaledInplace<BE>
+ VecZnxDftZero<BE>,
Scratch<BE>: ScratchTakeBasic,
{
fn convolution_tmp_bytes(&self, b_size: usize) -> usize {
self.bytes_of_svp_ppol(1) + self.bytes_of_vec_znx_dft(1, b_size)
}
#[allow(clippy::too_many_arguments)]
/// Evaluates a bivariate convolution over Z[X, Y] / (X^N + 1) where Y = 2^-K over the
/// selected columsn and stores the result on the selected column, scaled by 2^{k * Base2K}
///
/// # Example
/// a = [a00, a10, a20, a30] = (a00 * 2^-K + a01 * 2^-2K) + (a10 * 2^-K + a11 * 2^-2K) * X ...
/// [a01, a11, a21, a31]
///
/// b = [b00, b10, b20, b30] = (b00 * 2^-K + b01 * 2^-2K) + (b10 * 2^-K + b11 * 2^-2K) * X ...
/// [b01, b11, b21, b31]
///
/// If k = 0:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-2K + r02 * 2^-3K + r03 * 2^-4K + r04 * 2^-5K) + ...
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
///
/// If k = 1:
/// res = [r01, r11, r21, r31] = (r01 * 2^-K + r02 * 2^-2K + r03 * 2^-3K + r04 * 2^-4K + r05 * 2^-5K) + ...
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
/// [r04, r14, r24, r34]
/// [r05, r15, r25, r35]
///
/// If k = -1:
/// res = [ 0, 0, 0, 0] = (r01 * 2^-3K + r02 * 2^-4K + r03 * 2^-5K) + ...
/// [ 0, 0, 0, 0]
/// [r01, r11, r21, r31]
/// [r02, r12, r22, r32]
/// [r03, r13, r23, r33]
///
/// If res.size() < a.size() + b.size() + 1 + k, result is truncated accordingly in the Y dimension.
fn bivariate_convolution_add<R, A, B>(
&self,
k: i64,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
b: &B,
b_col: usize,
scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
B: VecZnxDftToRef<BE>,
{
let res: &mut crate::layouts::VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
let a: &crate::layouts::VecZnx<&[u8]> = &a.to_ref();
let b: &crate::layouts::VecZnxDft<&[u8], BE> = &b.to_ref();
let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1); let (mut ppol, scratch_1) = scratch.take_svp_ppol(self, 1);
let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); let (mut res_tmp, _) = scratch_1.take_vec_znx_dft(self, 1, b.size());
for a_col in 0..a.cols() {
for a_limb in 0..a.size() { for a_limb in 0..a.size() {
// Prepares the j-th limb of the i-th col of A
self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0); self.svp_prepare(&mut ppol, 0, &a.as_scalar_znx_ref(a_col, a_limb), 0);
for b_col in 0..b.cols() {
// Multiplies with the i-th col of B
self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col); self.svp_apply_dft_to_dft(&mut res_tmp, 0, &ppol, 0, b, b_col);
self.vec_znx_dft_add_scaled_inplace(res, res_col, &res_tmp, 0, -(1 + a_limb as i64) + k);
}
}
// Adds on the [a_col + b_col] of res, scaled by 2^{-(a_limb + 1) * Base2K} #[allow(clippy::too_many_arguments)]
self.vec_znx_dft_add_scaled_inplace( fn bivariate_convolution<R, A, B>(
res, &self,
a_col + b_col, k: i64,
&res_tmp, res: &mut R,
0, res_col: usize,
-(1 + a_limb as i64) + res_scale, a: &A,
); a_col: usize,
} b: &B,
} b_col: usize,
} scratch: &mut Scratch<BE>,
) where
R: VecZnxDftToMut<BE>,
A: VecZnxToRef,
B: VecZnxDftToRef<BE>,
{
self.vec_znx_dft_zero(res, res_col);
self.bivariate_convolution_add(k, res, res_col, a, a_col, b, b_col, scratch);
} }
} }

View File

@@ -8,6 +8,12 @@ pub trait VecZnxNormalizeTmpBytes {
fn vec_znx_normalize_tmp_bytes(&self) -> usize; fn vec_znx_normalize_tmp_bytes(&self) -> usize;
} }
pub trait VecZnxZero {
fn vec_znx_zero<R>(&self, res: &mut R, res_col: usize)
where
R: VecZnxToMut;
}
pub trait VecZnxNormalize<B: Backend> { pub trait VecZnxNormalize<B: Backend> {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`. /// Normalizes the selected column of `a` and stores the result into the selected column of `res`.

View File

@@ -97,7 +97,7 @@ pub trait VecZnxDftCopy<B: Backend> {
} }
pub trait VecZnxDftZero<B: Backend> { pub trait VecZnxDftZero<B: Backend> {
fn vec_znx_dft_zero<R>(&self, res: &mut R) fn vec_znx_dft_zero<R>(&self, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<B>; R: VecZnxDftToMut<B>;
} }

View File

@@ -6,7 +6,7 @@ use crate::{
VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize, VecZnxMulXpMinusOneInplace, VecZnxMulXpMinusOneInplaceTmpBytes, VecZnxNegate, VecZnxNegateInplace, VecZnxNormalize,
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes,
VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxRsh, VecZnxRshInplace, VecZnxRshTmpBytes, VecZnxSplitRing, VecZnxSplitRingTmpBytes, VecZnxSub, VecZnxSubInplace,
VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, VecZnxSubNegateInplace, VecZnxSubScalar, VecZnxSubScalarInplace, VecZnxSwitchRing, VecZnxZero,
}, },
layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef}, layouts::{Backend, Module, ScalarZnxToRef, Scratch, VecZnxToMut, VecZnxToRef},
oep::{ oep::{
@@ -18,11 +18,23 @@ use crate::{
VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl, VecZnxNormalizeInplaceImpl, VecZnxNormalizeTmpBytesImpl, VecZnxRotateImpl, VecZnxRotateInplaceImpl,
VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl, VecZnxRotateInplaceTmpBytesImpl, VecZnxRshImpl, VecZnxRshInplaceImpl, VecZnxRshTmpBytesImpl, VecZnxSplitRingImpl,
VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl, VecZnxSplitRingTmpBytesImpl, VecZnxSubImpl, VecZnxSubInplaceImpl, VecZnxSubNegateInplaceImpl, VecZnxSubScalarImpl,
VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxSubScalarInplaceImpl, VecZnxSwitchRingImpl, VecZnxZeroImpl,
}, },
source::Source, source::Source,
}; };
impl<B> VecZnxZero for Module<B>
where
B: Backend + VecZnxZeroImpl<B>,
{
fn vec_znx_zero<R>(&self, res: &mut R, res_col: usize)
where
R: VecZnxToMut,
{
B::vec_znx_zero_impl(self, res, res_col);
}
}
impl<B> VecZnxNormalizeTmpBytes for Module<B> impl<B> VecZnxNormalizeTmpBytes for Module<B>
where where
B: Backend + VecZnxNormalizeTmpBytesImpl<B>, B: Backend + VecZnxNormalizeTmpBytesImpl<B>,

View File

@@ -200,10 +200,10 @@ impl<B> VecZnxDftZero<B> for Module<B>
where where
B: Backend + VecZnxDftZeroImpl<B>, B: Backend + VecZnxDftZeroImpl<B>,
{ {
fn vec_znx_dft_zero<R>(&self, res: &mut R) fn vec_znx_dft_zero<R>(&self, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<B>, R: VecZnxDftToMut<B>,
{ {
B::vec_znx_dft_zero_impl(self, res); B::vec_znx_dft_zero_impl(self, res, res_col);
} }
} }

View File

@@ -3,6 +3,16 @@ use crate::{
source::Source, source::Source,
}; };
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See TODO for reference implementation.
/// * See [crate::api::VecZnxZero] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxZeroImpl<B: Backend> {
fn vec_znx_zero_impl<R>(module: &Module<B>, res: &mut R, res_col: usize)
where
R: VecZnxToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe) /// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation. /// * See [poulpy-backend/src/cpu_fft64_ref/vec_znx.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/vec_znx.rs) for reference implementation.
/// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API. /// * See [crate::api::VecZnxNormalizeTmpBytes] for corresponding public API.

View File

@@ -188,7 +188,7 @@ pub unsafe trait VecZnxDftCopyImpl<B: Backend> {
/// * See [crate::api::VecZnxDftZero] for corresponding public API. /// * See [crate::api::VecZnxDftZero] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract. /// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait VecZnxDftZeroImpl<B: Backend> { pub unsafe trait VecZnxDftZeroImpl<B: Backend> {
fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R) fn vec_znx_dft_zero_impl<R>(module: &Module<B>, res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<B>; R: VecZnxDftToMut<B>;
} }

View File

@@ -118,7 +118,7 @@ where
} }
} else if a_scale < 0 { } else if a_scale < 0 {
let shift: usize = (a_scale.unsigned_abs() as usize).min(res_size); let shift: usize = (a_scale.unsigned_abs() as usize).min(res_size);
let sum_size: usize = a_size.min(res_size).saturating_sub(shift); let sum_size: usize = a_size.min(res_size.saturating_sub(shift));
for j in 0..sum_size { for j in 0..sum_size {
BE::reim_add_inplace(res.at_mut(res_col, j + shift), a.at(a_col, j)); BE::reim_add_inplace(res.at_mut(res_col, j + shift), a.at(a_col, j));
} }
@@ -398,10 +398,13 @@ where
} }
} }
pub fn vec_znx_dft_zero<R, BE>(res: &mut R) pub fn vec_znx_dft_zero<R, BE>(res: &mut R, res_col: usize)
where where
R: VecZnxDftToMut<BE>, R: VecZnxDftToMut<BE>,
BE: Backend<ScalarPrep = f64> + ReimZero, BE: Backend<ScalarPrep = f64> + ReimZero,
{ {
BE::reim_zero(res.to_mut().raw_mut()); let res: &mut VecZnxDft<&mut [u8], BE> = &mut res.to_mut();
for j in 0..res.size() {
BE::reim_zero(res.at_mut(res_col, j))
}
} }

View File

@@ -13,6 +13,7 @@ mod split_ring;
mod sub; mod sub;
mod sub_scalar; mod sub_scalar;
mod switch_ring; mod switch_ring;
mod zero;
pub use add::*; pub use add::*;
pub use add_scalar::*; pub use add_scalar::*;
@@ -29,3 +30,4 @@ pub use split_ring::*;
pub use sub::*; pub use sub::*;
pub use sub_scalar::*; pub use sub_scalar::*;
pub use switch_ring::*; pub use switch_ring::*;
pub use zero::*;

View File

@@ -0,0 +1,16 @@
use crate::{
layouts::{VecZnx, VecZnxToMut, ZnxInfos, ZnxViewMut},
reference::znx::ZnxZero,
};
pub fn vec_znx_zero<R, ZNXARI>(res: &mut R, res_col: usize)
where
R: VecZnxToMut,
ZNXARI: ZnxZero,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let res_size = res.size();
for j in 0..res_size {
ZNXARI::znx_zero(res.at_mut(res_col, j));
}
}

View File

@@ -1,7 +1,7 @@
use crate::{ use crate::{
api::{ api::{
Convolution, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigNormalize, BivariateTensoring, ModuleN, ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, TakeSlice, VecZnxBigAlloc,
VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalizeInplace, VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeInplace,
}, },
layouts::{ layouts::{
Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, Backend, FillUniform, Scratch, ScratchOwned, VecZnx, VecZnxBig, VecZnxDft, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView,
@@ -10,15 +10,16 @@ use crate::{
source::Source, source::Source,
}; };
pub fn test_convolution<M, BE: Backend>(module: &M) pub fn test_bivariate_tensoring<M, BE: Backend>(module: &M)
where where
M: ModuleN M: ModuleN
+ Convolution<BE> + BivariateTensoring<BE>
+ VecZnxDftAlloc<BE> + VecZnxDftAlloc<BE>
+ VecZnxDftApply<BE> + VecZnxDftApply<BE>
+ VecZnxIdftApplyConsume<BE> + VecZnxIdftApplyTmpA<BE>
+ VecZnxBigNormalize<BE> + VecZnxBigNormalize<BE>
+ VecZnxNormalizeInplace<BE>, + VecZnxNormalizeInplace<BE>
+ VecZnxBigAlloc<BE>,
Scratch<BE>: ScratchTakeBasic, Scratch<BE>: ScratchTakeBasic,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{ {
@@ -26,36 +27,41 @@ where
let base2k: usize = 12; let base2k: usize = 12;
for a_cols in 1..3 { let a_cols: usize = 3;
for b_cols in 1..3 { let b_cols: usize = 3;
for a_size in 1..5 { let a_size: usize = 3;
for b_size in 1..5 { let b_size: usize = 3;
let c_cols: usize = a_cols + b_cols - 1;
let c_size: usize = a_size + b_size;
let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size); let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols, a_size);
let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size); let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), b_cols, b_size);
let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), a_cols + b_cols - 1, b_size + a_size); let mut c_want: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_want.cols(), c_want.size()); let mut c_have: VecZnx<Vec<u8>> = VecZnx::alloc(module.n(), c_cols, c_size);
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_cols, c_size);
let mut c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_big_alloc(c_cols, c_size);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(c_want.size())); let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(module.convolution_tmp_bytes(b_size));
a.fill_uniform(base2k, &mut source); a.fill_uniform(base2k, &mut source);
b.fill_uniform(base2k, &mut source); b.fill_uniform(base2k, &mut source);
let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b.cols(), b.size()); let mut b_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(b_cols, b_size);
for i in 0..b.cols() { for i in 0..b.cols() {
module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i); module.vec_znx_dft_apply(1, 0, &mut b_dft, i, &b, i);
} }
for mut res_scale in 0..2 * c_want.size() as i64 + 1 { for mut k in 0..(2 * c_size + 1) as i64 {
res_scale -= c_want.size() as i64; k -= c_size as i64;
let mut c_have_dft: VecZnxDft<Vec<u8>, BE> = module.vec_znx_dft_alloc(c_have.cols(), c_have.size()); module.bivariate_tensoring(k, &mut c_have_dft, &a, &b_dft, scratch.borrow());
module.convolution(&mut c_have_dft, res_scale, &a, &b_dft, scratch.borrow());
let c_have_big: VecZnxBig<Vec<u8>, BE> = module.vec_znx_idft_apply_consume(c_have_dft); for i in 0..c_cols {
module.vec_znx_idft_apply_tmpa(&mut c_have_big, i, &mut c_have_dft, i);
}
for i in 0..c_have.cols() { for i in 0..c_cols {
module.vec_znx_big_normalize( module.vec_znx_big_normalize(
base2k, base2k,
&mut c_have, &mut c_have,
@@ -67,29 +73,17 @@ where
); );
} }
convolution_naive( bivariate_tensoring_naive(module, base2k, k, &mut c_want, &a, &b, scratch.borrow());
module,
base2k,
&mut c_want,
res_scale,
&a,
&b,
scratch.borrow(),
);
assert_eq!(c_want, c_have); assert_eq!(c_want, c_have);
} }
} }
}
}
}
}
fn convolution_naive<R, A, B, M, BE: Backend>( fn bivariate_tensoring_naive<R, A, B, M, BE: Backend>(
module: &M, module: &M,
base2k: usize, base2k: usize,
k: i64,
res: &mut R, res: &mut R,
res_scale: i64,
a: &A, a: &A,
b: &B, b: &B,
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
@@ -112,11 +106,11 @@ fn convolution_naive<R, A, B, M, BE: Backend>(
for a_limb in 0..a.size() { for a_limb in 0..a.size() {
for b_col in 0..b.cols() { for b_col in 0..b.cols() {
for b_limb in 0..b.size() { for b_limb in 0..b.size() {
let res_scale_abs = res_scale.unsigned_abs() as usize; let res_scale_abs = k.unsigned_abs() as usize;
let mut res_limb: usize = a_limb + b_limb + 1; let mut res_limb: usize = a_limb + b_limb + 1;
if res_scale <= 0 { if k <= 0 {
res_limb += res_scale_abs; res_limb += res_scale_abs;
if res_limb < res.size() { if res_limb < res.size() {

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "poulpy-schemes" name = "poulpy-schemes"
version = "0.2.0" version = "0.3.0"
edition = "2024" edition = "2024"
license = "Apache-2.0" license = "Apache-2.0"
readme = "README.md" readme = "README.md"

View File

@@ -0,0 +1,204 @@
use poulpy_core::{
GLWECopy, GLWERotate, ScratchTakeCore,
layouts::{GGSW, GGSWInfos, GGSWToMut, GGSWToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos},
};
use poulpy_hal::{
api::{VecZnxAddScalarInplace, VecZnxNormalizeInplace},
layouts::{Backend, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxZero},
};
use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger};
impl<T: UnsignedInteger, BE: Backend> GGSWBlindRotation<T, BE> for Module<BE>
where
Self: GLWEBlindRotation<T, BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
}
pub trait GGSWBlindRotation<T: UnsignedInteger, BE: Backend>
where
Self: GLWEBlindRotation<T, BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn ggsw_to_ggsw_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where
R: GLWEInfos,
K: GGSWInfos,
{
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos)
}
#[allow(clippy::too_many_arguments)]
/// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn ggsw_to_ggsw_blind_rotation<R, A, K>(
&self,
res: &mut R,
a: &A,
k: &K,
bit_start: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GGSWToMut,
A: GGSWToRef,
K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref();
assert!(res.dnum() <= a.dnum());
assert_eq!(res.dsize(), a.dsize());
for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() {
self.glwe_to_glwe_blind_rotation(
&mut res.at_mut(row, col),
&a.at(row, col),
k,
bit_start,
bit_mask,
bit_lsh,
scratch,
);
}
}
}
fn scalar_to_ggsw_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where
R: GLWEInfos,
K: GGSWInfos,
{
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
}
#[allow(clippy::too_many_arguments)]
fn scalar_to_ggsw_blind_rotation<R, S, K>(
&self,
res: &mut R,
test_vector: &S,
k: &K,
bit_start: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GGSWToMut,
S: ScalarZnxToRef,
K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let test_vector: &ScalarZnx<&[u8]> = &test_vector.to_ref();
let base2k: usize = res.base2k().into();
let dsize: usize = res.dsize().into();
let (mut tmp_glwe, scratch_1) = scratch.take_glwe(res);
for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() {
tmp_glwe.data_mut().zero();
self.vec_znx_add_scalar_inplace(
tmp_glwe.data_mut(),
col,
(dsize - 1) + row * dsize,
test_vector,
0,
);
self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1);
self.glwe_to_glwe_blind_rotation(
&mut res.at_mut(row, col),
&tmp_glwe,
k,
bit_start,
bit_mask,
bit_lsh,
scratch_1,
);
}
}
}
}
impl<T: UnsignedInteger, BE: Backend> GLWEBlindRotation<T, BE> for Module<BE>
where
Self: GLWECopy + GLWERotate<BE> + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
}
pub trait GLWEBlindRotation<T: UnsignedInteger, BE: Backend>
where
Self: GLWECopy + GLWERotate<BE> + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn glwe_to_glwe_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where
R: GLWEInfos,
K: GGSWInfos,
{
self.cmux_tmp_bytes(res_infos, res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
}
#[allow(clippy::too_many_arguments)]
/// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn glwe_to_glwe_blind_rotation<R, A, K>(
&self,
res: &mut R,
a: &A,
k: &K,
bit_rsh: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
A: GLWEToRef,
K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
assert!(bit_rsh + bit_mask <= T::WORD_SIZE);
let mut res: GLWE<&mut [u8]> = res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let (mut tmp_res, scratch_1) = scratch.take_glwe(&res);
// a <- a ; b <- a * X^{-2^{i + bit_lsh}}
self.glwe_rotate(-1 << bit_lsh, &mut res, a);
// b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1);
// a_is_res = true => (a, b) = (&mut res, &mut tmp_res)
// a_is_res = false => (a, b) = (&mut tmp_res, &mut res)
let mut a_is_res: bool = true;
for i in 1..bit_mask {
let (a, b) = if a_is_res {
(&mut res, &mut tmp_res)
} else {
(&mut tmp_res, &mut res)
};
// a <- a ; b <- a * X^{-2^{i + bit_lsh}}
self.glwe_rotate(-1 << (i + bit_lsh), b, a);
// b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1);
// ping-pong roles for next iter
a_is_res = !a_is_res;
}
// Ensure the final value ends up in `res`
if !a_is_res {
self.glwe_copy(&mut res, &tmp_res);
}
}
}

View File

@@ -39,6 +39,17 @@ impl<D: DataRef, T: UnsignedInteger> GLWEInfos for FheUintBlocks<D, T> {
} }
} }
impl<D: Data, T: UnsignedInteger> FheUintBlocks<D, T> {
pub fn new(blocks: Vec<GLWE<D>>) -> Self {
assert_eq!(blocks.len(), T::WORD_SIZE);
Self {
blocks,
_base: 1,
_phantom: PhantomData,
}
}
}
impl<T: UnsignedInteger> FheUintBlocks<Vec<u8>, T> { impl<T: UnsignedInteger> FheUintBlocks<Vec<u8>, T> {
pub fn alloc_from_infos<A, BE: Backend>(module: &Module<BE>, infos: &A) -> Self pub fn alloc_from_infos<A, BE: Backend>(module: &Module<BE>, infos: &A) -> Self
where where

View File

@@ -3,6 +3,7 @@ use std::marker::PhantomData;
use poulpy_core::layouts::{ use poulpy_core::layouts::{
Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared, Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared,
}; };
use poulpy_core::layouts::{GGSWPreparedToMut, GGSWPreparedToRef};
use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef}; use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef};
use poulpy_hal::layouts::{Backend, Data, DataRef, Module}; use poulpy_hal::layouts::{Backend, Data, DataRef, Module};
@@ -28,6 +29,28 @@ impl<T: UnsignedInteger, BE: Backend> FheUintBlocksPreparedFactory<T, BE> for Mo
{ {
} }
pub trait GetGGSWBit<T: UnsignedInteger, BE: Backend> {
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>;
}
impl<D: DataRef, T: UnsignedInteger, BE: Backend> GetGGSWBit<T, BE> for FheUintBlocksPrepared<D, T, BE> {
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> {
assert!(bit <= self.blocks.len());
self.blocks[bit].to_ref()
}
}
pub trait GetGGSWBitMut<T: UnsignedInteger, BE: Backend> {
fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE>;
}
impl<D: DataMut, T: UnsignedInteger, BE: Backend> GetGGSWBitMut<T, BE> for FheUintBlocksPrepared<D, T, BE> {
    /// Returns a mutable view of the GGSW ciphertext encrypting bit `bit`.
    ///
    /// # Panics
    /// Panics if `bit >= self.blocks.len()`.
    fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE> {
        // Bound must be strict: the previous `bit <= self.blocks.len()` assert
        // let `bit == len` through and panicked on the index below instead.
        assert!(bit < self.blocks.len());
        self.blocks[bit].to_mut()
    }
}
pub trait FheUintBlocksPreparedFactory<T: UnsignedInteger, BE: Backend> pub trait FheUintBlocksPreparedFactory<T: UnsignedInteger, BE: Backend>
where where
Self: Sized + GGSWPreparedFactory<BE>, Self: Sized + GGSWPreparedFactory<BE>,

View File

@@ -3,12 +3,9 @@ use core::panic;
use itertools::Itertools; use itertools::Itertools;
use poulpy_core::{ use poulpy_core::{
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore, GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore,
layouts::{ layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
GLWE, LWEInfos,
prepared::{GGSWPrepared, GGSWPreparedToRef},
},
}; };
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero}; use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
use crate::tfhe::bdd_arithmetic::UnsignedInteger; use crate::tfhe::bdd_arithmetic::UnsignedInteger;
@@ -146,30 +143,47 @@ pub enum Node {
None, None,
} }
pub trait Cmux<BE: Backend> { pub trait Cmux<BE: Backend>
fn cmux<O, T, F, S>(&self, out: &mut GLWE<O>, t: &GLWE<T>, f: &GLWE<F>, s: &GGSWPrepared<S, BE>, scratch: &mut Scratch<BE>)
where where
O: DataMut, Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
T: DataRef, Scratch<BE>: ScratchTakeCore<BE>,
F: DataRef, {
S: DataRef; fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
R: GLWEInfos,
A: GLWEInfos,
B: GGSWInfos,
{
self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
}
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
T: GLWEToRef,
F: GLWEToRef,
S: GGSWPreparedToRef<BE>,
{
self.glwe_sub(res, t, f);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, f);
}
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
S: GGSWPreparedToRef<BE>,
{
self.glwe_sub_inplace(res, a);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, a);
}
} }
impl<BE: Backend> Cmux<BE> for Module<BE> impl<BE: Backend> Cmux<BE> for Module<BE>
where where
Module<BE>: GLWEExternalProduct<BE> + GLWESub + GLWEAdd, Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn cmux<O, T, F, S>(&self, out: &mut GLWE<O>, t: &GLWE<T>, f: &GLWE<F>, s: &GGSWPrepared<S, BE>, scratch: &mut Scratch<BE>)
where
O: DataMut,
T: DataRef,
F: DataRef,
S: DataRef,
{
// let mut out: GLWECiphertext<&mut [u8]> = out.to_mut();
self.glwe_sub(out, t, f);
self.glwe_external_product_inplace(out, s, scratch);
self.glwe_add_inplace(out, f);
}
} }

View File

@@ -10,8 +10,8 @@ use crate::tfhe::{
use poulpy_core::{ use poulpy_core::{
GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore, GLWEToLWESwitchingKeyEncryptSk, GetDistribution, LWEFromGLWE, ScratchTakeCore,
layouts::{ layouts::{
GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKeyLayout, GLWEToLWESwitchingKey, GGSWInfos, GGSWPreparedFactory, GLWEInfos, GLWESecretToRef, GLWEToLWEKey, GLWEToLWEKeyLayout,
GLWEToLWESwitchingKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWESwitchingKeyPrepared, GLWEToLWEKeyPreparedFactory, LWE, LWEInfos, LWESecretToRef, prepared::GLWEToLWEKeyPrepared,
}, },
}; };
use poulpy_hal::{ use poulpy_hal::{
@@ -46,7 +46,7 @@ where
BRA: BlindRotationAlgo, BRA: BlindRotationAlgo,
{ {
cbt: CircuitBootstrappingKey<D, BRA>, cbt: CircuitBootstrappingKey<D, BRA>,
ks: GLWEToLWESwitchingKey<D>, ks: GLWEToLWEKey<D>,
} }
impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, BRA> impl<BRA: BlindRotationAlgo> BDDKey<Vec<u8>, BRA>
@@ -59,7 +59,7 @@ where
{ {
Self { Self {
cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()), cbt: CircuitBootstrappingKey::alloc_from_infos(&infos.cbt_infos()),
ks: GLWEToLWESwitchingKey::alloc_from_infos(&infos.ks_infos()), ks: GLWEToLWEKey::alloc_from_infos(&infos.ks_infos()),
} }
} }
} }
@@ -130,12 +130,12 @@ where
BE: Backend, BE: Backend,
{ {
pub(crate) cbt: CircuitBootstrappingKeyPrepared<D, BRA, BE>, pub(crate) cbt: CircuitBootstrappingKeyPrepared<D, BRA, BE>,
pub(crate) ks: GLWEToLWESwitchingKeyPrepared<D, BE>, pub(crate) ks: GLWEToLWEKeyPrepared<D, BE>,
} }
pub trait BDDKeyPreparedFactory<BRA: BlindRotationAlgo, BE: Backend> pub trait BDDKeyPreparedFactory<BRA: BlindRotationAlgo, BE: Backend>
where where
Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWESwitchingKeyPreparedFactory<BE>, Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWEKeyPreparedFactory<BE>,
{ {
fn alloc_bdd_key_from_infos<A>(&self, infos: &A) -> BDDKeyPrepared<Vec<u8>, BRA, BE> fn alloc_bdd_key_from_infos<A>(&self, infos: &A) -> BDDKeyPrepared<Vec<u8>, BRA, BE>
where where
@@ -143,7 +143,7 @@ where
{ {
BDDKeyPrepared { BDDKeyPrepared {
cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()), cbt: CircuitBootstrappingKeyPrepared::alloc_from_infos(self, &infos.cbt_infos()),
ks: GLWEToLWESwitchingKeyPrepared::alloc_from_infos(self, &infos.ks_infos()), ks: GLWEToLWEKeyPrepared::alloc_from_infos(self, &infos.ks_infos()),
} }
} }
@@ -152,7 +152,7 @@ where
A: BDDKeyInfos, A: BDDKeyInfos,
{ {
self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos()) self.circuit_bootstrapping_key_prepare_tmp_bytes(&infos.cbt_infos())
.max(self.prepare_glwe_to_lwe_switching_key_tmp_bytes(&infos.ks_infos())) .max(self.prepare_glwe_to_lwe_key_tmp_bytes(&infos.ks_infos()))
} }
fn prepare_bdd_key<DM, DR>(&self, res: &mut BDDKeyPrepared<DM, BRA, BE>, other: &BDDKey<DR, BRA>, scratch: &mut Scratch<BE>) fn prepare_bdd_key<DM, DR>(&self, res: &mut BDDKeyPrepared<DM, BRA, BE>, other: &BDDKey<DR, BRA>, scratch: &mut Scratch<BE>)
@@ -166,7 +166,7 @@ where
} }
} }
impl<BRA: BlindRotationAlgo, BE: Backend> BDDKeyPreparedFactory<BRA, BE> for Module<BE> where impl<BRA: BlindRotationAlgo, BE: Backend> BDDKeyPreparedFactory<BRA, BE> for Module<BE> where
Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWESwitchingKeyPreparedFactory<BE> Self: Sized + CircuitBootstrappingKeyPreparedFactory<BRA, BE> + GLWEToLWEKeyPreparedFactory<BE>
{ {
} }

View File

@@ -1,10 +1,12 @@
mod bdd_2w_to_1w; mod bdd_2w_to_1w;
mod blind_rotation;
mod ciphertexts; mod ciphertexts;
mod circuits; mod circuits;
mod eval; mod eval;
mod key; mod key;
pub use bdd_2w_to_1w::*; pub use bdd_2w_to_1w::*;
pub use blind_rotation::*;
pub use ciphertexts::*; pub use ciphertexts::*;
pub(crate) use circuits::*; pub(crate) use circuits::*;
pub(crate) use eval::*; pub(crate) use eval::*;

View File

@@ -3,11 +3,21 @@ use poulpy_backend::FFT64Ref;
use crate::tfhe::{ use crate::tfhe::{
bdd_arithmetic::tests::test_suite::{ bdd_arithmetic::tests::test_suite::{
test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra, test_bdd_add, test_bdd_and, test_bdd_or, test_bdd_prepare, test_bdd_sll, test_bdd_slt, test_bdd_sltu, test_bdd_sra,
test_bdd_srl, test_bdd_sub, test_bdd_xor, test_bdd_srl, test_bdd_sub, test_bdd_xor, test_glwe_to_glwe_blind_rotation, test_scalar_to_ggsw_blind_rotation,
}, },
blind_rotation::CGGI, blind_rotation::CGGI,
}; };
#[test]
// Instantiates the generic GLWE blind-rotation test with the FFT64 reference backend.
fn test_glwe_to_glwe_blind_rotation_fft64_ref() {
test_glwe_to_glwe_blind_rotation::<FFT64Ref>()
}
#[test]
// Instantiates the generic scalar-to-GGSW blind-rotation test with the FFT64 reference backend.
fn test_scalar_to_ggsw_blind_rotation_fft64_ref() {
test_scalar_to_ggsw_blind_rotation::<FFT64Ref>()
}
#[test] #[test]
fn test_bdd_prepare_fft64_ref() { fn test_bdd_prepare_fft64_ref() {
test_bdd_prepare::<CGGI, FFT64Ref>() test_bdd_prepare::<CGGI, FFT64Ref>()

View File

@@ -0,0 +1,149 @@
use poulpy_core::{
GGSWEncryptSk, GGSWNoise, GLWEDecrypt, GLWEEncryptSk, SIGMA, ScratchTakeCore,
layouts::{
Base2K, Degree, Dnum, Dsize, GGSW, GGSWLayout, GGSWPreparedFactory, GLWESecret, GLWESecretPrepared,
GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision,
},
};
use poulpy_hal::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotateInplace},
layouts::{Backend, Module, ScalarZnx, Scratch, ScratchOwned, ZnxView, ZnxViewMut},
source::Source,
};
use rand::RngCore;
use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GGSWBlindRotation};
pub fn test_scalar_to_ggsw_blind_rotation<BE: Backend>()
where
Module<BE>: ModuleNew<BE>
+ GLWESecretPreparedFactory<BE>
+ GGSWPreparedFactory<BE>
+ GGSWEncryptSk<BE>
+ GGSWBlindRotation<u32, BE>
+ GGSWNoise<BE>
+ GLWEDecrypt<BE>
+ GLWEEncryptSk<BE>
+ VecZnxRotateInplace<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let n: Degree = Degree(1 << 11);
let base2k: Base2K = Base2K(13);
let rank: Rank = Rank(1);
let k_ggsw_res: TorusPrecision = TorusPrecision(39);
let k_ggsw_apply: TorusPrecision = TorusPrecision(52);
let ggsw_res_infos: GGSWLayout = GGSWLayout {
n,
base2k,
k: k_ggsw_res,
rank,
dnum: Dnum(2),
dsize: Dsize(1),
};
let ggsw_k_infos: GGSWLayout = GGSWLayout {
n,
base2k,
k: k_ggsw_apply,
rank,
dnum: Dnum(3),
dsize: Dsize(1),
};
let n_glwe: usize = n.into();
let module: Module<BE> = Module::<BE>::new(n_glwe as u64);
let mut source: Source = Source::new([6u8; 32]);
let mut source_xs: Source = Source::new([1u8; 32]);
let mut source_xa: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([3u8; 32]);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_glwe_prep: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc(&module, rank);
sk_glwe_prep.prepare(&module, &sk_glwe);
let mut res: GGSW<Vec<u8>> = GGSW::alloc_from_infos(&ggsw_res_infos);
let mut scalar: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n_glwe, 1);
scalar
.raw_mut()
.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = i as i64);
let k: u32 = source.next_u32();
// println!("k: {k}");
let mut k_enc_prep: FheUintBlocksPrepared<Vec<u8>, u32, BE> =
FheUintBlocksPrepared::<Vec<u8>, u32, BE>::alloc(&module, &ggsw_k_infos);
k_enc_prep.encrypt_sk(
&module,
k,
&sk_glwe_prep,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let base: [usize; 2] = [6, 5];
assert_eq!(base.iter().sum::<usize>(), module.log_n());
// Starting bit
let mut bit_start: usize = 0;
let max_noise = |col_i: usize| {
let mut noise: f64 = -(ggsw_res_infos.size() as f64 * base2k.as_usize() as f64) + SIGMA.log2() + 2.0;
noise += 0.5 * ggsw_res_infos.log_n() as f64;
if col_i != 0 {
noise += 0.5 * ggsw_res_infos.log_n() as f64
}
noise
};
for _ in 0..32_usize.div_ceil(module.log_n()) {
// By how many bits to left shift
let mut bit_step: usize = 0;
for digit in base {
let mask: u32 = (1 << digit) - 1;
// How many bits to take
let bit_size: usize = (32 - bit_start).min(digit);
module.scalar_to_ggsw_blind_rotation(
&mut res,
&scalar,
&k_enc_prep,
bit_start,
bit_size,
bit_step,
scratch.borrow(),
);
let rot: i64 = (((k >> bit_start) & mask) << bit_step) as i64;
let mut scalar_want: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(module.n(), 1);
scalar_want.raw_mut().copy_from_slice(scalar.raw());
module.vec_znx_rotate_inplace(-rot, &mut scalar_want.as_vec_znx_mut(), 0, scratch.borrow());
// res.print_noise(&module, &sk_glwe_prep, &scalar_want);
res.assert_noise(&module, &sk_glwe_prep, &scalar_want, &max_noise);
bit_step += digit;
bit_start += digit;
if bit_start >= 32 {
break;
}
}
}
}

View File

@@ -0,0 +1,130 @@
use poulpy_core::{
GGSWEncryptSk, GLWEDecrypt, GLWEEncryptSk, ScratchTakeCore,
layouts::{
Base2K, Degree, Dnum, Dsize, GGSWLayout, GGSWPreparedFactory, GLWE, GLWELayout, GLWEPlaintext, GLWESecret,
GLWESecretPrepared, GLWESecretPreparedFactory, LWEInfos, Rank, TorusPrecision,
},
};
use poulpy_hal::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow},
layouts::{Backend, Module, Scratch, ScratchOwned},
source::Source,
};
use rand::RngCore;
use crate::tfhe::bdd_arithmetic::{FheUintBlocksPrepared, GLWEBlindRotation};
/// End-to-end test of `glwe_to_glwe_blind_rotation`.
///
/// Encrypts a random 32-bit value `k` bit-by-bit as a prepared
/// `FheUintBlocksPrepared`, then for successive digit-sized chunks of `k`
/// blind-rotates a known GLWE plaintext and checks that coefficient 0 of the
/// decryption equals `((k >> bit_start) & mask) << bit_step`.
pub fn test_glwe_to_glwe_blind_rotation<BE: Backend>()
where
    Module<BE>: ModuleNew<BE>
        + GLWESecretPreparedFactory<BE>
        + GGSWPreparedFactory<BE>
        + GGSWEncryptSk<BE>
        + GLWEBlindRotation<u32, BE>
        + GLWEDecrypt<BE>
        + GLWEEncryptSk<BE>,
    ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    // Bit width of the rotated key type (u32); replaces the magic constant 32.
    const TOTAL_BITS: usize = u32::BITS as usize;

    let n: Degree = Degree(1 << 11);
    let base2k: Base2K = Base2K(13);
    let rank: Rank = Rank(1);
    let k_glwe: TorusPrecision = TorusPrecision(26);
    let k_ggsw: TorusPrecision = TorusPrecision(39);
    let dnum: Dnum = Dnum(3);

    let glwe_infos: GLWELayout = GLWELayout {
        n,
        base2k,
        k: k_glwe,
        rank,
    };

    // Layout of the per-bit GGSW encryptions of `k` consumed by the rotation.
    let ggsw_infos: GGSWLayout = GGSWLayout {
        n,
        base2k,
        k: k_ggsw,
        rank,
        dnum,
        dsize: Dsize(1),
    };

    let n_glwe: usize = glwe_infos.n().into();
    let module: Module<BE> = Module::<BE>::new(n_glwe as u64);

    // Fixed seeds keep the test deterministic and reproducible.
    let mut source: Source = Source::new([6u8; 32]);
    let mut source_xs: Source = Source::new([1u8; 32]);
    let mut source_xa: Source = Source::new([2u8; 32]);
    let mut source_xe: Source = Source::new([3u8; 32]);

    let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(1 << 22);

    let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&glwe_infos);
    sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
    let mut sk_glwe_prep: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(&module, &glwe_infos);
    sk_glwe_prep.prepare(&module, &sk_glwe);

    let mut res: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&glwe_infos);

    // Test plaintext 0, 1, 2, ..., n-1: all-distinct coefficients make any
    // mis-rotation visible at coefficient 0.
    let mut test_glwe: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos);
    let mut data: Vec<i64> = vec![0i64; module.n()];
    data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
    test_glwe.encode_vec_i64(&data, base2k.as_usize().into());

    let k: u32 = source.next_u32();

    let mut k_enc_prep: FheUintBlocksPrepared<Vec<u8>, u32, BE> =
        FheUintBlocksPrepared::<Vec<u8>, u32, BE>::alloc(&module, &ggsw_infos);
    k_enc_prep.encrypt_sk(
        &module,
        k,
        &sk_glwe_prep,
        &mut source_xa,
        &mut source_xe,
        scratch.borrow(),
    );

    // Digit decomposition of one ring dimension's worth of key bits;
    // the digits must tile log2(n) exactly.
    let base: [usize; 2] = [6, 5];
    assert_eq!(base.iter().sum::<usize>(), module.log_n());

    // Starting bit
    let mut bit_start: usize = 0;

    let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos);

    'all_bits: for _ in 0..TOTAL_BITS.div_ceil(module.log_n()) {
        // By how many bits to left shift
        let mut bit_step: usize = 0;
        for digit in base {
            let mask: u32 = (1 << digit) - 1;
            // How many bits to take
            let bit_size: usize = (TOTAL_BITS - bit_start).min(digit);
            module.glwe_to_glwe_blind_rotation(
                &mut res,
                &test_glwe,
                &k_enc_prep,
                bit_start,
                bit_size,
                bit_step,
                scratch.borrow(),
            );
            // Coefficient 0 of the rotated plaintext must equal the digit
            // extracted from `k`, shifted into position.
            res.decrypt(&module, &mut pt, &sk_glwe_prep, scratch.borrow());
            assert_eq!(
                (((k >> bit_start) & mask) << bit_step) as i64,
                pt.decode_coeff_i64(base2k.as_usize().into(), 0)
            );
            bit_step += digit;
            bit_start += digit;
            if bit_start >= TOTAL_BITS {
                // Labeled break: a plain `break` only left the inner loop, and a
                // further outer iteration would underflow `TOTAL_BITS - bit_start`.
                break 'all_bits;
            }
        }
    }
}

View File

@@ -1,5 +1,7 @@
mod add; mod add;
mod and; mod and;
mod ggsw_blind_rotations;
mod glwe_blind_rotation;
mod or; mod or;
mod prepare; mod prepare;
mod sll; mod sll;
@@ -12,6 +14,8 @@ mod xor;
pub use add::*; pub use add::*;
pub use and::*; pub use and::*;
pub use ggsw_blind_rotations::*;
pub use glwe_blind_rotation::*;
pub use or::*; pub use or::*;
pub use prepare::*; pub use prepare::*;
pub use sll::*; pub use sll::*;

View File

@@ -189,12 +189,12 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, M, BE: Backend>(
brk.data.chunks_exact(block_size) brk.data.chunks_exact(block_size)
) )
.for_each(|(ai, ski)| { .for_each(|(ai, ski)| {
(0..extension_factor).for_each(|i| { for i in 0..extension_factor {
(0..cols).for_each(|j| { for j in 0..cols {
module.vec_znx_dft_apply(1, 0, &mut acc_dft[i], j, &acc[i], j); module.vec_znx_dft_apply(1, 0, &mut acc_dft[i], j, &acc[i], j);
}); module.vec_znx_dft_zero(&mut acc_add_dft[i], j)
module.vec_znx_dft_zero(&mut acc_add_dft[i]) }
}); }
// TODO: first & last iterations can be optimized // TODO: first & last iterations can be optimized
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
@@ -342,11 +342,10 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, M, BE: Backend>(
brk.data.chunks_exact(block_size) brk.data.chunks_exact(block_size)
) )
.for_each(|(ai, ski)| { .for_each(|(ai, ski)| {
(0..cols).for_each(|j| { for j in 0..cols {
module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j); module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j);
}); module.vec_znx_dft_zero(&mut acc_add_dft, j)
}
module.vec_znx_dft_zero(&mut acc_add_dft);
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize; let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;

View File

@@ -6,7 +6,7 @@ use poulpy_hal::{
}; };
use poulpy_core::{ use poulpy_core::{
GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWETrace, ScratchTakeCore, GGSWFromGGLWE, GLWEDecrypt, GLWEPacking, GLWERotate, GLWETrace, ScratchTakeCore,
layouts::{ layouts::{
Dsize, GGLWELayout, GGSWInfos, GGSWToMut, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, LWEInfos, LWEToRef, Dsize, GGLWELayout, GGSWInfos, GGSWToMut, GLWEInfos, GLWESecretPreparedFactory, GLWEToMut, GLWEToRef, LWEInfos, LWEToRef,
}, },
@@ -115,7 +115,8 @@ where
+ GLWEPacking<BE> + GLWEPacking<BE>
+ GGSWFromGGLWE<BE> + GGSWFromGGLWE<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWEDecrypt<BE>, + GLWEDecrypt<BE>
+ GLWERotate<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
@@ -216,7 +217,9 @@ pub fn circuit_bootstrap_core<R, L, D, M, BRA: BlindRotationAlgo, BE: Backend>(
+ GLWEPacking<BE> + GLWEPacking<BE>
+ GGSWFromGGLWE<BE> + GGSWFromGGLWE<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWEDecrypt<BE>, + GLWEDecrypt<BE>
+ GLWERotate<BE>
+ ModuleLogN,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
@@ -332,7 +335,7 @@ fn post_process<R, A, M, BE: Backend>(
) where ) where
R: GLWEToMut, R: GLWEToMut,
A: GLWEToRef, A: GLWEToRef,
M: ModuleLogN + GLWETrace<BE> + GLWEPacking<BE>, M: ModuleLogN + GLWETrace<BE> + GLWEPacking<BE> + GLWERotate<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

View File

@@ -1,8 +1,8 @@
use poulpy_core::{ use poulpy_core::{
Distribution, GLWEAutomorphismKeyEncryptSk, GLWETensorKeyEncryptSk, GetDistribution, ScratchTakeCore, Distribution, GGLWEToGGSWKeyEncryptSk, GLWEAutomorphismKeyEncryptSk, GetDistribution, ScratchTakeCore,
layouts::{ layouts::{
GGLWEInfos, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos, GLWESecretPreparedFactory, GGLWEInfos, GGLWEToGGSWKey, GGSWInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEInfos,
GLWESecretToRef, GLWETensorKey, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared, GLWESecretPreparedFactory, GLWESecretToRef, GLWETensorKeyLayout, LWEInfos, LWESecretToRef, prepared::GLWESecretPrepared,
}, },
trace_galois_elements, trace_galois_elements,
}; };
@@ -81,14 +81,14 @@ impl<BRA: BlindRotationAlgo> CircuitBootstrappingKey<Vec<u8>, BRA> {
(gal_el, key) (gal_el, key)
}) })
.collect(), .collect(),
tsk: GLWETensorKey::alloc_from_infos(trk_infos), tsk: GGLWEToGGSWKey::alloc_from_infos(trk_infos),
} }
} }
} }
pub struct CircuitBootstrappingKey<D: Data, BRA: BlindRotationAlgo> { pub struct CircuitBootstrappingKey<D: Data, BRA: BlindRotationAlgo> {
pub(crate) brk: BlindRotationKey<D, BRA>, pub(crate) brk: BlindRotationKey<D, BRA>,
pub(crate) tsk: GLWETensorKey<Vec<u8>>, pub(crate) tsk: GGLWEToGGSWKey<Vec<u8>>,
pub(crate) atk: HashMap<i64, GLWEAutomorphismKey<Vec<u8>>>, pub(crate) atk: HashMap<i64, GLWEAutomorphismKey<Vec<u8>>>,
} }
@@ -112,7 +112,7 @@ impl<D: DataMut, BRA: BlindRotationAlgo> CircuitBootstrappingKey<D, BRA> {
impl<BRA: BlindRotationAlgo, BE: Backend> CircuitBootstrappingKeyEncryptSk<BRA, BE> for Module<BE> impl<BRA: BlindRotationAlgo, BE: Backend> CircuitBootstrappingKeyEncryptSk<BRA, BE> for Module<BE>
where where
Self: GLWETensorKeyEncryptSk<BE> Self: GGLWEToGGSWKeyEncryptSk<BE>
+ BlindRotationKeyEncryptSk<BRA, BE> + BlindRotationKeyEncryptSk<BRA, BE>
+ GLWEAutomorphismKeyEncryptSk<BE> + GLWEAutomorphismKeyEncryptSk<BE>
+ GLWESecretPreparedFactory<BE>, + GLWESecretPreparedFactory<BE>,

View File

@@ -1,8 +1,8 @@
use poulpy_core::{ use poulpy_core::{
layouts::{ layouts::{
GGLWEInfos, GGSWInfos, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedFactory, GGSWInfos, GLWEAutomorphismKeyLayout,
GLWETensorKeyPreparedFactory, LWEInfos, GLWEAutomorphismKeyPreparedFactory, GLWEInfos, GLWETensorKeyLayout, GLWETensorKeyPreparedFactory, LWEInfos,
prepared::{GLWEAutomorphismKeyPrepared, GLWETensorKeyPrepared}, prepared::GLWEAutomorphismKeyPrepared,
}, },
trace_galois_elements, trace_galois_elements,
}; };
@@ -50,7 +50,7 @@ pub trait CircuitBootstrappingKeyPreparedFactory<BRA: BlindRotationAlgo, BE: Bac
where where
Self: Sized Self: Sized
+ BlindRotationKeyPreparedFactory<BRA, BE> + BlindRotationKeyPreparedFactory<BRA, BE>
+ GLWETensorKeyPreparedFactory<BE> + GGLWEToGGSWKeyPreparedFactory<BE>
+ GLWEAutomorphismKeyPreparedFactory<BE>, + GLWEAutomorphismKeyPreparedFactory<BE>,
{ {
fn circuit_bootstrapping_key_prepared_alloc_from_infos<A>( fn circuit_bootstrapping_key_prepared_alloc_from_infos<A>(
@@ -65,7 +65,7 @@ where
CircuitBootstrappingKeyPrepared { CircuitBootstrappingKeyPrepared {
brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()), brk: BlindRotationKeyPrepared::alloc(self, &infos.brk_infos()),
tsk: GLWETensorKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()), tsk: GGLWEToGGSWKeyPrepared::alloc_from_infos(self, &infos.tsk_infos()),
atk: gal_els atk: gal_els
.iter() .iter()
.map(|&gal_el| { .map(|&gal_el| {
@@ -81,7 +81,7 @@ where
A: CircuitBootstrappingKeyInfos, A: CircuitBootstrappingKeyInfos,
{ {
self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos()) self.blind_rotation_key_prepare_tmp_bytes(&infos.brk_infos())
.max(self.prepare_tensor_key_tmp_bytes(&infos.tsk_infos())) .max(self.prepare_gglwe_to_ggsw_key_tmp_bytes(&infos.tsk_infos()))
.max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos())) .max(self.prepare_glwe_automorphism_key_tmp_bytes(&infos.atk_infos()))
} }
@@ -105,7 +105,7 @@ where
pub struct CircuitBootstrappingKeyPrepared<D: Data, BRA: BlindRotationAlgo, B: Backend> { pub struct CircuitBootstrappingKeyPrepared<D: Data, BRA: BlindRotationAlgo, B: Backend> {
pub(crate) brk: BlindRotationKeyPrepared<D, BRA, B>, pub(crate) brk: BlindRotationKeyPrepared<D, BRA, B>,
pub(crate) tsk: GLWETensorKeyPrepared<Vec<u8>, B>, pub(crate) tsk: GGLWEToGGSWKeyPrepared<Vec<u8>, B>,
pub(crate) atk: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, B>>, pub(crate) atk: HashMap<i64, GLWEAutomorphismKeyPrepared<Vec<u8>, B>>,
} }