diff --git a/base2k/src/scalar_znx.rs b/base2k/src/scalar_znx.rs index 8da145f..fa812a8 100644 --- a/base2k/src/scalar_znx.rs +++ b/base2k/src/scalar_znx.rs @@ -130,7 +130,7 @@ pub trait ScalarZnxOps { A: ScalarZnxToRef; /// Applies the automorphism X^i -> X^ik on the selected column of `a`. - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: ScalarZnxToMut; } @@ -162,7 +162,7 @@ impl ScalarZnxOps for Module { } } - fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) + fn scalar_znx_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: ScalarZnxToMut, { diff --git a/base2k/src/vec_znx_big.rs b/base2k/src/vec_znx_big.rs index 8b3223b..eba90e9 100644 --- a/base2k/src/vec_znx_big.rs +++ b/base2k/src/vec_znx_big.rs @@ -1,6 +1,6 @@ use crate::ffi::vec_znx_big; use crate::znx_base::{ZnxInfos, ZnxView}; -use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; +use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64}; use std::fmt; use std::marker::PhantomData; @@ -97,7 +97,18 @@ impl VecZnxBig { impl VecZnxBig where VecZnxBig: VecZnxBigToMut + ZnxInfos, -{ +{ + // Consumes the VecZnxBig to return a VecZnx. + // Useful when no normalization is needed. + pub fn to_vec_znx_small(self) -> VecZnx{ + VecZnx{ + data: self.data, + n: self.n, + cols: self.cols, + size: self.size, + } + } + /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. pub fn extract_column(&mut self, self_col: usize, a: &VecZnxBig, a_col: usize) where diff --git a/base2k/src/vec_znx_big_ops.rs b/base2k/src/vec_znx_big_ops.rs index 8208c97..f6dad7a 100644 --- a/base2k/src/vec_znx_big_ops.rs +++ b/base2k/src/vec_znx_big_ops.rs @@ -147,6 +147,7 @@ pub trait VecZnxBigOps { fn vec_znx_big_automorphism_inplace(&self, k: i64, a: &mut A, a_col: usize) where A: VecZnxBigToMut; + } pub trait VecZnxBigScratch { @@ -169,6 +170,7 @@ impl VecZnxBigAlloc for Module { } impl VecZnxBigOps for Module { + fn vec_znx_big_add(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) where R: VecZnxBigToMut, diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index ed6a954..8741bf9 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -1,6 +1,6 @@ use base2k::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, - ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero, }; use sampling::source::Source; @@ -8,6 +8,7 @@ use crate::{ elem::{GetRow, Infos, SetRow}, gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, + glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, @@ -203,6 +204,103 @@ impl AutomorphismKey where MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, { + pub fn automorphism( + &mut self, + module: &Module, + lhs: &AutomorphismKey, + rhs: &AutomorphismKey, + scratch: &mut base2k::Scratch, + ) where + MatZnxDft: MatZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + #[cfg(debug_assertions)] + { + assert_eq!( + self.rank_in(), + lhs.rank_in(), + "ksk_out input rank: {} != ksk_in input rank: {}", + self.rank_in(), + lhs.rank_in() + ); + assert_eq!( + lhs.rank_out(), + rhs.rank_in(), + "ksk_in output rank: {} != ksk_apply input rank: {}", + self.rank_out(), + rhs.rank_in() + ); + assert_eq!( + self.rank_out(), + rhs.rank_out(), + "ksk_out output rank: {} != ksk_apply output rank: {}", + self.rank_out(), + rhs.rank_out() + ); + } + + let cols_out: usize = rhs.rank_out() + 1; + + let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size()); + + let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { + data: tmp_dft_data, + basek: lhs.basek(), + k: lhs.k(), + }; + + (0..self.rank_in()).for_each(|col_i| { + (0..self.rows()).for_each(|row_j| { + // Extracts relevant row + lhs.get_row(module, row_j, col_i, &mut tmp_dft); + + // Get a VecZnxBig from scratch space + let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size()); + + // Switches input outside of DFT + (0..cols_out).for_each(|i| { + module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2); + }); + + // Consumes to small vec znx + let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small(); + + // Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a) + (0..cols_out).for_each(|i| { + module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i); + }); + + // Wraps into ciphertext + let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: tmp_idft_small_data, + basek: self.basek(), + k: self.k(), + }; + + // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) + tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2); + + // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a) + // and switches back to DFT domain + (0..self.rank_out() + 1).for_each(|i| { + module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i); + module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i); + }); + + // Sets back the relevant row + self.set_row(module, row_j, col_i, &tmp_dft); + }); + }); + + tmp_dft.data.zero(); + + (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { + (0..self.rank_in()).for_each(|col_j| { + self.set_row(module, row_i, col_j, &tmp_dft); + }); + }); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 422f2cc..245ee26 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -299,6 +299,70 @@ where }) } + pub(crate) fn keyswitch_from_fourier( + &mut self, + module: &Module, + lhs: &GLWECiphertextFourier, + rhs: &GLWESwitchingKey, + scratch: &mut Scratch, + ) where + VecZnxDft: VecZnxDftToRef, + MatZnxDft: MatZnxDftToRef, + { + let basek: usize = self.basek(); + + #[cfg(debug_assertions)] + { + assert_eq!(lhs.rank(), rhs.rank_in()); + assert_eq!(self.rank(), rhs.rank_out()); + assert_eq!(self.basek(), basek); + assert_eq!(lhs.basek(), basek); + assert_eq!(rhs.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(lhs.n(), module.n()); + assert!( + scratch.available() + >= GLWECiphertextFourier::keyswitch_scratch_space( + module, + self.size(), + self.rank(), + lhs.size(), + lhs.rank(), + rhs.size(), + ) + ); + } + + let cols_in: usize = rhs.rank_in(); + let cols_out: usize = rhs.rank_out() + 1; + + // Buffer of the result of VMP in DFT + let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise + + { + // Applies VMP + let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); + (0..cols_in).for_each(|col_i| { + module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); + }); + module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); + } + + // Switches result of VMP outside of DFT + let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); + + { + // Switches lhs 0-th outside of DFT domain and adds on + let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); + module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); + module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); + } + + (0..cols_out).for_each(|i| { + module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); + }); + } + pub fn keyswitch( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index ebbe9cf..4c22507 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -174,60 +174,21 @@ where VecZnxDft: VecZnxDftToRef, MatZnxDft: MatZnxDftToRef, { - let basek: usize = self.basek(); - - #[cfg(debug_assertions)] - { - assert_eq!(lhs.rank(), rhs.rank_in()); - assert_eq!(self.rank(), rhs.rank_out()); - assert_eq!(self.basek(), basek); - assert_eq!(lhs.basek(), basek); - assert_eq!(rhs.n(), module.n()); - assert_eq!(self.n(), module.n()); - assert_eq!(lhs.n(), module.n()); - assert!( - scratch.available() - >= GLWECiphertextFourier::keyswitch_scratch_space( - module, - self.size(), - self.rank(), - lhs.size(), - lhs.rank(), - rhs.size(), - ) - ); - } - - let cols_in: usize = rhs.rank_in(); let cols_out: usize = rhs.rank_out() + 1; - // Buffer of the result of VMP in DFT - let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise - - { - // Applies VMP - let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); - (0..cols_in).for_each(|col_i| { - module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); - }); - module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); - } - - // Switches result of VMP outside of DFT - let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); - - { - // Switches lhs 0-th outside of DFT domain and adds on - let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size()); - module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2); - module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0); - } - // Space fr normalized VMP result outside of DFT domain - let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size()); + let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size()); + + let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { + data: res_idft_data, + basek: self.basek, + k: self.k, + }; + + res_idft.keyswitch_from_fourier(module, self, rhs, scratch1); + (0..cols_out).for_each(|i| { - module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); - module.vec_znx_dft(self, i, &res_small, i); + module.vec_znx_dft(self, i, &res_idft, i); }); }