mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
wip adding automorphism on AutomorphismKey
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
|
||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
|
||||
ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::{
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
@@ -203,6 +204,103 @@ impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn automorphism<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &AutomorphismKey<DataLhs, FFT64>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
|
||||
|
||||
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_dft_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
// Extracts relevant row
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
|
||||
|
||||
// Get a VecZnxBig from scratch space
|
||||
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
|
||||
|
||||
// Switches input outside of DFT
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
|
||||
});
|
||||
|
||||
// Consumes to small vec znx
|
||||
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
|
||||
|
||||
// Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i);
|
||||
});
|
||||
|
||||
// Wraps into ciphertext
|
||||
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: tmp_idft_small_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
|
||||
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
|
||||
|
||||
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
|
||||
// and switches back to DFT domain
|
||||
(0..self.rank_out() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
|
||||
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
|
||||
});
|
||||
|
||||
// Sets back the relevant row
|
||||
self.set_row(module, row_j, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_dft.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_dft);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
|
||||
@@ -299,6 +299,70 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Buffer of the result of VMP in DFT
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
// Applies VMP
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
// Switches result of VMP outside of DFT
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
{
|
||||
// Switches lhs 0-th outside of DFT domain and adds on
|
||||
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
||||
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
||||
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
||||
}
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
|
||||
@@ -174,60 +174,21 @@ where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Buffer of the result of VMP in DFT
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
// Applies VMP
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
// Switches result of VMP outside of DFT
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
{
|
||||
// Switches lhs 0-th outside of DFT domain and adds on
|
||||
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
||||
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
||||
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
||||
}
|
||||
|
||||
// Space fr normalized VMP result outside of DFT domain
|
||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
|
||||
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
|
||||
|
||||
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: res_idft_data,
|
||||
basek: self.basek,
|
||||
k: self.k,
|
||||
};
|
||||
|
||||
res_idft.keyswitch_from_fourier(module, self, rhs, scratch1);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||
module.vec_znx_dft(self, i, &res_small, i);
|
||||
module.vec_znx_dft(self, i, &res_idft, i);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user