From 52154d6f8a85ffbf6ad08b7b327df7edff9d7d53 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Fri, 20 Jun 2025 11:00:04 +0200 Subject: [PATCH] wip CGGI BR for extended LUT --- core/src/blind_rotation/ccgi.rs | 239 ++++++++++++++++++++- core/src/blind_rotation/lut.rs | 8 +- core/src/blind_rotation/test_fft64/cggi.rs | 2 +- 3 files changed, 239 insertions(+), 10 deletions(-) diff --git a/core/src/blind_rotation/ccgi.rs b/core/src/blind_rotation/ccgi.rs index f560d09..38c8bf2 100644 --- a/core/src/blind_rotation/ccgi.rs +++ b/core/src/blind_rotation/ccgi.rs @@ -1,13 +1,13 @@ use std::time::Instant; use backend::{ - FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, - ZnxZero, + FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps, + VecZnxOps, ZnxView, ZnxViewMut, ZnxZero, }; use itertools::izip; use crate::{ - GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, + FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore, blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable}, lwe::ciphertext::LWECiphertextToRef, }; @@ -37,7 +37,232 @@ pub fn cggi_blind_rotate( DataRes: AsRef<[u8]> + AsMut<[u8]>, DataIn: AsRef<[u8]>, { - let basek = res.basek(); + if lut.data.len() > 1 { + cggi_blind_rotate_block_binary_extended(module, res, lwe, lut, brk, scratch); + } else if brk.block_size() > 1 { + cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch); + } else { + todo!("implement this case") + } +} + +pub(crate) fn cggi_blind_rotate_block_binary_extended( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: 
from scratch space + let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); + let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref(); + let basek: usize = out_mut.basek(); + + let cols: usize = out_mut.rank() + 1; + + mod_switch_2n( + 2 * module.n() * lut.extension_factor(), + &mut lwe_2n, + &lwe_ref, + ); + + let extension_factor: i64 = lut.extension_factor() as i64; + + let mut acc: Vec>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc.push(GLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let a: &[i64] = &lwe_2n[1..]; + let b: i64 = lwe_2n[0]; + + let b_inner: i64 = b / extension_factor; + let b_outer: i64 = b % extension_factor; + + for (i, j) in (0..b_outer).zip(extension_factor - b_outer..extension_factor) { + module.vec_znx_rotate( + b_inner + 1, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + } + for (i, j) in (b_outer..extension_factor).zip(0..extension_factor - b_outer) { + module.vec_znx_rotate( + b_inner, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + } + + let block_size: usize = brk.block_size(); + + let mut acc_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc_dft.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut vmp_res: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + vmp_res.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut acc_add_dft: Vec, FFT64>> = Vec::with_capacity(lut.extension_factor()); + + for _ in 0..extension_factor { + acc_add_dft.push(FourierGLWECiphertext::alloc( + module, + basek, + out_mut.k(), + out_mut.rank(), + )); + } + + let mut xai_minus_one: backend::ScalarZnx> = module.new_scalar_znx(1); + let mut xai_minus_one_dft: backend::ScalarZnxDft, FFT64> = module.new_scalar_znx_dft(1); + + izip!( 
+ a.chunks_exact(block_size), + brk.data.chunks_exact(block_size) + ) + .enumerate() + .for_each(|(i, (ai, ski))| { + (0..lut.extension_factor()).for_each(|i| { + acc[i].dft(module, &mut acc_dft[i]); + acc_add_dft[i].data.zero(); + }); + + izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| { + let aii_inner: i64 = aii / extension_factor; + let aii_outer: i64 = aii % extension_factor; + + // vmp_res = DFT(acc) * BRK[i] + (0..lut.extension_factor()).for_each(|i| { + module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch); + }); + + if aii_outer == 0 { + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + (0..lut.extension_factor()).for_each(|j| { + (0..cols).for_each(|i| { + module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i); + }); + }) + } else { + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner + 1, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + for (i, j) in (0..aii_outer).zip(extension_factor - aii_outer..extension_factor) { + module.vec_znx_rotate( + b_inner + 1, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace( + &mut acc_add_dft[j as usize].data, + k, + &vmp_res[i as usize].data, + k, + ); + }); + } + + xai_minus_one.zero(); + xai_minus_one.at_mut(0, 0)[0] = 1; + module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0); + xai_minus_one.at_mut(0, 0)[0] -= 1; + module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0); + + for (i, j) in 
(aii_outer..extension_factor).zip(0..extension_factor - aii_outer) { + module.vec_znx_rotate( + b_inner, + &mut acc[j as usize].data, + 0, + &lut.data[i as usize], + 0, + ); + + (0..cols).for_each(|k| { + module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0); + module.vec_znx_dft_add_inplace( + &mut acc_add_dft[j as usize].data, + k, + &vmp_res[i as usize].data, + k, + ); + }); + } + } + }); + + if i == lwe.n() - block_size { + (0..cols).for_each(|i| { + module.vec_znx_dft_add_inplace(&mut acc_dft[0].data, i, &acc_add_dft[0].data, i); + }); + acc_dft[0].idft(module, &mut out_mut, scratch); + } else { + (0..lut.extension_factor()).for_each(|j| { + (0..cols).for_each(|i| { + module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i); + }); + + acc_dft[j].idft(module, &mut acc[j], scratch); + }) + } + }); +} + +pub(crate) fn cggi_blind_rotate_block_binary( + module: &Module, + res: &mut GLWECiphertext, + lwe: &LWECiphertext, + lut: &LookUpTable, + brk: &BlindRotationKeyCGGI, + scratch: &mut Scratch, +) where + DataRes: AsRef<[u8]> + AsMut<[u8]>, + DataIn: AsRef<[u8]>, +{ + let basek: usize = res.basek(); let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut(); @@ -45,7 +270,7 @@ pub fn cggi_blind_rotate( let cols: usize = out_mut.rank() + 1; - mod_switch_2n(module, &mut lwe_2n, &lwe_ref); + mod_switch_2n(2 * module.n(), &mut lwe_2n, &lwe_ref); let a: &[i64] = &lwe_2n[1..]; let b: i64 = lwe_2n[0]; @@ -102,10 +327,10 @@ pub fn cggi_blind_rotate( println!("external products: {} us", duration.as_micros()); } -pub(crate) fn mod_switch_2n(module: &Module, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { +pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) { let basek: usize = lwe.basek(); - let log2n: usize = module.log_n() + 1; + let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1; 
res.copy_from_slice(&lwe.data.at(0, 0)); diff --git a/core/src/blind_rotation/lut.rs b/core/src/blind_rotation/lut.rs index 56b65a0..743036b 100644 --- a/core/src/blind_rotation/lut.rs +++ b/core/src/blind_rotation/lut.rs @@ -16,6 +16,10 @@ impl LookUpTable { Self { data, basek, k } } + pub fn extension_factor(&self) -> usize { + self.data.len() + } + pub fn set(&mut self, module: &Module, f: fn(i64) -> i64, message_modulus: usize) { let basek: usize = self.basek; @@ -29,7 +33,7 @@ impl LookUpTable { let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale; // If LUT size > module.n() - let domain_size: usize = self.data[0].n() * self.data.len(); + let domain_size: usize = self.data[0].n() * self.extension_factor(); let size: usize = self.k.div_ceil(self.basek); @@ -63,7 +67,7 @@ impl LookUpTable { let mut tmp_bytes: Vec = alloc_aligned(lut_full.n() * size_of::()); lut_full.normalize(self.basek, 0, &mut tmp_bytes); - if self.data.len() > 1 { + if self.extension_factor() > 1 { let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size)); module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow()); } else { diff --git a/core/src/blind_rotation/test_fft64/cggi.rs b/core/src/blind_rotation/test_fft64/cggi.rs index bb98cea..6322c75 100644 --- a/core/src/blind_rotation/test_fft64/cggi.rs +++ b/core/src/blind_rotation/test_fft64/cggi.rs @@ -103,7 +103,7 @@ fn blind_rotation() { let mut lwe_2n: Vec = vec![0i64; lwe.n() + 1]; // TODO: from scratch space - mod_switch_2n(&module, &mut lwe_2n, &lwe.to_ref()); + mod_switch_2n(module.n() * 2, &mut lwe_2n, &lwe.to_ref()); let pt_want: i64 = (lwe_2n[0] + lwe_2n[1..]