use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, pub(crate) basek: usize, pub(crate) k: usize, } impl LookUpTable { pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { let size: usize = k.div_ceil(basek); let mut data: Vec>> = Vec::with_capacity(extension_factor); (0..extension_factor).for_each(|_| { data.push(module.new_vec_znx(1, size)); }); Self { data, basek, k } } pub fn extension_factor(&self) -> usize { self.data.len() } pub fn domain_size(&self) -> usize { self.data.len() * self.data[0].n() } pub fn set(&mut self, module: &Module, f: &Vec, k: usize) { assert!(f.len() <= module.n()); let basek: usize = self.basek; // Get the number minimum limb to store the message modulus let limbs: usize = k.div_ceil(1 << basek); // Scaling factor let scale: i64 = (1 << (basek * limbs - 1)).div_round(k) as i64; // #elements in lookup table let f_len: usize = f.len(); // If LUT size > module.n() let domain_size: usize = self.domain_size(); let size: usize = self.k.div_ceil(self.basek); // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); f.iter().enumerate().for_each(|(i, fi)| { let start: usize = (i * domain_size).div_round(f_len); let end: usize = ((i + 1) * domain_size).div_round(f_len); lut_at[start..end].fill(fi * scale); }); // Rotates half the step to the left let half_step: usize = domain_size.div_round(f_len << 1); lut_full.rotate(-(half_step as i64)); let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); lut_full.normalize(self.basek, 0, &mut tmp_bytes); if self.extension_factor() > 1 { (0..self.extension_factor()).for_each(|i| { module.switch_degree(&mut self.data[i], 0, &lut_full, 0); if i < self.extension_factor() { lut_full.rotate(-1); } }); } else { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } } #[allow(dead_code)] pub(crate) fn rotate(&mut self, k: i64) { let extension_factor: usize = self.extension_factor(); let two_n: usize = 2 * self.data[0].n(); let two_n_ext: usize = two_n * extension_factor; let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; let k_hi: usize = k_pos / extension_factor; let k_lo: usize = k_pos % extension_factor; (0..extension_factor - k_lo).for_each(|i| { self.data[i].rotate(k_hi as i64); }); (extension_factor - k_lo..extension_factor).for_each(|i| { self.data[i].rotate(k_hi as i64 + 1); }); self.data.rotate_right(k_lo as usize); } } pub(crate) trait DivRound { fn div_round(self, rhs: Self) -> Self; } impl DivRound for usize { #[inline] fn div_round(self, rhs: Self) -> Self { (self + rhs / 2) / rhs } }