Fixes after meeting

This commit is contained in:
Jean-Philippe Bossuat
2025-07-11 12:29:49 +02:00
parent 38df06f7ab
commit 52a6a130a5
6 changed files with 188 additions and 151 deletions

View File

@@ -1,5 +1,5 @@
use backend::{
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView,
ZnxViewMut, ZnxZero,
};
@@ -30,7 +30,7 @@ pub fn cggi_blind_rotate_scratch_space(
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
let acc_dft_add: usize = vmp_res;
let xai_plus_y: usize = module.bytes_of_scalar_znx(1);
let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1);
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
@@ -54,16 +54,17 @@ pub fn cggi_blind_rotate_scratch_space(
}
}
pub fn cggi_blind_rotate<DataRes, DataIn>(
pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<FFT64>,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
{
match brk.dist {
Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
@@ -82,16 +83,17 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
}
}
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<FFT64>,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
{
let extension_factor: usize = lut.extension_factor();
let basek: usize = res.basek();
@@ -102,25 +104,35 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows);
let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1);
let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1);
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
(0..extension_factor).for_each(|i| {
acc[i].zero();
});
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
if let Some(b) = &brk.x_pow_a {
x_pow_a = b
} else {
panic!("invalid key: x_pow_a has not been initialized")
}
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
let two_n: usize = 2 * module.n();
let two_n_ext: usize = 2 * lut.domain_size();
negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref);
let a: &[i64] = &lwe_2n[1..];
let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) % two_n_ext as i64) as usize;
let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize;
let b_hi: usize = b_pos / extension_factor;
let b_lo: usize = b_pos % extension_factor;
let b_lo: usize = b_pos & (extension_factor - 1);
for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) {
module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0);
@@ -145,9 +157,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
// TODO: first & last iterations can be optimized
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
let ai_pos: usize = ((aii + two_n_ext as i64) % two_n_ext as i64) as usize;
let ai_pos: usize = ((aii + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize;
let ai_hi: usize = ai_pos / extension_factor;
let ai_lo: usize = ai_pos % extension_factor;
let ai_lo: usize = ai_pos & (extension_factor - 1);
// vmp_res = DFT(acc) * BRK[i]
(0..extension_factor).for_each(|i| {
@@ -156,48 +168,62 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
// Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
if ai_lo == 0 {
// DFT X^{-ai}
set_xai_plus_y(module, ai_hi, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
// Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1)
(0..extension_factor).for_each(|j| {
(0..cols).for_each(|i| {
module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
if ai_hi != 0 {
// DFT X^{-ai}
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0);
(0..extension_factor).for_each(|j| {
(0..cols).for_each(|i| {
module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
});
});
});
}
// Non trivial case: rotation between polynomials
// In this case we can't directly multiply with (X^{-ai} - 1) because of the
// ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the
// computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai}
} else {
// Sets acc_add_dft[i] = acc[i] * sk
(0..extension_factor).for_each(|i| {
(0..cols).for_each(|k| {
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
})
});
// DFT X^{-ai}
set_xai_plus_y(module, ai_hi + 1, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
// Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
});
// Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk
if (ai_hi + 1) & (two_n - 1) != 0 {
for i in 0..ai_lo {
(0..cols).for_each(|k| {
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
// DFT X^{-ai}
set_xai_plus_y(module, ai_hi, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
// Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk
if ai_hi != 0 {
for i in ai_lo..extension_factor {
(0..cols).for_each(|k: usize| {
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
// Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
if (ai_hi + 1) & (two_n - 1) != 0 {
for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
});
}
}
// Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
});
if ai_hi != 0 {
// Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
});
}
}
}
});
@@ -220,49 +246,17 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
});
}
/// Writes the monomial-plus-constant `X^{ai} + y` into `buf`, hands it to
/// `module.svp_prepare` with `res` as destination, then zeroes the touched
/// coefficients of `buf` again so the buffer can be reused by the next call.
///
/// The exponent is interpreted negacyclically: exponents in `[n, 2n)` wrap to
/// coefficient `ai - n` with sign `-1` (i.e. `X^n = -1`), where
/// `n = module.n()`. The exponent is first reduced modulo `2n`, so `ai == 2n`
/// is treated as `X^0 = +1` instead of being written as `-1`.
///
/// # Arguments
/// * `module` - backend module providing `n()` and `svp_prepare`.
/// * `ai`     - exponent of the monomial; any value is accepted and reduced mod `2n`.
/// * `y`      - constant added to coefficient 0.
/// * `res`    - destination passed to `svp_prepare` (column 0).
/// * `buf`    - scratch polynomial (column 0 is used).
///
/// NOTE(review): assumes `n` is a power of two (the mask-based reduction relies
/// on it, as the original `& (n - 1)` did) and that `buf` column 0 is all-zero
/// on entry, since only the two touched coefficients are reset — confirm
/// against callers.
fn set_xai_plus_y(
    module: &Module<FFT64>,
    ai: usize,
    y: i64,
    res: &mut ScalarZnxDft<&mut [u8], FFT64>,
    buf: &mut ScalarZnx<&mut [u8]>,
) {
    let n: usize = module.n();

    // X has multiplicative order 2n in the negacyclic ring, so reduce the
    // exponent before choosing index and sign. Identity for ai < 2n.
    let ai: usize = ai & ((n << 1) - 1);

    // Coefficient index and sign of X^{ai}; computed once and reused for the reset.
    let (idx, coeff): (usize, i64) = if ai < n { (ai, 1) } else { (ai - n, -1) };

    {
        let raw: &mut [i64] = buf.at_mut(0, 0);
        raw[idx] = coeff;
        // Added after the monomial so that idx == 0 yields `coeff + y`.
        raw[0] += y;
    }

    module.svp_prepare(res, 0, buf, 0);

    {
        // Restore the scratch polynomial to zero at the positions written above.
        let raw: &mut [i64] = buf.at_mut(0, 0);
        raw[idx] = 0;
        raw[0] = 0;
    }
}
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<FFT64>,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
{
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
@@ -290,9 +284,18 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows);
let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size());
let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size());
let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1);
let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1);
let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
if let Some(b) = &brk.x_pow_a {
x_pow_a = b
} else {
panic!("invalid key: x_pow_a has not been initialized")
}
izip!(
a.chunks_exact(block_size),
brk.data.chunks_exact(block_size)
@@ -305,13 +308,13 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
acc_add_dft.zero();
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
let ai_pos: usize = ((aii + two_n as i64) % two_n as i64) as usize;
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;
// vmp_res = DFT(acc) * BRK[i]
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5);
// DFT(X^ai -1)
set_xai_plus_y(module, ai_pos, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0);
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
(0..cols).for_each(|i| {
@@ -320,10 +323,6 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
});
});
(0..cols).for_each(|i| {
module.vec_znx_dft_add_inplace(&mut acc_dft, i, &acc_add_dft, i);
});
{
let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size());
@@ -336,16 +335,17 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
});
}
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn>(
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<FFT64>,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
{
#[cfg(debug_assertions)]
{