mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip CGGI BR for extended LUT
This commit is contained in:
@@ -1,13 +1,13 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use backend::{
|
||||
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut,
|
||||
ZnxZero,
|
||||
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps,
|
||||
VecZnxOps, ZnxView, ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
|
||||
use crate::{
|
||||
GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore,
|
||||
FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext, ScratchCore,
|
||||
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
|
||||
lwe::ciphertext::LWECiphertextToRef,
|
||||
};
|
||||
@@ -37,7 +37,232 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
{
|
||||
let basek = res.basek();
|
||||
if lut.data.len() > 1 {
|
||||
cggi_blind_rotate_block_binary_exnteded(module, res, lwe, lut, brk, scratch);
|
||||
} else if brk.block_size() > 1 {
|
||||
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
|
||||
} else {
|
||||
todo!("implement this case")
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_block_binary_exnteded<DataRes, DataIn>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
{
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
|
||||
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
|
||||
let basek: usize = out_mut.basek();
|
||||
|
||||
let cols: usize = out_mut.rank() + 1;
|
||||
|
||||
mod_switch_2n(
|
||||
2 * module.n() * lut.extension_factor(),
|
||||
&mut lwe_2n,
|
||||
&lwe_ref,
|
||||
);
|
||||
|
||||
let extension_factor: i64 = lut.extension_factor() as i64;
|
||||
|
||||
let mut acc: Vec<GLWECiphertext<Vec<u8>>> = Vec::with_capacity(lut.extension_factor());
|
||||
|
||||
for _ in 0..extension_factor {
|
||||
acc.push(GLWECiphertext::alloc(
|
||||
module,
|
||||
basek,
|
||||
out_mut.k(),
|
||||
out_mut.rank(),
|
||||
));
|
||||
}
|
||||
|
||||
let a: &[i64] = &lwe_2n[1..];
|
||||
let b: i64 = lwe_2n[0];
|
||||
|
||||
let b_inner: i64 = b / extension_factor;
|
||||
let b_outer: i64 = b % extension_factor;
|
||||
|
||||
for (i, j) in (0..b_outer).zip(extension_factor - b_outer..extension_factor) {
|
||||
module.vec_znx_rotate(
|
||||
b_inner + 1,
|
||||
&mut acc[j as usize].data,
|
||||
0,
|
||||
&lut.data[i as usize],
|
||||
0,
|
||||
);
|
||||
}
|
||||
for (i, j) in (b_outer..extension_factor).zip(0..extension_factor - b_outer) {
|
||||
module.vec_znx_rotate(
|
||||
b_inner,
|
||||
&mut acc[j as usize].data,
|
||||
0,
|
||||
&lut.data[i as usize],
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
let block_size: usize = brk.block_size();
|
||||
|
||||
let mut acc_dft: Vec<FourierGLWECiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(lut.extension_factor());
|
||||
|
||||
for _ in 0..extension_factor {
|
||||
acc_dft.push(FourierGLWECiphertext::alloc(
|
||||
module,
|
||||
basek,
|
||||
out_mut.k(),
|
||||
out_mut.rank(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut vmp_res: Vec<FourierGLWECiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(lut.extension_factor());
|
||||
|
||||
for _ in 0..extension_factor {
|
||||
vmp_res.push(FourierGLWECiphertext::alloc(
|
||||
module,
|
||||
basek,
|
||||
out_mut.k(),
|
||||
out_mut.rank(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut acc_add_dft: Vec<FourierGLWECiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(lut.extension_factor());
|
||||
|
||||
for _ in 0..extension_factor {
|
||||
acc_add_dft.push(FourierGLWECiphertext::alloc(
|
||||
module,
|
||||
basek,
|
||||
out_mut.k(),
|
||||
out_mut.rank(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut xai_minus_one: backend::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut xai_minus_one_dft: backend::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
|
||||
|
||||
izip!(
|
||||
a.chunks_exact(block_size),
|
||||
brk.data.chunks_exact(block_size)
|
||||
)
|
||||
.enumerate()
|
||||
.for_each(|(i, (ai, ski))| {
|
||||
(0..lut.extension_factor()).for_each(|i| {
|
||||
acc[i].dft(module, &mut acc_dft[i]);
|
||||
acc_add_dft[i].data.zero();
|
||||
});
|
||||
|
||||
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
|
||||
let aii_inner: i64 = aii / extension_factor;
|
||||
let aii_outer: i64 = aii % extension_factor;
|
||||
|
||||
// vmp_res = DFT(acc) * BRK[i]
|
||||
(0..lut.extension_factor()).for_each(|i| {
|
||||
module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch);
|
||||
});
|
||||
|
||||
if aii_outer == 0 {
|
||||
xai_minus_one.zero();
|
||||
xai_minus_one.at_mut(0, 0)[0] = 1;
|
||||
module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0);
|
||||
xai_minus_one.at_mut(0, 0)[0] -= 1;
|
||||
module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0);
|
||||
|
||||
(0..lut.extension_factor()).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_minus_one_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i);
|
||||
});
|
||||
})
|
||||
} else {
|
||||
xai_minus_one.zero();
|
||||
xai_minus_one.at_mut(0, 0)[0] = 1;
|
||||
module.vec_znx_rotate_inplace(aii_inner + 1, &mut xai_minus_one, 0);
|
||||
xai_minus_one.at_mut(0, 0)[0] -= 1;
|
||||
module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0);
|
||||
|
||||
for (i, j) in (0..aii_outer).zip(extension_factor - aii_outer..extension_factor) {
|
||||
module.vec_znx_rotate(
|
||||
b_inner + 1,
|
||||
&mut acc[j as usize].data,
|
||||
0,
|
||||
&lut.data[i as usize],
|
||||
0,
|
||||
);
|
||||
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(
|
||||
&mut acc_add_dft[j as usize].data,
|
||||
k,
|
||||
&vmp_res[i as usize].data,
|
||||
k,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
xai_minus_one.zero();
|
||||
xai_minus_one.at_mut(0, 0)[0] = 1;
|
||||
module.vec_znx_rotate_inplace(aii_inner, &mut xai_minus_one, 0);
|
||||
xai_minus_one.at_mut(0, 0)[0] -= 1;
|
||||
module.svp_prepare(&mut xai_minus_one_dft, 0, &xai_minus_one, 0);
|
||||
|
||||
for (i, j) in (aii_outer..extension_factor).zip(0..extension_factor - aii_outer) {
|
||||
module.vec_znx_rotate(
|
||||
b_inner,
|
||||
&mut acc[j as usize].data,
|
||||
0,
|
||||
&lut.data[i as usize],
|
||||
0,
|
||||
);
|
||||
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[i as usize].data, k, &xai_minus_one_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(
|
||||
&mut acc_add_dft[j as usize].data,
|
||||
k,
|
||||
&vmp_res[i as usize].data,
|
||||
k,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if i == lwe.n() - block_size {
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft_add_inplace(&mut acc_dft[0].data, i, &acc_add_dft[0].data, i);
|
||||
});
|
||||
acc_dft[0].idft(module, &mut out_mut, scratch);
|
||||
} else {
|
||||
(0..lut.extension_factor()).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i);
|
||||
});
|
||||
|
||||
acc_dft[j].idft(module, &mut acc[j], scratch);
|
||||
})
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
{
|
||||
let basek: usize = res.basek();
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
|
||||
@@ -45,7 +270,7 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
|
||||
|
||||
let cols: usize = out_mut.rank() + 1;
|
||||
|
||||
mod_switch_2n(module, &mut lwe_2n, &lwe_ref);
|
||||
mod_switch_2n(2 * module.n(), &mut lwe_2n, &lwe_ref);
|
||||
|
||||
let a: &[i64] = &lwe_2n[1..];
|
||||
let b: i64 = lwe_2n[0];
|
||||
@@ -102,10 +327,10 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
|
||||
println!("external products: {} us", duration.as_micros());
|
||||
}
|
||||
|
||||
pub(crate) fn mod_switch_2n(module: &Module<FFT64>, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
||||
pub(crate) fn mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
||||
let basek: usize = lwe.basek();
|
||||
|
||||
let log2n: usize = module.log_n() + 1;
|
||||
let log2n: usize = usize::BITS as usize - (n - 1).leading_zeros() as usize + 1;
|
||||
|
||||
res.copy_from_slice(&lwe.data.at(0, 0));
|
||||
|
||||
|
||||
@@ -16,6 +16,10 @@ impl LookUpTable {
|
||||
Self { data, basek, k }
|
||||
}
|
||||
|
||||
pub fn extension_factor(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn set(&mut self, module: &Module<FFT64>, f: fn(i64) -> i64, message_modulus: usize) {
|
||||
let basek: usize = self.basek;
|
||||
|
||||
@@ -29,7 +33,7 @@ impl LookUpTable {
|
||||
let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale;
|
||||
|
||||
// If LUT size > module.n()
|
||||
let domain_size: usize = self.data[0].n() * self.data.len();
|
||||
let domain_size: usize = self.data[0].n() * self.extension_factor();
|
||||
|
||||
let size: usize = self.k.div_ceil(self.basek);
|
||||
|
||||
@@ -63,7 +67,7 @@ impl LookUpTable {
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(lut_full.n() * size_of::<i64>());
|
||||
lut_full.normalize(self.basek, 0, &mut tmp_bytes);
|
||||
|
||||
if self.data.len() > 1 {
|
||||
if self.extension_factor() > 1 {
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.bytes_of_vec_znx(1, size));
|
||||
module.vec_znx_split(&mut self.data, 0, &lut_full, 0, scratch.borrow());
|
||||
} else {
|
||||
|
||||
@@ -103,7 +103,7 @@ fn blind_rotation() {
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
|
||||
mod_switch_2n(&module, &mut lwe_2n, &lwe.to_ref());
|
||||
mod_switch_2n(module.n() * 2, &mut lwe_2n, &lwe.to_ref());
|
||||
|
||||
let pt_want: i64 = (lwe_2n[0]
|
||||
+ lwe_2n[1..]
|
||||
|
||||
Reference in New Issue
Block a user