mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
wip adding automorphism on AutomorphismKey
This commit is contained in:
@@ -130,7 +130,7 @@ pub trait ScalarZnxOps {
|
|||||||
A: ScalarZnxToRef;
|
A: ScalarZnxToRef;
|
||||||
|
|
||||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||||
where
|
where
|
||||||
A: ScalarZnxToMut;
|
A: ScalarZnxToMut;
|
||||||
}
|
}
|
||||||
@@ -162,7 +162,7 @@ impl<B: Backend> ScalarZnxOps for Module<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||||
where
|
where
|
||||||
A: ScalarZnxToMut,
|
A: ScalarZnxToMut,
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::ffi::vec_znx_big;
|
use crate::ffi::vec_znx_big;
|
||||||
use crate::znx_base::{ZnxInfos, ZnxView};
|
use crate::znx_base::{ZnxInfos, ZnxView};
|
||||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
|
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
@@ -98,6 +98,17 @@ impl<D> VecZnxBig<D, FFT64>
|
|||||||
where
|
where
|
||||||
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
|
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
|
||||||
{
|
{
|
||||||
|
// Consumes the VecZnxBig to return a VecZnx.
|
||||||
|
// Useful when no normalization is needed.
|
||||||
|
pub fn to_vec_znx_small(self) -> VecZnx<D>{
|
||||||
|
VecZnx{
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: self.size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||||
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
|
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
|
||||||
where
|
where
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
|
|||||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||||
where
|
where
|
||||||
A: VecZnxBigToMut<BACKEND>;
|
A: VecZnxBigToMut<BACKEND>;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait VecZnxBigScratch {
|
pub trait VecZnxBigScratch {
|
||||||
@@ -169,6 +170,7 @@ impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||||
|
|
||||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||||
where
|
where
|
||||||
R: VecZnxBigToMut<FFT64>,
|
R: VecZnxBigToMut<FFT64>,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
|
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
|
||||||
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
|
ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
@@ -8,6 +8,7 @@ use crate::{
|
|||||||
elem::{GetRow, Infos, SetRow},
|
elem::{GetRow, Infos, SetRow},
|
||||||
gglwe_ciphertext::GGLWECiphertext,
|
gglwe_ciphertext::GGLWECiphertext,
|
||||||
ggsw_ciphertext::GGSWCiphertext,
|
ggsw_ciphertext::GGSWCiphertext,
|
||||||
|
glwe_ciphertext::GLWECiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
keys::{SecretKey, SecretKeyFourier},
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
keyswitch_key::GLWESwitchingKey,
|
keyswitch_key::GLWESwitchingKey,
|
||||||
@@ -203,6 +204,103 @@ impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
|
|||||||
where
|
where
|
||||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
|
pub fn automorphism<DataLhs, DataRhs>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
lhs: &AutomorphismKey<DataLhs, FFT64>,
|
||||||
|
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||||
|
scratch: &mut base2k::Scratch,
|
||||||
|
) where
|
||||||
|
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in(),
|
||||||
|
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||||
|
self.rank_in(),
|
||||||
|
lhs.rank_in()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lhs.rank_out(),
|
||||||
|
rhs.rank_in(),
|
||||||
|
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_in()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out(),
|
||||||
|
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||||
|
self.rank_out(),
|
||||||
|
rhs.rank_out()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols_out: usize = rhs.rank_out() + 1;
|
||||||
|
|
||||||
|
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
|
||||||
|
|
||||||
|
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_dft_data,
|
||||||
|
basek: lhs.basek(),
|
||||||
|
k: lhs.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..self.rank_in()).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
// Extracts relevant row
|
||||||
|
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
|
||||||
|
|
||||||
|
// Get a VecZnxBig from scratch space
|
||||||
|
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
|
||||||
|
|
||||||
|
// Switches input outside of DFT
|
||||||
|
(0..cols_out).for_each(|i| {
|
||||||
|
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Consumes to small vec znx
|
||||||
|
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
|
||||||
|
|
||||||
|
// Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
|
||||||
|
(0..cols_out).for_each(|i| {
|
||||||
|
module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wraps into ciphertext
|
||||||
|
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||||
|
data: tmp_idft_small_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
|
||||||
|
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
|
||||||
|
|
||||||
|
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
|
||||||
|
// and switches back to DFT domain
|
||||||
|
(0..self.rank_out() + 1).for_each(|i| {
|
||||||
|
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
|
||||||
|
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Sets back the relevant row
|
||||||
|
self.set_row(module, row_j, col_i, &tmp_dft);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
tmp_dft.data.zero();
|
||||||
|
|
||||||
|
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||||
|
(0..self.rank_in()).for_each(|col_j| {
|
||||||
|
self.set_row(module, row_i, col_j, &tmp_dft);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn keyswitch<DataLhs, DataRhs>(
|
pub fn keyswitch<DataLhs, DataRhs>(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
|
|||||||
@@ -299,6 +299,70 @@ where
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||||
|
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
let basek: usize = self.basek();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||||
|
assert_eq!(self.rank(), rhs.rank_out());
|
||||||
|
assert_eq!(self.basek(), basek);
|
||||||
|
assert_eq!(lhs.basek(), basek);
|
||||||
|
assert_eq!(rhs.n(), module.n());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(lhs.n(), module.n());
|
||||||
|
assert!(
|
||||||
|
scratch.available()
|
||||||
|
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
||||||
|
module,
|
||||||
|
self.size(),
|
||||||
|
self.rank(),
|
||||||
|
lhs.size(),
|
||||||
|
lhs.rank(),
|
||||||
|
rhs.size(),
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cols_in: usize = rhs.rank_in();
|
||||||
|
let cols_out: usize = rhs.rank_out() + 1;
|
||||||
|
|
||||||
|
// Buffer of the result of VMP in DFT
|
||||||
|
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||||
|
|
||||||
|
{
|
||||||
|
// Applies VMP
|
||||||
|
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||||
|
(0..cols_in).for_each(|col_i| {
|
||||||
|
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||||
|
});
|
||||||
|
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switches result of VMP outside of DFT
|
||||||
|
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||||
|
|
||||||
|
{
|
||||||
|
// Switches lhs 0-th outside of DFT domain and adds on
|
||||||
|
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
||||||
|
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
||||||
|
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
(0..cols_out).for_each(|i| {
|
||||||
|
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn keyswitch<DataLhs, DataRhs>(
|
pub fn keyswitch<DataLhs, DataRhs>(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
|
|||||||
@@ -174,60 +174,21 @@ where
|
|||||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
let basek: usize = self.basek();
|
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
|
||||||
assert_eq!(self.rank(), rhs.rank_out());
|
|
||||||
assert_eq!(self.basek(), basek);
|
|
||||||
assert_eq!(lhs.basek(), basek);
|
|
||||||
assert_eq!(rhs.n(), module.n());
|
|
||||||
assert_eq!(self.n(), module.n());
|
|
||||||
assert_eq!(lhs.n(), module.n());
|
|
||||||
assert!(
|
|
||||||
scratch.available()
|
|
||||||
>= GLWECiphertextFourier::keyswitch_scratch_space(
|
|
||||||
module,
|
|
||||||
self.size(),
|
|
||||||
self.rank(),
|
|
||||||
lhs.size(),
|
|
||||||
lhs.rank(),
|
|
||||||
rhs.size(),
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let cols_in: usize = rhs.rank_in();
|
|
||||||
let cols_out: usize = rhs.rank_out() + 1;
|
let cols_out: usize = rhs.rank_out() + 1;
|
||||||
|
|
||||||
// Buffer of the result of VMP in DFT
|
|
||||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
|
||||||
|
|
||||||
{
|
|
||||||
// Applies VMP
|
|
||||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
|
||||||
(0..cols_in).for_each(|col_i| {
|
|
||||||
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
|
||||||
});
|
|
||||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Switches result of VMP outside of DFT
|
|
||||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
|
||||||
|
|
||||||
{
|
|
||||||
// Switches lhs 0-th outside of DFT domain and adds on
|
|
||||||
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
|
|
||||||
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
|
|
||||||
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Space fr normalized VMP result outside of DFT domain
|
// Space fr normalized VMP result outside of DFT domain
|
||||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
|
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
|
||||||
|
|
||||||
|
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||||
|
data: res_idft_data,
|
||||||
|
basek: self.basek,
|
||||||
|
k: self.k,
|
||||||
|
};
|
||||||
|
|
||||||
|
res_idft.keyswitch_from_fourier(module, self, rhs, scratch1);
|
||||||
|
|
||||||
(0..cols_out).for_each(|i| {
|
(0..cols_out).for_each(|i| {
|
||||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
module.vec_znx_dft(self, i, &res_idft, i);
|
||||||
module.vec_znx_dft(self, i, &res_small, i);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user