Mirror of https://github.com/arnaucube/poulpy.git (synced 2026-02-10 05:06:44 +01:00)
Add cross-basek normalization (#90)
* added cross_basek_normalization
* updated method signatures to take layouts
* fixed cross-base normalization

Fixes #91, fixes #93
Commit: 37e13b965c
Parent: 4da790ea6a
Committed by: GitHub
@@ -4,34 +4,33 @@ use poulpy_hal::{
         ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
         TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
         VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply,
-        VecZnxDftSubABInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes,
+        VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes,
         VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
-        VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+        VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
     },
     layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero},
 };
 
 use poulpy_core::{
     Distribution, GLWEOperations, TakeGLWECt,
-    layouts::{GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, LWECiphertextToRef},
+    layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos},
 };
 
 use crate::tfhe::blind_rotation::{
-    BlincRotationExecute, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
+    BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
 };
 
 #[allow(clippy::too_many_arguments)]
-pub fn cggi_blind_rotate_scratch_space<B: Backend>(
+pub fn cggi_blind_rotate_scratch_space<B: Backend, OUT, GGSW>(
     module: &Module<B>,
     block_size: usize,
     extension_factor: usize,
-    basek: usize,
-    k_res: usize,
-    k_brk: usize,
-    rows: usize,
-    rank: usize,
+    glwe_infos: &OUT,
+    brk_infos: &GGSW,
 ) -> usize
 where
+    OUT: GLWEInfos,
+    GGSW: GGSWInfos,
     Module<B>: VecZnxDftAllocBytes
         + VmpApplyDftToDftTmpBytes
         + VecZnxNormalizeTmpBytes
@@ -39,10 +38,11 @@ where
         + VecZnxIdftApplyTmpBytes
         + VecZnxBigNormalizeTmpBytes,
 {
-    let brk_size: usize = k_brk.div_ceil(basek);
+    let brk_size: usize = brk_infos.size();
 
     if block_size > 1 {
-        let cols: usize = rank + 1;
+        let cols: usize = (brk_infos.rank() + 1).into();
+        let rows: usize = brk_infos.rows().into();
         let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
         let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
         let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
@@ -50,7 +50,7 @@ where
         let acc_dft_add: usize = vmp_res;
         let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
         let acc: usize = if extension_factor > 1 {
-            VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
+            VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor
         } else {
             0
         };
@@ -61,8 +61,8 @@ where
             + vmp_xai
             + (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes())))
     } else {
-        GLWECiphertext::bytes_of(module.n(), basek, k_res, rank)
-            + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
+        GLWECiphertext::alloc_bytes(glwe_infos)
+            + GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos)
     }
 }
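
Note: the signature change above replaces six loose usize parameters with two layout descriptors. A standalone sketch of this "infos" pattern, with invented names (not poulpy's actual types), showing why a size() method can stand in for k_res.div_ceil(basek):

    // Hypothetical layout struct bundling the parameters that were
    // previously passed individually (base2k, k, rank).
    struct GlweLayout {
        base2k: usize, // bits per limb
        k: usize,      // total precision in bits
        rank: usize,   // number of mask polynomials
    }

    impl GlweLayout {
        // Limbs needed to hold k bits in base 2^base2k, i.e. the
        // k.div_ceil(basek) expression this commit replaces.
        fn size(&self) -> usize {
            self.k.div_ceil(self.base2k)
        }
    }

    fn scratch_bytes(n: usize, infos: &GlweLayout) -> usize {
        // (rank + 1) polynomial columns, n coefficients each, one i64 per limb.
        (infos.rank + 1) * n * infos.size() * std::mem::size_of::<i64>()
    }

    fn main() {
        let infos = GlweLayout { base2k: 19, k: 54, rank: 1 };
        assert_eq!(infos.size(), 3); // ceil(54 / 19) = 3 limbs
        println!("scratch bytes: {}", scratch_bytes(1024, &infos));
    }

The call site can no longer misorder two adjacent usize arguments, which is the practical payoff of the refactor.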
@@ -80,11 +80,11 @@ where
         + VecZnxDftApply<B>
         + VecZnxDftZero<B>
         + SvpApplyDftToDft<B>
-        + VecZnxDftSubABInplace<B>
+        + VecZnxDftSubInplace<B>
         + VecZnxBigAddSmallInplace<B>
         + VecZnxRotate
         + VecZnxAddInplace
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxNormalize<B>
         + VecZnxNormalizeInplace<B>
         + VecZnxCopy
@@ -142,11 +142,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
         + VecZnxDftApply<B>
         + VecZnxDftZero<B>
         + SvpApplyDftToDft<B>
-        + VecZnxDftSubABInplace<B>
+        + VecZnxDftSubInplace<B>
         + VecZnxBigAddSmallInplace<B>
         + VecZnxRotate
         + VecZnxAddInplace
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxNormalize<B>
         + VecZnxNormalizeInplace<B>
         + VecZnxCopy
@@ -155,11 +155,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
         + VmpApplyDftToDft<B>,
     Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
 {
-    let n_glwe: usize = brk.n();
+    let n_glwe: usize = brk.n_glwe().into();
     let extension_factor: usize = lut.extension_factor();
-    let basek: usize = res.basek();
-    let rows: usize = brk.rows();
-    let cols: usize = res.rank() + 1;
+    let base2k: usize = res.base2k().into();
+    let rows: usize = brk.rows().into();
+    let cols: usize = (res.rank() + 1).into();
 
     let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size());
     let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows);
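
The .into() and .as_usize() calls above indicate that getters such as base2k(), rows(), rank() and n_glwe() now return typed wrappers rather than raw usize values. A minimal sketch of that newtype pattern (the wrapper name below is an assumption, not a poulpy identifier):

    // Hypothetical newtype: a rank is distinct from other integers at the
    // type level, and conversion to usize is explicit.
    #[derive(Clone, Copy)]
    struct Rank(u32);

    impl Rank {
        fn as_usize(self) -> usize {
            self.0 as usize
        }
    }

    impl From<Rank> for usize {
        fn from(r: Rank) -> usize {
            r.0 as usize
        }
    }

    fn main() {
        let rank = Rank(2);
        // Mirrors `(res.rank() + 1).into()`: convert exactly where a plain
        // usize is required, e.g. a column count.
        let cols: usize = rank.as_usize() + 1;
        assert_eq!(cols, usize::from(rank) + 1);
        println!("cols = {cols}");
    }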
@@ -178,7 +178,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
         panic!("invalid key: x_pow_a has not been initialized")
     }
 
-    let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+    let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space
     let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
 
     let two_n: usize = 2 * n_glwe;
@@ -233,7 +233,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
             (0..cols).for_each(|i| {
                 module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i);
                 module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0);
-                module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
+                module.vec_znx_dft_sub_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
             });
         });
     }
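
For context, the add/sub pair above (here and in the two analogous loops below) realizes the standard CGGI CMUX update in the DFT domain; writing ⊙ for the GGSW external product:

    \[
    \mathrm{acc} \;\leftarrow\; \mathrm{acc} + (X^{a_i} - 1)\cdot(\mathrm{acc} \odot \mathrm{BRK}_i),
    \qquad \mathrm{BRK}_i = \mathrm{GGSW}(s_i),\ s_i \in \{0, 1\}.
    \]

Since acc ⊙ GGSW(s_i) ≈ s_i · acc, the update leaves acc unchanged when s_i = 0 and rotates it to X^{a_i} · acc when s_i = 1; svp_apply_dft_to_dft supplies the X^{a_i} twist and vec_znx_dft_sub_inplace the "− 1" term.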
@@ -249,7 +249,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
                 (0..cols).for_each(|k| {
                     module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k);
                     module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
-                    module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+                    module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
                 });
             }
         }
@@ -261,7 +261,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
                 (0..cols).for_each(|k| {
                     module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k);
                     module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
-                    module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+                    module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
                 });
             }
         }
@@ -275,14 +275,14 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
             (0..cols).for_each(|i| {
                 module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
                 module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i);
-                module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7);
+                module.vec_znx_big_normalize(base2k, &mut acc[j], i, base2k, &acc_add_big, 0, scratch7);
             });
         });
     }
     });
 
     (0..cols).for_each(|i| {
-        module.vec_znx_copy(&mut res.data, i, &acc[0], i);
+        module.vec_znx_copy(res.data_mut(), i, &acc[0], i);
     });
 }
 
@@ -309,11 +309,11 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
         + VecZnxDftApply<B>
         + VecZnxDftZero<B>
         + SvpApplyDftToDft<B>
-        + VecZnxDftSubABInplace<B>
+        + VecZnxDftSubInplace<B>
         + VecZnxBigAddSmallInplace<B>
         + VecZnxRotate
         + VecZnxAddInplace
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxNormalize<B>
         + VecZnxNormalizeInplace<B>
         + VecZnxCopy
@@ -322,15 +322,15 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
         + VecZnxBigNormalize<B>,
     Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
 {
-    let n_glwe: usize = brk.n();
-    let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+    let n_glwe: usize = brk.n_glwe().into();
+    let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
     let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
     let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
     let two_n: usize = n_glwe << 1;
-    let basek: usize = brk.basek();
-    let rows: usize = brk.rows();
+    let base2k: usize = brk.base2k().into();
+    let rows: usize = brk.rows().into();
 
-    let cols: usize = out_mut.rank() + 1;
+    let cols: usize = (out_mut.rank() + 1).into();
 
     mod_switch_2n(
         2 * lut.domain_size(),
@@ -342,10 +342,10 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
     let a: &[i64] = &lwe_2n[1..];
     let b: i64 = lwe_2n[0];
 
-    out_mut.data.zero();
+    out_mut.data_mut().zero();
 
     // Initialize out to X^{b} * LUT(X)
-    module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
+    module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0);
 
     let block_size: usize = brk.block_size();
 
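
Together with the preceding mod-switch, this initialization establishes the blind-rotation invariant: after all LWE coefficients have been processed, the accumulator holds

    \[
    \mathrm{ACC} \;=\; X^{\tilde{b} + \sum_i \tilde{a}_i s_i} \cdot \mathrm{LUT}(X)
    \;\in\; \mathbb{Z}_Q[X]/(X^N + 1),
    \]

where (ã, b̃) is the LWE ciphertext rescaled to Z_{2N} by mod_switch_2n. The exponent is the mod-switched LWE phase, so the constant coefficient of ACC is the LUT entry selected by the encrypted message (up to the sign convention chosen via LookUpTableRotationDirection).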
@@ -369,7 +369,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
         )
         .for_each(|(ai, ski)| {
             (0..cols).for_each(|j| {
-                module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, &out_mut.data, j);
+                module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j);
             });
 
             module.vec_znx_dft_zero(&mut acc_add_dft);
@@ -384,7 +384,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
             (0..cols).for_each(|i| {
                 module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i);
                 module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0);
-                module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i);
+                module.vec_znx_dft_sub_inplace(&mut acc_add_dft, i, &vmp_res, i);
             });
         });
 
@@ -393,8 +393,16 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
 
         (0..cols).for_each(|i| {
             module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5);
-            module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i);
-            module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch_5);
+            module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, out_mut.data_mut(), i);
+            module.vec_znx_big_normalize(
+                base2k,
+                out_mut.data_mut(),
+                i,
+                base2k,
+                &acc_add_big,
+                0,
+                scratch_5,
+            );
         });
     }
 });
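
This hunk shows the cross-basek normalization of the commit title: vec_znx_big_normalize now takes the destination limb base (first argument) and the source limb base (fourth argument) separately, so a vector carried in one base can be renormalized into another; here both are base2k, so behavior is unchanged, but the two may now differ. A standalone sketch of the underlying arithmetic on a single value (poulpy operates on coefficient vectors with scratch buffers; this only illustrates the limb math):

    // Recompose signed limbs in base 2^k into one integer (MSB first).
    fn recompose(limbs: &[i64], k: usize) -> i128 {
        limbs.iter().fold(0i128, |acc, &l| (acc << k) + l as i128)
    }

    // Decompose into `size` centered limbs in base 2^k (LSB extracted first).
    fn decompose(mut v: i128, k: usize, size: usize) -> Vec<i64> {
        let base = 1i128 << k;
        let half = base >> 1;
        let mut limbs = vec![0i64; size];
        for limb in limbs.iter_mut().rev() {
            let mut r = v % base; // remainder in (-base, base)
            v /= base;
            if r >= half {
                r -= base; // center into [-2^{k-1}, 2^{k-1})
                v += 1;
            } else if r < -half {
                r += base;
                v -= 1;
            }
            *limb = r as i64;
        }
        limbs
    }

    fn main() {
        // Three base-2^19 limbs renormalized into four base-2^17 limbs.
        let x = recompose(&[3, -70_000, 123_456], 19);
        let y = decompose(x, 17, 4);
        assert_eq!(recompose(&y, 17), x); // cross-base round trip is exact
        println!("{y:?}");
    }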
@@ -423,11 +431,11 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
         + VecZnxDftApply<B>
         + VecZnxDftZero<B>
         + SvpApplyDftToDft<B>
-        + VecZnxDftSubABInplace<B>
+        + VecZnxDftSubInplace<B>
         + VecZnxBigAddSmallInplace<B>
         + VecZnxRotate
         + VecZnxAddInplace
-        + VecZnxSubABInplace
+        + VecZnxSubInplace
         + VecZnxNormalize<B>
         + VecZnxNormalizeInplace<B>
         + VecZnxCopy
@@ -450,10 +458,10 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
         );
         assert_eq!(
             lut.domain_size(),
-            brk.n(),
+            brk.n_glwe().as_usize(),
             "lut.n(): {} != brk.n(): {}",
             lut.domain_size(),
-            brk.n()
+            brk.n_glwe().as_usize()
         );
         assert_eq!(
             res.rank(),
@@ -464,17 +472,16 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
         );
         assert_eq!(
             lwe.n(),
-            brk.data.len(),
+            brk.n_lwe(),
             "lwe.n(): {} != brk.data.len(): {}",
             lwe.n(),
-            brk.data.len()
+            brk.n_lwe()
         );
     }
 
-    let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+    let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
     let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
     let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
-    let basek: usize = brk.basek();
 
     mod_switch_2n(
         2 * lut.domain_size(),
@@ -486,13 +493,13 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
     let a: &[i64] = &lwe_2n[1..];
     let b: i64 = lwe_2n[0];
 
-    out_mut.data.zero();
+    out_mut.data_mut().zero();
 
     // Initialize out to X^{b} * LUT(X)
-    module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
+    module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0);
 
     // ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
-    let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(out_mut.n(), basek, out_mut.k(), out_mut.rank());
+    let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut);
 
     // TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
     // TODO: first iteration can be optimized to be a gglwe product
@@ -507,13 +514,13 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
         out_mut.add_inplace(module, &acc_tmp);
     });
 
-    // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
-    // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow.
+    // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}]
+    // on top of each others, thus ~ 2^{63-base2k} additions are supported before overflow.
     out_mut.normalize_inplace(module, scratch_1);
 }
 
 pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) {
-    let basek: usize = lwe.basek();
+    let base2k: usize = lwe.base2k().into();
 
     let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;
 
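
To make the overflow comment above concrete: limbs of a normalized vector lie in [-2^{base2k-1}, 2^{base2k-1}], and the accumulator sums them in i64 words, so m additions stay in range while

    \[
    m \cdot 2^{\mathrm{base2k}-1} < 2^{63}
    \;\Longleftrightarrow\;
    m < 2^{64-\mathrm{base2k}},
    \]

which is the comment's ~2^{63-base2k} bound up to a factor of two. For example, with base2k = 19 roughly 2^44 ≈ 1.8·10^13 additions fit, vastly more than an LWE dimension's worth of external products.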
@@ -526,23 +533,23 @@ pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_
         LookUpTableRotationDirection::Right => {}
     }
 
-    if basek > log2n {
-        let diff: usize = basek - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
+    if base2k > log2n {
+        let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
         res.iter_mut().for_each(|x| {
             *x = div_round_by_pow2(x, diff);
         })
     } else {
-        let rem: usize = basek - (log2n % basek);
-        let size: usize = log2n.div_ceil(basek);
+        let rem: usize = base2k - (log2n % base2k);
+        let size: usize = log2n.div_ceil(base2k);
         (1..size).for_each(|i| {
-            if i == size - 1 && rem != basek {
-                let k_rem: usize = basek - rem;
+            if i == size - 1 && rem != base2k {
+                let k_rem: usize = base2k - rem;
                 izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
                     *y = (*y << k_rem) + (x >> rem);
                 });
             } else {
                 izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
-                    *y = (*y << basek) + x;
+                    *y = (*y << base2k) + x;
                 });
             }
         })
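
mod_switch_2n recombines each LWE coefficient's base-2^base2k limbs into a single integer and rounds it onto the 2N grid. A hedged sketch of both steps; the body of div_round_by_pow2 below is an assumption (round-to-nearest division by 2^k), since only its name appears in the diff:

    // Assumed semantics: divide by 2^k rounding to nearest
    // (the arithmetic shift floors; the added offset rounds).
    fn div_round_by_pow2(x: i64, k: usize) -> i64 {
        (x + (1i64 << (k - 1))) >> k
    }

    fn main() {
        let base2k = 19usize;

        // Recombine two base-2^19 limbs (most significant first), as the
        // accumulation loop `*y = (*y << base2k) + x` does.
        let (hi, lo) = (7i64, -12_345i64);
        let coeff = (hi << base2k) + lo;

        // Drop to an 11-bit grid (2N = 2^11): keep the top bits, rounded
        // rather than truncated, as in the `base2k > log2n` branch.
        let log2n = 11usize;
        let switched = div_round_by_pow2(coeff, base2k - log2n);
        assert_eq!(coeff, 3_657_671);
        assert_eq!(switched, 14_288); // 3_657_671 / 256 = 14_287.78 -> 14_288
        println!("coeff = {coeff}, mod-switched = {switched}");
    }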