mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 21:26:41 +01:00
refactor
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
use backend::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
|
||||
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, ZnxZero,
|
||||
Backend, FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
|
||||
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -48,24 +47,6 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
@@ -124,11 +105,8 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
pub fn encrypt_zero_sk<DataSk>(
|
||||
impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
|
||||
pub fn encrypt_zero_sk<DataSk: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
@@ -136,9 +114,7 @@ where
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
|
||||
let mut ct_idft = GLWECiphertext {
|
||||
data: vec_znx_tmp,
|
||||
@@ -150,16 +126,13 @@ where
|
||||
ct_idft.dft(module, self);
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Space fr normalized VMP result outside of DFT domain
|
||||
@@ -174,34 +147,29 @@ where
|
||||
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_dft(self, i, &res_idft, i);
|
||||
module.vec_znx_dft(&mut self.data, i, &res_idft.data, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -221,7 +189,7 @@ where
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
|
||||
|
||||
{
|
||||
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
|
||||
module.vmp_apply(&mut res_dft, &lhs.data, &rhs.data, scratch1);
|
||||
}
|
||||
|
||||
// VMP result in high precision
|
||||
@@ -231,18 +199,16 @@ where
|
||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||
module.vec_znx_dft(self, i, &res_small, i);
|
||||
module.vec_znx_dft(&mut self.data, i, &res_small, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||
@@ -250,20 +216,14 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn decrypt<DataPt, DataSk>(
|
||||
impl<DataSelf: AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
|
||||
pub fn decrypt<DataPt: AsRef<[u8]> + AsMut<[u8]>, DataSk: AsRef<[u8]>>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut GLWEPlaintext<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToMut,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
@@ -280,7 +240,7 @@ where
|
||||
{
|
||||
(1..cols).for_each(|i| {
|
||||
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i);
|
||||
module.svp_apply(&mut ci_dft, 0, &sk_dft.data, i - 1, &self.data, i);
|
||||
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
|
||||
});
|
||||
@@ -289,22 +249,24 @@ where
|
||||
{
|
||||
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2);
|
||||
module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
|
||||
}
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1);
|
||||
|
||||
pt.basek = self.basek();
|
||||
pt.k = pt.k().min(self.k());
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
||||
where
|
||||
GLWECiphertext<DataRes>: VecZnxToMut,
|
||||
{
|
||||
pub(crate) fn idft<DataRes: AsRef<[u8]> + AsMut<[u8]>>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
scratch: &mut Scratch,
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
@@ -316,8 +278,8 @@ where
|
||||
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
|
||||
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1);
|
||||
module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user