Fixed lut & added test for lut

This commit is contained in:
Jean-Philippe Bossuat
2025-07-02 12:25:22 +02:00
parent 52154d6f8a
commit c98bf75b61
4 changed files with 121 additions and 11 deletions

View File

@@ -111,6 +111,16 @@ impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D> {
}
}
pub fn rotate(&mut self, k: i64){
unsafe{
(0..self.cols()).for_each(|i|{
(0..self.size()).for_each(|j|{
znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j));
});
})
}
}
pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) {
let n: usize = self.n();
let cols: usize = self.cols();

View File

@@ -1,4 +1,4 @@
use backend::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned};
use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned};
pub struct LookUpTable {
pub(crate) data: Vec<VecZnx<Vec<u8>>>,
@@ -7,10 +7,10 @@ pub struct LookUpTable {
}
impl LookUpTable {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, extend_factor: usize) -> Self {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, extension_factor: usize) -> Self {
let size: usize = k.div_ceil(basek);
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extend_factor);
(0..extend_factor).for_each(|_| {
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
(0..extension_factor).for_each(|_| {
data.push(module.new_vec_znx(1, size));
});
Self { data, basek, k }
@@ -20,6 +20,10 @@ impl LookUpTable {
self.data.len()
}
pub fn domain_size(&self) -> usize {
self.data.len() * self.data[0].n()
}
pub fn set(&mut self, module: &Module<FFT64>, f: fn(i64) -> i64, message_modulus: usize) {
let basek: usize = self.basek;
@@ -33,11 +37,11 @@ impl LookUpTable {
let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale;
// If LUT size > module.n()
let domain_size: usize = self.data[0].n() * self.extension_factor();
let domain_size: usize = self.domain_size();
let size: usize = self.k.div_ceil(self.basek);
// Equivalent to AUTO([f(0), f(1), ..., f(n-1)], -1)
// Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)
let mut lut_full: VecZnx<Vec<u8>> = VecZnx::new::<i64>(domain_size, 1, size);
{
let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1);
@@ -55,28 +59,54 @@ impl LookUpTable {
let end: usize = ((x + 1) * domain_size).div_round(message_modulus);
let y: i64 = f_scaled(x as i64);
(start..end).for_each(|i| {
lut_at[domain_size - i] = -y;
lut_at[i] = y;
})
});
}
// Rotates half the step to the left
let half_step: usize = domain_size.div_round(message_modulus << 1);
module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0);
lut_full.rotate(-(half_step as i64));
let mut tmp_bytes: Vec<u8> = alloc_aligned(lut_full.n() * size_of::<i64>());
lut_full.normalize(self.basek, 0, &mut tmp_bytes);
if self.extension_factor() > 1 {
let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size));
module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow());
(0..self.extension_factor()).for_each(|i| {
module.switch_degree(&mut self.data[i], 0, &lut_full, 0);
if i < self.extension_factor() {
lut_full.rotate(-1);
}
});
} else {
module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
}
}
pub(crate) fn rotate(&mut self, k: i64) {
let extension_factor: usize = self.extension_factor();
let two_n: usize = 2 * self.data[0].n();
let two_n_ext: usize = two_n * extension_factor;
let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize;
let k_hi: usize = k_pos / extension_factor;
let k_lo: usize = k_pos % extension_factor;
(0..extension_factor - k_lo).for_each(|i| {
self.data[i].rotate(k_hi as i64);
});
(extension_factor - k_lo..extension_factor).for_each(|i| {
self.data[i].rotate(k_hi as i64 + 1);
});
self.data.rotate_right(k_lo as usize);
}
}
pub trait DivRound {
pub(crate) trait DivRound {
fn div_round(self, rhs: Self) -> Self;
}

View File

@@ -0,0 +1,69 @@
use backend::{FFT64, Module, ZnxView};
use crate::blind_rotation::lut::{DivRound, LookUpTable};
#[test]
fn standard() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let basek: usize = 20;
let k_lut: usize = 40;
let message_modulus: usize = 16;
let extension_factor: usize = 1;
let scale: usize = (1 << (basek - 1)) / message_modulus;
fn lut_fn(x: i64) -> i64 {
x - 8
}
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
lut.set(&module, lut_fn, message_modulus);
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
lut.rotate(half_step);
let step: usize = lut.domain_size().div_round(message_modulus);
(0..lut.domain_size()).step_by(step).for_each(|i| {
(0..step).for_each(|j| {
assert_eq!(
lut_fn((i / step) as i64) % message_modulus as i64,
lut.data[0].raw()[0] / scale as i64
);
lut.rotate(-1);
});
});
}
#[test]
fn extended() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let basek: usize = 20;
let k_lut: usize = 40;
let message_modulus: usize = 16;
let extension_factor: usize = 4;
let scale: usize = (1 << (basek - 1)) / message_modulus;
fn lut_fn(x: i64) -> i64 {
x - 8
}
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
lut.set(&module, lut_fn, message_modulus);
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
lut.rotate(half_step);
let step: usize = lut.domain_size().div_round(message_modulus);
(0..lut.domain_size()).step_by(step).for_each(|i| {
(0..step).for_each(|j| {
assert_eq!(
lut_fn((i / step) as i64) % message_modulus as i64,
lut.data[0].raw()[0] / scale as i64
);
lut.rotate(-1);
});
});
}

View File

@@ -1 +1,2 @@
pub mod cggi;
pub mod lut;