Fixes after meeting

This commit is contained in:
Jean-Philippe Bossuat
2025-07-11 12:29:49 +02:00
parent 38df06f7ab
commit 52a6a130a5
6 changed files with 188 additions and 151 deletions

View File

@@ -1,11 +1,15 @@
use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut};
use backend::{
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch,
ZnxView, ZnxViewMut,
};
use sampling::source::Source;
use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret};
pub struct BlindRotationKeyCGGI<B: Backend> {
pub(crate) data: Vec<GGSWCiphertext<Vec<u8>, B>>,
pub struct BlindRotationKeyCGGI<D, B: Backend> {
pub(crate) data: Vec<GGSWCiphertext<D, B>>,
pub(crate) dist: Distribution,
pub(crate) x_pow_a: Option<Vec<ScalarZnxDft<Vec<u8>, B>>>,
}
// pub struct BlindRotationKeyFHEW<B: Backend> {
@@ -13,20 +17,61 @@ pub struct BlindRotationKeyCGGI<B: Backend> {
// pub(crate) auto: Vec<GLWEAutomorphismKey<Vec<u8>, B>>,
//}
impl BlindRotationKeyCGGI<FFT64> {
impl BlindRotationKeyCGGI<Vec<u8>, FFT64> {
pub fn allocate(module: &Module<FFT64>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut data: Vec<GGSWCiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank)));
Self {
data,
dist: Distribution::NONE,
x_pow_a: None::<Vec<ScalarZnxDft<Vec<u8>, FFT64>>>,
}
}
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
}
}
impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
#[allow(dead_code)]
pub(crate) fn n(&self) -> usize {
self.data[0].n()
}
#[allow(dead_code)]
pub(crate) fn rows(&self) -> usize {
self.data[0].rows()
}
#[allow(dead_code)]
pub(crate) fn k(&self) -> usize {
self.data[0].k()
}
#[allow(dead_code)]
pub(crate) fn size(&self) -> usize {
self.data[0].size()
}
#[allow(dead_code)]
pub(crate) fn rank(&self) -> usize {
self.data[0].rank()
}
pub(crate) fn basek(&self) -> usize {
self.data[0].basek()
}
pub(crate) fn block_size(&self) -> usize {
match self.dist {
Distribution::BinaryBlock(value) => value,
_ => 1,
}
}
}
impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
pub fn generate_from_sk<DataSkGLWE, DataSkLWE>(
&mut self,
module: &Module<FFT64>,
@@ -64,42 +109,51 @@ impl BlindRotationKeyCGGI<FFT64> {
self.data.iter_mut().enumerate().for_each(|(i, ggsw)| {
pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
})
}
});
pub(crate) fn block_size(&self) -> usize {
match self.dist {
Distribution::BinaryBlock(value) => value,
_ => 1,
match sk_lwe.dist {
Distribution::BinaryBlock(_) => {
let mut x_pow_a: Vec<ScalarZnxDft<Vec<u8>, FFT64>> = Vec::with_capacity(module.n() << 1);
let mut buf: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
(0..module.n() << 1).for_each(|i| {
let mut res: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
set_xai_plus_y(module, i, 0, &mut res, &mut buf);
x_pow_a.push(res);
});
self.x_pow_a = Some(x_pow_a);
}
_ => {}
}
}
}
#[allow(dead_code)]
pub(crate) fn n(&self) -> usize {
self.data[0].n()
pub fn set_xai_plus_y<A, B>(module: &Module<FFT64>, ai: usize, y: i64, res: &mut ScalarZnxDft<A, FFT64>, buf: &mut ScalarZnx<B>)
where
A: AsRef<[u8]> + AsMut<[u8]>,
B: AsRef<[u8]> + AsMut<[u8]>,
{
let n: usize = module.n();
{
let raw: &mut [i64] = buf.at_mut(0, 0);
if ai < n {
raw[ai] = 1;
} else {
raw[(ai - n) & (n - 1)] = -1;
}
raw[0] += y;
}
#[allow(dead_code)]
pub(crate) fn rows(&self) -> usize {
self.data[0].rows()
}
module.svp_prepare(res, 0, buf, 0);
#[allow(dead_code)]
pub(crate) fn k(&self) -> usize {
self.data[0].k()
}
{
let raw: &mut [i64] = buf.at_mut(0, 0);
#[allow(dead_code)]
pub(crate) fn size(&self) -> usize {
self.data[0].size()
}
#[allow(dead_code)]
pub(crate) fn rank(&self) -> usize {
self.data[0].rank()
}
pub(crate) fn basek(&self) -> usize {
self.data[0].basek()
if ai < n {
raw[ai] = 0;
} else {
raw[(ai - n) & (n - 1)] = 0;
}
raw[0] = 0;
}
}