diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs
index 8da145f..fa812a8 100644
--- a/base2k/src/scalar_znx.rs
+++ b/base2k/src/scalar_znx.rs
@@ -130,7 +130,7 @@ pub trait ScalarZnxOps {
A: ScalarZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
+ fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut;
}
@@ -162,7 +162,7 @@ impl ScalarZnxOps for Module {
}
}
- fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
+ fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut,
{
diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs
index 8b3223b..eba90e9 100644
--- a/base2k/src/vec_znx_big.rs
+++ b/base2k/src/vec_znx_big.rs
@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView};
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
+use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64};
use std::fmt;
use std::marker::PhantomData;
@@ -97,7 +97,18 @@ impl VecZnxBig {
impl VecZnxBig
where
VecZnxBig: VecZnxBigToMut + ZnxInfos,
-{
+{
+ // Consumes the VecZnxBig to return a VecZnx.
+ // Useful when no normalization is needed.
+ pub fn to_vec_znx_small(self) -> VecZnx{
+ VecZnx{
+ data: self.data,
+ n: self.n,
+ cols: self.cols,
+ size: self.size,
+ }
+ }
+
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize)
where
diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs
index 8208c97..f6dad7a 100644
--- a/base2k/src/vec_znx_big_ops.rs
+++ b/base2k/src/vec_znx_big_ops.rs
@@ -147,6 +147,7 @@ pub trait VecZnxBigOps {
fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut;
+
}
pub trait VecZnxBigScratch {
@@ -169,6 +170,7 @@ impl VecZnxBigAlloc for Module {
}
impl VecZnxBigOps for Module {
+
fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut,
diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs
index ed6a954..8741bf9 100644
--- a/core/src/automorphism.rs
+++ b/core/src/automorphism.rs
@@ -1,6 +1,6 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
- ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
+ ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
};
use sampling::source::Source;
@@ -8,6 +8,7 @@ use crate::{
elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
+ glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
@@ -203,6 +204,103 @@ impl AutomorphismKey
where
MatZnxDft: MatZnxDftToMut + MatZnxDftToRef,
{
+ pub fn automorphism(
+ &mut self,
+ module: &Module,
+ lhs: &AutomorphismKey,
+ rhs: &AutomorphismKey,
+ scratch: &mut base2k::Scratch,
+ ) where
+ MatZnxDft: MatZnxDftToRef,
+ MatZnxDft: MatZnxDftToRef,
+ {
+ #[cfg(debug_assertions)]
+ {
+ assert_eq!(
+ self.rank_in(),
+ lhs.rank_in(),
+ "ksk_out input rank: {} != ksk_in input rank: {}",
+ self.rank_in(),
+ lhs.rank_in()
+ );
+ assert_eq!(
+ lhs.rank_out(),
+ rhs.rank_in(),
+ "ksk_in output rank: {} != ksk_apply input rank: {}",
+ self.rank_out(),
+ rhs.rank_in()
+ );
+ assert_eq!(
+ self.rank_out(),
+ rhs.rank_out(),
+ "ksk_out output rank: {} != ksk_apply output rank: {}",
+ self.rank_out(),
+ rhs.rank_out()
+ );
+ }
+
+ let cols_out: usize = rhs.rank_out() + 1;
+
+ let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
+
+ let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
+ data: tmp_dft_data,
+ basek: lhs.basek(),
+ k: lhs.k(),
+ };
+
+ (0..self.rank_in()).for_each(|col_i| {
+ (0..self.rows()).for_each(|row_j| {
+ // Extracts relevant row
+ lhs.get_row(module, row_j, col_i, &mut tmp_dft);
+
+ // Get a VecZnxBig from scratch space
+ let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
+
+ // Switches input outside of DFT
+ (0..cols_out).for_each(|i| {
+ module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
+ });
+
+ // Consumes to small vec znx
+ let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
+
+ // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
+ (0..cols_out).for_each(|i| {
+ module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i);
+ });
+
+ // Wraps into ciphertext
+ let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
+ data: tmp_idft_small_data,
+ basek: self.basek(),
+ k: self.k(),
+ };
+
+ // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
+ tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
+
+ // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
+ // and switches back to DFT domain
+ (0..self.rank_out() + 1).for_each(|i| {
+ module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
+ module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
+ });
+
+ // Sets back the relevant row
+ self.set_row(module, row_j, col_i, &tmp_dft);
+ });
+ });
+
+ tmp_dft.data.zero();
+
+ (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
+ (0..self.rank_in()).for_each(|col_j| {
+ self.set_row(module, row_i, col_j, &tmp_dft);
+ });
+ });
+ }
+
pub fn keyswitch(
&mut self,
module: &Module,
diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs
index 422f2cc..245ee26 100644
--- a/core/src/glwe_ciphertext.rs
+++ b/core/src/glwe_ciphertext.rs
@@ -299,6 +299,70 @@ where
})
}
+ pub(crate) fn keyswitch_from_fourier(
+ &mut self,
+ module: &Module,
+ lhs: &GLWECiphertextFourier,
+ rhs: &GLWESwitchingKey,
+ scratch: &mut Scratch,
+ ) where
+ VecZnxDft: VecZnxDftToRef,
+ MatZnxDft: MatZnxDftToRef,
+ {
+ let basek: usize = self.basek();
+
+ #[cfg(debug_assertions)]
+ {
+ assert_eq!(lhs.rank(), rhs.rank_in());
+ assert_eq!(self.rank(), rhs.rank_out());
+ assert_eq!(self.basek(), basek);
+ assert_eq!(lhs.basek(), basek);
+ assert_eq!(rhs.n(), module.n());
+ assert_eq!(self.n(), module.n());
+ assert_eq!(lhs.n(), module.n());
+ assert!(
+ scratch.available()
+ >= GLWECiphertextFourier::keyswitch_scratch_space(
+ module,
+ self.size(),
+ self.rank(),
+ lhs.size(),
+ lhs.rank(),
+ rhs.size(),
+ )
+ );
+ }
+
+ let cols_in: usize = rhs.rank_in();
+ let cols_out: usize = rhs.rank_out() + 1;
+
+ // Buffer of the result of VMP in DFT
+ let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
+
+ {
+ // Applies VMP
+ let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
+ (0..cols_in).for_each(|col_i| {
+ module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
+ });
+ module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
+ }
+
+ // Switches result of VMP outside of DFT
+ let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
+
+ {
+ // Switches lhs 0-th outside of DFT domain and adds on
+ let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
+ module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
+ module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
+ }
+
+ (0..cols_out).for_each(|i| {
+ module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
+ });
+ }
+
pub fn keyswitch(
&mut self,
module: &Module,
diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs
index ebbe9cf..4c22507 100644
--- a/core/src/glwe_ciphertext_fourier.rs
+++ b/core/src/glwe_ciphertext_fourier.rs
@@ -174,60 +174,21 @@ where
VecZnxDft: VecZnxDftToRef,
MatZnxDft: MatZnxDftToRef,
{
- let basek: usize = self.basek();
-
- #[cfg(debug_assertions)]
- {
- assert_eq!(lhs.rank(), rhs.rank_in());
- assert_eq!(self.rank(), rhs.rank_out());
- assert_eq!(self.basek(), basek);
- assert_eq!(lhs.basek(), basek);
- assert_eq!(rhs.n(), module.n());
- assert_eq!(self.n(), module.n());
- assert_eq!(lhs.n(), module.n());
- assert!(
- scratch.available()
- >= GLWECiphertextFourier::keyswitch_scratch_space(
- module,
- self.size(),
- self.rank(),
- lhs.size(),
- lhs.rank(),
- rhs.size(),
- )
- );
- }
-
- let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
- // Buffer of the result of VMP in DFT
- let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
-
- {
- // Applies VMP
- let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
- (0..cols_in).for_each(|col_i| {
- module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
- });
- module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
- }
-
- // Switches result of VMP outside of DFT
- let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
-
- {
- // Switches lhs 0-th outside of DFT domain and adds on
- let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
- module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
- module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
- }
-
// Space fr normalized VMP result outside of DFT domain
- let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
+ let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
+
+ let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
+ data: res_idft_data,
+ basek: self.basek,
+ k: self.k,
+ };
+
+ res_idft.keyswitch_from_fourier(module, self, rhs, scratch1);
+
(0..cols_out).for_each(|i| {
- module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
- module.vec_znx_dft(self, i, &res_small, i);
+ module.vec_znx_dft(self, i, &res_idft, i);
});
}