diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 6189bad..00568dd 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -111,6 +111,16 @@ impl + AsRef<[u8]>> VecZnx { } } + pub fn rotate(&mut self, k: i64){ + unsafe{ + (0..self.cols()).for_each(|i|{ + (0..self.size()).for_each(|j|{ + znx::znx_rotate_inplace_i64(self.n() as u64, k, self.at_mut_ptr(i, j)); + }); + }) + } + } + pub fn rsh(&mut self, basek: usize, k: usize, scratch: &mut Scratch) { let n: usize = self.n(); let cols: usize = self.cols(); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 743036b..b6e9a7f 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -1,4 +1,4 @@ -use backend::{FFT64, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; +use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned}; pub struct LookUpTable { pub(crate) data: Vec>>, @@ -7,10 +7,10 @@ pub struct LookUpTable { } impl LookUpTable { - pub fn alloc(module: &Module, basek: usize, k: usize, extend_factor: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, extension_factor: usize) -> Self { let size: usize = k.div_ceil(basek); - let mut data: Vec>> = Vec::with_capacity(extend_factor); - (0..extend_factor).for_each(|_| { + let mut data: Vec>> = Vec::with_capacity(extension_factor); + (0..extension_factor).for_each(|_| { data.push(module.new_vec_znx(1, size)); }); Self { data, basek, k } @@ -20,6 +20,10 @@ impl LookUpTable { self.data.len() } + pub fn domain_size(&self) -> usize { + self.data.len() * self.data[0].n() + } + pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { let basek: usize = self.basek; @@ -33,11 +37,11 @@ impl LookUpTable { let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; // If LUT size > module.n() - let domain_size: usize = self.data[0].n() * self.extension_factor(); + let domain_size: usize = self.domain_size(); let size: usize = self.k.div_ceil(self.basek); - // Equivalent to AUTO([f(0), f(1), ..., f(n-1)], -1) + // Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1) let mut lut_full: VecZnx> = VecZnx::new::(domain_size, 1, size); { let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1); @@ -55,28 +59,54 @@ impl LookUpTable { let end: usize = ((x + 1) * domain_size).div_round(message_modulus); let y: i64 = f_scaled(x as i64); (start..end).for_each(|i| { - lut_at[domain_size - i] = -y; + lut_at[i] = y; }) }); } // Rotates half the step to the left let half_step: usize = domain_size.div_round(message_modulus << 1); - module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0); + + lut_full.rotate(-(half_step as i64)); let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); lut_full.normalize(self.basek, 0, &mut tmp_bytes); if self.extension_factor() > 1 { - let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size)); - module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow()); + (0..self.extension_factor()).for_each(|i| { + module.switch_degree(&mut self.data[i], 0, &lut_full, 0); + if i < self.extension_factor() { + lut_full.rotate(-1); + } + }); } else { module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0); } } + + pub(crate) fn rotate(&mut self, k: i64) { + let extension_factor: usize = self.extension_factor(); + let two_n: usize = 2 * self.data[0].n(); + let two_n_ext: usize = two_n * extension_factor; + + let k_pos: usize = ((k + two_n_ext as i64) % two_n_ext as i64) as usize; + + let k_hi: usize = k_pos / extension_factor; + let k_lo: usize = k_pos % extension_factor; + + (0..extension_factor - k_lo).for_each(|i| { + self.data[i].rotate(k_hi as i64); + }); + + (extension_factor - k_lo..extension_factor).for_each(|i| { + self.data[i].rotate(k_hi as i64 + 1); + }); + + self.data.rotate_right(k_lo as usize); + } } -pub trait DivRound { +pub(crate) trait DivRound { fn div_round(self, rhs: Self) -> Self; } diff --git a/core/src/blind_rotation/test_fft64/lut.rs b/core/src/blind_rotation/test_fft64/lut.rs new file mode 100644 index 0000000..9377d76 --- /dev/null +++ b/core/src/blind_rotation/test_fft64/lut.rs @@ -0,0 +1,69 @@ +use backend::{FFT64, Module, ZnxView}; + +use crate::blind_rotation::lut::{DivRound, LookUpTable}; + +#[test] +fn standard() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 1; + + let scale: usize = (1 << (basek - 1)) / message_modulus; + + fn lut_fn(x: i64) -> i64 { + x - 8 + } + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, lut_fn, message_modulus); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|j| { + assert_eq!( + lut_fn((i / step) as i64) % message_modulus as i64, + lut.data[0].raw()[0] / scale as i64 + ); + lut.rotate(-1); + }); + }); +} + +#[test] +fn extended() { + let module: Module = Module::::new(32); + let basek: usize = 20; + let k_lut: usize = 40; + let message_modulus: usize = 16; + let extension_factor: usize = 4; + + let scale: usize = (1 << (basek - 1)) / message_modulus; + + fn lut_fn(x: i64) -> i64 { + x - 8 + } + + let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor); + lut.set(&module, lut_fn, message_modulus); + + let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64; + lut.rotate(half_step); + + let step: usize = lut.domain_size().div_round(message_modulus); + + (0..lut.domain_size()).step_by(step).for_each(|i| { + (0..step).for_each(|j| { + assert_eq!( + lut_fn((i / step) as i64) % message_modulus as i64, + lut.data[0].raw()[0] / scale as i64 + ); + lut.rotate(-1); + }); + }); +} diff --git a/core/src/blind_rotation/test_fft64/mod.rs b/core/src/blind_rotation/test_fft64/mod.rs index 1a23dff..18ac93c 100644 --- a/core/src/blind_rotation/test_fft64/mod.rs +++ b/core/src/blind_rotation/test_fft64/mod.rs @@ -1 +1,2 @@ pub mod cggi; +pub mod lut;