mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Add cross-basek normalization (#90)
* added cross_basek_normalization * updated method signatures to take layouts * fixed cross-base normalization fix #91 fix #93
This commit is contained in:
committed by
GitHub
parent
4da790ea6a
commit
37e13b965c
@@ -15,29 +15,28 @@ pub enum LookUpTableRotationDirection {
|
||||
pub struct LookUpTable {
|
||||
pub(crate) data: Vec<VecZnx<Vec<u8>>>,
|
||||
pub(crate) rot_dir: LookUpTableRotationDirection,
|
||||
pub(crate) basek: usize,
|
||||
pub(crate) base2k: usize,
|
||||
pub(crate) k: usize,
|
||||
pub(crate) drift: usize,
|
||||
}
|
||||
|
||||
impl LookUpTable {
|
||||
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, extension_factor: usize) -> Self {
|
||||
pub fn alloc<B: Backend>(module: &Module<B>, base2k: usize, k: usize, extension_factor: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
extension_factor & (extension_factor - 1) == 0,
|
||||
"extension_factor must be a power of two but is: {}",
|
||||
extension_factor
|
||||
"extension_factor must be a power of two but is: {extension_factor}"
|
||||
);
|
||||
}
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let size: usize = k.div_ceil(base2k);
|
||||
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
|
||||
(0..extension_factor).for_each(|_| {
|
||||
data.push(VecZnx::alloc(module.n(), 1, size));
|
||||
});
|
||||
Self {
|
||||
data,
|
||||
basek,
|
||||
base2k,
|
||||
k,
|
||||
drift: 0,
|
||||
rot_dir: LookUpTableRotationDirection::Left,
|
||||
@@ -80,27 +79,27 @@ impl LookUpTable {
|
||||
{
|
||||
assert!(f.len() <= module.n());
|
||||
|
||||
let basek: usize = self.basek;
|
||||
let base2k: usize = self.base2k;
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
// Get the number minimum limb to store the message modulus
|
||||
let limbs: usize = k.div_ceil(basek);
|
||||
let limbs: usize = k.div_ceil(base2k);
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(f.len() <= module.n());
|
||||
assert!(
|
||||
(max_bit_size(f) + (k % basek) as u32) < i64::BITS,
|
||||
"overflow: max(|f|) << (k%basek) > i64::BITS"
|
||||
(max_bit_size(f) + (k % base2k) as u32) < i64::BITS,
|
||||
"overflow: max(|f|) << (k%base2k) > i64::BITS"
|
||||
);
|
||||
assert!(limbs <= self.data[0].size());
|
||||
}
|
||||
|
||||
// Scaling factor
|
||||
let mut scale = 1;
|
||||
if !k.is_multiple_of(basek) {
|
||||
scale <<= basek - (k % basek);
|
||||
if !k.is_multiple_of(base2k) {
|
||||
scale <<= base2k - (k % base2k);
|
||||
}
|
||||
|
||||
// #elements in lookup table
|
||||
@@ -109,7 +108,7 @@ impl LookUpTable {
|
||||
// If LUT size > TakeScalarZnx
|
||||
let domain_size: usize = self.domain_size();
|
||||
|
||||
let size: usize = self.k.div_ceil(self.basek);
|
||||
let size: usize = self.k.div_ceil(self.base2k);
|
||||
|
||||
// Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)
|
||||
let mut lut_full: VecZnx<Vec<u8>> = VecZnx::alloc(domain_size, 1, size);
|
||||
@@ -140,7 +139,7 @@ impl LookUpTable {
|
||||
}
|
||||
|
||||
self.data.iter_mut().for_each(|a| {
|
||||
module.vec_znx_normalize_inplace(self.basek, a, 0, scratch.borrow());
|
||||
module.vec_znx_normalize_inplace(self.base2k, a, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
self.rotate(module, -(drift as i64));
|
||||
|
||||
Reference in New Issue
Block a user