wip adding automorphism on AutomorphismKey

This commit is contained in:
Jean-Philippe Bossuat
2025-05-16 16:27:49 +02:00
parent 7434f289fe
commit b71e526260
6 changed files with 191 additions and 55 deletions

View File

@@ -130,7 +130,7 @@ pub trait ScalarZnxOps {
A: ScalarZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut;
}
@@ -162,7 +162,7 @@ impl<B: Backend> ScalarZnxOps for Module<B> {
}
}
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut,
{

View File

@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, FFT64};
use std::fmt;
use std::marker::PhantomData;
@@ -97,7 +97,18 @@ impl<D, B: Backend> VecZnxBig<D, B> {
impl<D> VecZnxBig<D, FFT64>
where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
{
// Consumes the VecZnxBig to return a VecZnx.
// Useful when no normalization is needed.
pub fn to_vec_znx_small(self) -> VecZnx<D>{
VecZnx{
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
where

View File

@@ -147,6 +147,7 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
}
pub trait VecZnxBigScratch {
@@ -169,6 +170,7 @@ impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,

View File

@@ -1,6 +1,6 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
ScalarZnxToRef, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxZero,
};
use sampling::source::Source;
@@ -8,6 +8,7 @@ use crate::{
elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
@@ -203,6 +204,103 @@ impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &AutomorphismKey<DataLhs, FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
self.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let cols_out: usize = rhs.rank_out() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
// Extracts relevant row
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
// Get a VecZnxBig from scratch space
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
// Switches input outside of DFT
(0..cols_out).for_each(|i| {
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
});
// Consumes to small vec znx
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
// Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(self.p(), &mut tmp_idft_small_data, i);
});
// Wraps into ciphertext
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: tmp_idft_small_data,
basek: self.basek(),
k: self.k(),
};
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
// and switches back to DFT domain
(0..self.rank_out() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
});
// Sets back the relevant row
self.set_row(module, row_j, col_i, &tmp_dft);
});
});
tmp_dft.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_dft);
});
});
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,

View File

@@ -299,6 +299,70 @@ where
})
}
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,

View File

@@ -174,60 +174,21 @@ where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer of the result of VMP in DFT
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
// Switches result of VMP outside of DFT
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
{
// Switches lhs 0-th outside of DFT domain and adds on
let (mut a0_big, scratch2) = scratch1.tmp_vec_znx_big(module, 1, lhs.size());
module.vec_znx_idft(&mut a0_big, 0, lhs, 0, scratch2);
module.vec_znx_big_add_inplace(&mut res_big, 0, &a0_big, 0);
}
// Space fr normalized VMP result outside of DFT domain
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols_out, lhs.size());
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_idft_data,
basek: self.basek,
k: self.k,
};
res_idft.keyswitch_from_fourier(module, self, rhs, scratch1);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
module.vec_znx_dft(self, i, &res_idft, i);
});
}