Remove Zn (replaced by VecZnx), add more cross-base2k ops & tests

This commit is contained in:
Pro7ech
2025-11-18 01:08:20 +01:00
parent a3264b8851
commit f39e3e2865
52 changed files with 952 additions and 1550 deletions

View File

@@ -1,5 +1,5 @@
use itertools::izip; use itertools::izip;
use poulpy_backend::cpu_spqlios::FFT64Spqlios; use poulpy_backend::cpu_fft64_ref::FFT64Ref;
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal, ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
@@ -16,9 +16,9 @@ fn main() {
let ct_size: usize = 3; let ct_size: usize = 3;
let msg_size: usize = 2; let msg_size: usize = 2;
let log_scale: usize = msg_size * base2k - 5; let log_scale: usize = msg_size * base2k - 5;
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(n as u64); let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(n as u64);
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::<FFT64Spqlios>::alloc(module.vec_znx_big_normalize_tmp_bytes()); let mut scratch: ScratchOwned<FFT64Ref> = ScratchOwned::<FFT64Ref>::alloc(module.vec_znx_big_normalize_tmp_bytes());
let seed: [u8; 32] = [0; 32]; let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed); let mut source: Source = Source::new(seed);
@@ -28,7 +28,7 @@ fn main() {
s.fill_ternary_prob(0, 0.5, &mut source); s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain // Buffer to store s in the DFT domain
let mut s_dft: SvpPPol<Vec<u8>, FFT64Spqlios> = module.svp_ppol_alloc(s.cols()); let mut s_dft: SvpPPol<Vec<u8>, FFT64Ref> = module.svp_ppol_alloc(s.cols());
// s_dft <- DFT(s) // s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, 0, &s, 0); module.svp_prepare(&mut s_dft, 0, &s, 0);
@@ -43,7 +43,7 @@ fn main() {
// Fill the second column with random values: ct = (0, a) // Fill the second column with random values: ct = (0, a)
module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source); module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size); let mut buf_dft: VecZnxDft<Vec<u8>, FFT64Ref> = module.vec_znx_dft_alloc(1, ct_size);
module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1); module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1);
@@ -58,7 +58,7 @@ fn main() {
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>) // Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized) // BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
let mut buf_big: VecZnxBig<Vec<u8>, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size); let mut buf_big: VecZnxBig<Vec<u8>, FFT64Ref> = module.vec_znx_big_alloc(1, ct_size);
module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0); module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column // Creates a plaintext: VecZnx with 1 column

View File

@@ -7,7 +7,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp; mod vmp;
mod zn;
mod znx_avx; mod znx_avx;
pub struct FFT64Avx {} pub struct FFT64Avx {}

View File

@@ -1,73 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, ZnToMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
source::Source,
};
use crate::cpu_fft64_avx::FFT64Avx;
unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Avx {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Avx>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Avx {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -6,7 +6,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp; mod vmp;
mod zn;
mod znx; mod znx;
#[cfg(test)] #[cfg(test)]

View File

@@ -1,73 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, ZnToMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
source::Source,
};
use crate::cpu_fft64_ref::FFT64Ref;
unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Ref {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Ref>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -5,7 +5,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp_pmat; mod vmp_pmat;
mod zn;
mod znx; mod znx;
pub struct FFT64Spqlios; pub struct FFT64Spqlios;

View File

@@ -1,82 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
source::Source,
};
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<A>(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
where
A: ZnToMut,
{
let mut a: Zn<&mut [u8]> = a.to_mut();
let (tmp_bytes, _) = scratch.take_slice(n * size_of::<i64>());
unsafe {
zn64::zn64_normalize_base2k_ref(
n as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -55,7 +55,11 @@ where
A: GGLWEInfos, A: GGLWEInfos,
K: GGLWEInfos, K: GGLWEInfos,
{ {
if res_infos.glwe_layout() == a_infos.glwe_layout() {
self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos)
} else {
self.glwe_keyswitch_tmp_bytes(res_infos, a_infos, key_infos) + GLWE::bytes_of_from_infos(a_infos)
}
} }
fn glwe_automorphism_key_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_key_automorphism<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
@@ -79,12 +83,16 @@ where
a.dsize() a.dsize()
); );
assert_eq!(res.base2k(), a.base2k());
let cols_out: usize = (key.rank_out() + 1).into(); let cols_out: usize = (key.rank_out() + 1).into();
let cols_in: usize = key.rank_in().into(); let cols_in: usize = key.rank_in().into();
let p: i64 = a.p(); let p: i64 = a.p();
let p_inv: i64 = self.galois_element_inv(p); let p_inv: i64 = self.galois_element_inv(p);
let same_layout: bool = res.glwe_layout() == a.glwe_layout();
{ {
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref(); let a: &GGLWE<&[u8]> = &a.to_ref();
@@ -94,6 +102,7 @@ where
let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col); let mut res_tmp: GLWE<&mut [u8]> = res.at_mut(row, col);
let a_ct: GLWE<&[u8]> = a.at(row, col); let a_ct: GLWE<&[u8]> = a.at(row, col);
if same_layout {
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
for i in 0..cols_out { for i in 0..cols_out {
self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i); self.vec_znx_automorphism(p, res_tmp.data_mut(), i, &a_ct.data, i);
@@ -101,11 +110,22 @@ where
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch); self.glwe_keyswitch_inplace(&mut res_tmp, key, scratch);
} else {
let (mut tmp_glwe, scratch_1) = scratch.take_glwe(a);
// Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
for i in 0..cols_out {
self.vec_znx_automorphism(p, tmp_glwe.data_mut(), i, &a_ct.data, i);
}
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
self.glwe_keyswitch(&mut res_tmp, &tmp_glwe, key, scratch_1);
}
// Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a) // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
(0..cols_out).for_each(|i| { for i in 0..cols_out {
self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch); self.vec_znx_automorphism_inplace(p_inv, res_tmp.data_mut(), i, scratch);
}); }
} }
} }
} }

View File

@@ -34,9 +34,9 @@ impl GGSW<Vec<u8>> {
impl<D: DataMut> GGSW<D> { impl<D: DataMut> GGSW<D> {
pub fn automorphism<A, K, T, M, BE: Backend>(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>) pub fn automorphism<A, K, T, M, BE: Backend>(&mut self, module: &M, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where where
A: GGSWToRef, A: GGSWToRef + GGSWInfos,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GGLWEToGGSWKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGSWAutomorphism<BE>, M: GGSWAutomorphism<BE>,
{ {
@@ -73,20 +73,21 @@ where
fn ggsw_automorphism<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_automorphism<R, A, K, T>(&self, res: &mut R, a: &A, key: &K, tsk: &T, scratch: &mut Scratch<BE>)
where where
R: GGSWToMut, R: GGSWToMut + GGSWInfos,
A: GGSWToRef, A: GGSWToRef + GGSWInfos,
K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos, K: GetGaloisElement + GGLWEPreparedToRef<BE> + GGLWEInfos,
T: GGLWEToGGSWKeyPreparedToRef<BE>, T: GGLWEToGGSWKeyPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
assert_eq!(res.dsize(), a.dsize());
assert_eq!(res.base2k(), a.base2k());
assert!(res.dnum() <= a.dnum());
assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk));
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a: &GGSW<&[u8]> = &a.to_ref(); let a: &GGSW<&[u8]> = &a.to_ref();
let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
assert_eq!(res.dsize(), a.dsize());
assert!(res.dnum() <= a.dnum());
assert!(scratch.available() >= self.ggsw_automorphism_tmp_bytes(res, a, key, tsk));
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
for row in 0..res.dnum().as_usize() { for row in 0..res.dnum().as_usize() {
// Key-switch column 0, i.e. // Key-switch column 0, i.e.

View File

@@ -7,8 +7,8 @@ use poulpy_hal::{
}; };
use crate::{ use crate::{
GLWEKeySwitchInternal, GLWEKeyswitch, ScratchTakeCore, GLWEKeySwitchInternal, GLWEKeyswitch, GLWENormalize, ScratchTakeCore,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos}, layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos},
}; };
impl GLWE<Vec<u8>> { impl GLWE<Vec<u8>> {
@@ -164,7 +164,8 @@ where
+ VecZnxBigSubSmallInplace<BE> + VecZnxBigSubSmallInplace<BE>
+ VecZnxBigSubSmallNegateInplace<BE> + VecZnxBigSubSmallNegateInplace<BE>
+ VecZnxBigAddSmallInplace<BE> + VecZnxBigAddSmallInplace<BE>
+ VecZnxBigNormalize<BE>, + VecZnxBigNormalize<BE>
+ GLWENormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
fn glwe_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize fn glwe_automorphism_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
@@ -217,22 +218,50 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_a: usize = a.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); let base2k_key: usize = key.base2k().into();
let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_a != base2k_key {
let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: a.n(),
base2k: key.base2k(),
k: a.k(),
rank: a.rank(),
});
self.glwe_normalize(&mut a_conv, a, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_add_small_inplace(&mut res_big, i, a_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i); self.vec_znx_big_add_small_inplace(&mut res_big, i, a.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
fn glwe_automorphism_add_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_add_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
@@ -243,22 +272,49 @@ where
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_key: usize = key.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_res != base2k_key {
let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: res.n(),
base2k: key.base2k(),
k: res.k(),
rank: res.rank(),
});
self.glwe_normalize(&mut res_conv, res, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_add_small_inplace(&mut res_big, i, res_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i); self.vec_znx_big_add_small_inplace(&mut res_big, i, res.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
fn glwe_automorphism_sub<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_sub<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
@@ -271,22 +327,50 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_a: usize = a.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); let base2k_key: usize = key.base2k().into();
let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_a != base2k_key {
let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: a.n(),
base2k: key.base2k(),
k: a.k(),
rank: a.rank(),
});
self.glwe_normalize(&mut a_conv, a, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_sub_small_inplace(&mut res_big, i, a_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i); self.vec_znx_big_sub_small_inplace(&mut res_big, i, a.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
fn glwe_automorphism_sub_negate<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_sub_negate<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
@@ -299,22 +383,50 @@ where
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref(); let a: &GLWE<&[u8]> = &a.to_ref();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_a: usize = a.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1); let base2k_key: usize = key.base2k().into();
let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_a != base2k_key {
let (mut a_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: a.n(),
base2k: key.base2k(),
k: a.k(),
rank: a.rank(),
});
self.glwe_normalize(&mut a_conv, a, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &a_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, a, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, a.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
fn glwe_automorphism_sub_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_sub_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
@@ -325,22 +437,49 @@ where
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_key: usize = key.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_res != base2k_key {
let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: res.n(),
base2k: key.base2k(),
k: res.k(),
rank: res.rank(),
});
self.glwe_normalize(&mut res_conv, res, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_sub_small_inplace(&mut res_big, i, res_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i); self.vec_znx_big_sub_small_inplace(&mut res_big, i, res.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
fn glwe_automorphism_sub_negate_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>) fn glwe_automorphism_sub_negate_inplace<R, K>(&self, res: &mut R, key: &K, scratch: &mut Scratch<BE>)
@@ -351,21 +490,48 @@ where
{ {
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut(); let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size let base2k_key: usize = key.base2k().into();
let mut res_big: VecZnxBig<_, BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1); let base2k_res: usize = res.base2k().into();
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), key.size()); // TODO: optimise size
if base2k_res != base2k_key {
let (mut res_conv, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: res.n(),
base2k: key.base2k(),
k: res.k(),
rank: res.rank(),
});
self.glwe_normalize(&mut res_conv, res, scratch_2);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, &res_conv, key, scratch_2);
for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_2);
self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res_conv.data(), i);
self.vec_znx_big_normalize(
base2k_res,
res.data_mut(),
i,
base2k_key,
&res_big,
i,
scratch_2,
);
}
} else {
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_keyswitch_internal(res_dft, res, key, scratch_1);
for i in 0..res.rank().as_usize() + 1 { for i in 0..res.rank().as_usize() + 1 {
self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1); self.vec_znx_big_automorphism_inplace(key.p(), &mut res_big, i, scratch_1);
self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i); self.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, res.data(), i);
self.vec_znx_big_normalize( self.vec_znx_big_normalize(
res.base2k().into(), base2k_res,
res.data_mut(), res.data_mut(),
i, i,
key.base2k().into(), base2k_key,
&res_big, &res_big,
i, i,
scratch_1, scratch_1,
); );
} }
};
} }
} }

View File

@@ -1,9 +1,9 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ api::{
ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize, ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
}, },
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig}, layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef},
}; };
use crate::{ use crate::{
@@ -65,6 +65,7 @@ where
assert_eq!(res.n(), self.n() as u32); assert_eq!(res.n(), self.n() as u32);
assert_eq!(a.n(), self.n() as u32); assert_eq!(a.n(), self.n() as u32);
assert_eq!(tsk.n(), self.n() as u32); assert_eq!(tsk.n(), self.n() as u32);
assert_eq!(res.base2k(), a.base2k());
for row in 0..res.dnum().into() { for row in 0..res.dnum().into() {
self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0)); self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0));
@@ -111,28 +112,29 @@ where
+ VecZnxDftApply<BE> + VecZnxDftApply<BE>
+ VecZnxNormalize<BE> + VecZnxNormalize<BE>
+ VecZnxBigAddSmallInplace<BE> + VecZnxBigAddSmallInplace<BE>
+ VecZnxIdftApplyConsume<BE>, + VecZnxIdftApplyConsume<BE>
+ VecZnxCopy,
{ {
fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
where where
R: GGSWInfos, R: GGSWInfos,
A: GGLWEInfos, A: GGLWEInfos,
{ {
let base2k_in: usize = res_infos.base2k().into();
let base2k_tsk: usize = tsk_infos.base2k().into(); let base2k_tsk: usize = tsk_infos.base2k().into();
let rank: usize = res_infos.rank().into(); let rank: usize = res_infos.rank().into();
let cols: usize = rank + 1; let cols: usize = rank + 1;
let res_size = res_infos.size(); let res_size: usize = res_infos.size();
let a_size: usize = (res_infos.size() * base2k_in).div_ceil(base2k_tsk); let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk);
let a_dft = self.bytes_of_vec_znx_dft(cols - 1, a_size); let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
let res_dft = self.bytes_of_vec_znx_dft(cols, a_size); let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);
let res_dft: usize = self.bytes_of_vec_znx_dft(cols, a_size);
let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos); let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos);
let normalize = self.vec_znx_big_normalize_tmp_bytes(); let normalize: usize = self.vec_znx_big_normalize_tmp_bytes();
(a_dft + res_dft + gglwe_prod).max(normalize) (a_0 + a_dft + res_dft + gglwe_prod).max(normalize)
} }
fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>) fn ggsw_expand_row<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
@@ -144,7 +146,7 @@ where
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut(); let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref(); let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
let base2k_in: usize = res.base2k().into(); let base2k_res: usize = res.base2k().into();
let base2k_tsk: usize = tsk.base2k().into(); let base2k_tsk: usize = tsk.base2k().into();
assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk)); assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));
@@ -152,35 +154,70 @@ where
let rank: usize = res.rank().into(); let rank: usize = res.rank().into();
let cols: usize = rank + 1; let cols: usize = rank + 1;
let a_size: usize = (res.size() * base2k_in).div_ceil(base2k_tsk); let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk);
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
for row in 0..res.dnum().as_usize() { for row in 0..res.dnum().as_usize() {
let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, a_size);
{
let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0); let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);
if base2k_in == base2k_tsk { if base2k_res == base2k_tsk {
for col_i in 0..cols - 1 { for col_i in 0..cols - 1 {
self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1); self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
} }
self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
} else { } else {
let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, a_size);
for i in 0..cols - 1 { for i in 0..cols - 1 {
self.vec_znx_normalize( self.vec_znx_normalize(
base2k_tsk, base2k_tsk,
&mut a_conv, &mut a_0,
0, 0,
base2k_in, base2k_res,
glwe_mi_1.data(), glwe_mi_1.data(),
i + 1, i + 1,
scratch_2, scratch_2,
); );
self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_conv, 0); self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
} }
self.vec_znx_normalize(
base2k_tsk,
&mut a_0,
0,
base2k_res,
glwe_mi_1.data(),
0,
scratch_2,
);
}
ggsw_expand_rows_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2)
} }
} }
}
fn ggsw_expand_rows_internal<M, R, C, A, T, BE: Backend>(
module: &M,
row: usize,
res: &mut R,
a_0: &C,
a_dft: &A,
tsk: &T,
scratch: &mut Scratch<BE>,
) where
R: GGSWToMut,
C: VecZnxToRef,
A: VecZnxDftToRef<BE>,
M: GGLWEProduct<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
T: GGLWEToGGSWKeyPreparedToRef<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
let a_0: &VecZnx<&[u8]> = &a_0.to_ref();
let a_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref();
let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
let cols: usize = res.rank().as_usize() + 1;
// Example for rank 3: // Example for rank 3:
// //
@@ -202,7 +239,7 @@ where
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 ) // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i]) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i])
for col in 1..cols { for col in 1..cols {
let (mut res_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, tsk.size()); // Todo optimise let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, tsk.size()); // Todo optimise
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2 // Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
// //
@@ -215,9 +252,9 @@ where
// a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2) // a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
// = // =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2) // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
self.gglwe_product_dft(&mut res_dft, &a_dft, tsk.at(col - 1), scratch_2); module.gglwe_product_dft(&mut res_dft, a_dft, tsk.at(col - 1), scratch_1);
let mut res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft); let mut res_big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(res_dft);
// Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i // Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i
// //
@@ -228,20 +265,18 @@ where
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2) // (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
// = // =
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2) // (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
self.vec_znx_big_add_small_inplace(&mut res_big, col, res.at(row, 0).data(), 0); module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);
for j in 0..cols { for j in 0..cols {
self.vec_znx_big_normalize( module.vec_znx_big_normalize(
res.base2k().as_usize(), res.base2k().as_usize(),
res.at_mut(row, col).data_mut(), res.at_mut(row, col).data_mut(),
j, j,
tsk.base2k().as_usize(), tsk.base2k().as_usize(),
&res_big, &res_big,
j, j,
scratch_2, scratch_1,
); );
} }
} }
}
}
} }

View File

@@ -1,39 +1,44 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, api::VecZnxNormalizeInplace,
layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, ZnxView, ZnxViewMut}, layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut},
}; };
use crate::layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut}; use crate::{
ScratchTakeCore,
layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToMut, LWESecret, LWESecretToRef, LWEToMut},
};
impl<DataSelf: DataRef + DataMut> LWE<DataSelf> { impl<DataSelf: DataRef + DataMut> LWE<DataSelf> {
pub fn decrypt<P, S, M, B: Backend>(&mut self, module: &M, pt: &mut P, sk: &S) pub fn decrypt<P, S, M, BE: Backend>(&mut self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where where
P: LWEPlaintextToMut, P: LWEPlaintextToMut,
S: LWESecretToRef, S: LWESecretToRef,
M: LWEDecrypt<B>, M: LWEDecrypt<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.lwe_decrypt(self, pt, sk); module.lwe_decrypt(self, pt, sk, scratch);
} }
} }
pub trait LWEDecrypt<BE: Backend> { pub trait LWEDecrypt<BE: Backend> {
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S) fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where
R: LWEToMut,
P: LWEPlaintextToMut,
S: LWESecretToRef;
}
impl<BE: Backend> LWEDecrypt<BE> for Module<BE>
where
Self: Sized + ZnNormalizeInplace<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S)
where where
R: LWEToMut, R: LWEToMut,
P: LWEPlaintextToMut, P: LWEPlaintextToMut,
S: LWESecretToRef, S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>;
}
impl<BE: Backend> LWEDecrypt<BE> for Module<BE>
where
Self: Sized + VecZnxNormalizeInplace<BE>,
{
fn lwe_decrypt<R, P, S>(&self, res: &mut R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
where
R: LWEToMut,
P: LWEPlaintextToMut,
S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut(); let pt: &mut LWEPlaintext<&mut [u8]> = &mut pt.to_mut();
@@ -52,13 +57,7 @@ where
.map(|(x, y)| x * y) .map(|(x, y)| x * y)
.sum::<i64>(); .sum::<i64>();
}); });
self.zn_normalize_inplace( self.vec_znx_normalize_inplace(res.base2k().into(), &mut pt.data, 0, scratch);
1,
res.base2k().into(),
&mut pt.data,
0,
ScratchOwned::alloc(size_of::<i64>()).borrow(),
);
pt.base2k = res.base2k(); pt.base2k = res.base2k();
pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0)); pt.k = crate::layouts::TorusPrecision(res.k().0.min(pt.size() as u32 * res.base2k().0));
} }

View File

@@ -1,43 +1,67 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace}, api::{VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace},
layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut}, layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut},
source::Source, source::Source,
}; };
use crate::{ use crate::{
ScratchTakeCore,
encryption::{SIGMA, SIGMA_BOUND}, encryption::{SIGMA, SIGMA_BOUND},
layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut}, layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut},
}; };
impl<DataSelf: DataMut> LWE<DataSelf> { impl<DataSelf: DataMut> LWE<DataSelf> {
pub fn encrypt_sk<P, S, M, BE: Backend>(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) pub fn encrypt_sk<P, S, M, BE: Backend>(
where &mut self,
module: &M,
pt: &P,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
P: LWEPlaintextToRef, P: LWEPlaintextToRef,
S: LWESecretToRef, S: LWESecretToRef,
M: LWEEncryptSk<BE>, M: LWEEncryptSk<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe); module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
} }
} }
pub trait LWEEncryptSk<BE: Backend> { pub trait LWEEncryptSk<BE: Backend> {
fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) fn lwe_encrypt_sk<R, P, S>(
where &self,
res: &mut R,
pt: &P,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: LWEToMut, R: LWEToMut,
P: LWEPlaintextToRef, P: LWEPlaintextToRef,
S: LWESecretToRef; S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>;
} }
impl<BE: Backend> LWEEncryptSk<BE> for Module<BE> impl<BE: Backend> LWEEncryptSk<BE> for Module<BE>
where where
Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace<BE>, Self: Sized + VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
{ {
fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source) fn lwe_encrypt_sk<R, P, S>(
where &self,
res: &mut R,
pt: &P,
sk: &S,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch<BE>,
) where
R: LWEToMut, R: LWEToMut,
P: LWEPlaintextToRef, P: LWEPlaintextToRef,
S: LWESecretToRef, S: LWESecretToRef,
Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut LWE<&mut [u8]> = &mut res.to_mut(); let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
let pt: &LWEPlaintext<&[u8]> = &pt.to_ref(); let pt: &LWEPlaintext<&[u8]> = &pt.to_ref();
@@ -51,11 +75,11 @@ where
let base2k: usize = res.base2k().into(); let base2k: usize = res.base2k().into();
let k: usize = res.k().into(); let k: usize = res.k().into();
self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa); self.vec_znx_fill_uniform(base2k, &mut res.data, 0, source_xa);
let mut tmp_znx: Zn<Vec<u8>> = Zn::alloc(1, 1, res.size()); let mut tmp_znx: VecZnx<Vec<u8>> = VecZnx::alloc(1, 1, res.size());
let min_size = res.size().min(pt.size()); let min_size: usize = res.size().min(pt.size());
(0..min_size).for_each(|i| { (0..min_size).for_each(|i| {
tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0] tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0]
@@ -74,24 +98,9 @@ where
.sum::<i64>(); .sum::<i64>();
}); });
self.zn_add_normal( self.vec_znx_add_normal(base2k, &mut tmp_znx, 0, k, source_xe, SIGMA, SIGMA_BOUND);
1,
base2k,
&mut res.data,
0,
k,
source_xe,
SIGMA,
SIGMA_BOUND,
);
self.zn_normalize_inplace( self.vec_znx_normalize_inplace(base2k, &mut tmp_znx, 0, scratch);
1,
base2k,
&mut tmp_znx,
0,
ScratchOwned::alloc(size_of::<i64>()).borrow(),
);
(0..res.size()).for_each(|i| { (0..res.size()).for_each(|i| {
res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0]; res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0];

View File

@@ -30,8 +30,8 @@ impl<DataSelf: DataMut> GLWEAutomorphismKey<DataSelf> {
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
M: GGLWEExternalProduct<BE>, M: GGLWEExternalProduct<BE>,
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGSWPreparedToRef<BE>, B: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.gglwe_external_product(self, a, b, scratch); module.gglwe_external_product(self, a, b, scratch);
@@ -62,15 +62,11 @@ where
fn gglwe_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>) fn gglwe_external_product<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
R: GGLWEToMut, R: GGLWEToMut + GGLWEInfos,
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGSWPreparedToRef<BE>, B: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
assert_eq!( assert_eq!(
res.rank_in(), res.rank_in(),
a.rank_in(), a.rank_in(),
@@ -92,6 +88,11 @@ where
res.rank_out(), res.rank_out(),
b.rank() b.rank()
); );
assert_eq!(res.base2k(), a.base2k());
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
let b: &GGSWPrepared<&[u8], BE> = &b.to_ref();
for row in 0..res.dnum().into() { for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() { for col in 0..res.rank_in().into() {
@@ -149,8 +150,8 @@ impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
M: GGLWEExternalProduct<BE>, M: GGLWEExternalProduct<BE>,
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGSWPreparedToRef<BE>, B: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
module.gglwe_external_product(self, a, b, scratch); module.gglwe_external_product(self, a, b, scratch);

View File

@@ -50,6 +50,8 @@ where
b.rank() b.rank()
); );
assert_eq!(res.base2k(), a.base2k());
assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b)); assert!(scratch.available() >= self.ggsw_external_product_tmp_bytes(res, a, b));
let min_dnum: usize = res.dnum().min(a.dnum()).into(); let min_dnum: usize = res.dnum().min(a.dnum()).into();

View File

@@ -21,7 +21,7 @@ impl GLWEAutomorphismKey<Vec<u8>> {
impl<DataSelf: DataMut> GLWEAutomorphismKey<DataSelf> { impl<DataSelf: DataMut> GLWEAutomorphismKey<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
A: GGLWEToRef + GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGLWEPreparedToRef<BE> + GGLWEInfos, B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>, M: GGLWEKeyswitch<BE>,
@@ -54,7 +54,7 @@ impl GLWESwitchingKey<Vec<u8>> {
impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> { impl<DataSelf: DataMut> GLWESwitchingKey<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGLWEPreparedToRef<BE> + GGLWEInfos, B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>, M: GGLWEKeyswitch<BE>,
@@ -87,7 +87,7 @@ impl GGLWE<Vec<u8>> {
impl<DataSelf: DataMut> GGLWE<DataSelf> { impl<DataSelf: DataMut> GGLWE<DataSelf> {
pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>) pub fn keyswitch<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGLWEPreparedToRef<BE> + GGLWEInfos, B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
M: GGLWEKeyswitch<BE>, M: GGLWEKeyswitch<BE>,
@@ -122,14 +122,11 @@ where
fn gglwe_keyswitch<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>) fn gglwe_keyswitch<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
where where
R: GGLWEToMut, R: GGLWEToMut + GGLWEInfos,
A: GGLWEToRef, A: GGLWEToRef + GGLWEInfos,
B: GGLWEPreparedToRef<BE> + GGLWEInfos, B: GGLWEPreparedToRef<BE> + GGLWEInfos,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
assert_eq!( assert_eq!(
res.rank_in(), res.rank_in(),
a.rank_in(), a.rank_in(),
@@ -164,6 +161,10 @@ where
res.dsize(), res.dsize(),
a.dsize() a.dsize()
); );
assert_eq!(res.base2k(), a.base2k());
let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
let a: &GGLWE<&[u8]> = &a.to_ref();
for row in 0..res.dnum().into() { for row in 0..res.dnum().into() {
for col in 0..res.rank_in().into() { for col in 0..res.rank_in().into() {

View File

@@ -3,7 +3,7 @@ use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch};
use crate::{ use crate::{
GGSWExpandRows, ScratchTakeCore, GGSWExpandRows, ScratchTakeCore,
keyswitching::GLWEKeyswitch, keyswitching::GLWEKeyswitch,
layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef}, layouts::{GGLWEInfos, GGLWEPreparedToRef, GGLWEToGGSWKeyPreparedToRef, GGSW, GGSWInfos, GGSWToMut, GGSWToRef, LWEInfos},
}; };
impl GGSW<Vec<u8>> { impl GGSW<Vec<u8>> {
@@ -98,6 +98,7 @@ where
assert!(res.dnum() <= a.dnum()); assert!(res.dnum() <= a.dnum());
assert_eq!(res.dsize(), a.dsize()); assert_eq!(res.dsize(), a.dsize());
assert_eq!(res.base2k(), a.base2k());
for row in 0..a.dnum().into() { for row in 0..a.dnum().into() {
// Key-switch column 0, i.e. // Key-switch column 0, i.e.

View File

@@ -57,21 +57,19 @@ where
B: GGLWEInfos, B: GGLWEInfos,
{ {
let cols: usize = res_infos.rank().as_usize() + 1; let cols: usize = res_infos.rank().as_usize() + 1;
let size: usize = self let size: usize = if a_infos.base2k() != key_infos.base2k() {
.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos) let a_conv_infos = &GLWELayout {
.max(self.vec_znx_big_normalize_tmp_bytes())
+ self.bytes_of_vec_znx_dft(cols, key_infos.size());
if a_infos.base2k() != key_infos.base2k() {
size + GLWE::bytes_of_from_infos(&GLWELayout {
n: a_infos.n(), n: a_infos.n(),
base2k: key_infos.base2k(), base2k: key_infos.base2k(),
k: a_infos.k(), k: a_infos.k(),
rank: a_infos.rank(), rank: a_infos.rank(),
}) };
self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_conv_infos, key_infos) + GLWE::bytes_of_from_infos(a_conv_infos)
} else { } else {
size self.glwe_keyswitch_internal_tmp_bytes(res_infos, a_infos, key_infos)
} };
size.max(self.vec_znx_big_normalize_tmp_bytes()) + self.bytes_of_vec_znx_dft(cols, key_infos.size())
} }
fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>) fn glwe_keyswitch<R, A, K>(&self, res: &mut R, a: &A, key: &K, scratch: &mut Scratch<BE>)
@@ -256,7 +254,7 @@ where
{ {
let cols: usize = (a_infos.rank() + 1).into(); let cols: usize = (a_infos.rank() + 1).into();
let a_size: usize = a_infos.size(); let a_size: usize = a_infos.size();
self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols, a_size) self.gglwe_product_dft_tmp_bytes(res_infos.size(), a_size, key_infos) + self.bytes_of_vec_znx_dft(cols - 1, a_size)
} }
fn glwe_keyswitch_internal<DR, A, K>( fn glwe_keyswitch_internal<DR, A, K>(

View File

@@ -83,34 +83,30 @@ where
assert_eq!(ksk.n(), self.n() as u32); assert_eq!(ksk.n(), self.n() as u32);
assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk)); assert!(scratch.available() >= self.lwe_keyswitch_tmp_bytes(res, a, ksk));
let max_k: TorusPrecision = res.k().max(a.k());
let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize;
let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout { let (mut glwe_in, scratch_1) = scratch.take_glwe(&GLWELayout {
n: ksk.n(), n: ksk.n(),
base2k: a.base2k(), base2k: a.base2k(),
k: max_k, k: a.k(),
rank: Rank(1), rank: Rank(1),
}); });
glwe_in.data.zero(); glwe_in.data.zero();
let (mut glwe_out, scratch_1) = scratch_1.take_glwe(&GLWELayout {
n: ksk.n(),
base2k: res.base2k(),
k: max_k,
rank: Rank(1),
});
let n_lwe: usize = a.n().into(); let n_lwe: usize = a.n().into();
for i in 0..a_size { for i in 0..a.size() {
let data_lwe: &[i64] = a.data.at(0, i); let data_lwe: &[i64] = a.data.at(0, i);
glwe_in.data.at_mut(0, i)[0] = data_lwe[0]; glwe_in.data.at_mut(0, i)[0] = data_lwe[0];
glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]); glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
} }
self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_1); let (mut glwe_out, scratch_2) = scratch_1.take_glwe(&GLWELayout {
n: ksk.n(),
base2k: res.base2k(),
k: res.k(),
rank: Rank(1),
});
self.glwe_keyswitch(&mut glwe_out, &glwe_in, ksk, scratch_2);
self.lwe_sample_extract(res, &glwe_out); self.lwe_sample_extract(res, &glwe_out);
} }
} }

View File

@@ -1,10 +1,10 @@
use std::fmt; use std::fmt;
use poulpy_hal::{ use poulpy_hal::{
api::ZnFillUniform, api::VecZnxFillUniform,
layouts::{ layouts::{
Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, Backend, Data, DataMut, DataRef, FillUniform, Module, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos,
ZnxViewMut, ZnxView, ZnxViewMut,
}, },
source::Source, source::Source,
}; };
@@ -13,7 +13,7 @@ use crate::layouts::{Base2K, Degree, LWE, LWEInfos, LWEToMut, TorusPrecision};
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct LWECompressed<D: Data> { pub struct LWECompressed<D: Data> {
pub(crate) data: Zn<D>, pub(crate) data: VecZnx<D>,
pub(crate) k: TorusPrecision, pub(crate) k: TorusPrecision,
pub(crate) base2k: Base2K, pub(crate) base2k: Base2K,
pub(crate) seed: [u8; 32], pub(crate) seed: [u8; 32],
@@ -72,7 +72,7 @@ impl LWECompressed<Vec<u8>> {
pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self {
LWECompressed { LWECompressed {
data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize),
k, k,
base2k, base2k,
seed: [0u8; 32], seed: [0u8; 32],
@@ -87,7 +87,7 @@ impl LWECompressed<Vec<u8>> {
} }
pub fn bytes_of(base2k: Base2K, k: TorusPrecision) -> usize { pub fn bytes_of(base2k: Base2K, k: TorusPrecision) -> usize {
Zn::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize) VecZnx::bytes_of(1, 1, k.0.div_ceil(base2k.0) as usize)
} }
} }
@@ -113,7 +113,7 @@ impl<D: DataRef> WriterTo for LWECompressed<D> {
pub trait LWEDecompress pub trait LWEDecompress
where where
Self: ZnFillUniform, Self: VecZnxFillUniform,
{ {
fn decompress_lwe<R, O>(&self, res: &mut R, other: &O) fn decompress_lwe<R, O>(&self, res: &mut R, other: &O)
where where
@@ -126,20 +126,14 @@ where
assert_eq!(res.lwe_layout(), other.lwe_layout()); assert_eq!(res.lwe_layout(), other.lwe_layout());
let mut source: Source = Source::new(other.seed); let mut source: Source = Source::new(other.seed);
self.zn_fill_uniform( self.vec_znx_fill_uniform(other.base2k().into(), &mut res.data, 0, &mut source);
res.n().into(),
other.base2k().into(),
&mut res.data,
0,
&mut source,
);
for i in 0..res.size() { for i in 0..res.size() {
res.data.at_mut(0, i)[0] = other.data.at(0, i)[0]; res.data.at_mut(0, i)[0] = other.data.at(0, i)[0];
} }
} }
} }
impl<B: Backend> LWEDecompress for Module<B> where Self: ZnFillUniform {} impl<B: Backend> LWEDecompress for Module<B> where Self: VecZnxFillUniform {}
impl<D: DataMut> LWE<D> { impl<D: DataMut> LWE<D> {
pub fn decompress<O, M>(&mut self, module: &M, other: &O) pub fn decompress<O, M>(&mut self, module: &M, other: &O)

View File

@@ -158,3 +158,15 @@ impl<D: DataMut> GLWEPlaintextToMut for GLWEPlaintext<D> {
} }
} }
} }
impl<D: DataMut> GLWEPlaintext<D> {
pub fn data_mut(&mut self) -> &mut VecZnx<D> {
&mut self.data
}
}
impl<D: DataRef> GLWEPlaintext<D> {
pub fn data(&self) -> &VecZnx<D> {
&self.data
}
}

View File

@@ -1,7 +1,7 @@
use std::fmt; use std::fmt;
use poulpy_hal::{ use poulpy_hal::{
layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, WriterTo, Zn, ZnToMut, ZnToRef, ZnxInfos}, layouts::{Data, DataMut, DataRef, FillUniform, ReaderFrom, VecZnx, VecZnxToMut, VecZnxToRef, WriterTo, ZnxInfos},
source::Source, source::Source,
}; };
@@ -57,7 +57,7 @@ impl LWEInfos for LWELayout {
} }
#[derive(PartialEq, Eq, Clone)] #[derive(PartialEq, Eq, Clone)]
pub struct LWE<D: Data> { pub struct LWE<D: Data> {
pub(crate) data: Zn<D>, pub(crate) data: VecZnx<D>,
pub(crate) k: TorusPrecision, pub(crate) k: TorusPrecision,
pub(crate) base2k: Base2K, pub(crate) base2k: Base2K,
} }
@@ -90,13 +90,13 @@ impl<D: Data> SetLWEInfos for LWE<D> {
} }
impl<D: DataRef> LWE<D> { impl<D: DataRef> LWE<D> {
pub fn data(&self) -> &Zn<D> { pub fn data(&self) -> &VecZnx<D> {
&self.data &self.data
} }
} }
impl<D: DataMut> LWE<D> { impl<D: DataMut> LWE<D> {
pub fn data_mut(&mut self) -> &Zn<D> { pub fn data_mut(&mut self) -> &VecZnx<D> {
&mut self.data &mut self.data
} }
} }
@@ -121,7 +121,7 @@ impl<D: DataRef> fmt::Display for LWE<D> {
impl<D: DataMut> FillUniform for LWE<D> impl<D: DataMut> FillUniform for LWE<D>
where where
Zn<D>: FillUniform, VecZnx<D>: FillUniform,
{ {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) { fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
self.data.fill_uniform(log_bound, source); self.data.fill_uniform(log_bound, source);
@@ -138,7 +138,7 @@ impl LWE<Vec<u8>> {
pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self { pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self {
LWE { LWE {
data: Zn::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize), data: VecZnx::alloc((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize),
k, k,
base2k, base2k,
} }
@@ -152,7 +152,7 @@ impl LWE<Vec<u8>> {
} }
pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize { pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize {
Zn::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize) VecZnx::bytes_of((n + 1).into(), 1, k.0.div_ceil(base2k.0) as usize)
} }
} }

View File

@@ -1,6 +1,6 @@
use std::fmt; use std::fmt;
use poulpy_hal::layouts::{Data, DataMut, DataRef, Zn, ZnToMut, ZnToRef, ZnxInfos}; use poulpy_hal::layouts::{Data, DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos};
use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision}; use crate::layouts::{Base2K, Degree, LWEInfos, TorusPrecision};
@@ -29,7 +29,7 @@ impl LWEInfos for LWEPlaintextLayout {
} }
pub struct LWEPlaintext<D: Data> { pub struct LWEPlaintext<D: Data> {
pub(crate) data: Zn<D>, pub(crate) data: VecZnx<D>,
pub(crate) k: TorusPrecision, pub(crate) k: TorusPrecision,
pub(crate) base2k: Base2K, pub(crate) base2k: Base2K,
} }
@@ -62,7 +62,7 @@ impl LWEPlaintext<Vec<u8>> {
pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self { pub fn alloc(base2k: Base2K, k: TorusPrecision) -> Self {
LWEPlaintext { LWEPlaintext {
data: Zn::alloc(1, 1, k.0.div_ceil(base2k.0) as usize), data: VecZnx::alloc(1, 1, k.0.div_ceil(base2k.0) as usize),
k, k,
base2k, base2k,
} }
@@ -111,8 +111,14 @@ impl<D: DataMut> LWEPlaintextToMut for LWEPlaintext<D> {
} }
} }
impl<D: DataRef> LWEPlaintext<D> {
pub fn data(&self) -> &VecZnx<D> {
&self.data
}
}
impl<D: DataMut> LWEPlaintext<D> { impl<D: DataMut> LWEPlaintext<D> {
pub fn data_mut(&mut self) -> &mut Zn<D> { pub fn data_mut(&mut self) -> &mut VecZnx<D> {
&mut self.data &mut self.data
} }
} }

View File

@@ -42,7 +42,7 @@ pub(crate) fn var_noise_gglwe_product(
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn var_noise_gglwe_product_v2( pub(crate) fn var_noise_gglwe_product_v2(
n: f64, n: f64,
logq: usize, k_ksk: usize,
dnum: usize, dnum: usize,
dsize: usize, dsize: usize,
base2k: usize, base2k: usize,
@@ -55,7 +55,7 @@ pub(crate) fn var_noise_gglwe_product_v2(
) -> f64 { ) -> f64 {
let base: f64 = ((dsize * base2k) as f64).exp2(); let base: f64 = ((dsize * base2k) as f64).exp2();
let var_base: f64 = base * base / 12f64; let var_base: f64 = base * base / 12f64;
let scale: f64 = (logq as f64).exp2(); let scale: f64 = (k_ksk as f64).exp2();
let mut noise: f64 = (dnum as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs); let mut noise: f64 = (dnum as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a_err * var_base * n; noise += var_msg * var_a_err * var_base * n;

View File

@@ -23,7 +23,7 @@ where
where where
A: LWEInfos, A: LWEInfos,
{ {
let (data, scratch) = self.take_zn(infos.n().into(), 1, infos.size()); let (data, scratch) = self.take_vec_znx(infos.n().into(), 1, infos.size());
( (
LWE { LWE {
k: infos.k(), k: infos.k(),

View File

@@ -9,10 +9,10 @@ use crate::{
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEPlaintext, GGLWEInfos, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWEPlaintext,
GLWESecret, GLWESecretPreparedFactory, GLWESecret, GLWESecretPreparedFactory, LWEInfos,
prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared},
}, },
noise::log2_std_noise_gglwe_product, var_noise_gglwe_product_v2,
}; };
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@@ -29,26 +29,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_in: usize = 17;
let k_in: usize = 60; let base2k_key: usize = 13;
let k_out: usize = 40; let base2k_out: usize = base2k_in; // MUST BE SAME
let dsize: usize = k_in.div_ceil(base2k); let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key);
let p0: i64 = -1; let p0: i64 = -1;
let p1: i64 = -5; let p1: i64 = -5;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_apply: usize = (dsize + di) * base2k; let k_ksk: usize = k_in + base2k_key * dsize;
let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n(); let n: usize = module.n();
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_in / (base2k * di); let dnum_in: usize = k_in / base2k_in;
let dnum_out: usize = k_out / (base2k * di); let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
let dnum_apply: usize = k_in.div_ceil(base2k * di);
let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let auto_key_in_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -57,19 +58,19 @@ where
let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let auto_key_out_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_out.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let auto_key_apply_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_apply.into(), k: k_ksk.into(),
dnum: dnum_apply.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
@@ -83,13 +84,16 @@ where
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos) GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_in_infos)
| GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &auto_key_apply_infos) .max(GLWEAutomorphismKey::encrypt_sk_tmp_bytes(
| GLWEAutomorphismKey::automorphism_tmp_bytes( module,
&auto_key_apply_infos,
))
.max(GLWEAutomorphismKey::automorphism_tmp_bytes(
module, module,
&auto_key_out_infos, &auto_key_out_infos,
&auto_key_in_infos, &auto_key_in_infos,
&auto_key_apply_infos, &auto_key_apply_infos,
), )),
); );
let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&auto_key_in); let mut sk: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&auto_key_in);
@@ -128,7 +132,7 @@ where
scratch.borrow(), scratch.borrow(),
); );
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos); let mut pt_out: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&auto_key_out_infos);
let mut sk_auto: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&auto_key_out_infos); let mut sk_auto: GLWESecret<Vec<u8>> = GLWESecret::alloc_from_infos(&auto_key_out_infos);
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
@@ -145,41 +149,44 @@ where
let mut sk_auto_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto); let mut sk_auto_dft: GLWESecretPrepared<Vec<u8>, BE> = GLWESecretPrepared::alloc_from_infos(module, &sk_auto);
sk_auto_dft.prepare(module, &sk_auto); sk_auto_dft.prepare(module, &sk_auto);
(0..auto_key_out.rank_in().into()).for_each(|col_i| { for col_i in 0..auto_key_out.rank_in().into() {
(0..auto_key_out.dnum().into()).for_each(|row_i| { for row_i in 0..auto_key_out.dnum().into() {
auto_key_out auto_key_out
.at(row_i, col_i) .at(row_i, col_i)
.decrypt(module, &mut pt, &sk_auto_dft, scratch.borrow()); .decrypt(module, &mut pt_out, &sk_auto_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace( module.vec_znx_sub_scalar_inplace(
&mut pt.data, &mut pt_out.data,
0, 0,
(dsize_in - 1) + row_i * dsize_in, (dsize_in - 1) + row_i * dsize_in,
&sk.data, &sk.data,
col_i, col_i,
); );
let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); let noise_have: f64 = pt_out.data.stats(pt_out.base2k().into(), 0).std().log2();
let noise_want: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = var_noise_gglwe_product_v2(
n as f64, module.n() as f64,
base2k * di, k_ksk,
dnum_ksk,
dsize,
base2k_key,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
SIGMA * SIGMA, SIGMA * SIGMA,
0f64, 0f64,
rank as f64, rank as f64,
k_out, )
k_apply, .sqrt()
); .log2();
assert!( assert!(
noise_have < noise_want + 0.5, noise_have < max_noise + 0.5,
"{noise_have} {}", "{noise_have} {}",
noise_want + 0.5 max_noise + 0.5
); );
}); }
}); }
} }
} }
} }
@@ -198,25 +205,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_out: usize = 17;
let k_in: usize = 60; let base2k_key: usize = 13;
let dsize: usize = k_in.div_ceil(base2k); let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key);
let p0: i64 = -1; let p0: i64 = -1;
let p1: i64 = -5; let p1: i64 = -5;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_apply: usize = (dsize + di) * base2k; let k_ksk: usize = k_out + base2k_key * dsize;
let n: usize = module.n(); let n: usize = module.n();
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_in / (base2k * di); let dnum_in: usize = k_out / base2k_out;
let dnum_apply: usize = k_in.div_ceil(base2k * di); let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let auto_key_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_in.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
@@ -224,10 +233,10 @@ where
let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let auto_key_apply_layout: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_apply.into(), k: k_ksk.into(),
dnum: dnum_apply.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
@@ -306,24 +315,27 @@ where
col_i, col_i,
); );
let noise_have: f64 = pt.data.stats(base2k, 0).std().log2(); let noise_have: f64 = pt.data.stats(pt.base2k().into(), 0).std().log2();
let noise_want: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = var_noise_gglwe_product_v2(
n as f64, module.n() as f64,
base2k * di, k_ksk,
dnum_ksk,
dsize,
base2k_key,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
SIGMA * SIGMA, SIGMA * SIGMA,
0f64, 0f64,
rank as f64, rank as f64,
k_in, )
k_apply, .sqrt()
); .log2();
assert!( assert!(
noise_have < noise_want + 0.5, noise_have < max_noise + 0.5,
"{noise_have} {}", "{noise_have} {}",
noise_want + 0.5 max_noise + 0.5
); );
}); });
}); });

View File

@@ -29,26 +29,28 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_in: usize = 17;
let k_in: usize = 54; let base2k_key: usize = 13;
let dsize: usize = k_in.div_ceil(base2k); let base2k_out: usize = base2k_in; // MUST BE SAME
let p: i64 = -5; let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key);
let p: i64 = -5;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_in + base2k * di; let k_ksk: usize = k_in + base2k_key * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let k_out: usize = k_ksk; // Better capture noise. let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_in.div_ceil(base2k * di); let dnum_in: usize = k_in / base2k_in;
let dnum_in: usize = k_in.div_euclid(base2k * di); let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_in_layout: GGSWLayout = GGSWLayout { let ggsw_in_layout: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -57,7 +59,7 @@ where
let ggsw_out_layout: GGSWLayout = GGSWLayout { let ggsw_out_layout: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -66,19 +68,19 @@ where
let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
@@ -154,7 +156,7 @@ where
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k * di, base2k_key * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,
@@ -187,23 +189,25 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_out: usize = 17;
let k_out: usize = 54; let base2k_key: usize = 13;
let dsize: usize = k_out.div_ceil(base2k); let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key);
let p: i64 = -1; let p: i64 = -1;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_out + base2k * di; let k_ksk: usize = k_out + base2k_key * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_out.div_ceil(di * base2k); let dnum_in: usize = k_out / base2k_out;
let dnum_in: usize = k_out.div_euclid(base2k * di); let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_out_layout: GGSWLayout = GGSWLayout { let ggsw_out_layout: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_in.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
@@ -212,19 +216,19 @@ where
let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { let tsk_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout { let auto_key_layout: GGLWEToGGSWKeyLayout = GGLWEToGGSWKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
@@ -293,7 +297,7 @@ where
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k * di, base2k_key * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,

View File

@@ -5,14 +5,14 @@ use poulpy_hal::{
}; };
use crate::{ use crate::{
GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, ScratchTakeCore, GLWEAutomorphism, GLWEAutomorphismKeyEncryptSk, GLWEDecrypt, GLWEEncryptSk, GLWENoise, GLWENormalize, ScratchTakeCore,
encryption::SIGMA, encryption::SIGMA,
layouts::{ layouts::{
GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext, GLWE, GLWEAutomorphismKey, GLWEAutomorphismKeyLayout, GLWEAutomorphismKeyPreparedFactory, GLWELayout, GLWEPlaintext,
GLWESecret, GLWESecretPreparedFactory, GLWESecret, GLWESecretPreparedFactory,
prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared}, prepared::{GLWEAutomorphismKeyPrepared, GLWESecretPrepared},
}, },
noise::log2_std_noise_gglwe_product, var_noise_gglwe_product_v2,
}; };
pub fn test_glwe_automorphism<BE: Backend>(module: &Module<BE>) pub fn test_glwe_automorphism<BE: Backend>(module: &Module<BE>)
@@ -25,55 +25,59 @@ where
+ GLWEAutomorphismKeyEncryptSk<BE> + GLWEAutomorphismKeyEncryptSk<BE>
+ GLWEAutomorphismKeyPreparedFactory<BE> + GLWEAutomorphismKeyPreparedFactory<BE>
+ GLWENoise<BE> + GLWENoise<BE>
+ VecZnxAutomorphismInplace<BE>, + VecZnxAutomorphismInplace<BE>
+ GLWENormalize<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_in: usize = 17;
let k_in: usize = 60; let base2k_key: usize = 13;
let dsize: usize = k_in.div_ceil(base2k); let base2k_out: usize = 15;
let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key);
let p: i64 = -5; let p: i64 = -5;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_in + base2k * di; let k_ksk: usize = k_in + base2k_key * dsize;
let k_out: usize = k_ksk; // Better capture noise. let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_in.div_ceil(base2k * dsize); let dnum: usize = k_in.div_ceil(base2k_key * dsize);
let ct_in_infos: GLWELayout = GLWELayout { let ct_in_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_in.into(), k: k_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let ct_out_infos: GLWELayout = GLWELayout { let ct_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
rank: rank.into(), rank: rank.into(),
}; };
let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_out.into(), k: k_out.into(),
rank: rank.into(), rank: rank.into(),
dnum: dnum.into(), dnum: dnum.into(),
dsize: di.into(), dsize: dsize.into(),
}; };
let mut autokey: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); let mut autokey: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos);
let mut ct_in: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&ct_in_infos); let mut ct_in: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&ct_in_infos);
let mut ct_out: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&ct_out_infos); let mut ct_out: GLWE<Vec<u8>> = GLWE::alloc_from_infos(&ct_out_infos);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&ct_out_infos); let mut pt_in: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&ct_in_infos);
let mut pt_out: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&ct_out_infos);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); module.vec_znx_fill_uniform(base2k_in, &mut pt_in.data, 0, &mut source_xa);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey)
@@ -99,7 +103,7 @@ where
ct_in.encrypt_sk( ct_in.encrypt_sk(
module, module,
&pt_want, &pt_in,
&sk_prepared, &sk_prepared,
&mut source_xa, &mut source_xa,
&mut source_xe, &mut source_xe,
@@ -112,22 +116,26 @@ where
ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow()); ct_out.automorphism(module, &ct_in, &autokey_prepared, scratch.borrow());
let max_noise: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = var_noise_gglwe_product_v2(
module.n() as f64, module.n() as f64,
base2k * dsize, k_ksk,
dnum,
max_dsize,
base2k_key,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
SIGMA * SIGMA, SIGMA * SIGMA,
0f64, 0f64,
rank as f64, rank as f64,
k_in, )
k_ksk, .sqrt()
); .log2();
module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); module.glwe_normalize(&mut pt_out, &pt_in, scratch.borrow());
module.vec_znx_automorphism_inplace(p, &mut pt_out.data, 0, scratch.borrow());
ct_out.assert_noise(module, &sk_prepared, &pt_want, max_noise + 1.0); ct_out.assert_noise(module, &sk_prepared, &pt_out, max_noise + 1.0);
} }
} }
} }
@@ -147,31 +155,33 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_out: usize = 17;
let k_out: usize = 60; let base2k_key: usize = 13;
let dsize: usize = k_out.div_ceil(base2k); let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key);
let p = -5; let p = -5;
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_out + base2k * di; let k_ksk: usize = k_out + base2k_key * dsize;
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_out.div_ceil(base2k * dsize); let dnum: usize = k_out.div_ceil(base2k_key * dsize);
let ct_out_infos: GLWELayout = GLWELayout { let ct_out_infos: GLWELayout = GLWELayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
rank: rank.into(), rank: rank.into(),
}; };
let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout { let autokey_infos: GLWEAutomorphismKeyLayout = GLWEAutomorphismKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
rank: rank.into(), rank: rank.into(),
dnum: dnum.into(), dnum: dnum.into(),
dsize: di.into(), dsize: dsize.into(),
}; };
let mut autokey: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos); let mut autokey: GLWEAutomorphismKey<Vec<u8>> = GLWEAutomorphismKey::alloc_from_infos(&autokey_infos);
@@ -182,7 +192,7 @@ where
let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
module.vec_znx_fill_uniform(base2k, &mut pt_want.data, 0, &mut source_xa); module.vec_znx_fill_uniform(base2k_out, &mut pt_want.data, 0, &mut source_xa);
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey) GLWEAutomorphismKey::encrypt_sk_tmp_bytes(module, &autokey)
@@ -221,18 +231,21 @@ where
ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow()); ct.automorphism_inplace(module, &autokey_prepared, scratch.borrow());
let max_noise: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = var_noise_gglwe_product_v2(
module.n() as f64, module.n() as f64,
base2k * dsize, k_ksk,
dnum,
dsize,
base2k_key,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
SIGMA * SIGMA, SIGMA * SIGMA,
0f64, 0f64,
rank as f64, rank as f64,
k_out, )
k_ksk, .sqrt()
); .log2();
module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow()); module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0, scratch.borrow());

View File

@@ -1,5 +1,5 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform}, api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxFillUniform, VecZnxNormalize},
layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, ZnxView}, layouts::{Backend, FillUniform, Module, Scratch, ScratchOwned, ZnxView},
source::Source, source::Source,
}; };
@@ -104,7 +104,8 @@ where
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ LWEEncryptSk<BE> + LWEEncryptSk<BE>
+ LWEToGLWEKeyPreparedFactory<BE>, + LWEToGLWEKeyPreparedFactory<BE>
+ VecZnxNormalize<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
@@ -120,23 +121,23 @@ where
let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout { let lwe_to_glwe_infos: LWEToGLWEKeyLayout = LWEToGLWEKeyLayout {
n: n_glwe, n: n_glwe,
base2k: Base2K(17), base2k: Base2K(13),
k: TorusPrecision(51), k: TorusPrecision(92),
dnum: Dnum(2), dnum: Dnum(2),
rank_out: rank, rank_out: rank,
}; };
let glwe_infos: GLWELayout = GLWELayout { let glwe_infos: GLWELayout = GLWELayout {
n: n_glwe, n: n_glwe,
base2k: Base2K(17), base2k: Base2K(15),
k: TorusPrecision(34), k: TorusPrecision(75),
rank, rank,
}; };
let lwe_infos: LWELayout = LWELayout { let lwe_infos: LWELayout = LWELayout {
n: n_lwe, n: n_lwe,
base2k: Base2K(17), base2k: Base2K(17),
k: TorusPrecision(34), k: TorusPrecision(75),
}; };
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
@@ -160,7 +161,14 @@ where
lwe_pt.encode_i64(data, k_lwe_pt); lwe_pt.encode_i64(data, k_lwe_pt);
let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut lwe_ct: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
lwe_ct.encrypt_sk(module, &lwe_pt, &sk_lwe, &mut source_xa, &mut source_xe); lwe_ct.encrypt_sk(
module,
&lwe_pt,
&sk_lwe,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let mut ksk: LWEToGLWEKey<Vec<u8>> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos); let mut ksk: LWEToGLWEKey<Vec<u8>> = LWEToGLWEKey::alloc_from_infos(&lwe_to_glwe_infos);
@@ -183,7 +191,19 @@ where
let mut glwe_pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos); let mut glwe_pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(&glwe_infos);
glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow()); glwe_ct.decrypt(module, &mut glwe_pt, &sk_glwe_prepared, scratch.borrow());
assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt.data.at(0, 0)[0]); let mut lwe_pt_conv = LWEPlaintext::alloc(glwe_pt.base2k(), lwe_pt.k());
module.vec_znx_normalize(
glwe_pt.base2k().as_usize(),
lwe_pt_conv.data_mut(),
0,
lwe_pt.base2k().as_usize(),
lwe_pt.data(),
0,
scratch.borrow(),
);
assert_eq!(glwe_pt.data.at(0, 0)[0], lwe_pt_conv.data.at(0, 0)[0]);
} }
pub fn test_glwe_to_lwe<BE: Backend>(module: &Module<BE>) pub fn test_glwe_to_lwe<BE: Backend>(module: &Module<BE>)
@@ -196,7 +216,8 @@ where
+ GLWEDecrypt<BE> + GLWEDecrypt<BE>
+ GLWESecretPreparedFactory<BE> + GLWESecretPreparedFactory<BE>
+ GLWEToLWESwitchingKeyEncryptSk<BE> + GLWEToLWESwitchingKeyEncryptSk<BE>
+ GLWEToLWEKeyPreparedFactory<BE>, + GLWEToLWEKeyPreparedFactory<BE>
+ VecZnxNormalize<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
@@ -208,8 +229,8 @@ where
let glwe_to_lwe_infos: GLWEToLWEKeyLayout = GLWEToLWEKeyLayout { let glwe_to_lwe_infos: GLWEToLWEKeyLayout = GLWEToLWEKeyLayout {
n: n_glwe, n: n_glwe,
base2k: Base2K(17), base2k: Base2K(13),
k: TorusPrecision(51), k: TorusPrecision(91),
dnum: Dnum(2), dnum: Dnum(2),
rank_in: rank, rank_in: rank,
}; };
@@ -217,14 +238,14 @@ where
let glwe_infos: GLWELayout = GLWELayout { let glwe_infos: GLWELayout = GLWELayout {
n: n_glwe, n: n_glwe,
base2k: Base2K(17), base2k: Base2K(17),
k: TorusPrecision(34), k: TorusPrecision(72),
rank, rank,
}; };
let lwe_infos: LWELayout = LWELayout { let lwe_infos: LWELayout = LWELayout {
n: n_lwe, n: n_lwe,
base2k: Base2K(17), base2k: Base2K(15),
k: TorusPrecision(34), k: TorusPrecision(72),
}; };
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
@@ -284,7 +305,19 @@ where
lwe_ct.from_glwe(module, &glwe_ct, a_idx, &ksk_prepared, scratch.borrow()); lwe_ct.from_glwe(module, &glwe_ct, a_idx, &ksk_prepared, scratch.borrow());
let mut lwe_pt: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_infos); let mut lwe_pt: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_infos);
lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe); lwe_ct.decrypt(module, &mut lwe_pt, &sk_lwe, scratch.borrow());
assert_eq!(glwe_pt.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]); let mut glwe_pt_conv = GLWEPlaintext::alloc(glwe_ct.n(), lwe_pt.base2k(), lwe_pt.k());
module.vec_znx_normalize(
lwe_pt.base2k().as_usize(),
glwe_pt_conv.data_mut(),
0,
glwe_ct.base2k().as_usize(),
glwe_pt.data(),
0,
scratch.borrow(),
);
assert_eq!(glwe_pt_conv.data.at(0, 0)[a_idx], lwe_pt.data.at(0, 0)[0]);
} }

View File

@@ -12,6 +12,7 @@ use crate::{
prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared},
}, },
noise::log2_std_noise_gglwe_product, noise::log2_std_noise_gglwe_product,
var_noise_gglwe_product_v2,
}; };
pub fn test_gglwe_switching_key_keyswitch<BE: Backend>(module: &Module<BE>) pub fn test_gglwe_switching_key_keyswitch<BE: Backend>(module: &Module<BE>)
@@ -24,27 +25,29 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_in: usize = 17;
let k_in: usize = 60; let base2k_key: usize = 13;
let dsize: usize = k_in.div_ceil(base2k); let base2k_out: usize = base2k_in; // MUST BE SAME
let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key);
for rank_in_s0s1 in 1_usize..3 { for rank_in_s0s1 in 1_usize..2 {
for rank_out_s0s1 in 1_usize..3 { for rank_out_s0s1 in 1_usize..3 {
for rank_out_s1s2 in 1_usize..3 { for rank_out_s1s2 in 1_usize..3 {
for di in 1_usize..dsize + 1 { for dsize in 1_usize..max_dsize + 1 {
let k_ksk: usize = k_in + base2k * di; let k_ksk: usize = k_in + base2k_key * dsize;
let k_out: usize = k_ksk; // Better capture noise. let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_in / base2k;
let dnum_apply: usize = k_in.div_ceil(base2k * di);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_in / base2k_in;
let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank_in: rank_in_s0s1.into(), rank_in: rank_in_s0s1.into(),
rank_out: rank_out_s0s1.into(), rank_out: rank_out_s0s1.into(),
@@ -52,19 +55,19 @@ where
let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum_apply.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank_in: rank_out_s0s1.into(), rank_in: rank_out_s0s1.into(),
rank_out: rank_out_s1s2.into(), rank_out: rank_out_s1s2.into(),
}; };
let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum_apply.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank_in: rank_in_s0s1.into(), rank_in: rank_in_s0s1.into(),
rank_out: rank_out_s1s2.into(), rank_out: rank_out_s1s2.into(),
@@ -85,8 +88,8 @@ where
); );
let mut scratch_apply: ScratchOwned<BE> = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes( let mut scratch_apply: ScratchOwned<BE> = ScratchOwned::alloc(GLWESwitchingKey::keyswitch_tmp_bytes(
module, module,
&gglwe_s0s1_infos,
&gglwe_s0s2_infos, &gglwe_s0s2_infos,
&gglwe_s0s1_infos,
&gglwe_s1s2_infos, &gglwe_s1s2_infos,
)); ));
@@ -135,18 +138,21 @@ where
scratch_apply.borrow(), scratch_apply.borrow(),
); );
let max_noise: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = var_noise_gglwe_product_v2(
n as f64, module.n() as f64,
base2k * di, k_ksk,
dnum_ksk,
dsize,
base2k_key,
0.5, 0.5,
0.5, 0.5,
0f64, 0f64,
SIGMA * SIGMA, SIGMA * SIGMA,
0f64, 0f64,
rank_out_s0s1 as f64, rank_out_s0s1 as f64,
k_in, )
k_ksk, .sqrt()
); .log2();
gglwe_s0s2 gglwe_s0s2
.key .key
@@ -168,23 +174,27 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_out: usize = 17;
let k_out: usize = 60; let base2k_key: usize = 13;
let dsize: usize = k_out.div_ceil(base2k); let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key);
for rank_in in 1_usize..3 { for rank_in in 1_usize..3 {
for rank_out in 1_usize..3 { for rank_out in 1_usize..3 {
for di in 1_usize..dsize + 1 { for dsize in 1_usize..max_dsize + 1 {
let k_ksk: usize = k_out + base2k * di; let k_ksk: usize = k_out + base2k_key * dsize;
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_out.div_ceil(base2k * di);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let dnum_in: usize = k_out / base2k_out;
let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s0s1_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank_in: rank_in.into(), rank_in: rank_in.into(),
rank_out: rank_out.into(), rank_out: rank_out.into(),
@@ -192,10 +202,10 @@ where
let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let gglwe_s1s2_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank_in: rank_out.into(), rank_in: rank_out.into(),
rank_out: rank_out.into(), rank_out: rank_out.into(),
}; };
@@ -263,7 +273,7 @@ where
let max_noise: f64 = log2_std_noise_gglwe_product( let max_noise: f64 = log2_std_noise_gglwe_product(
n as f64, n as f64,
base2k * di, base2k_key * dsize,
var_xs, var_xs,
var_xs, var_xs,
0f64, 0f64,

View File

@@ -30,53 +30,57 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_in: usize = 17;
let k_in: usize = 54; let base2k_key: usize = 13;
let dsize: usize = k_in.div_ceil(base2k); let base2k_out: usize = base2k_in; // MUST BE SAME
let k_in: usize = 102;
let max_dsize: usize = k_in.div_ceil(base2k_key);
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_in + base2k * di; let k_ksk: usize = k_in + base2k_key * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let k_out: usize = k_ksk; // Better capture noise. let k_out: usize = k_ksk; // Better capture noise.
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_in.div_ceil(di * base2k); let dnum_in: usize = k_in / base2k_in;
let dnum_ksk: usize = k_in.div_ceil(base2k_key * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_in_infos: GGSWLayout = GGSWLayout { let ggsw_in_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_in.into(), k: k_in.into(),
dnum: dnum.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let ggsw_out_infos: GGSWLayout = GGSWLayout { let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank_in: rank.into(), rank_in: rank.into(),
rank_out: rank.into(), rank_out: rank.into(),
}; };
@@ -163,7 +167,7 @@ where
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k * di, base2k_key * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,
@@ -195,43 +199,45 @@ where
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let base2k: usize = 12; let base2k_out: usize = 17;
let k_out: usize = 54; let base2k_key: usize = 13;
let dsize: usize = k_out.div_ceil(base2k); let k_out: usize = 102;
let max_dsize: usize = k_out.div_ceil(base2k_key);
for rank in 1_usize..3 { for rank in 1_usize..3 {
for di in 1..dsize + 1 { for dsize in 1..max_dsize + 1 {
let k_ksk: usize = k_out + base2k * di; let k_ksk: usize = k_out + base2k_key * dsize;
let k_tsk: usize = k_ksk; let k_tsk: usize = k_ksk;
let n: usize = module.n(); let n: usize = module.n();
let dnum: usize = k_out.div_ceil(di * base2k); let dnum_in: usize = k_out / base2k_out;
let dnum_ksk: usize = k_out.div_ceil(base2k_key * dsize);
let dsize_in: usize = 1; let dsize_in: usize = 1;
let ggsw_out_infos: GGSWLayout = GGSWLayout { let ggsw_out_infos: GGSWLayout = GGSWLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
k: k_out.into(), k: k_out.into(),
dnum: dnum.into(), dnum: dnum_in.into(),
dsize: dsize_in.into(), dsize: dsize_in.into(),
rank: rank.into(), rank: rank.into(),
}; };
let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout { let tsk_infos: GLWETensorKeyLayout = GLWETensorKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_tsk.into(), k: k_tsk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank: rank.into(), rank: rank.into(),
}; };
let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout { let ksk_apply_infos: GLWESwitchingKeyLayout = GLWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum_ksk.into(),
dsize: di.into(), dsize: dsize.into(),
rank_in: rank.into(), rank_in: rank.into(),
rank_out: rank.into(), rank_out: rank.into(),
}; };
@@ -311,7 +317,7 @@ where
let max_noise = |col_j: usize| -> f64 { let max_noise = |col_j: usize| -> f64 {
noise_ggsw_keyswitch( noise_ggsw_keyswitch(
n as f64, n as f64,
base2k * di, base2k_key * dsize,
col_j, col_j,
var_xs, var_xs,
0f64, 0f64,

View File

@@ -12,7 +12,6 @@ use crate::{
GLWESwitchingKeyPreparedFactory, LWEInfos, GLWESwitchingKeyPreparedFactory, LWEInfos,
prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared}, prepared::{GLWESecretPrepared, GLWESwitchingKeyPrepared},
}, },
noise::log2_std_noise_gglwe_product,
var_noise_gglwe_product_v2, var_noise_gglwe_product_v2,
}; };

View File

@@ -1,5 +1,5 @@
use poulpy_hal::{ use poulpy_hal::{
api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow}, api::{ScratchAvailable, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize},
layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView}, layouts::{Backend, Module, Scratch, ScratchOwned, ZnxView},
source::Source, source::Source,
}; };
@@ -14,21 +14,27 @@ use crate::{
pub fn test_lwe_keyswitch<BE: Backend>(module: &Module<BE>) pub fn test_lwe_keyswitch<BE: Backend>(module: &Module<BE>)
where where
Module<BE>: Module<BE>: LWEKeySwitch<BE>
LWEKeySwitch<BE> + LWESwitchingKeyEncrypt<BE> + LWEEncryptSk<BE> + LWESwitchingKeyPreparedFactory<BE> + LWEDecrypt<BE>, + LWESwitchingKeyEncrypt<BE>
+ LWEEncryptSk<BE>
+ LWESwitchingKeyPreparedFactory<BE>
+ LWEDecrypt<BE>
+ VecZnxNormalize<BE>,
ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>, ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>, Scratch<BE>: ScratchAvailable + ScratchTakeCore<BE>,
{ {
let n: usize = module.n(); let n: usize = module.n();
let base2k: usize = 17; let base2k_in: usize = 17;
let base2k_out: usize = 15;
let base2k_key: usize = 13;
let n_lwe_in: usize = 22; let n_lwe_in: usize = module.n() >> 1;
let n_lwe_out: usize = 30; let n_lwe_out: usize = module.n() >> 1;
let k_lwe_ct: usize = 2 * base2k; let k_lwe_ct: usize = 102;
let k_lwe_pt: usize = 8; let k_lwe_pt: usize = 8;
let k_ksk: usize = k_lwe_ct + base2k; let k_ksk: usize = k_lwe_ct + base2k_key;
let dnum: usize = k_lwe_ct.div_ceil(base2k); let dnum: usize = k_lwe_ct.div_ceil(base2k_key);
let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]);
@@ -36,21 +42,21 @@ where
let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout { let key_apply_infos: LWESwitchingKeyLayout = LWESwitchingKeyLayout {
n: n.into(), n: n.into(),
base2k: base2k.into(), base2k: base2k_key.into(),
k: k_ksk.into(), k: k_ksk.into(),
dnum: dnum.into(), dnum: dnum.into(),
}; };
let lwe_in_infos: LWELayout = LWELayout { let lwe_in_infos: LWELayout = LWELayout {
n: n_lwe_in.into(), n: n_lwe_in.into(),
base2k: base2k.into(), base2k: base2k_in.into(),
k: k_lwe_ct.into(), k: k_lwe_ct.into(),
}; };
let lwe_out_infos: LWELayout = LWELayout { let lwe_out_infos: LWELayout = LWELayout {
n: n_lwe_out.into(), n: n_lwe_out.into(),
k: k_lwe_ct.into(), k: k_lwe_ct.into(),
base2k: base2k.into(), base2k: base2k_out.into(),
}; };
let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc( let mut scratch: ScratchOwned<BE> = ScratchOwned::alloc(
@@ -66,7 +72,7 @@ where
let data: i64 = 17; let data: i64 = 17;
let mut lwe_pt_in: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(base2k.into(), k_lwe_pt.into()); let mut lwe_pt_in: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(base2k_in.into(), k_lwe_pt.into());
lwe_pt_in.encode_i64(data, k_lwe_pt.into()); lwe_pt_in.encode_i64(data, k_lwe_pt.into());
let mut lwe_ct_in: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_in_infos); let mut lwe_ct_in: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_in_infos);
@@ -76,6 +82,7 @@ where
&sk_lwe_in, &sk_lwe_in,
&mut source_xa, &mut source_xa,
&mut source_xe, &mut source_xe,
scratch.borrow(),
); );
let mut ksk: LWESwitchingKey<Vec<u8>> = LWESwitchingKey::alloc_from_infos(&key_apply_infos); let mut ksk: LWESwitchingKey<Vec<u8>> = LWESwitchingKey::alloc_from_infos(&key_apply_infos);
@@ -97,7 +104,18 @@ where
lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow()); lwe_ct_out.keyswitch(module, &lwe_ct_in, &ksk_prepared, scratch.borrow());
let mut lwe_pt_out: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_out_infos); let mut lwe_pt_out: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_out_infos);
lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out); lwe_ct_out.decrypt(module, &mut lwe_pt_out, &sk_lwe_out, scratch.borrow());
assert_eq!(lwe_pt_in.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]); let mut lwe_pt_want: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc_from_infos(&lwe_out_infos);
module.vec_znx_normalize(
base2k_out,
lwe_pt_want.data_mut(),
0,
base2k_in,
lwe_pt_in.data(),
0,
scratch.borrow(),
);
assert_eq!(lwe_pt_want.data.at(0, 0)[0], lwe_pt_out.data.at(0, 0)[0]);
} }

View File

@@ -37,16 +37,20 @@ impl<D: DataRef> GLWEPlaintext<D> {
impl<D: DataMut> LWEPlaintext<D> { impl<D: DataMut> LWEPlaintext<D> {
pub fn encode_i64(&mut self, data: i64, k: TorusPrecision) { pub fn encode_i64(&mut self, data: i64, k: TorusPrecision) {
let base2k: usize = self.base2k().into(); let base2k: usize = self.base2k().into();
self.data.encode_i64(base2k, k.into(), data); self.data.encode_coeff_i64(base2k, 0, k.into(), 0, data);
} }
} }
impl<D: DataRef> LWEPlaintext<D> { impl<D: DataRef> LWEPlaintext<D> {
pub fn decode_i64(&self, k: TorusPrecision) -> i64 { pub fn decode_i64(&self, k: TorusPrecision) -> i64 {
self.data.decode_i64(self.base2k().into(), k.into()) self.data
.decode_coeff_i64(self.base2k().into(), 0, k.into(), 0)
} }
pub fn decode_float(&self) -> Float { pub fn decode_float(&self) -> Float {
self.data.decode_float(self.base2k().into()) let mut out: [Float; 1] = [Float::new(self.k().as_u32())];
self.data
.decode_vec_float(self.base2k().into(), 0, &mut out);
out[0].clone()
} }
} }

View File

@@ -6,7 +6,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp_pmat; mod vmp_pmat;
mod zn;
pub use convolution::*; pub use convolution::*;
pub use module::*; pub use module::*;
@@ -16,4 +15,3 @@ pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub use vmp_pmat::*; pub use vmp_pmat::*;
pub use zn::*;

View File

@@ -1,6 +1,6 @@
use crate::{ use crate::{
api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf}, api::{ModuleN, SvpPPolBytesOf, VecZnxBigBytesOf, VecZnxDftBytesOf, VmpPMatBytesOf},
layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, Zn}, layouts::{Backend, MatZnx, ScalarZnx, Scratch, SvpPPol, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
}; };
/// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes. /// Allocates a new [crate::layouts::ScratchOwned] of `size` aligned bytes.
@@ -69,11 +69,6 @@ where
(SvpPPol::from_data(take_slice, module.n(), cols), rem_slice) (SvpPPol::from_data(take_slice, module.n(), cols), rem_slice)
} }
fn take_zn(&mut self, n: usize, cols: usize, size: usize) -> (Zn<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = self.take_slice(Zn::bytes_of(n, cols, size));
(Zn::from_data(take_slice, n, cols, size), rem_slice)
}
fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) { fn take_vec_znx(&mut self, n: usize, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size)); let (take_slice, rem_slice) = self.take_slice(VecZnx::bytes_of(n, cols, size));
(VecZnx::from_data(take_slice, n, cols, size), rem_slice) (VecZnx::from_data(take_slice, n, cols, size), rem_slice)

View File

@@ -1,58 +0,0 @@
use crate::{
layouts::{Backend, Scratch, ZnToMut},
reference::zn::zn_normalize_tmp_bytes,
source::Source,
};
pub trait ZnNormalizeTmpBytes {
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
pub trait ZnNormalizeInplace<B: Backend> {
/// Normalizes the selected column of `a`.
fn zn_normalize_inplace<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: ZnToMut;
}
pub trait ZnFillUniform {
/// Fills the first `size` size with uniform values in \[-2^{base2k-1}, 2^{base2k-1}\]
fn zn_fill_uniform<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnFillNormal {
fn zn_fill_normal<R>(
&self,
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
pub trait ZnAddNormal {
/// Adds a discrete normal vector scaled by 2^{-k} with the provided standard deviation and bounded to \[-bound, bound\].
fn zn_add_normal<R>(
&self,
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}

View File

@@ -5,4 +5,3 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp_pmat; mod vmp_pmat;
mod zn;

View File

@@ -1,81 +0,0 @@
use crate::{
api::{ZnAddNormal, ZnFillNormal, ZnFillUniform, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, Scratch, ZnToMut},
oep::{ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
source::Source,
};
impl<B> ZnNormalizeTmpBytes for Module<B>
where
B: Backend + ZnNormalizeTmpBytesImpl<B>,
{
fn zn_normalize_tmp_bytes(&self, n: usize) -> usize {
B::zn_normalize_tmp_bytes_impl(n)
}
}
impl<B> ZnNormalizeInplace<B> for Module<B>
where
B: Backend + ZnNormalizeInplaceImpl<B>,
{
fn zn_normalize_inplace<A>(&self, n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<B>)
where
A: ZnToMut,
{
B::zn_normalize_inplace_impl(n, base2k, a, a_col, scratch)
}
}
impl<B> ZnFillUniform for Module<B>
where
B: Backend + ZnFillUniformImpl<B>,
{
fn zn_fill_uniform<R>(&self, n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
B::zn_fill_uniform_impl(n, base2k, res, res_col, source);
}
}
impl<B> ZnFillNormal for Module<B>
where
B: Backend + ZnFillNormalImpl<B>,
{
fn zn_fill_normal<R>(
&self,
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
B::zn_fill_normal_impl(n, base2k, res, res_col, k, source, sigma, bound);
}
}
impl<B> ZnAddNormal for Module<B>
where
B: Backend + ZnAddNormalImpl<B>,
{
fn zn_add_normal<R>(
&self,
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
B::zn_add_normal_impl(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -2,7 +2,7 @@ use itertools::izip;
use rug::{Assign, Float}; use rug::{Assign, Float};
use crate::{ use crate::{
layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, Zn, ZnToMut, ZnToRef, ZnxInfos, ZnxView, ZnxViewMut}, layouts::{DataMut, DataRef, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ reference::znx::{
ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef, ZnxZero,
get_carry_i128, get_digit_i128, znx_zero_ref, get_carry_i128, get_digit_i128, znx_zero_ref,
@@ -245,90 +245,6 @@ impl<D: DataRef> VecZnx<D> {
} }
} }
impl<D: DataMut> Zn<D> {
pub fn encode_i64(&mut self, base2k: usize, k: usize, data: i64) {
let size: usize = k.div_ceil(base2k);
#[cfg(debug_assertions)]
{
let a: Zn<&mut [u8]> = self.to_mut();
assert!(
size <= a.size(),
"invalid argument k.div_ceil(base2k)={} > a.size()={}",
size,
a.size()
);
}
let mut a: Zn<&mut [u8]> = self.to_mut();
let a_size = a.size();
for j in 0..a_size {
a.at_mut(0, j)[0] = 0
}
a.at_mut(0, size - 1)[0] = data;
let mut carry: Vec<i64> = vec![0i64; 1];
let k_rem: usize = (base2k - (k % base2k)) % base2k;
for j in (0..size).rev() {
let slice = &mut a.at_mut(0, j)[..1];
if j == size - 1 {
ZnxRef::znx_normalize_first_step_inplace(base2k, k_rem, slice, &mut carry);
} else if j == 0 {
ZnxRef::znx_normalize_final_step_inplace(base2k, k_rem, slice, &mut carry);
} else {
ZnxRef::znx_normalize_middle_step_inplace(base2k, k_rem, slice, &mut carry);
}
}
}
}
impl<D: DataRef> Zn<D> {
pub fn decode_i64(&self, base2k: usize, k: usize) -> i64 {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = k.div_ceil(base2k);
let mut res: i64 = 0;
let rem: usize = base2k - (k % base2k);
(0..size).for_each(|j| {
let x: i64 = a.at(0, j)[0];
if j == size - 1 && rem != base2k {
let k_rem: usize = (base2k - rem) % base2k;
let scale: i64 = 1 << rem as i64;
res = (res << k_rem) + div_round(x, scale);
} else {
res = (res << base2k) + x;
}
});
res
}
pub fn decode_float(&self, base2k: usize) -> Float {
let a: Zn<&[u8]> = self.to_ref();
let size: usize = a.size();
let prec: u32 = (base2k * size) as u32;
// 2^{base2k}
let base: Float = Float::with_val(prec, (1 << base2k) as f64);
let mut res: Float = Float::with_val(prec, (1 << base2k) as f64);
// y[i] = sum x[j][i] * 2^{-base2k*j}
(0..size).for_each(|i| {
if i == 0 {
res.assign(a.at(0, size - i - 1)[0]);
res /= &base;
} else {
res += Float::with_val(prec, a.at(0, size - i - 1)[0]);
res /= &base;
}
});
res
}
}
#[inline] #[inline]
pub fn div_round(a: i64, b: i64) -> i64 { pub fn div_round(a: i64, b: i64) -> i64 {
assert!(b != 0, "division by zero"); assert!(b != 0, "division by zero");

View File

@@ -10,7 +10,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp_pmat; mod vmp_pmat;
mod zn;
mod znx_base; mod znx_base;
pub use mat_znx::*; pub use mat_znx::*;
@@ -24,7 +23,6 @@ pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub use vmp_pmat::*; pub use vmp_pmat::*;
pub use zn::*;
pub use znx_base::*; pub use znx_base::*;
pub trait Data = PartialEq + Eq + Sized + Default; pub trait Data = PartialEq + Eq + Sized + Default;

View File

@@ -1,273 +0,0 @@
use std::{
fmt,
hash::{DefaultHasher, Hasher},
};
use crate::{
alloc_aligned,
layouts::{
Data, DataMut, DataRef, DataView, DataViewMut, DigestU64, FillUniform, ReaderFrom, ToOwnedDeep, WriterTo, ZnxInfos,
ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero,
},
source::Source,
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use rand::RngCore;
#[repr(C)]
#[derive(PartialEq, Eq, Clone, Copy, Hash)]
pub struct Zn<D: Data> {
pub data: D,
pub n: usize,
pub cols: usize,
pub size: usize,
pub max_size: usize,
}
impl<D: DataRef> DigestU64 for Zn<D> {
fn digest_u64(&self) -> u64 {
let mut h: DefaultHasher = DefaultHasher::new();
h.write(self.data.as_ref());
h.write_usize(self.n);
h.write_usize(self.cols);
h.write_usize(self.size);
h.write_usize(self.max_size);
h.finish()
}
}
impl<D: DataRef> ToOwnedDeep for Zn<D> {
type Owned = Zn<Vec<u8>>;
fn to_owned_deep(&self) -> Self::Owned {
Zn {
data: self.data.as_ref().to_vec(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataRef> fmt::Debug for Zn<D> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{self}")
}
}
impl<D: Data> ZnxInfos for Zn<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D: Data> ZnxSliceSize for Zn<D> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
impl<D: Data> DataView for Zn<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D: Data> DataViewMut for Zn<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: DataRef> ZnxView for Zn<D> {
type Scalar = i64;
}
impl Zn<Vec<u8>> {
pub fn rsh_tmp_bytes(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
}
impl<D: DataMut> ZnxZero for Zn<D> {
fn zero(&mut self) {
self.raw_mut().fill(0)
}
fn zero_at(&mut self, i: usize, j: usize) {
self.at_mut(i, j).fill(0);
}
}
impl Zn<Vec<u8>> {
pub fn bytes_of(n: usize, cols: usize, size: usize) -> usize {
n * cols * size * size_of::<i64>()
}
pub fn alloc(n: usize, cols: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned::<u8>(Self::bytes_of(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
pub fn from_bytes<Scalar: Sized>(n: usize, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(n, cols, size));
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: Data> Zn<D> {
pub fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
max_size: size,
}
}
}
impl<D: DataRef> fmt::Display for Zn<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"Zn(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {col}:")?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {size}: [")?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{coeff}")?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}
impl<D: DataMut> FillUniform for Zn<D> {
fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
match log_bound {
64 => source.fill_bytes(self.data.as_mut()),
0 => panic!("invalid log_bound, cannot be zero"),
_ => {
let mask: u64 = (1u64 << log_bound) - 1;
for x in self.raw_mut().iter_mut() {
let r = source.next_u64() & mask;
*x = ((r << (64 - log_bound)) as i64) >> (64 - log_bound);
}
}
}
}
}
pub type ZnOwned = Zn<Vec<u8>>;
pub type ZnMut<'a> = Zn<&'a mut [u8]>;
pub type ZnRef<'a> = Zn<&'a [u8]>;
pub trait ZnToRef {
fn to_ref(&self) -> Zn<&[u8]>;
}
impl<D: DataRef> ZnToRef for Zn<D> {
fn to_ref(&self) -> Zn<&[u8]> {
Zn {
data: self.data.as_ref(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
pub trait ZnToMut {
fn to_mut(&mut self) -> Zn<&mut [u8]>;
}
impl<D: DataMut> ZnToMut for Zn<D> {
fn to_mut(&mut self) -> Zn<&mut [u8]> {
Zn {
data: self.data.as_mut(),
n: self.n,
cols: self.cols,
size: self.size,
max_size: self.max_size,
}
}
}
impl<D: DataMut> ReaderFrom for Zn<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
self.n = reader.read_u64::<LittleEndian>()? as usize;
self.cols = reader.read_u64::<LittleEndian>()? as usize;
self.size = reader.read_u64::<LittleEndian>()? as usize;
self.max_size = reader.read_u64::<LittleEndian>()? as usize;
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
let buf: &mut [u8] = self.data.as_mut();
if buf.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("self.data.len()={} != read len={}", buf.len(), len),
));
}
reader.read_exact(&mut buf[..len])?;
Ok(())
}
}
impl<D: DataRef> WriterTo for Zn<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
writer.write_u64::<LittleEndian>(self.n as u64)?;
writer.write_u64::<LittleEndian>(self.cols as u64)?;
writer.write_u64::<LittleEndian>(self.size as u64)?;
writer.write_u64::<LittleEndian>(self.max_size as u64)?;
let buf: &[u8] = self.data.as_ref();
writer.write_u64::<LittleEndian>(buf.len() as u64)?;
writer.write_all(buf)?;
Ok(())
}
}

View File

@@ -5,7 +5,6 @@ mod vec_znx;
mod vec_znx_big; mod vec_znx_big;
mod vec_znx_dft; mod vec_znx_dft;
mod vmp_pmat; mod vmp_pmat;
mod zn;
pub use module::*; pub use module::*;
pub use scratch::*; pub use scratch::*;
@@ -14,4 +13,3 @@ pub use vec_znx::*;
pub use vec_znx_big::*; pub use vec_znx_big::*;
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub use vmp_pmat::*; pub use vmp_pmat::*;
pub use zn::*;

View File

@@ -1,70 +0,0 @@
use crate::{
layouts::{Backend, Scratch, ZnToMut},
source::Source,
};
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnNormalizeTmpBytes] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeTmpBytesImpl<B: Backend> {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnNormalizeInplace] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnNormalizeInplaceImpl<B: Backend> {
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<B>)
where
R: ZnToMut;
}
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnFillUniform] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillUniformImpl<B: Backend> {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnFillNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnFillNormalImpl<B: Backend> {
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}
#[allow(clippy::too_many_arguments)]
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
/// * See [poulpy-backend/src/cpu_fft64_ref/zn.rs](https://github.com/phantomzone-org/poulpy/blob/main/poulpy-backend/src/cpu_fft64_ref/zn.rs) for reference implementation.
/// * See [crate::api::ZnAddNormal] for corresponding public API.
/// # Safety [crate::doc::backend_safety] for safety contract.
pub unsafe trait ZnAddNormalImpl<B: Backend> {
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut;
}

View File

@@ -1,4 +1,3 @@
pub mod fft64; pub mod fft64;
pub mod vec_znx; pub mod vec_znx;
pub mod zn;
pub mod znx; pub mod znx;

View File

@@ -53,6 +53,8 @@ pub fn vec_znx_normalize<R, A, ZNXARI>(
let res_size: usize = res.size(); let res_size: usize = res.size();
let a_size: usize = a.size(); let a_size: usize = a.size();
let carry = &mut carry[..2 * n];
if res_base2k == a_base2k { if res_base2k == a_base2k {
if a_size > res_size { if a_size > res_size {
for j in (res_size..a_size).rev() { for j in (res_size..a_size).rev() {

View File

@@ -1,5 +0,0 @@
mod normalization;
mod sampling;
pub use normalization::*;
pub use sampling::*;

View File

@@ -1,72 +0,0 @@
use crate::{
api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace, ZnNormalizeTmpBytes},
layouts::{Backend, Module, ScratchOwned, Zn, ZnToMut, ZnxInfos, ZnxView, ZnxViewMut},
reference::znx::{ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStepInplace, ZnxRef},
source::Source,
};
pub fn zn_normalize_tmp_bytes(n: usize) -> usize {
n * size_of::<i64>()
}
pub fn zn_normalize_inplace<R, ARI>(n: usize, base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
R: ZnToMut,
ARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeFinalStepInplace + ZnxNormalizeMiddleStepInplace,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(carry.len(), res.n());
}
let res_size: usize = res.size();
for j in (0..res_size).rev() {
let out = &mut res.at_mut(res_col, j)[..n];
if j == res_size - 1 {
ARI::znx_normalize_first_step_inplace(base2k, 0, out, carry);
} else if j == 0 {
ARI::znx_normalize_final_step_inplace(base2k, 0, out, carry);
} else {
ARI::znx_normalize_middle_step_inplace(base2k, 0, out, carry);
}
}
}
pub fn test_zn_normalize_inplace<B: Backend>(module: &Module<B>)
where
Module<B>: ZnNormalizeInplace<B> + ZnNormalizeTmpBytes,
ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let base2k: usize = 12;
let n = 33;
let mut carry: Vec<i64> = vec![0i64; zn_normalize_tmp_bytes(n)];
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.zn_normalize_tmp_bytes(module.n()));
for res_size in [1, 2, 6, 11] {
let mut res_0: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
let mut res_1: Zn<Vec<u8>> = Zn::alloc(n, cols, res_size);
res_0
.raw_mut()
.iter_mut()
.for_each(|x| *x = source.next_i32() as i64);
res_1.raw_mut().copy_from_slice(res_0.raw());
// Reference
for i in 0..cols {
zn_normalize_inplace::<_, ZnxRef>(n, base2k, &mut res_0, i, &mut carry);
module.zn_normalize_inplace(n, base2k, &mut res_1, i, scratch.borrow());
}
assert_eq!(res_0.raw(), res_1.raw());
}
}

View File

@@ -1,75 +0,0 @@
use crate::{
layouts::{Zn, ZnToMut, ZnxInfos, ZnxViewMut},
reference::znx::{znx_add_normal_f64_ref, znx_fill_normal_f64_ref, znx_fill_uniform_ref},
source::Source,
};
pub fn zn_fill_uniform<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
for j in 0..res.size() {
znx_fill_uniform_ref(base2k, &mut res.at_mut(res_col, j)[..n], source)
}
}
#[allow(clippy::too_many_arguments)]
pub fn zn_fill_normal<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_fill_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}
#[allow(clippy::too_many_arguments)]
pub fn zn_add_normal<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
let mut res: Zn<&mut [u8]> = res.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = k.div_ceil(base2k) - 1;
let scale: f64 = (1 << ((limb + 1) * base2k - k)) as f64;
znx_add_normal_f64_ref(
&mut res.at_mut(res_col, limb)[..n],
sigma * scale,
bound * scale,
source,
)
}

View File

@@ -2,7 +2,7 @@ use poulpy_core::{
GLWENormalize, GLWENormalize,
layouts::{ layouts::{
GGLWEToGGSWKeyLayout, GGSW, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWELayout, GLWEPlaintext, GLWESecret, LWE, GGLWEToGGSWKeyLayout, GGSW, GGSWLayout, GLWE, GLWEAutomorphismKeyLayout, GLWELayout, GLWEPlaintext, GLWESecret, LWE,
LWEInfos, LWELayout, LWEPlaintext, LWESecret, LWELayout, LWEPlaintext, LWESecret,
prepared::{GGSWPrepared, GLWESecretPrepared}, prepared::{GGSWPrepared, GLWESecretPrepared},
}, },
}; };
@@ -15,7 +15,7 @@ use poulpy_backend::FFT64Avx as BackendImpl;
use poulpy_backend::FFT64Ref as BackendImpl; use poulpy_backend::FFT64Ref as BackendImpl;
use poulpy_hal::{ use poulpy_hal::{
api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, ZnNormalizeInplace}, api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalizeInplace},
layouts::{Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut}, layouts::{Module, ScalarZnx, ScratchOwned, ZnxView, ZnxViewMut},
source::Source, source::Source,
}; };
@@ -155,20 +155,21 @@ fn main() {
pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit pt_lwe.encode_i64(data, (k_lwe_pt + 1).into()); // +1 for padding bit
// Normalize plaintext to nicely print coefficients // Normalize plaintext to nicely print coefficients
module.zn_normalize_inplace( module.vec_znx_normalize_inplace(base2k, pt_lwe.data_mut(), 0, scratch.borrow());
pt_lwe.n().into(),
base2k,
pt_lwe.data_mut(),
0,
scratch.borrow(),
);
println!("pt_lwe: {pt_lwe}"); println!("pt_lwe: {pt_lwe}");
// LWE ciphertext // LWE ciphertext
let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
// Encrypt LWE Plaintext // Encrypt LWE Plaintext
ct_lwe.encrypt_sk(&module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); ct_lwe.encrypt_sk(
&module,
&pt_lwe,
&sk_lwe,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let now: Instant = Instant::now(); let now: Instant = Instant::now();

View File

@@ -111,7 +111,14 @@ pub fn test_blind_rotation<BRA: BlindRotationAlgo, M, BE: Backend>(
pt_lwe.encode_i64(x, (log_message_modulus + 1).into()); pt_lwe.encode_i64(x, (log_message_modulus + 1).into());
lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); lwe.encrypt_sk(
module,
&pt_lwe,
&sk_lwe,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let f = |x: i64| -> i64 { 2 * x + 1 }; let f = |x: i64| -> i64 { 2 * x + 1 };

View File

@@ -132,7 +132,14 @@ where
println!("pt_lwe: {pt_lwe}"); println!("pt_lwe: {pt_lwe}");
let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); ct_lwe.encrypt_sk(
module,
&pt_lwe,
&sk_lwe,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let now: Instant = Instant::now(); let now: Instant = Instant::now();
let mut cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); let mut cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos);
@@ -313,7 +320,14 @@ where
println!("pt_lwe: {pt_lwe}"); println!("pt_lwe: {pt_lwe}");
let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos); let mut ct_lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(&lwe_infos);
ct_lwe.encrypt_sk(module, &pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe); ct_lwe.encrypt_sk(
module,
&pt_lwe,
&sk_lwe,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
);
let now: Instant = Instant::now(); let now: Instant = Instant::now();
let mut cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos); let mut cbt_key: CircuitBootstrappingKey<Vec<u8>, BRA> = CircuitBootstrappingKey::alloc_from_infos(&cbt_infos);