mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
updated key-switch for rank switching & updated glwe key-switching test
This commit is contained in:
@@ -52,6 +52,14 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.data.cols_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToMut<B> for GGLWECiphertext<C, B>
|
||||
@@ -104,7 +112,8 @@ where
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.rank_in(), pt.cols());
|
||||
assert_eq!(self.rank_out(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
@@ -115,11 +124,12 @@ where
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
|
||||
let cols: usize = self.rank() + 1;
|
||||
let cols_in: usize = self.rank_in();
|
||||
let cols_out: usize = self.rank_out() + 1;
|
||||
|
||||
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols, size);
|
||||
let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols, size);
|
||||
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols_out, size);
|
||||
let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols_out, size);
|
||||
|
||||
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
||||
data: tmp_znx_pt,
|
||||
@@ -139,29 +149,42 @@ where
|
||||
k,
|
||||
};
|
||||
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
|
||||
// For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
|
||||
//
|
||||
// Example for ksk rank 2 to rank 3:
|
||||
//
|
||||
// (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
|
||||
// (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
|
||||
//
|
||||
// Example ksk rank 2 to rank 1
|
||||
//
|
||||
// (-(a*s) + s0, a)
|
||||
// (-(b*s) + s1, b)
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th
|
||||
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
|
||||
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
vec_znx_ct.encrypt_sk(
|
||||
module,
|
||||
&vec_znx_pt,
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch_3,
|
||||
);
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
vec_znx_ct.encrypt_sk(
|
||||
module,
|
||||
&vec_znx_pt,
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch_3,
|
||||
);
|
||||
|
||||
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
vec_znx_ct.dft(module, &mut vec_znx_ct_dft);
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
vec_znx_ct.dft(module, &mut vec_znx_ct_dft);
|
||||
|
||||
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
|
||||
module.vmp_prepare_row(self, row_i, 0, &vec_znx_ct_dft);
|
||||
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
|
||||
module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -174,10 +197,6 @@ where
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(col_j, 0);
|
||||
}
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
@@ -190,20 +209,23 @@ where
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(col_j, 0);
|
||||
}
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl VecGLWEProductScratchSpace for GGLWECiphertext<Vec<u8>, FFT64> {
|
||||
fn prod_with_glwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(2, grlwe_size)
|
||||
fn prod_with_glwe_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
grlwe_size: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx_dft(rank_out + 1, grlwe_size)
|
||||
+ (module.vec_znx_big_normalize_tmp_bytes()
|
||||
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
|
||||
+ module.bytes_of_vec_znx_dft(1, a_size)))
|
||||
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, rank_in, rank_out + 1, grlwe_size)
|
||||
+ module.bytes_of_vec_znx_dft(rank_in, a_size)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,30 +244,38 @@ where
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
let log_base2k: usize = self.basek();
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.basek(), log_base2k);
|
||||
assert_eq!(a.basek(), log_base2k);
|
||||
assert_eq!(a.rank(), self.rank_in());
|
||||
assert_eq!(res.rank(), self.rank_out());
|
||||
assert_eq!(res.basek(), basek);
|
||||
assert_eq!(a.basek(), basek);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
|
||||
let cols_in: usize = self.rank_in();
|
||||
let cols_out: usize = self.rank_out() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size());
|
||||
module.vec_znx_dft(&mut a1_dft, 0, a, 1);
|
||||
module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2);
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, a.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ai_dft, col_i, a, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
||||
|
||||
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
|
||||
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, res, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user