wip adding automorphism on AutomorphismKey

This commit is contained in:
Jean-Philippe Bossuat
2025-05-16 16:27:49 +02:00
parent 7434f289fe
commit b71e526260
6 changed files with 191 additions and 55 deletions

View File

@@ -130,7 +130,7 @@ pub trait ScalarZnxOps {
A: ScalarZnxToRef; A: ScalarZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`. /// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize) fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where where
A: ScalarZnxToMut; A: ScalarZnxToMut;
} }
@@ -162,7 +162,7 @@ impl<B: Backend> ScalarZnxOps for Module<B> {
} }
} }
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize) fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where where
A: ScalarZnxToMut, A: ScalarZnxToMut,
{ {

View File

@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big; use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView}; use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned}; use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64};
use std::fmt; use std::fmt;
use std::marker::PhantomData; use std::marker::PhantomData;
@@ -98,6 +98,17 @@ impl<D> VecZnxBig<D, FFT64>
where where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos, VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{ {
// Consumes the VecZnxBig to return a VecZnx.
// Useful when no normalization is needed.
pub fn to_vec_znx_small(self) -> VecZnx<D>{
VecZnx{
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self]. /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize) pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
where where

View File

@@ -147,6 +147,7 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize) fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where where
A: VecZnxBigToMut<BACKEND>; A: VecZnxBigToMut<BACKEND>;
} }
pub trait VecZnxBigScratch { pub trait VecZnxBigScratch {
@@ -169,6 +170,7 @@ impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
} }
impl VecZnxBigOps<FFT64> for Module<FFT64> { impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize) fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where where
R: VecZnxBigToMut<FFT64>, R: VecZnxBigToMut<FFT64>,

View File

@@ -1,6 +1,6 @@
use base2k::{ use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -8,6 +8,7 @@ use crate::{
elem::{GetRow, Infos, SetRow}, elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext, gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier}, keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey, keyswitch_key::GLWESwitchingKey,
@@ -203,6 +204,103 @@ impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{ {
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &AutomorphismKey<DataLhs, FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let cols_out: usize = rhs.rank_out() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
// Extracts relevant row
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
// Get a VecZnxBig from scratch space
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
// Switches input outside of DFT
(0..cols_out).for_each(|i| {
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
});
// Consumes to small vec znx
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
// Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i);
});
// Wraps into ciphertext
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: tmp_idft_small_data,
basek: self.basek(),
k: self.k(),
};
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
// and switches back to DFT domain
(0..self.rank_out() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
});
// Sets back the relevant row
self.set_row(module, row_j, col_i, &tmp_dft);
});
});
tmp_dft.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_dft);
});
});
}
pub fn keyswitch<DataLhs, DataRhs>( pub fn keyswitch<DataLhs, DataRhs>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,

View File

@@ -299,6 +299,70 @@ where
}) })
} }
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch<DataLhs, DataRhs>( pub fn keyswitch<DataLhs, DataRhs>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,

View File

@@ -174,60 +174,21 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>, VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>, MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{ {
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1; let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
// Space fr normalized VMP result outside of DFT domain // Space fr normalized VMP result outside of DFT domain
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size()); let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_idft_data,
basek: self.basek,
k: self.k,
};
res_idft.keyswitch_from_fourier(module, self, rhs, scratch1);
(0..cols_out).for_each(|i| { (0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); module.vec_znx_dft(self, i, &res_idft, i);
module.vec_znx_dft(self, i, &res_small, i);
}); });
} }