Add cross-basek normalization (#90)

* added cross_basek_normalization

* updated method signatures to take layouts

* fixed cross-base normalization

fix #91
fix #93
Jean-Philippe Bossuat
2025-09-30 14:40:10 +02:00
committed by GitHub
parent 4da790ea6a
commit 37e13b965c
216 changed files with 12481 additions and 7745 deletions
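For context: "cross-basek" normalization means re-expressing a value whose limbs are given in one signed base 2^{k_in} as limbs in a possibly different signed base 2^{k_out}; before this commit the normalization entry points assumed a single shared `basek`, which is why `vec_znx_big_normalize` takes two base parameters in the diff below. The following is a minimal plain-Rust sketch of the idea on a scalar integer, with hypothetical names; it is not the poulpy implementation, which applies the same principle per polynomial coefficient.

```rust
// Hypothetical scalar sketch of cross-base renormalization, NOT the poulpy API.
// Re-expresses signed base-2^{k_in} limbs (most significant first) as balanced
// signed base-2^{k_out} limbs.
fn cross_base_renormalize(limbs_in: &[i64], k_in: u32, k_out: u32, n_out: usize) -> Vec<i64> {
    // Recompose the input limbs into one integer (i128 suffices for this small test).
    let mut v: i128 = 0;
    for &l in limbs_in {
        v = (v << k_in) + l as i128;
    }
    // Decompose into balanced digits in [-2^{k_out - 1}, 2^{k_out - 1}).
    let base: i128 = 1i128 << k_out;
    let half: i128 = base >> 1;
    let mut out = vec![0i64; n_out];
    for limb in out.iter_mut().rev() {
        let d = (v + half).rem_euclid(base) - half;
        *limb = d as i64;
        v = (v - d) >> k_out; // exact: v - d is a multiple of 2^{k_out}
    }
    debug_assert_eq!(v, 0, "n_out limbs are too few for the value");
    out
}

fn main() {
    // Re-express two base-2^12 limbs as three base-2^17 limbs and check the round trip.
    let out = cross_base_renormalize(&[1234, -567], 12, 17, 3);
    let back: i128 = out.iter().fold(0i128, |v, &l| (v << 17) + l as i128);
    assert_eq!(back, (1234i128 << 12) - 567);
}
```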


@@ -4,34 +4,33 @@ use poulpy_hal::{
ScratchAvailable, SvpApplyDftToDft, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
TakeVecZnxSlice, VecZnxAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftApply,
- VecZnxDftSubABInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes,
+ VecZnxDftSubInplace, VecZnxDftZero, VecZnxIdftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpBytes,
VecZnxMulXpMinusOneInplace, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate,
- VecZnxSubABInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
+ VecZnxSubInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol, VecZnx, ZnxView, ZnxZero},
};
use poulpy_core::{
Distribution, GLWEOperations, TakeGLWECt,
- layouts::{GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, LWECiphertextToRef},
+ layouts::{GGSWInfos, GLWECiphertext, GLWECiphertextToMut, GLWEInfos, LWECiphertext, LWECiphertextToRef, LWEInfos},
};
use crate::tfhe::blind_rotation::{
- BlincRotationExecute, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
+ BlincRotationExecute, BlindRotationKeyInfos, BlindRotationKeyPrepared, CGGI, LookUpTable, LookUpTableRotationDirection,
};
#[allow(clippy::too_many_arguments)]
- pub fn cggi_blind_rotate_scratch_space<B: Backend>(
+ pub fn cggi_blind_rotate_scratch_space<B: Backend, OUT, GGSW>(
module: &Module<B>,
block_size: usize,
extension_factor: usize,
- basek: usize,
- k_res: usize,
- k_brk: usize,
- rows: usize,
- rank: usize,
+ glwe_infos: &OUT,
+ brk_infos: &GGSW,
) -> usize
where
+ OUT: GLWEInfos,
+ GGSW: GGSWInfos,
Module<B>: VecZnxDftAllocBytes
+ VmpApplyDftToDftTmpBytes
+ VecZnxNormalizeTmpBytes
@@ -39,10 +38,11 @@ where
+ VecZnxIdftApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes,
{
- let brk_size: usize = k_brk.div_ceil(basek);
+ let brk_size: usize = brk_infos.size();
if block_size > 1 {
- let cols: usize = rank + 1;
+ let cols: usize = (brk_infos.rank() + 1).into();
+ let rows: usize = brk_infos.rows().into();
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
@@ -50,7 +50,7 @@ where
let acc_dft_add: usize = vmp_res;
let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize = if extension_factor > 1 {
- VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
+ VecZnx::alloc_bytes(module.n(), cols, glwe_infos.size()) * extension_factor
} else {
0
};
@@ -61,8 +61,8 @@ where
+ vmp_xai
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_apply_tmp_bytes())))
} else {
- GLWECiphertext::bytes_of(module.n(), basek, k_res, rank)
- + GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
+ GLWECiphertext::alloc_bytes(glwe_infos)
+ + GLWECiphertext::external_product_inplace_scratch_space(module, glwe_infos, brk_infos)
}
}
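The new scratch-space signature above replaces five loose scalars (`basek`, `k_res`, `k_brk`, `rows`, `rank`) with two trait-bounded layout descriptors from which all derived quantities are read. A compilable plain-Rust sketch of that pattern, with stand-in trait and method names rather than the actual poulpy_core definitions:

```rust
// Stand-in traits illustrating the "pass layouts, not scalars" refactor;
// names and methods are assumptions, not the poulpy_core API.
trait GlweInfos {
    fn base2k(&self) -> usize;
    fn k(&self) -> usize;
    fn rank(&self) -> usize;
    /// Number of base-2^k limbs needed to hold k bits (the old k.div_ceil(basek)).
    fn size(&self) -> usize {
        self.k().div_ceil(self.base2k())
    }
}

trait GgswInfos: GlweInfos {
    fn rows(&self) -> usize;
}

// Before: fn scratch_bytes(basek, k_res, k_brk, rows, rank) -> usize
// After: the same quantities are derived from the two descriptors.
fn scratch_bytes<OUT: GlweInfos, GGSW: GgswInfos>(glwe_infos: &OUT, brk_infos: &GGSW) -> usize {
    let cols = brk_infos.rank() + 1;
    let rows = brk_infos.rows();
    let brk_size = brk_infos.size();
    // Placeholder sizing arithmetic; the real function sums DFT/big/temporary buffers.
    (cols * rows + cols * brk_size + glwe_infos.size()) * std::mem::size_of::<i64>()
}

struct Layout { base2k: usize, k: usize, rank: usize, rows: usize }
impl GlweInfos for Layout {
    fn base2k(&self) -> usize { self.base2k }
    fn k(&self) -> usize { self.k }
    fn rank(&self) -> usize { self.rank }
}
impl GgswInfos for Layout {
    fn rows(&self) -> usize { self.rows }
}

fn main() {
    let glwe = Layout { base2k: 17, k: 544, rank: 1, rows: 0 };
    let brk = Layout { base2k: 17, k: 612, rank: 1, rows: 4 };
    println!("{} scratch bytes", scratch_bytes(&glwe, &brk));
}
```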
@@ -80,11 +80,11 @@ where
+ VecZnxDftApply<B>
+ VecZnxDftZero<B>
+ SvpApplyDftToDft<B>
- + VecZnxDftSubABInplace<B>
+ + VecZnxDftSubInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
- + VecZnxSubABInplace
+ + VecZnxSubInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
@@ -142,11 +142,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxDftApply<B>
+ VecZnxDftZero<B>
+ SvpApplyDftToDft<B>
- + VecZnxDftSubABInplace<B>
+ + VecZnxDftSubInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
- + VecZnxSubABInplace
+ + VecZnxSubInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
@@ -155,11 +155,11 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
+ VmpApplyDftToDft<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
- let n_glwe: usize = brk.n();
+ let n_glwe: usize = brk.n_glwe().into();
let extension_factor: usize = lut.extension_factor();
- let basek: usize = res.basek();
- let rows: usize = brk.rows();
- let cols: usize = res.rank() + 1;
+ let base2k: usize = res.base2k().into();
+ let rows: usize = brk.rows().into();
+ let cols: usize = (res.rank() + 1).into();
let (mut acc, scratch_1) = scratch.take_vec_znx_slice(extension_factor, n_glwe, cols, res.size());
let (mut acc_dft, scratch_2) = scratch_1.take_vec_znx_dft_slice(extension_factor, n_glwe, cols, rows);
@@ -178,7 +178,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
panic!("invalid key: x_pow_a has not been initialized")
}
- let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+ let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).as_usize()]; // TODO: from scratch space
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let two_n: usize = 2 * n_glwe;
@@ -233,7 +233,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|i| {
module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0);
- module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
+ module.vec_znx_dft_sub_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
});
});
}
@@ -249,7 +249,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|k| {
module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
- module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+ module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
@@ -261,7 +261,7 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|k| {
module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
- module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+ module.vec_znx_dft_sub_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
@@ -275,14 +275,14 @@ fn execute_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|i| {
module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i);
- module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7);
+ module.vec_znx_big_normalize(base2k, &mut acc[j], i, base2k, &acc_add_big, 0, scratch7);
});
});
}
});
(0..cols).for_each(|i| {
- module.vec_znx_copy(&mut res.data, i, &acc[0], i);
+ module.vec_znx_copy(res.data_mut(), i, &acc[0], i);
});
}
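For orientation, the `svp_apply_dft_to_dft` / `vec_znx_dft_add_inplace` / `vec_znx_dft_sub_inplace` triples above implement the standard CGGI accumulator update in the DFT domain (the same relation is spelled out further down in the comment "ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]"):

$$\mathrm{ACC} \leftarrow \mathrm{ACC} + (X^{a_i} - 1)\cdot(\mathrm{ACC} \boxtimes \mathrm{BRK}_i),$$

which leaves ACC unchanged when BRK_i encrypts s_i = 0 and yields X^{a_i}·ACC when s_i = 1: the external product `vmp_res` is computed once, multiplied by X^{a_i} via the SVP apply into `vmp_xai`, then the rotated term is added and the unrotated one subtracted.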
@@ -309,11 +309,11 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxDftApply<B>
+ VecZnxDftZero<B>
+ SvpApplyDftToDft<B>
- + VecZnxDftSubABInplace<B>
+ + VecZnxDftSubInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
- + VecZnxSubABInplace
+ + VecZnxSubInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
@@ -322,15 +322,15 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxBigNormalize<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice + ScratchAvailable + TakeVecZnx,
{
- let n_glwe: usize = brk.n();
- let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+ let n_glwe: usize = brk.n_glwe().into();
+ let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let two_n: usize = n_glwe << 1;
- let basek: usize = brk.basek();
- let rows: usize = brk.rows();
+ let base2k: usize = brk.base2k().into();
+ let rows: usize = brk.rows().into();
- let cols: usize = out_mut.rank() + 1;
+ let cols: usize = (out_mut.rank() + 1).into();
mod_switch_2n(
2 * lut.domain_size(),
@@ -342,10 +342,10 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
let a: &[i64] = &lwe_2n[1..];
let b: i64 = lwe_2n[0];
- out_mut.data.zero();
+ out_mut.data_mut().zero();
// Initialize out to X^{b} * LUT(X)
- module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
+ module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0);
let block_size: usize = brk.block_size();
@@ -369,7 +369,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
)
.for_each(|(ai, ski)| {
(0..cols).for_each(|j| {
- module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, &out_mut.data, j);
+ module.vec_znx_dft_apply(1, 0, &mut acc_dft, j, out_mut.data_mut(), j);
});
module.vec_znx_dft_zero(&mut acc_add_dft);
@@ -384,7 +384,7 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|i| {
module.svp_apply_dft_to_dft(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i);
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0);
- module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i);
+ module.vec_znx_dft_sub_inplace(&mut acc_add_dft, i, &vmp_res, i);
});
});
@@ -393,8 +393,16 @@ fn execute_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
(0..cols).for_each(|i| {
module.vec_znx_idft_apply(&mut acc_add_big, 0, &acc_add_dft, i, scratch_5);
- module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i);
- module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch_5);
+ module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, out_mut.data_mut(), i);
+ module.vec_znx_big_normalize(
+ base2k,
+ out_mut.data_mut(),
+ i,
+ base2k,
+ &acc_add_big,
+ 0,
+ scratch_5,
+ );
});
}
});
@@ -423,11 +431,11 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
+ VecZnxDftApply<B>
+ VecZnxDftZero<B>
+ SvpApplyDftToDft<B>
- + VecZnxDftSubABInplace<B>
+ + VecZnxDftSubInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ VecZnxRotate
+ VecZnxAddInplace
- + VecZnxSubABInplace
+ + VecZnxSubInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
@@ -450,10 +458,10 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
);
assert_eq!(
lut.domain_size(),
- brk.n(),
+ brk.n_glwe().as_usize(),
"lut.n(): {} != brk.n(): {}",
lut.domain_size(),
- brk.n()
+ brk.n_glwe().as_usize()
);
assert_eq!(
res.rank(),
@@ -464,17 +472,16 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
);
assert_eq!(
lwe.n(),
- brk.data.len(),
+ brk.n_lwe(),
"lwe.n(): {} != brk.data.len(): {}",
lwe.n(),
- brk.data.len()
+ brk.n_lwe()
);
}
- let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
+ let mut lwe_2n: Vec<i64> = vec![0i64; (lwe.n() + 1).into()]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
- let basek: usize = brk.basek();
mod_switch_2n(
2 * lut.domain_size(),
@@ -486,13 +493,13 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
let a: &[i64] = &lwe_2n[1..];
let b: i64 = lwe_2n[0];
- out_mut.data.zero();
+ out_mut.data_mut().zero();
// Initialize out to X^{b} * LUT(X)
- module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
+ module.vec_znx_rotate(b, out_mut.data_mut(), 0, &lut.data[0], 0);
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
- let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(out_mut.n(), basek, out_mut.k(), out_mut.rank());
+ let (mut acc_tmp, scratch_1) = scratch.take_glwe_ct(&out_mut);
// TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
// TODO: first iteration can be optimized to be a gglwe product
@@ -507,13 +514,13 @@ fn execute_standard<DataRes, DataIn, DataBrk, B: Backend>(
out_mut.add_inplace(module, &acc_tmp);
});
- // We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
- // on top of each others, thus ~ 2^{63-basek} additions are supported before overflow.
+ // We can normalize only at the end because we add normalized values in [-2^{base2k-1}, 2^{base2k-1}]
+ // on top of each others, thus ~ 2^{63-base2k} additions are supported before overflow.
out_mut.normalize_inplace(module, scratch_1);
}
pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_dir: LookUpTableRotationDirection) {
- let basek: usize = lwe.basek();
+ let base2k: usize = lwe.base2k().into();
let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;
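A quick worked instance of the overflow bound in the normalization comment above (base2k = 17 is an assumed example value, not a parameter taken from this diff):

```rust
// Headroom argument behind deferring normalization to the end: each
// already-normalized limb lies in [-2^(base2k-1), 2^(base2k-1)], so an i64
// accumulator tolerates ~2^(63-base2k) additions before it can overflow.
fn main() {
    let base2k: u32 = 17; // assumed example value
    let limb_bound: i64 = 1 << (base2k - 1); // 2^16
    let safe_adds: i64 = 1 << (63 - base2k); // 2^46 ~ 7e13 additions
    // 2^16 * 2^46 = 2^62 < 2^63, so the running sum still fits in an i64.
    assert!(limb_bound.checked_mul(safe_adds).is_some());
    println!("~{safe_adds} additions of limbs bounded by {limb_bound} fit in i64");
}
```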
@@ -526,23 +533,23 @@ pub fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>, rot_
LookUpTableRotationDirection::Right => {}
}
- if basek > log2n {
- let diff: usize = basek - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
+ if base2k > log2n {
+ let diff: usize = base2k - (log2n - 1); // additional -1 because we map to [-N/2, N/2) instead of [0, N)
res.iter_mut().for_each(|x| {
*x = div_round_by_pow2(x, diff);
})
} else {
- let rem: usize = basek - (log2n % basek);
- let size: usize = log2n.div_ceil(basek);
+ let rem: usize = base2k - (log2n % base2k);
+ let size: usize = log2n.div_ceil(base2k);
(1..size).for_each(|i| {
- if i == size - 1 && rem != basek {
- let k_rem: usize = basek - rem;
+ if i == size - 1 && rem != base2k {
+ let k_rem: usize = base2k - rem;
izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(lwe.data().at(0, i).iter(), res.iter_mut()).for_each(|(x, y)| {
- *y = (*y << basek) + x;
+ *y = (*y << base2k) + x;
});
}
})
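The branch above recomposes just enough base-2^k limbs to cover log2(2N) bits, splicing in the last partial limb via the `rem`/`k_rem` shifts, and rounds the surplus low bits away. A self-contained sketch of that recompose-then-round step on a single coefficient (the limb layout and names are illustrative, not the poulpy `ZnxView` accessors):

```rust
// Recompose-then-round sketch of the mod_switch_2n digit arithmetic for one
// coefficient; illustrative only, the real code works limb-column by limb-column.
fn round_to_2n(limbs: &[i64], base2k: u32, log_2n: u32) -> i64 {
    let total_bits = base2k * limbs.len() as u32;
    assert!(log_2n <= total_bits && total_bits < 63);
    // Same recomposition step as in the hunk above: y = (y << base2k) + x.
    let mut acc: i64 = 0;
    for &x in limbs {
        acc = (acc << base2k) + x;
    }
    // Keep the top log_2n bits, rounding to nearest (the library's
    // div_round_by_pow2 may break ties differently).
    let shift = total_bits - log_2n;
    if shift == 0 { acc } else { (acc + (1 << (shift - 1))) >> shift }
}

fn main() {
    // Two base-2^12 limbs, reduced to an 11-bit target (2N = 2048):
    let v = round_to_2n(&[1234, -567], 12, 11);
    assert_eq!(v, 617); // round(5053897 / 2^13)
}
```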