This commit is contained in:
Jean-Philippe Bossuat
2025-05-27 17:49:43 +02:00
parent dec3481a6f
commit a295085724
32 changed files with 897 additions and 1375 deletions

View File

@@ -1,7 +1,6 @@
use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, ZnxZero,
Backend, FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxZero,
};
use sampling::source::Source;
@@ -48,24 +47,6 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
@@ -124,11 +105,8 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
}
}
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
where
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64>,
{
pub fn encrypt_zero_sk<DataSk>(
impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
pub fn encrypt_zero_sk<DataSk: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
@@ -136,9 +114,7 @@ where
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
) {
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
let mut ct_idft = GLWECiphertext {
data: vec_znx_tmp,
@@ -150,16 +126,13 @@ where
ct_idft.dft(module, self);
}
pub fn keyswitch<DataLhs, DataRhs>(
pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
) {
let cols_out: usize = rhs.rank_out() + 1;
// Space fr normalized VMP result outside of DFT domain
@@ -174,34 +147,29 @@ where
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
(0..cols_out).for_each(|i| {
module.vec_znx_dft(self, i, &res_idft, i);
module.vec_znx_dft(&mut self.data, i, &res_idft.data, i);
});
}
pub fn keyswitch_inplace<DataRhs>(
pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
) {
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.keyswitch(&module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
) {
let basek: usize = self.basek();
#[cfg(debug_assertions)]
@@ -221,7 +189,7 @@ where
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
module.vmp_apply(&mut res_dft, &lhs.data, &rhs.data, scratch1);
}
// VMP result in high precision
@@ -231,18 +199,16 @@ where
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
module.vec_znx_dft(&mut self.data, i, &res_small, i);
});
}
pub fn external_product_inplace<DataRhs>(
pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
) {
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.external_product(&module, &*self_ptr, rhs, scratch);
@@ -250,20 +216,14 @@ where
}
}
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
where
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
{
pub fn decrypt<DataPt, DataSk>(
impl<DataSelf: AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
pub fn decrypt<DataPt: AsRef<[u8]> + AsMut<[u8]>, DataSk: AsRef<[u8]>>(
&self,
module: &Module<FFT64>,
pt: &mut GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToMut,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
) {
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
@@ -280,7 +240,7 @@ where
{
(1..cols).for_each(|i| {
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i);
module.svp_apply(&mut ci_dft, 0, &sk_dft.data, i - 1, &self.data, i);
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
});
@@ -289,22 +249,24 @@ where
{
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2);
module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
}
// pt = norm(BIG(m + e))
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1);
module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1);
pt.basek = self.basek();
pt.k = pt.k().min(self.k());
}
#[allow(dead_code)]
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
where
GLWECiphertext<DataRes>: VecZnxToMut,
{
pub(crate) fn idft<DataRes: AsRef<[u8]> + AsMut<[u8]>>(
&self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<DataRes>,
scratch: &mut Scratch,
) {
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), res.rank());
@@ -316,8 +278,8 @@ where
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1);
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1);
module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1);
module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1);
});
}
}