updated key-switch for rank switching & updated glwe key-switching test

Jean-Philippe Bossuat
2025-05-14 16:34:52 +02:00
parent cb1928802a
commit f517a730a3
10 changed files with 1806 additions and 1676 deletions


@@ -52,6 +52,14 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
pub fn rank(&self) -> usize {
self.data.cols_out() - 1
}
pub fn rank_in(&self) -> usize {
self.data.cols_in()
}
pub fn rank_out(&self) -> usize {
self.data.cols_out() - 1
}
}
impl<C, B: Backend> MatZnxDftToMut<B> for GGLWECiphertext<C, B>
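Note on the new accessors: rank_in() is the number of columns consumed from the input key (cols_in of the underlying MatZnxDft), rank_out() is the number of mask columns of each stored GLWE row (cols_out - 1), and rank() is kept unchanged and still reports the output rank. A minimal usage sketch, assuming a hypothetical ksk generated for a rank-2 to rank-3 switch (the value itself is not part of this diff):

// assuming ksk: GGLWECiphertext<Vec<u8>, FFT64> built to switch a rank-2 key to a rank-3 key
assert_eq!(ksk.rank_in(), 2);           // columns taken from the input ciphertext mask
assert_eq!(ksk.rank_out(), 3);          // mask columns of each stored GLWE row
assert_eq!(ksk.rank(), ksk.rank_out()); // rank() still reports the output rank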
@@ -104,7 +112,8 @@ where
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.rank_in(), pt.cols());
assert_eq!(self.rank_out(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
assert_eq!(pt.n(), module.n());
@@ -115,11 +124,12 @@ where
let basek: usize = self.basek();
let k: usize = self.k();
let cols: usize = self.rank() + 1;
let cols_in: usize = self.rank_in();
let cols_out: usize = self.rank_out() + 1;
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size);
let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size);
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols_out, size);
let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols_out, size);
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
@@ -139,29 +149,42 @@ where
k,
};
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
// For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
//
// Example for ksk rank 2 to rank 3:
//
// (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
// (-(b0*s0 + b1*s1 + b2*s2) + s1', b0, b1, b2)
//
// Example ksk rank 2 to rank 1
//
// (-(a*s) + s0, a)
// (-(b*s) + s1, b)
(0..cols_in).for_each(|col_i| {
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the col_i-th column of pt
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk(
module,
&vec_znx_pt,
sk_dft,
source_xa,
source_xe,
sigma,
scratch_3,
);
vec_znx_pt.data.zero(); // zeroes for next iteration
// Switch vec_znx_ct into DFT domain
vec_znx_ct.dft(module, &mut vec_znx_ct_dft);
// Stores vec_znx_dft_ct into the i-th row of the MatZnxDft
module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft);
// Stores vec_znx_dft_ct into the i-th row of the MatZnxDft
module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft);
});
});
}
}
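With the outer loop over cols_in, the switching key is now a rows x rank_in matrix of gadget encryptions: entry (row_i, col_i) is a GLWE encryption under the rank_out output key of the col_i-th input polynomial pt[col_i], placed at the row_i-th base-2^basek limb (i.e. scaled by roughly 2^(-basek*(row_i+1)), up to the normalization applied above). Written out in the same style as the comment, with e the fresh noise:

KSK[row_i][col_i] = (-(a0*s0 + ... + a_{rank_out-1}*s_{rank_out-1}) + pt[col_i]/2^(basek*(row_i+1)) + e, a0, ..., a_{rank_out-1})

which is the (rank_out + 1)-column row that vmp_prepare_row stores at position (row_i, col_i).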
@@ -174,10 +197,6 @@ where
where
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(col_j, 0);
}
module.vmp_extract_row(res, self, row_i, col_j);
}
}
@@ -190,20 +209,23 @@ where
where
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(col_j, 0);
}
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, grlwe_size)
fn prod_with_glwe_scratch_space(
module: &Module<FFT64>,
res_size: usize,
a_size: usize,
grlwe_size: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
+ module.bytes_of_vec_znx_dft(1, a_size)))
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
}
}
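The scratch bound now scales with both ranks: a DFT buffer of rank_out + 1 columns for the encrypted rows, plus the same |-combined term as before, resized for a rank_in-column input DFT buffer and a (rank_in, rank_out + 1) vmp_apply. For illustration, instantiating the new formula for a rank-2 to rank-3 key gives

module.bytes_of_vec_znx_dft(4, grlwe_size)
    + (module.vec_znx_big_normalize_tmp_bytes()
        | (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 4, grlwe_size)
            + module.bytes_of_vec_znx_dft(2, a_size)))

and the old hard-coded 2 and 1 are recovered as the rank-1 to rank-1 special case.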
@@ -222,30 +244,38 @@ where
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
{
let log_base2k: usize = self.basek();
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(res.basek(), log_base2k);
assert_eq!(a.basek(), log_base2k);
assert_eq!(a.rank(), self.rank_in());
assert_eq!(res.rank(), self.rank_out());
assert_eq!(res.basek(), basek);
assert_eq!(a.basek(), basek);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
}
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
let cols_in: usize = self.rank_in();
let cols_out: usize = self.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
{
let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size());
module.vec_znx_dft(&mut a1_dft, 0, a, 1);
module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2);
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
});
}
}
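The application side mirrors the key layout: each of the rank_in mask columns of a is sent to the DFT domain, vmp_apply against the switching key produces rank_out + 1 output columns, the body a[0] is added back onto column 0 after the inverse DFT, and every output column is normalized to basek. As a worked equation (a sketch, with Decomp denoting the base-2^basek limb decomposition consumed by vmp_apply):

(res_0, ..., res_{rank_out}) ~ Normalize( (a_0, 0, ..., 0) + sum_{j=1..rank_in} Decomp(a_j) x KSK[:, j-1] )

so, up to key-switching noise, res decrypts under the rank_out output key to the same message that a encrypted under the rank_in input key.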