Commit a295085724 by Jean-Philippe Bossuat, 2025-05-27 17:49:43 +02:00
Parent: dec3481a6f
32 changed files with 897 additions and 1375 deletions

View File

@@ -152,81 +152,49 @@ impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
 }

 pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
+pub type MatZnxDftMut<'a, B> = MatZnxDft<&'a mut [u8], B>;
+pub type MatZnxDftRef<'a, B> = MatZnxDft<&'a [u8], B>;

-pub trait MatZnxDftToRef<B: Backend> {
+pub trait MatZnxToRef<B: Backend> {
     fn to_ref(&self) -> MatZnxDft<&[u8], B>;
 }

+impl<D, B: Backend> MatZnxToRef<B> for MatZnxDft<D, B>
+where
+    D: AsRef<[u8]>,
+    B: Backend,
+{
+    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
+        MatZnxDft {
+            data: self.data.as_ref(),
+            n: self.n,
+            rows: self.rows,
+            cols_in: self.cols_in,
+            cols_out: self.cols_out,
+            size: self.size,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
-pub trait MatZnxDftToMut<B: Backend>: MatZnxDftToRef<B> {
+pub trait MatZnxToMut<B: Backend> {
     fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
 }

-impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
+impl<D, B: Backend> MatZnxToMut<B> for MatZnxDft<D, B>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+    B: Backend,
+{
     fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
         MatZnxDft {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             rows: self.rows,
             cols_in: self.cols_in,
             cols_out: self.cols_out,
             size: self.size,
-            _phantom: PhantomData,
+            _phantom: std::marker::PhantomData,
         }
     }
 }
-
-impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        MatZnxDft {
-            data: self.data.as_slice(),
-            n: self.n,
-            rows: self.rows,
-            cols_in: self.cols_in,
-            cols_out: self.cols_out,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
-    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
-        MatZnxDft {
-            data: self.data,
-            n: self.n,
-            rows: self.rows,
-            cols_in: self.cols_in,
-            cols_out: self.cols_out,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        MatZnxDft {
-            data: self.data,
-            n: self.n,
-            rows: self.rows,
-            cols_in: self.cols_in,
-            cols_out: self.cols_out,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        MatZnxDft {
-            data: self.data,
-            n: self.n,
-            rows: self.rows,
-            cols_in: self.cols_in,
-            cols_out: self.cols_out,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
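The same refactor repeats for every container type in this commit, so the shape is worth isolating once. Below is a minimal, compilable sketch under assumed names (`Mat`/`MatToRef` stand in for `MatZnxDft`/`MatZnxToRef`): one blanket impl over the storage parameter replaces the three per-storage impls, because `Vec<u8>`, `&mut [u8]` and `&[u8]` all implement `AsRef<[u8]>`.

```rust
// Hypothetical stand-ins for MatZnxDft and MatZnxToRef; only the
// shape of the refactor is real here.
struct Mat<D> {
    data: D,
    n: usize,
}

trait MatToRef {
    fn to_ref(&self) -> Mat<&[u8]>;
}

// One blanket impl over any byte storage replaces three impls for
// Vec<u8>, &mut [u8] and &[u8].
impl<D: AsRef<[u8]>> MatToRef for Mat<D> {
    fn to_ref(&self) -> Mat<&[u8]> {
        Mat {
            data: self.data.as_ref(),
            n: self.n,
        }
    }
}

fn main() {
    let owned = Mat { data: vec![0u8; 8], n: 8 };
    let view: Mat<&[u8]> = owned.to_ref(); // same call works for &[u8] and &mut [u8] storage
    assert_eq!(view.data.len(), owned.n);
}
```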

View File

@@ -2,7 +2,7 @@ use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::ffi::vmp;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-    Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
+    Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxToMut, MatZnxToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
     VecZnxDftToRef,
 };
@@ -47,27 +47,27 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
     ///
     /// # Arguments
     ///
-    /// * `b`: [MatZnxDft] on which the values are encoded.
+    /// * `res`: [MatZnxDft] on which the values are encoded.
     /// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
     /// * `row_i`: the index of the row to prepare.
     ///
     /// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
     fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
     where
-        R: MatZnxDftToMut<BACKEND>,
-        A: VecZnxDftToRef<BACKEND>;
+        R: MatZnxToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>;

     /// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
     ///
     /// # Arguments
     ///
-    /// * `b`: the [VecZnxDft] on which to extract the row of the [MatZnxDft].
+    /// * `res`: the [VecZnxDft] on which to extract the row of the [MatZnxDft].
     /// * `a`: [MatZnxDft] on which the values are encoded.
     /// * `row_i`: the index of the row to extract.
     fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
     where
-        R: VecZnxDftToMut<BACKEND>,
-        A: MatZnxDftToRef<BACKEND>;
+        R: VecZnxDftToMut<FFT64>,
+        A: MatZnxToRef<FFT64>;

     /// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
     /// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
@@ -96,16 +96,16 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
     /// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
     fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
     where
-        R: VecZnxDftToMut<BACKEND>,
-        A: VecZnxDftToRef<BACKEND>,
-        B: MatZnxDftToRef<BACKEND>;
+        R: VecZnxDftToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>,
+        B: MatZnxToRef<FFT64>;

     // Same as [MatZnxDftOps::vmp_apply] except the result is added on R instead of overwriting R.
     fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
     where
-        R: VecZnxDftToMut<BACKEND>,
-        A: VecZnxDftToRef<BACKEND>,
-        B: MatZnxDftToRef<BACKEND>;
+        R: VecZnxDftToMut<FFT64>,
+        A: VecZnxDftToRef<FFT64>,
+        B: MatZnxToRef<FFT64>;
 }

 impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
@@ -154,10 +154,10 @@ impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
 impl MatZnxDftOps<FFT64> for Module<FFT64> {
     fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
     where
-        R: MatZnxDftToMut<FFT64>,
+        R: MatZnxToMut<FFT64>,
         A: VecZnxDftToRef<FFT64>,
     {
-        let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
+        let mut res: MatZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a: VecZnxDft<&[u8], _> = a.to_ref();

         #[cfg(debug_assertions)]
@@ -207,9 +207,9 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
     where
         R: VecZnxDftToMut<FFT64>,
-        A: MatZnxDftToRef<FFT64>,
+        A: MatZnxToRef<FFT64>,
     {
-        let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
+        let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a: MatZnxDft<&[u8], _> = a.to_ref();

         #[cfg(debug_assertions)]
@@ -259,7 +259,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     where
         R: VecZnxDftToMut<FFT64>,
         A: VecZnxDftToRef<FFT64>,
-        B: MatZnxDftToRef<FFT64>,
+        B: MatZnxToRef<FFT64>,
     {
         let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
         let a: VecZnxDft<&[u8], _> = a.to_ref();
@@ -313,7 +313,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
     where
         R: VecZnxDftToMut<FFT64>,
         A: VecZnxDftToRef<FFT64>,
-        B: MatZnxDftToRef<FFT64>,
+        B: MatZnxToRef<FFT64>,
     {
         let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
         let a: VecZnxDft<&[u8], _> = a.to_ref();
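Note that the bounds inside `MatZnxDftOps` now name `FFT64` even though the trait keeps its `BACKEND` parameter. That is legal Rust, and presumably acceptable while `Module<FFT64>` is the only implementor. A small sketch of the arrangement, with toy types standing in for the crate's `Backend`, `FFT64` and `Module`:

```rust
use std::marker::PhantomData;

// Hypothetical backend marker and module; the real crate's types are richer.
struct FFT64;
trait Backend {}
impl Backend for FFT64 {}

struct Module<B: Backend> {
    _backend: PhantomData<B>,
}

// The trait stays generic over BACKEND, but nothing stops the one
// concrete backend from being named in bounds or in the lone impl.
trait MatOps<BACKEND: Backend> {
    fn vmp_apply(&self) -> &'static str;
}

impl MatOps<FFT64> for Module<FFT64> {
    fn vmp_apply(&self) -> &'static str {
        "fft64 vector-matrix product"
    }
}

fn main() {
    let module = Module::<FFT64> { _backend: PhantomData };
    assert_eq!(module.vmp_apply(), "fft64 vector-matrix product");
}
```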

View File

@@ -42,7 +42,7 @@ pub trait AddNormal {
     fn add_normal(&mut self, basek: usize, col_i: usize, k: usize, source: &mut Source, sigma: f64, bound: f64);
 }

-impl<T> FillUniform for VecZnx<T>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> FillUniform for VecZnx<T>
 where
     VecZnx<T>: VecZnxToMut,
 {
@@ -59,7 +59,7 @@ where
     }
 }

-impl<T> FillDistF64 for VecZnx<T>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnx<T>
 where
     VecZnx<T>: VecZnxToMut,
 {
@@ -102,7 +102,7 @@ where
     }
 }

-impl<T> AddDistF64 for VecZnx<T>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnx<T>
 where
     VecZnx<T>: VecZnxToMut,
 {
@@ -145,7 +145,7 @@ where
     }
 }

-impl<T> FillNormal for VecZnx<T>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnx<T>
 where
     VecZnx<T>: VecZnxToMut,
 {
@@ -161,7 +161,7 @@ where
     }
 }

-impl<T> AddNormal for VecZnx<T>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnx<T>
 where
     VecZnx<T>: VecZnxToMut,
 {
@@ -177,7 +177,7 @@ where
     }
 }

-impl<T> FillDistF64 for VecZnxBig<T, FFT64>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> FillDistF64 for VecZnxBig<T, FFT64>
 where
     VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
 {
@@ -220,7 +220,7 @@ where
     }
 }

-impl<T> AddDistF64 for VecZnxBig<T, FFT64>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> AddDistF64 for VecZnxBig<T, FFT64>
 where
     VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
 {
@@ -263,7 +263,7 @@ where
     }
 }

-impl<T> FillNormal for VecZnxBig<T, FFT64>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> FillNormal for VecZnxBig<T, FFT64>
 where
     VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
 {
@@ -279,7 +279,7 @@ where
     }
 }

-impl<T> AddNormal for VecZnxBig<T, FFT64>
+impl<T: AsMut<[u8]> + AsRef<[u8]>> AddNormal for VecZnxBig<T, FFT64>
 where
     VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
 {
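Each of these impls trades the old bare `impl<T>` header for explicit byte-storage bounds, which lets the body reach the raw bytes directly. A compilable miniature of the idea, with a hypothetical `Poly`/`FillUniform` in place of the real `VecZnx` samplers (the real ones draw from a seeded `Source` with `basek`/`k`/`sigma` parameters rather than filling a constant):

```rust
// Hypothetical one-column polynomial wrapper.
struct Poly<T> {
    data: T,
}

trait FillUniform {
    fn fill_uniform(&mut self, value: u8);
}

// Stating the byte-storage bounds on the impl header lets the body
// mutate the backing bytes through AsMut.
impl<T: AsMut<[u8]> + AsRef<[u8]>> FillUniform for Poly<T> {
    fn fill_uniform(&mut self, value: u8) {
        // Constant fill stands in for the actual uniform sampler.
        self.data.as_mut().fill(value);
    }
}

fn main() {
    let mut p = Poly { data: vec![0u8; 16] };
    p.fill_uniform(0xA5);
    assert!(p.data.iter().all(|&b| b == 0xA5));
}
```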

View File

@@ -196,108 +196,57 @@ pub trait ScalarZnxToRef {
     fn to_ref(&self) -> ScalarZnx<&[u8]>;
 }

+impl<D> ScalarZnxToRef for ScalarZnx<D>
+where
+    D: AsRef<[u8]>,
+{
+    fn to_ref(&self) -> ScalarZnx<&[u8]> {
+        ScalarZnx {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+        }
+    }
+}
+
 pub trait ScalarZnxToMut {
     fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
 }

-impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
+impl<D> ScalarZnxToMut for ScalarZnx<D>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+{
     fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
         ScalarZnx {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
         }
     }
 }

-impl VecZnxToMut for ScalarZnx<Vec<u8>> {
+impl<D> VecZnxToRef for ScalarZnx<D>
+where
+    D: AsRef<[u8]>,
+{
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
+
+impl<D> VecZnxToMut for ScalarZnx<D>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+{
     fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
         VecZnx {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
             size: 1,
         }
     }
 }
-
-impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
-    fn to_ref(&self) -> ScalarZnx<&[u8]> {
-        ScalarZnx {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-        }
-    }
-}
-
-impl VecZnxToRef for ScalarZnx<Vec<u8>> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-        }
-    }
-}
-
-impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
-    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
-        ScalarZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-        }
-    }
-}
-
-impl VecZnxToMut for ScalarZnx<&mut [u8]> {
-    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-        }
-    }
-}
-
-impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
-    fn to_ref(&self) -> ScalarZnx<&[u8]> {
-        ScalarZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-        }
-    }
-}
-
-impl VecZnxToRef for ScalarZnx<&mut [u8]> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-        }
-    }
-}
-
-impl ScalarZnxToRef for ScalarZnx<&[u8]> {
-    fn to_ref(&self) -> ScalarZnx<&[u8]> {
-        ScalarZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-        }
-    }
-}
-
-impl VecZnxToRef for ScalarZnx<&[u8]> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-        }
-    }
-}

View File

@@ -113,14 +113,33 @@ pub trait ScalarZnxDftToRef<B: Backend> {
     fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
 }

+impl<D, B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<D, B>
+where
+    D: AsRef<[u8]>,
+    B: Backend,
+{
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
+        ScalarZnxDft {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            _phantom: PhantomData,
+        }
+    }
+}
+
 pub trait ScalarZnxDftToMut<B: Backend> {
     fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
 }

-impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
+impl<D, B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<D, B>
+where
+    D: AsMut<[u8]> + AsRef<[u8]>,
+    B: Backend,
+{
     fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
         ScalarZnxDft {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
             _phantom: PhantomData,
@@ -128,106 +147,34 @@ impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
     }
 }

-impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
-    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
-        ScalarZnxDft {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
-    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
-        ScalarZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
-    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
-        ScalarZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
-    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
-        ScalarZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
-    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
-        VecZnxDft {
-            data: self.data.as_mut_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
-    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: 1,
-            _phantom: PhantomData,
-        }
-    }
-}
+impl<D, B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<D, B>
+where
+    D: AsRef<[u8]>,
+    B: Backend,
+{
+    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+        VecZnxDft {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
+impl<D, B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<D, B>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+    B: Backend,
+{
+    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
+        VecZnxDft {
+            data: self.data.as_mut(),
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
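`ScalarZnxDft` carries its backend parameter only in `_phantom`, which is why every reconstructed view has to restate `_phantom: PhantomData`. A minimal sketch of why the marker is there, using made-up backend names:

```rust
use std::marker::PhantomData;

// Hypothetical miniature of ScalarZnxDft<D, B>: B never appears in a
// field, so PhantomData<B> pins it and keeps values for different
// backends as distinct, non-interchangeable types.
struct ScalarDft<D, B> {
    data: D,
    _phantom: PhantomData<B>,
}

struct FFT64;
struct NTT64; // a second, purely illustrative backend marker

fn fft_only(_x: &ScalarDft<Vec<u8>, FFT64>) {}

fn main() {
    let fft = ScalarDft::<Vec<u8>, FFT64> { data: vec![0u8; 4], _phantom: PhantomData };
    fft_only(&fft);
    // Passing a ScalarDft<_, NTT64> to fft_only would be rejected at compile time.
    let _ntt = ScalarDft::<Vec<u8>, NTT64> { data: vec![0u8; 4], _phantom: PhantomData };
}
```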

View File

@@ -1,103 +1,105 @@
 use crate::ffi::svp;
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
     Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
     VecZnxDftToMut, VecZnxDftToRef,
 };

 pub trait ScalarZnxDftAlloc<B: Backend> {
     fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
     fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
     fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
 }

 pub trait ScalarZnxDftOps<BACKEND: Backend> {
     fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: ScalarZnxDftToMut<BACKEND>,
         A: ScalarZnxToRef;

     fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
     where
         R: VecZnxDftToMut<BACKEND>,
         A: ScalarZnxDftToRef<BACKEND>,
-        B: VecZnxDftToRef<FFT64>;
+        B: VecZnxDftToRef<BACKEND>;

     fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxDftToMut<BACKEND>,
         A: ScalarZnxDftToRef<BACKEND>;
 }

 impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
     fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
         ScalarZnxDftOwned::new(self, cols)
     }

     fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
         ScalarZnxDftOwned::bytes_of(self, cols)
     }

     fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
         ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
     }
 }

 impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
     fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: ScalarZnxDftToMut<FFT64>,
         A: ScalarZnxToRef,
     {
         unsafe {
             svp::svp_prepare(
                 self.ptr,
                 res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
                 a.to_ref().at_ptr(a_col, 0),
             )
         }
     }

     fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
     where
         R: VecZnxDftToMut<FFT64>,
         A: ScalarZnxDftToRef<FFT64>,
         B: VecZnxDftToRef<FFT64>,
     {
         let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
         let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
         unsafe {
             svp::svp_apply_dft_to_dft(
                 self.ptr,
                 res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                 res.size() as u64,
                 res.cols() as u64,
                 a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                 b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
                 b.size() as u64,
                 b.cols() as u64,
             )
         }
     }

     fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
     where
         R: VecZnxDftToMut<FFT64>,
         A: ScalarZnxDftToRef<FFT64>,
     {
         let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
         let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
         unsafe {
             svp::svp_apply_dft_to_dft(
                 self.ptr,
                 res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                 res.size() as u64,
                 res.cols() as u64,
                 a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                 res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
                 res.size() as u64,
                 res.cols() as u64,
             )
         }
     }
 }
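For readers without the spqlios headers at hand, the following toy model shows what `svp_apply` and `svp_apply_inplace` compute: in the DFT domain, a scalar-vector product multiplies every limb of the vector coefficient-wise by the prepared scalar. This sketches the arithmetic only; the real implementations go through the `svp_apply_dft_to_dft` FFI call above.

```rust
// Toy model: Vec<f64> limbs stand in for DFT-domain polynomial limbs.
fn svp_apply(res: &mut [Vec<f64>], scalar: &[f64], b: &[Vec<f64>]) {
    for (res_limb, b_limb) in res.iter_mut().zip(b) {
        for (r, (s, x)) in res_limb.iter_mut().zip(scalar.iter().zip(b_limb)) {
            *r = s * x; // coefficient-wise multiply into res
        }
    }
}

fn svp_apply_inplace(res: &mut [Vec<f64>], scalar: &[f64]) {
    for limb in res.iter_mut() {
        for (r, s) in limb.iter_mut().zip(scalar) {
            *r *= s; // res doubles as input and output
        }
    }
}

fn main() {
    let scalar = vec![2.0, 3.0];
    let b = vec![vec![1.0, 1.0], vec![4.0, 5.0]];

    let mut res = vec![vec![0.0; 2]; 2];
    svp_apply(&mut res, &scalar, &b);

    let mut inplace = b.clone();
    svp_apply_inplace(&mut inplace, &scalar);

    assert_eq!(res, inplace);
}
```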

View File

@@ -237,14 +237,15 @@ fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(basek: usize, a: &mut VecZnx<D>, a_co
     }
 }

-impl<D> VecZnx<D>
+impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnx<D>
 where
     VecZnx<D>: VecZnxToMut + ZnxInfos,
 {
     /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column of [Self].
-    pub fn extract_column<R>(&mut self, self_col: usize, a: &R, a_col: usize)
+    pub fn extract_column<R>(&mut self, self_col: usize, a: &VecZnx<R>, a_col: usize)
     where
-        R: VecZnxToRef + ZnxInfos,
+        R: AsRef<[u8]>,
+        VecZnx<R>: VecZnxToRef + ZnxInfos,
     {
         #[cfg(debug_assertions)]
         {
@@ -313,72 +314,41 @@ pub trait VecZnxToRef {
     fn to_ref(&self) -> VecZnx<&[u8]>;
 }

+impl<D> VecZnxToRef for VecZnx<D>
+where
+    D: AsRef<[u8]>,
+{
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+        }
+    }
+}
+
-pub trait VecZnxToMut: VecZnxToRef {
+pub trait VecZnxToMut {
     fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
 }

-impl VecZnxToMut for VecZnx<Vec<u8>> {
+impl<D> VecZnxToMut for VecZnx<D>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+{
     fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
         VecZnx {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
             size: self.size,
         }
     }
 }
-
-impl VecZnxToRef for VecZnx<Vec<u8>> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-        }
-    }
-}
-
-impl VecZnxToMut for VecZnx<&mut [u8]> {
-    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-        }
-    }
-}
-
-impl VecZnxToRef for VecZnx<&mut [u8]> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-        }
-    }
-}
-
-impl VecZnxToRef for VecZnx<&[u8]> {
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        VecZnx {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-        }
-    }
-}
-
-impl<DataSelf> VecZnx<DataSelf>
-where
-    VecZnx<DataSelf>: VecZnxToRef,
-{
+impl<DataSelf: AsRef<[u8]>> VecZnx<DataSelf> {
     pub fn clone(&self) -> VecZnx<Vec<u8>> {
         let self_ref: VecZnx<&[u8]> = self.to_ref();
         VecZnx {
             data: self_ref.data.to_vec(),
             n: self_ref.n,
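The mutable counterpart needs `AsMut` on top of `AsRef`, and the one blanket impl then serves both `Vec<u8>` and `&mut [u8]` callers. A compilable miniature with assumed names (`Vecz`/`ToMutView` in place of `VecZnx`/`VecZnxToMut`):

```rust
// Hypothetical miniature of the VecZnxToMut blanket impl.
struct Vecz<D> {
    data: D,
}

trait ToMutView {
    fn to_mut(&mut self) -> Vecz<&mut [u8]>;
}

impl<D: AsRef<[u8]> + AsMut<[u8]>> ToMutView for Vecz<D> {
    fn to_mut(&mut self) -> Vecz<&mut [u8]> {
        Vecz { data: self.data.as_mut() }
    }
}

fn main() {
    // Works for owned storage...
    let mut owned = Vecz { data: vec![0u8; 4] };
    owned.to_mut().data[0] = 7; // mutates through the borrowed view
    assert_eq!(owned.data[0], 7);

    // ...and for a borrowed mutable slice.
    let mut buf = [0u8; 4];
    let mut borrowed = Vecz { data: &mut buf[..] };
    borrowed.to_mut().data[1] = 9;
    assert_eq!(buf[1], 9);
}
```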

View File

@@ -94,7 +94,7 @@ impl<D, B: Backend> VecZnxBig<D, B> {
     }
 }

-impl<D> VecZnxBig<D, FFT64>
+impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxBig<D, FFT64>
 where
     VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
 {
@@ -110,9 +110,9 @@ where
     }

     /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column of [Self].
-    pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
+    pub fn extract_column<C>(&mut self, self_col: usize, a: &C, a_col: usize)
     where
-        VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
+        C: VecZnxBigToRef<FFT64> + ZnxInfos,
     {
         #[cfg(debug_assertions)]
         {
@@ -144,66 +144,38 @@ pub trait VecZnxBigToRef<B: Backend> {
     fn to_ref(&self) -> VecZnxBig<&[u8], B>;
 }

+impl<D, B: Backend> VecZnxBigToRef<B> for VecZnxBig<D, B>
+where
+    D: AsRef<[u8]>,
+    B: Backend,
+{
+    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
+        VecZnxBig {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
 pub trait VecZnxBigToMut<B: Backend> {
     fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
 }

-impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
+impl<D, B: Backend> VecZnxBigToMut<B> for VecZnxBig<D, B>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+    B: Backend,
+{
     fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
         VecZnxBig {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
             size: self.size,
-            _phantom: PhantomData,
+            _phantom: std::marker::PhantomData,
         }
     }
 }
-
-impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
-    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
-        VecZnxBig {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
-    fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
-        VecZnxBig {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
-    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
-        VecZnxBig {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
-    fn to_ref(&self) -> VecZnxBig<&[u8], B> {
-        VecZnxBig {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
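Note that `extract_column` here moves in the opposite direction from `VecZnx::extract_column` earlier in this commit: the argument becomes a bare `&C` bounded by the view trait instead of a concrete wrapper. The two styles are interchangeable at the call site, as this sketch with stand-in types shows:

```rust
// Hypothetical miniature of VecZnxBig and its borrowed-view trait.
struct Big<C> {
    data: C,
}

trait BigToRef {
    fn to_ref(&self) -> &[u8];
}

impl<C: AsRef<[u8]>> BigToRef for Big<C> {
    fn to_ref(&self) -> &[u8] {
        self.data.as_ref()
    }
}

// Style 1: take the concrete wrapper, bound the storage parameter.
fn first_byte_wrapper<C: AsRef<[u8]>>(a: &Big<C>) -> u8 {
    a.to_ref()[0]
}

// Style 2: take any type implementing the view trait.
fn first_byte_trait<A: BigToRef>(a: &A) -> u8 {
    a.to_ref()[0]
}

fn main() {
    let b = Big { data: vec![42u8, 1, 2] };
    assert_eq!(first_byte_wrapper(&b), first_byte_trait(&b));
}
```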

View File

@@ -128,7 +128,7 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
     fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
     where
         R: VecZnxToMut,
-        A: VecZnxBigToRef<BACKEND>;
+        A: VecZnxBigToRef<FFT64>;

     /// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
     fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
@@ -501,7 +501,7 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
         }
     }

-    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
+    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
     where
         A: VecZnxBigToMut<FFT64>,
     {
@@ -513,10 +513,10 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
         unsafe {
             vec_znx::vec_znx_negate(
                 self.ptr,
-                a.at_mut_ptr(res_col, 0),
+                a.at_mut_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
-                a.at_ptr(res_col, 0),
+                a.at_ptr(a_col, 0),
                 a.size() as u64,
                 a.sl() as u64,
             )

View File

@@ -91,14 +91,14 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
     }
 }

-impl<D> VecZnxDft<D, FFT64>
+impl<D: AsMut<[u8]> + AsRef<[u8]>> VecZnxDft<D, FFT64>
 where
-    VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
+    VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
 {
     /// Extracts the a_col-th column of 'a' and stores it on the self_col-th column of [Self].
-    pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
+    pub fn extract_column<C: AsRef<[u8]>>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
     where
-        VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
+        VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
     {
         #[cfg(debug_assertions)]
         {
@@ -142,66 +142,38 @@ pub trait VecZnxDftToRef<B: Backend> {
     fn to_ref(&self) -> VecZnxDft<&[u8], B>;
 }

+impl<D, B: Backend> VecZnxDftToRef<B> for VecZnxDft<D, B>
+where
+    D: AsRef<[u8]>,
+    B: Backend,
+{
+    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
+        VecZnxDft {
+            data: self.data.as_ref(),
+            n: self.n,
+            cols: self.cols,
+            size: self.size,
+            _phantom: std::marker::PhantomData,
+        }
+    }
+}
+
-pub trait VecZnxDftToMut<B: Backend>: VecZnxDftToRef<B> {
+pub trait VecZnxDftToMut<B: Backend> {
     fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
 }

-impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
+impl<D, B: Backend> VecZnxDftToMut<B> for VecZnxDft<D, B>
+where
+    D: AsRef<[u8]> + AsMut<[u8]>,
+    B: Backend,
+{
     fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
         VecZnxDft {
-            data: self.data.as_mut_slice(),
+            data: self.data.as_mut(),
             n: self.n,
             cols: self.cols,
             size: self.size,
-            _phantom: PhantomData,
+            _phantom: std::marker::PhantomData,
         }
     }
 }
-
-impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data.as_slice(),
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
-    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        VecZnxDft {
-            data: self.data,
-            n: self.n,
-            cols: self.cols,
-            size: self.size,
-            _phantom: PhantomData,
-        }
-    }
-}

View File

@@ -59,7 +59,7 @@ pub trait VecZnxOps {
         A: VecZnxToRef;

     /// Adds the selected column of `a` on the selected column and limb of `res`.
-    fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
+    fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
     where
         R: VecZnxToMut,
         A: ScalarZnxToRef;

View File

@@ -5,8 +5,9 @@ use core::{
     glwe_ciphertext::GLWECiphertext,
     keys::{SecretKey, SecretKeyFourier},
 };
-use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
 use sampling::source::Source;
+use std::hint::black_box;

 fn bench_external_product_glwe_fft64(c: &mut Criterion) {
     let mut group = c.benchmark_group("external_product_glwe_fft64");
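The same import swap appears in the next benchmark file as well. `std::hint::black_box` (stable since Rust 1.66) is the drop-in replacement for the version criterion used to re-export; the commit only shows the import change, so the motivation is an assumption here. What the function does:

```rust
// black_box is an identity function the optimizer must treat as
// opaque, so benchmarked inputs and outputs are not const-folded or
// dead-code-eliminated away.
use std::hint::black_box;

fn sum(n: u64) -> u64 {
    (0..n).sum()
}

fn main() {
    // Without black_box, the whole computation could fold to a constant.
    let result = sum(black_box(1000));
    black_box(result);
    assert_eq!(result, 499_500);
}
```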

View File

@@ -5,8 +5,9 @@ use core::{
     keys::{SecretKey, SecretKeyFourier},
     keyswitch_key::GLWESwitchingKey,
 };
-use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
+use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
 use sampling::source::Source;
+use std::hint::black_box;

 fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
     let mut group = c.benchmark_group("keyswitch_glwe_fft64");
@@ -65,7 +66,7 @@ fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank_out);
     sk_out_dft.dft(&module, &sk_out);

-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -158,7 +159,7 @@ fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
     sk_out_dft.dft(&module, &sk_out);

-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,

View File

@@ -1,7 +1,6 @@
 use backend::{
-    Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftAlloc,
-    ScalarZnxDftOps, ScalarZnxOps, ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
-    VecZnxDftToRef, VecZnxOps, ZnxZero,
+    Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxOps, Scratch, VecZnx,
+    VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero,
 };
 use sampling::source::Source;
@@ -63,45 +62,27 @@ impl<T, B: Backend> AutomorphismKey<T, B> {
     }
 }

-impl<DataSelf, B: Backend> MatZnxDftToMut<B> for AutomorphismKey<DataSelf, B>
-where
-    MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
-        self.key.to_mut()
-    }
-}
-
-impl<DataSelf, B: Backend> MatZnxDftToRef<B> for AutomorphismKey<DataSelf, B>
-where
-    MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        self.key.to_ref()
-    }
-}
-
-impl<C> GetRow<FFT64> for AutomorphismKey<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
-{
-    fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
-    where
-        R: VecZnxDftToMut<FFT64>,
-    {
-        module.vmp_extract_row(res, self, row_i, col_j);
-    }
-}
-
-impl<C> SetRow<FFT64> for AutomorphismKey<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
-    where
-        R: VecZnxDftToRef<FFT64>,
-    {
-        module.vmp_prepare_row(self, row_i, col_j, a);
-    }
-}
+impl<C: AsRef<[u8]>> GetRow<FFT64> for AutomorphismKey<C, FFT64> {
+    fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(
+        &self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        res: &mut GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_extract_row(&mut res.data, &self.key.0.data, row_i, col_j);
+    }
+}
+
+impl<C: AsMut<[u8]> + AsRef<[u8]>> SetRow<FFT64> for AutomorphismKey<C, FFT64> {
+    fn set_row<R: AsRef<[u8]>>(
+        &mut self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        a: &GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_prepare_row(&mut self.key.0.data, row_i, col_j, &a.data);
+    }
+}
@@ -166,11 +147,8 @@ impl AutomorphismKey<Vec<u8>, FFT64> {
     }
 }

-impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    pub fn generate_from_sk<DataSk>(
+impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> AutomorphismKey<DataSelf, FFT64> {
+    pub fn generate_from_sk<DataSk: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         p: i64,
@@ -179,9 +157,7 @@ where
         source_xe: &mut Source,
         sigma: f64,
         scratch: &mut Scratch,
-    ) where
-        ScalarZnx<DataSk>: ScalarZnxToRef,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(self.n(), module.n());
@@ -209,12 +185,18 @@ where
         {
             (0..self.rank()).for_each(|i| {
                 let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1);
-                module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i);
-                module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0);
+                module.scalar_znx_automorphism(
+                    module.galois_element_inv(p),
+                    &mut sk_inv_auto,
+                    0,
+                    &sk.data,
+                    i,
+                );
+                module.svp_prepare(&mut sk_out_dft.data, i, &sk_inv_auto, 0);
             });
         }

-        self.key.encrypt_sk(
+        self.key.generate_from_sk(
             module,
             &sk,
             &sk_out_dft,
@@ -228,20 +210,14 @@ where
     }
 }

-impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    pub fn automorphism<DataLhs, DataRhs>(
+impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> AutomorphismKey<DataSelf, FFT64> {
+    pub fn automorphism<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         lhs: &AutomorphismKey<DataLhs, FFT64>,
         rhs: &AutomorphismKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(
@@ -311,8 +287,8 @@ where
         // Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
         // and switches back to DFT domain
         (0..self.rank_out() + 1).for_each(|i| {
-            module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft, i);
-            module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
+            module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft.data, i);
+            module.vec_znx_dft(&mut tmp_dft.data, i, &tmp_idft.data, i);
         });

         // Sets back the relevant row
@@ -331,65 +307,53 @@ where
         self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
     }

-    pub fn automorphism_inplace<DataRhs>(
+    pub fn automorphism_inplace<DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         rhs: &AutomorphismKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         unsafe {
             let self_ptr: *mut AutomorphismKey<DataSelf, FFT64> = self as *mut AutomorphismKey<DataSelf, FFT64>;
             self.automorphism(&module, &*self_ptr, rhs, scratch);
         }
     }

-    pub fn keyswitch<DataLhs, DataRhs>(
+    pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         lhs: &AutomorphismKey<DataLhs, FFT64>,
         rhs: &GLWESwitchingKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         self.key.keyswitch(module, &lhs.key, rhs, scratch);
     }

-    pub fn keyswitch_inplace<DataRhs>(
+    pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         rhs: &AutomorphismKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         self.key.keyswitch_inplace(module, &rhs.key, scratch);
     }

-    pub fn external_product<DataLhs, DataRhs>(
+    pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         lhs: &AutomorphismKey<DataLhs, FFT64>,
         rhs: &GGSWCiphertext<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         self.key.external_product(module, &lhs.key, rhs, scratch);
     }

-    pub fn external_product_inplace<DataRhs>(
+    pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         rhs: &GGSWCiphertext<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         self.key.external_product_inplace(module, rhs, scratch);
     }
 }
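One line kept by this refactor is worth pausing on: the exponent update `self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64)`. Composing the automorphisms X -> X^p and X -> X^q gives X -> X^{pq}, with the exponent reduced modulo the cyclotomic order (2n for the ring Z[X]/(X^n + 1)). A one-function sketch of that bookkeeping:

```rust
// Sketch of the exponent arithmetic in AutomorphismKey::automorphism.
fn compose(p: i64, q: i64, cyclotomic_order: i64) -> i64 {
    (p * q) % cyclotomic_order
}

fn main() {
    let two_n = 2048; // cyclotomic order for n = 1024
    assert_eq!(compose(5, 3, two_n), 15);
    // Exponents wrap around the cyclotomic order:
    assert_eq!(compose(1025, 3, two_n), (1025 * 3) % two_n);
}
```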

View File

@@ -1,6 +1,6 @@
-use backend::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos};
+use backend::{Backend, Module, ZnxInfos};

-use crate::utils::derive_size;
+use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size};

 pub trait Infos {
     type Inner: ZnxInfos;
@@ -56,13 +56,13 @@ pub trait SetMetaData {
 }

 pub trait GetRow<B: Backend> {
-    fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut R)
+    fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut GLWECiphertextFourier<R, B>)
     where
-        R: VecZnxDftToMut<B>;
+        R: AsMut<[u8]> + AsRef<[u8]>;
 }

 pub trait SetRow<B: Backend> {
-    fn set_row<R>(&mut self, module: &Module<B>, row_i: usize, col_j: usize, a: &R)
+    fn set_row<R>(&mut self, module: &Module<B>, row_i: usize, col_j: usize, a: &GLWECiphertextFourier<R, B>)
     where
-        R: VecZnxDftToRef<B>;
+        R: AsRef<[u8]>;
 }
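The traits now return rows through the concrete `GLWECiphertextFourier` wrapper instead of an opaque `VecZnxDftToMut`/`VecZnxDftToRef` bound, so implementors reach `res.data` directly. A compilable miniature with assumed names (`GLWEFourier`, `Key`):

```rust
// Hypothetical miniature of the reshaped GetRow trait.
struct GLWEFourier<R> {
    data: R,
}

trait GetRow {
    fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(&self, row_i: usize, res: &mut GLWEFourier<R>);
}

struct Key {
    rows: Vec<Vec<u8>>,
}

impl GetRow for Key {
    fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(&self, row_i: usize, res: &mut GLWEFourier<R>) {
        // The concrete wrapper exposes its storage, so no view trait is needed.
        res.data.as_mut().copy_from_slice(&self.rows[row_i]);
    }
}

fn main() {
    let key = Key { rows: vec![vec![7u8; 4], vec![9u8; 4]] };
    let mut out = GLWEFourier { data: vec![0u8; 4] };
    key.get_row(1, &mut out);
    assert_eq!(out.data, vec![9u8; 4]);
}
```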

View File

@@ -1,7 +1,6 @@
 use backend::{
-    Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
-    ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos,
-    ZnxZero,
+    Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, Module, ScalarZnx, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxOps,
+    ZnxInfos, ZnxZero,
 };
 use sampling::source::Source;
@@ -60,24 +59,6 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
     }
 }

-impl<C, B: Backend> MatZnxDftToMut<B> for GGLWECiphertext<C, B>
-where
-    MatZnxDft<C, B>: MatZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
-        self.data.to_mut()
-    }
-}
-
-impl<C, B: Backend> MatZnxDftToRef<B> for GGLWECiphertext<C, B>
-where
-    MatZnxDft<C, B>: MatZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        self.data.to_ref()
-    }
-}
-
 impl GGLWECiphertext<Vec<u8>, FFT64> {
     pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
         GLWECiphertext::encrypt_sk_scratch_space(module, size)
@@ -91,11 +72,8 @@ impl GGLWECiphertext<Vec<u8>, FFT64> {
     }
 }

-impl<DataSelf> GGLWECiphertext<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + ZnxInfos,
-{
-    pub fn generate_from_sk<DataPt, DataSk>(
+impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGLWECiphertext<DataSelf, FFT64> {
+    pub fn encrypt_sk<DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         pt: &ScalarZnx<DataPt>,
@@ -104,10 +82,7 @@ where
         source_xe: &mut Source,
         sigma: f64,
         scratch: &mut Scratch,
-    ) where
-        ScalarZnx<DataPt>: ScalarZnxToRef,
-        ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(self.rank_in(), pt.cols());
@@ -171,8 +146,8 @@ where
         (0..rows).for_each(|row_i| {
             // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
             vec_znx_pt.data.zero(); // zeroes for next iteration
-            module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th
-            module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
+            module.vec_znx_add_scalar_inplace(&mut vec_znx_pt.data, 0, row_i, pt, col_i); // Selects the i-th
+            module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt.data, 0, scratch_3);

             // rlwe encrypt of vec_znx_pt into vec_znx_ct
             vec_znx_ct.encrypt_sk(
@@ -189,32 +164,32 @@ where
             vec_znx_ct.dft(module, &mut vec_znx_ct_dft);

             // Stores vec_znx_dft_ct into the i-th row of the MatZnxDft
-            module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft);
+            module.vmp_prepare_row(&mut self.data, row_i, col_i, &vec_znx_ct_dft.data);
         });
     });
 }

-impl<C> GetRow<FFT64> for GGLWECiphertext<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
-{
-    fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
-    where
-        R: VecZnxDftToMut<FFT64>,
-    {
-        module.vmp_extract_row(res, self, row_i, col_j);
-    }
-}
+impl<C: AsRef<[u8]>> GetRow<FFT64> for GGLWECiphertext<C, FFT64> {
+    fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(
+        &self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        res: &mut GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j);
+    }
+}

-impl<C> SetRow<FFT64> for GGLWECiphertext<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
-    where
-        R: VecZnxDftToRef<FFT64>,
-    {
-        module.vmp_prepare_row(self, row_i, col_j, a);
-    }
-}
+impl<C: AsMut<[u8]> + AsRef<[u8]>> SetRow<FFT64> for GGLWECiphertext<C, FFT64> {
+    fn set_row<R: AsRef<[u8]>>(
+        &mut self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        a: &GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data);
+    }
+}
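The `encrypt_sk` loop zeroes the plaintext vector, adds the scalar at limb `row_i`, then encrypts, so successive rows carry the message at successive limbs of the gadget decomposition. A toy sketch of just that layout, with integers standing in for polynomials and no encryption at all (the limb interpretation is our reading of the loop, not something the diff states):

```rust
// Toy model of the gadget layout driven by the encrypt_sk loop.
fn gadget_rows(pt: i64, rows: usize) -> Vec<Vec<i64>> {
    (0..rows)
        .map(|row_i| {
            let mut limbs = vec![0i64; rows]; // vec_znx_pt.data.zero()
            limbs[row_i] = pt; // vec_znx_add_scalar_inplace(.., row_i, ..)
            limbs
        })
        .collect()
}

fn main() {
    assert_eq!(
        gadget_rows(7, 3),
        vec![vec![7, 0, 0], vec![0, 7, 0], vec![0, 0, 7]]
    );
}
```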

View File

@@ -1,7 +1,6 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, Scratch, VecZnxAlloc,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxInfos,
VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, ZnxInfos,
ZnxZero, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -56,24 +55,6 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
} }
} }
impl<C, B: Backend> MatZnxDftToMut<B> for GGSWCiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> MatZnxDftToRef<B> for GGSWCiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GGSWCiphertext<Vec<u8>, FFT64> { impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize { pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GLWECiphertext::encrypt_sk_scratch_space(module, size) GLWECiphertext::encrypt_sk_scratch_space(module, size)
@@ -146,7 +127,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
let res: usize = module.bytes_of_vec_znx(cols, out_size); let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank); let ks_internal: usize =
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank);
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank); let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
res + ci_dft + (ks_internal | expand | res_dft) res + ci_dft + (ks_internal | expand | res_dft)
} }
@@ -193,11 +175,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
} }
} }
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64> impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
where pub fn encrypt_sk<DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>>(
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: &ScalarZnx<DataPt>, pt: &ScalarZnx<DataPt>,
@@ -206,10 +185,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
ScalarZnx<DataPt>: ScalarZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), sk_dft.rank()); assert_eq!(self.rank(), sk_dft.rank());
@@ -242,8 +218,8 @@ where
vec_znx_pt.data.zero(); vec_znx_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0); module.vec_znx_add_scalar_inplace(&mut vec_znx_pt.data, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2); module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt.data, 0, scrach_2);
(0..cols).for_each(|col_j| { (0..cols).for_each(|col_j| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct // rlwe encrypt of vec_znx_pt into vec_znx_ct
@@ -263,16 +239,16 @@ where
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size); let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size);
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i); module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct.data, i);
}); });
self.set_row(module, row_i, col_j, &vec_znx_dft_ct); module.vmp_prepare_row(&mut self.data, row_i, col_j, &vec_znx_dft_ct);
} }
}); });
}); });
} }
pub(crate) fn expand_row<R, DataCi, DataTsk>( pub(crate) fn expand_row<R, DataCi: AsRef<[u8]>, DataTsk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
col_j: usize, col_j: usize,
@@ -282,8 +258,6 @@ where
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) where
R: VecZnxToMut, R: VecZnxToMut,
VecZnxDft<DataCi, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{ {
let cols: usize = self.rank() + 1; let cols: usize = self.rank() + 1;
@@ -332,14 +306,14 @@ where
module.vmp_apply( module.vmp_apply(
&mut tmp_dft_i, &mut tmp_dft_i,
&tmp_dft_col_data, &tmp_dft_col_data,
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) &tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
scratch2, scratch2,
); );
} else { } else {
module.vmp_apply_add( module.vmp_apply_add(
&mut tmp_dft_i, &mut tmp_dft_i,
&tmp_dft_col_data, &tmp_dft_col_data,
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j]) &tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
scratch2, scratch2,
); );
} }
@@ -363,18 +337,14 @@ where
}); });
} }
pub fn keyswitch<DataLhs, DataKsk, DataTsk>( pub fn keyswitch<DataLhs: AsRef<[u8]>, DataKsk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>, lhs: &GGSWCiphertext<DataLhs, FFT64>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>, ksk: &GLWESwitchingKey<DataKsk, FFT64>,
tsk: &TensorKey<DataTsk, FFT64>, tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
let cols: usize = self.rank() + 1; let cols: usize = self.rank() + 1;
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size());
@@ -394,10 +364,10 @@ where
// Isolates DFT(a[i]) // Isolates DFT(a[i])
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i);
}); });
self.set_row(module, row_i, 0, &ci_dft); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
// Generates // Generates
// //
@@ -405,46 +375,39 @@ where
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i])
(1..cols).for_each(|col_j| { (1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2); self.expand_row(module, col_j, &mut res.data, &ci_dft, tsk, scratch2);
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
module.vec_znx_dft(&mut res_dft, i, &res, i); module.vec_znx_dft(&mut res_dft, i, &res.data, i);
}); });
self.set_row(module, row_i, col_j, &res_dft); module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft);
}) });
}) })
} }
pub fn keyswitch_inplace<DataKsk, DataTsk>( pub fn keyswitch_inplace<DataKsk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>, ksk: &GLWESwitchingKey<DataKsk, FFT64>,
tsk: &TensorKey<DataTsk, FFT64>, tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
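// Alias `self` through a raw pointer so the in-place variant can reuse the out-of-place implementation with `self` as the read-only operand. // Alias `self` through a raw pointer so the in-place variant can reuse the out-of-place implementation with `self` as the read-only operand.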
unsafe { unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>; let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch); self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
} }
} }
pub fn automorphism<DataLhs, DataAk, DataTsk>( pub fn automorphism<DataLhs: AsRef<[u8]>, DataAk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>, lhs: &GGSWCiphertext<DataLhs, FFT64>,
auto_key: &AutomorphismKey<DataAk, FFT64>, auto_key: &AutomorphismKey<DataAk, FFT64>,
tensor_key: &TensorKey<DataTsk, FFT64>, tensor_key: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataAk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!( assert_eq!(
@@ -468,7 +431,17 @@ where
self.rank(), self.rank(),
tensor_key.rank() tensor_key.rank()
); );
assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self.size(), lhs.size(), auto_key.size(), tensor_key.size(), self.rank())) assert!(
scratch.available()
>= GGSWCiphertext::automorphism_scratch_space(
module,
self.size(),
lhs.size(),
auto_key.size(),
tensor_key.size(),
self.rank()
)
)
}; };
let cols: usize = self.rank() + 1; let cols: usize = self.rank() + 1;
@@ -491,11 +464,11 @@ where
// Isolates DFT(AUTO(a[i])) // Isolates DFT(AUTO(a[i]))
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i); module.vec_znx_automorphism_inplace(auto_key.p(), &mut res.data, col_i);
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i); module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i);
}); });
self.set_row(module, row_i, 0, &ci_dft); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
// Generates // Generates
// //
@@ -503,44 +476,38 @@ where
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i]))
(1..cols).for_each(|col_j| { (1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2); self.expand_row(module, col_j, &mut res.data, &ci_dft, tensor_key, scratch2);
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
module.vec_znx_dft(&mut res_dft, i, &res, i); module.vec_znx_dft(&mut res_dft, i, &res.data, i);
}); });
self.set_row(module, row_i, col_j, &res_dft); module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft);
}) });
}) })
} }
pub fn automorphism_inplace<DataKsk, DataTsk>( pub fn automorphism_inplace<DataKsk: AsRef<[u8]>, DataTsk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
auto_key: &AutomorphismKey<DataKsk, FFT64>, auto_key: &AutomorphismKey<DataKsk, FFT64>,
tensor_key: &TensorKey<DataTsk, FFT64>, tensor_key: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>; let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch); self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
} }
} }
pub fn external_product<DataLhs, DataRhs>( pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>, lhs: &GGSWCiphertext<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!( assert_eq!(
@@ -592,14 +559,12 @@ where
}); });
} }
pub fn external_product_inplace<DataRhs>( pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!( assert_eq!(
@@ -629,26 +594,29 @@ where
} }
} }
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64> impl<DataSelf: AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
where pub(crate) fn keyswitch_internal_col0<DataRes: AsMut<[u8]> + AsRef<[u8]>, DataKsk: AsRef<[u8]>>(
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
{
pub(crate) fn keyswitch_internal_col0<DataRes, DataKsk>(
&self, &self,
module: &Module<FFT64>, module: &Module<FFT64>,
row_i: usize, row_i: usize,
res: &mut GLWECiphertext<DataRes>, res: &mut GLWECiphertext<DataRes>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>, ksk: &GLWESwitchingKey<DataKsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataRes>: VecZnxToMut,
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), ksk.rank()); assert_eq!(self.rank(), ksk.rank());
assert_eq!(res.rank(), ksk.rank()); assert_eq!(res.rank(), ksk.rank());
assert!(scratch.available() >= GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, res.size(), self.size(), ksk.size(), ksk.rank())) assert!(
scratch.available()
>= GGSWCiphertext::keyswitch_internal_col0_scratch_space(
module,
res.size(),
self.size(),
ksk.size(),
ksk.rank()
)
)
} }
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
@@ -662,26 +630,26 @@ where
} }
} }
impl<DataSelf> GetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64> impl<DataSelf: AsRef<[u8]>> GetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64> {
where fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>, &self,
{ module: &Module<FFT64>,
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R) row_i: usize,
where col_j: usize,
R: VecZnxDftToMut<FFT64>, res: &mut GLWECiphertextFourier<R, FFT64>,
{ ) {
module.vmp_extract_row(res, self, row_i, col_j); module.vmp_extract_row(&mut res.data, &self.data, row_i, col_j);
} }
} }
impl<DataSelf> SetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64> impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> SetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64> {
where fn set_row<R: AsRef<[u8]>>(
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>, &mut self,
{ module: &Module<FFT64>,
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R) row_i: usize,
where col_j: usize,
R: VecZnxDftToRef<FFT64>, a: &GLWECiphertextFourier<R, FFT64>,
{ ) {
module.vmp_prepare_row(self, row_i, col_j, a); module.vmp_prepare_row(&mut self.data, row_i, col_j, &a.data);
} }
} }
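A minimal usage sketch of the reworked row accessors (hypothetical helper; assumes the GetRow/SetRow traits are in scope and that `ggsw` and `row` were allocated elsewhere with matching basek/k/rank metadata — allocation is outside this hunk):

fn roundtrip_row<D: AsRef<[u8]> + AsMut<[u8]>>(
    module: &Module<FFT64>,
    ggsw: &mut GGSWCiphertext<D, FFT64>,
    row: &mut GLWECiphertextFourier<Vec<u8>, FFT64>,
) {
    // Rows are now addressed as concrete GLWECiphertextFourier values with
    // plain AsRef<[u8]>/AsMut<[u8]> data bounds, replacing the removed
    // VecZnxDftToMut/VecZnxDftToRef where-clauses.
    ggsw.get_row(module, 0, 0, row); // extract row (0, 0) into `row`
    ggsw.set_row(module, 0, 0, row); // write it back unchanged
}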

View File

@@ -1,8 +1,7 @@
use backend::{ use backend::{
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, AddNormal, Backend, FFT64, FillUniform, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxAlloc, ScalarZnxDftAlloc,
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, ScalarZnxDftOps, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxDftOps, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -12,6 +11,7 @@ use crate::{
elem::{Infos, SetMetaData}, elem::{Infos, SetMetaData},
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_ops::GLWEOps,
glwe_plaintext::GLWEPlaintext, glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey, keyswitch_key::GLWESwitchingKey,
@@ -56,33 +56,9 @@ impl<T> GLWECiphertext<T> {
} }
} }
impl<C> VecZnxToMut for GLWECiphertext<C> impl<C: AsRef<[u8]>> GLWECiphertext<C> {
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for GLWECiphertext<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl<C> GLWECiphertext<C>
where
VecZnx<C>: VecZnxToRef,
{
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn dft<R>(&self, module: &Module<FFT64>, res: &mut GLWECiphertextFourier<R, FFT64>) pub(crate) fn dft<R: AsMut<[u8]> + AsRef<[u8]>>(&self, module: &Module<FFT64>, res: &mut GLWECiphertextFourier<R, FFT64>) {
where
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), res.rank()); assert_eq!(self.rank(), res.rank());
@@ -90,7 +66,7 @@ where
} }
(0..self.rank() + 1).for_each(|i| { (0..self.rank() + 1).for_each(|i| {
module.vec_znx_dft(res, i, self, i); module.vec_znx_dft(&mut res.data, i, &self.data, i);
}) })
} }
} }
@@ -199,10 +175,7 @@ impl GLWECiphertext<Vec<u8>> {
} }
} }
impl<DataSelf> SetMetaData for GLWECiphertext<DataSelf> impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> SetMetaData for GLWECiphertext<DataSelf> {
where
VecZnx<DataSelf>: VecZnxToMut,
{
fn set_k(&mut self, k: usize) { fn set_k(&mut self, k: usize) {
self.k = k self.k = k
} }
@@ -212,11 +185,8 @@ where
} }
} }
impl<DataSelf> GLWECiphertext<DataSelf> impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf> {
where pub fn encrypt_sk<DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>>(
VecZnx<DataSelf>: VecZnxToMut,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: &GLWEPlaintext<DataPt>, pt: &GLWEPlaintext<DataPt>,
@@ -225,10 +195,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
self.encrypt_sk_private( self.encrypt_sk_private(
module, module,
Some((pt, 0)), Some((pt, 0)),
@@ -240,7 +207,7 @@ where
); );
} }
pub fn encrypt_zero_sk<DataSk>( pub fn encrypt_zero_sk<DataSk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>, sk_dft: &SecretKeyFourier<DataSk, FFT64>,
@@ -248,13 +215,19 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>, self.encrypt_sk_private(
{ module,
self.encrypt_sk_private(module, None, sk_dft, source_xa, source_xe, sigma, scratch); None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
sk_dft,
source_xa,
source_xe,
sigma,
scratch,
);
} }
pub fn encrypt_pk<DataPt, DataPk>( pub fn encrypt_pk<DataPt: AsRef<[u8]>, DataPk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: &GLWEPlaintext<DataPt>, pt: &GLWEPlaintext<DataPt>,
@@ -263,10 +236,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToRef,
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
{
self.encrypt_pk_private( self.encrypt_pk_private(
module, module,
Some((pt, 0)), Some((pt, 0)),
@@ -278,7 +248,7 @@ where
); );
} }
pub fn encrypt_zero_pk<DataPk>( pub fn encrypt_zero_pk<DataPk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pk: &GLWEPublicKey<DataPk, FFT64>, pk: &GLWEPublicKey<DataPk, FFT64>,
@@ -286,133 +256,116 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>, self.encrypt_pk_private(
{ module,
self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); None::<(&GLWEPlaintext<Vec<u8>>, usize)>,
pk,
source_xu,
source_xe,
sigma,
scratch,
);
} }
pub fn automorphism<DataLhs, DataRhs>( pub fn automorphism<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.keyswitch(module, lhs, &rhs.key, scratch); self.keyswitch(module, lhs, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| { (0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), self, i); module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
}) })
} }
pub fn automorphism_inplace<DataRhs>( pub fn automorphism_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.keyswitch_inplace(module, &rhs.key, scratch); self.keyswitch_inplace(module, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| { (0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), self, i); module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
}) })
} }
pub fn automorphism_add<DataLhs, DataRhs>( pub fn automorphism_add<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch); Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, lhs, &rhs.key, scratch);
} }
pub fn automorphism_add_inplace<DataRhs>( pub fn automorphism_add_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>; let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); Self::keyswitch_private::<_, _, 1>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch);
} }
} }
pub fn automorphism_sub_ab<DataLhs, DataRhs>( pub fn automorphism_sub_ab<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch); Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, lhs, &rhs.key, scratch);
} }
pub fn automorphism_sub_ab_inplace<DataRhs>( pub fn automorphism_sub_ab_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>; let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); Self::keyswitch_private::<_, _, 2>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch);
} }
} }
pub fn automorphism_sub_ba<DataLhs, DataRhs>( pub fn automorphism_sub_ba<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch); Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, lhs, &rhs.key, scratch);
} }
pub fn automorphism_sub_ba_inplace<DataRhs>( pub fn automorphism_sub_ba_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>, rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>; let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch); Self::keyswitch_private::<_, _, 3>(self, rhs.p(), module, &*self_ptr, &rhs.key, scratch);
} }
} }
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>( pub(crate) fn keyswitch_from_fourier<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>, lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek(); let basek: usize = self.basek();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -447,45 +400,39 @@ where
// Applies VMP // Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| { (0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1); module.vec_znx_dft_copy(&mut ai_dft, col_i, &lhs.data, col_i + 1);
}); });
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2);
} }
module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0); module.vec_znx_dft_add_inplace(&mut res_dft, 0, &lhs.data, 0);
// Switches the VMP result out of the DFT domain // Switches the VMP result out of the DFT domain
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft); let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
(0..cols_out).for_each(|i| { (0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1);
}); });
} }
pub fn keyswitch<DataLhs, DataRhs>( pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch); Self::keyswitch_private::<_, _, 0>(self, 0, module, lhs, rhs, scratch);
} }
pub(crate) fn keyswitch_private<DataLhs, DataRhs, const OP: u8>( pub(crate) fn keyswitch_private<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>, const OP: u8>(
&mut self, &mut self,
apply_auto: i64, apply_auto: i64,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek(); let basek: usize = self.basek();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -518,14 +465,14 @@ where
{ {
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size()); let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| { (0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1); module.vec_znx_dft(&mut ai_dft, col_i, &lhs.data, col_i + 1);
}); });
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2); module.vmp_apply(&mut res_dft, &ai_dft, &rhs.0.data, scratch2);
} }
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0); module.vec_znx_big_add_small_inplace(&mut res_big, 0, &lhs.data, 0);
(0..cols_out).for_each(|i| { (0..cols_out).for_each(|i| {
if apply_auto != 0 { if apply_auto != 0 {
@@ -533,39 +480,34 @@ where
} }
match OP { match OP {
1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, lhs, i), 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i),
2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, lhs, i), 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i),
3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, lhs, i), 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i),
_ => {} _ => {}
} }
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1);
}); });
} }
pub fn keyswitch_inplace<DataRhs>( pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>; let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(&module, &*self_ptr, rhs, scratch); self.keyswitch(&module, &*self_ptr, rhs, scratch);
} }
} }
pub fn external_product<DataLhs, DataRhs>( pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>, lhs: &GLWECiphertext<DataLhs>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek(); let basek: usize = self.basek();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -586,33 +528,31 @@ where
{ {
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i); module.vec_znx_dft(&mut a_dft, col_i, &lhs.data, col_i);
}); });
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2); module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2);
} }
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft); let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); module.vec_znx_big_normalize(basek, &mut self.data, i, &res_big, i, scratch1);
}); });
} }
pub fn external_product_inplace<DataRhs>( pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>; let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.external_product(&module, &*self_ptr, rhs, scratch); self.external_product(&module, &*self_ptr, rhs, scratch);
} }
} }
pub(crate) fn encrypt_sk_private<DataPt, DataSk>( pub(crate) fn encrypt_sk_private<DataPt: AsRef<[u8]>, DataSk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: Option<(&GLWEPlaintext<DataPt>, usize)>, pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
@@ -621,10 +561,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), sk_dft.rank()); assert_eq!(self.rank(), sk_dft.rank());
@@ -660,21 +597,21 @@ where
self.data.fill_uniform(basek, i, size, source_xa); self.data.fill_uniform(basek, i, size, source_xa);
// c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i]))) // c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i])))
module.vec_znx_dft(&mut ci_dft, 0, self, i); module.vec_znx_dft(&mut ci_dft, 0, &self.data, i);
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); module.svp_apply_inplace(&mut ci_dft, 0, &sk_dft.data, i - 1);
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
// use c[0] as buffer, which is overwritten later by the normalization step // use c[0] as buffer, which is overwritten later by the normalization step
module.vec_znx_big_normalize(basek, self, 0, &ci_big, 0, scratch_2); module.vec_znx_big_normalize(basek, &mut self.data, 0, &ci_big, 0, scratch_2);
// c0_tmp = -c[i] * s[i] (use c[0] as buffer) // c0_tmp = -c[i] * s[i] (use c[0] as buffer)
module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0); module.vec_znx_sub_ab_inplace(&mut c0_big, 0, &self.data, 0);
// c[i] += m if col = i // c[i] += m if col = i
if let Some((pt, col)) = pt { if let Some((pt, col)) = pt {
if i == col { if i == col {
module.vec_znx_add_inplace(self, i, pt, 0); module.vec_znx_add_inplace(&mut self.data, i, &pt.data, 0);
module.vec_znx_normalize_inplace(basek, self, i, scratch_2); module.vec_znx_normalize_inplace(basek, &mut self.data, i, scratch_2);
} }
} }
}); });
@@ -686,15 +623,15 @@ where
// c[0] += m if col = 0 // c[0] += m if col = 0
if let Some((pt, col)) = pt { if let Some((pt, col)) = pt {
if col == 0 { if col == 0 {
module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0); module.vec_znx_add_inplace(&mut c0_big, 0, &pt.data, 0);
} }
} }
// c[0] = norm(c[0]) // c[0] = norm(c[0])
module.vec_znx_normalize(basek, self, 0, &c0_big, 0, scratch_1); module.vec_znx_normalize(basek, &mut self.data, 0, &c0_big, 0, scratch_1);
} }
pub(crate) fn encrypt_pk_private<DataPt, DataPk>( pub(crate) fn encrypt_pk_private<DataPt: AsRef<[u8]>, DataPk: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: Option<(&GLWEPlaintext<DataPt>, usize)>, pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
@@ -703,10 +640,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToRef,
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.basek(), pk.basek()); assert_eq!(self.basek(), pk.basek());
@@ -745,7 +679,7 @@ where
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk); let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk);
// ci_dft = DFT(u) * DFT(pk[i]) // ci_dft = DFT(u) * DFT(pk[i])
module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i); module.svp_apply(&mut ci_dft, 0, &u_dft, 0, &pk.data.data, i);
// ci_big = u * pk[i] // ci_big = u * pk[i]
let mut ci_big = module.vec_znx_idft_consume(ci_dft); let mut ci_big = module.vec_znx_idft_consume(ci_dft);
@@ -756,20 +690,17 @@ where
// ci_big = u * pk[i] + e + m (if col = i) // ci_big = u * pk[i] + e + m (if col = i)
if let Some((pt, col)) = pt { if let Some((pt, col)) = pt {
if col == i { if col == i {
module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0); module.vec_znx_big_add_small_inplace(&mut ci_big, 0, &pt.data, 0);
} }
} }
// ct[i] = norm(ci_big) // ct[i] = norm(ci_big)
module.vec_znx_big_normalize(basek, self, i, &ci_big, 0, scratch_2); module.vec_znx_big_normalize(basek, &mut self.data, i, &ci_big, 0, scratch_2);
}); });
} }
} }
impl<DataSelf> GLWECiphertext<DataSelf> impl<DataSelf: AsRef<[u8]>> GLWECiphertext<DataSelf> {
where
VecZnx<DataSelf>: VecZnxToRef,
{
pub fn clone(&self) -> GLWECiphertext<Vec<u8>> { pub fn clone(&self) -> GLWECiphertext<Vec<u8>> {
GLWECiphertext { GLWECiphertext {
data: self.data.clone(), data: self.data.clone(),
@@ -778,16 +709,13 @@ where
} }
} }
pub fn decrypt<DataPt, DataSk>( pub fn decrypt<DataPt: AsMut<[u8]> + AsRef<[u8]>, DataSk: AsRef<[u8]>>(
&self, &self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: &mut GLWEPlaintext<DataPt>, pt: &mut GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>, sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToMut,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), sk_dft.rank()); assert_eq!(self.rank(), sk_dft.rank());
@@ -805,8 +733,8 @@ where
(1..cols).for_each(|i| { (1..cols).for_each(|i| {
// ci_dft = DFT(a[i]) * DFT(s[i]) // ci_dft = DFT(a[i]) * DFT(s[i])
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
module.vec_znx_dft(&mut ci_dft, 0, self, i); module.vec_znx_dft(&mut ci_dft, 0, &self.data, i);
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1); module.svp_apply_inplace(&mut ci_dft, 0, &sk_dft.data, i - 1);
let ci_big = module.vec_znx_idft_consume(ci_dft); let ci_big = module.vec_znx_idft_consume(ci_dft);
// c0_big += a[i] * s[i] // c0_big += a[i] * s[i]
@@ -815,12 +743,47 @@ where
} }
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0); module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0);
// pt = norm(BIG(m + e)) // pt = norm(BIG(m + e))
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut c0_big, 0, scratch_1); module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut c0_big, 0, scratch_1);
pt.basek = self.basek(); pt.basek = self.basek();
pt.k = pt.k().min(self.k()); pt.k = pt.k().min(self.k());
} }
} }
pub trait GLWECiphertextToRef {
fn to_ref(&self) -> GLWECiphertext<&[u8]>;
}
impl<D: AsRef<[u8]>> GLWECiphertextToRef for GLWECiphertext<D> {
fn to_ref(&self) -> GLWECiphertext<&[u8]> {
GLWECiphertext {
data: self.data.to_ref(),
basek: self.basek,
k: self.k,
}
}
}
pub trait GLWECiphertextToMut {
fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>;
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextToMut for GLWECiphertext<D> {
fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]> {
GLWECiphertext {
data: self.data.to_mut(),
basek: self.basek,
k: self.k,
}
}
}
impl<D> GLWEOps for GLWECiphertext<D>
where
D: AsRef<[u8]> + AsMut<[u8]>,
GLWECiphertext<D>: GLWECiphertextToMut + Infos + SetMetaData,
{
}
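A hedged sketch of the simplified call shape after this cleanup (hypothetical helper; the arguments of encrypt_sk between `pt` and `source_xe` fall in an elided hunk and are assumed to mirror encrypt_sk_private):

fn encrypt_then_decrypt(
    module: &Module<FFT64>,
    ct: &mut GLWECiphertext<Vec<u8>>,
    pt: &mut GLWEPlaintext<Vec<u8>>,
    sk_dft: &SecretKeyFourier<Vec<u8>, FFT64>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
    scratch: &mut Scratch,
) {
    // Data-generic parameters are inferred from the concrete containers;
    // no VecZnxToRef/ScalarZnxDftToRef where-clauses remain at call sites.
    ct.encrypt_sk(module, pt, sk_dft, source_xa, source_xe, sigma, scratch);
    ct.decrypt(module, pt, sk_dft, scratch);
}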

View File

@@ -1,7 +1,6 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, Backend, FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDftOps, Scratch, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxZero,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -48,24 +47,6 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
} }
} }
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GLWECiphertextFourier<Vec<u8>, FFT64> { impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize { pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
@@ -124,11 +105,8 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
} }
} }
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64> impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
where pub fn encrypt_zero_sk<DataSk: AsRef<[u8]>>(
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64>,
{
pub fn encrypt_zero_sk<DataSk>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>, sk_dft: &SecretKeyFourier<DataSk, FFT64>,
@@ -136,9 +114,7 @@ where
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
let mut ct_idft = GLWECiphertext { let mut ct_idft = GLWECiphertext {
data: vec_znx_tmp, data: vec_znx_tmp,
@@ -150,16 +126,13 @@ where
ct_idft.dft(module, self); ct_idft.dft(module, self);
} }
pub fn keyswitch<DataLhs, DataRhs>( pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>, lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let cols_out: usize = rhs.rank_out() + 1; let cols_out: usize = rhs.rank_out() + 1;
// Space for the normalized VMP result outside of the DFT domain // Space for the normalized VMP result outside of the DFT domain
@@ -174,34 +147,29 @@ where
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1); res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
(0..cols_out).for_each(|i| { (0..cols_out).for_each(|i| {
module.vec_znx_dft(self, i, &res_idft, i); module.vec_znx_dft(&mut self.data, i, &res_idft.data, i);
}); });
} }
pub fn keyswitch_inplace<DataRhs>( pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>; let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.keyswitch(&module, &*self_ptr, rhs, scratch); self.keyswitch(&module, &*self_ptr, rhs, scratch);
} }
} }
pub fn external_product<DataLhs, DataRhs>( pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>, lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek(); let basek: usize = self.basek();
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
@@ -221,7 +189,7 @@ where
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{ {
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1); module.vmp_apply(&mut res_dft, &lhs.data, &rhs.data, scratch1);
} }
// VMP result in high precision // VMP result in high precision
@@ -231,18 +199,16 @@ where
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size()); let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| { (0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2); module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i); module.vec_znx_dft(&mut self.data, i, &res_small, i);
}); });
} }
pub fn external_product_inplace<DataRhs>( pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>, rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe { unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>; let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.external_product(&module, &*self_ptr, rhs, scratch); self.external_product(&module, &*self_ptr, rhs, scratch);
@@ -250,20 +216,14 @@ where
} }
} }
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64> impl<DataSelf: AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64> {
where pub fn decrypt<DataPt: AsRef<[u8]> + AsMut<[u8]>, DataSk: AsRef<[u8]>>(
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
{
pub fn decrypt<DataPt, DataSk>(
&self, &self,
module: &Module<FFT64>, module: &Module<FFT64>,
pt: &mut GLWEPlaintext<DataPt>, pt: &mut GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>, sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) {
VecZnx<DataPt>: VecZnxToMut,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), sk_dft.rank()); assert_eq!(self.rank(), sk_dft.rank());
@@ -280,7 +240,7 @@ where
{ {
(1..cols).for_each(|i| { (1..cols).for_each(|i| {
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i); module.svp_apply(&mut ci_dft, 0, &sk_dft.data, i - 1, &self.data, i);
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft); let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0); module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
}); });
@@ -289,22 +249,24 @@ where
{ {
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size()); let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2); module.vec_znx_idft(&mut c0_big, 0, &self.data, 0, scratch_2);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0); module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
} }
// pt = norm(BIG(m + e)) // pt = norm(BIG(m + e))
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1); module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &mut pt_big, 0, scratch_1);
pt.basek = self.basek(); pt.basek = self.basek();
pt.k = pt.k().min(self.k()); pt.k = pt.k().min(self.k());
} }
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch) pub(crate) fn idft<DataRes: AsRef<[u8]> + AsMut<[u8]>>(
where &self,
GLWECiphertext<DataRes>: VecZnxToMut, module: &Module<FFT64>,
{ res: &mut GLWECiphertext<DataRes>,
scratch: &mut Scratch,
) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.rank(), res.rank()); assert_eq!(self.rank(), res.rank());
@@ -316,8 +278,8 @@ where
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size); let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
(0..self.rank() + 1).for_each(|i| { (0..self.rank() + 1).for_each(|i| {
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1); module.vec_znx_idft(&mut res_big, 0, &self.data, i, scratch1);
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1); module.vec_znx_big_normalize(self.basek(), &mut res.data, i, &res_big, 0, scratch1);
}); });
} }
} }
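A hedged sketch of Fourier-domain keyswitching with the new bounds (hypothetical helper; ciphertexts and key are assumed parameter-compatible, as enforced by the debug assertions above):

fn keyswitch_fourier(
    module: &Module<FFT64>,
    res: &mut GLWECiphertextFourier<Vec<u8>, FFT64>,
    ct: &GLWECiphertextFourier<Vec<u8>, FFT64>,
    ksk: &GLWESwitchingKey<Vec<u8>, FFT64>,
    scratch: &mut Scratch,
) {
    // Any AsRef<[u8]>/AsMut<[u8]> data backing works, owned or borrowed.
    res.keyswitch(module, ct, ksk, scratch);
}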

View File

@@ -1,19 +1,15 @@
use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero}; use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero};
use crate::{ use crate::{
elem::{Infos, SetMetaData}, elem::{Infos, SetMetaData},
glwe_ciphertext::GLWECiphertext, glwe_ciphertext::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef},
}; };
impl<DataSelf> GLWECiphertext<DataSelf> pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
where fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
Self: Infos,
VecZnx<DataSelf>: VecZnxToMut,
{
pub fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
B: VecZnxToRef + Infos, B: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -28,25 +24,28 @@ where
let max_col: usize = a.rank().max(b.rank() + 1); let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1; let self_col: usize = self.rank() + 1;
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
(0..min_col).for_each(|i| { (0..min_col).for_each(|i| {
module.vec_znx_add(self, i, a, i, b, i); module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
}); });
if a.rank() > b.rank() { if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| { (min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i); module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
}); });
} else { } else {
(min_col..max_col).for_each(|i| { (min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i); module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
}); });
} }
let size: usize = self.size(); let size: usize = self_mut.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
(max_col..self_col).for_each(|i| { (max_col..self_col).for_each(|i| {
(0..size).for_each(|j| { (0..size).for_each(|j| {
self_mut.zero_at(i, j); self_mut.data.zero_at(i, j);
}); });
}); });
@@ -54,9 +53,9 @@ where
self.set_k(a.k().max(b.k())); self.set_k(a.k().max(b.k()));
} }
pub fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A) fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -66,17 +65,20 @@ where
assert!(self.rank() >= a.rank()) assert!(self.rank() >= a.rank())
} }
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| { (0..a.rank() + 1).for_each(|i| {
module.vec_znx_add_inplace(self, i, a, i); module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); self.set_k(a.k().max(self.k()));
} }
pub fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B) fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
B: VecZnxToRef + Infos, B: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -91,26 +93,29 @@ where
let max_col: usize = a.rank().max(b.rank() + 1); let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1; let self_col: usize = self.rank() + 1;
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
(0..min_col).for_each(|i| { (0..min_col).for_each(|i| {
module.vec_znx_sub(self, i, a, i, b, i); module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
}); });
if a.rank() > b.rank() { if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| { (min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i); module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
}); });
} else { } else {
(min_col..max_col).for_each(|i| { (min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i); module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
module.vec_znx_negate_inplace(self, i); module.vec_znx_negate_inplace(&mut self_mut.data, i);
}); });
} }
let size: usize = self.size(); let size: usize = self_mut.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
(max_col..self_col).for_each(|i| { (max_col..self_col).for_each(|i| {
(0..size).for_each(|j| { (0..size).for_each(|j| {
self_mut.zero_at(i, j); self_mut.data.zero_at(i, j);
}); });
}); });
@@ -118,9 +123,9 @@ where
self.set_k(a.k().max(b.k())); self.set_k(a.k().max(b.k()));
} }
pub fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A) fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -130,16 +135,19 @@ where
assert!(self.rank() >= a.rank()) assert!(self.rank() >= a.rank())
} }
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| { (0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ab_inplace(self, i, a, i); module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); self.set_k(a.k().max(self.k()));
} }
pub fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A) fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -149,16 +157,19 @@ where
assert!(self.rank() >= a.rank()) assert!(self.rank() >= a.rank())
} }
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| { (0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ba_inplace(self, i, a, i); module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k().max(self.k())); self.set_k(a.k().max(self.k()));
} }
pub fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A) fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -167,28 +178,33 @@ where
assert_eq!(self.rank(), a.rank()) assert_eq!(self.rank(), a.rank())
} }
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| { (0..a.rank() + 1).for_each(|i| {
module.vec_znx_rotate(k, self, i, a, i); module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_basek(a.basek()); self.set_basek(a.basek());
self.set_k(a.k()); self.set_k(a.k());
} }
pub fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) { fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.n(), module.n()); assert_eq!(self.n(), module.n());
} }
(0..self.rank() + 1).for_each(|i| { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
module.vec_znx_rotate_inplace(k, self, i);
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_rotate_inplace(k, &mut self_mut.data, i);
}); });
} }
pub fn copy<A>(&mut self, module: &Module<FFT64>, a: &A) fn copy<A>(&mut self, module: &Module<FFT64>, a: &A)
where where
A: VecZnxToRef + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -197,23 +213,26 @@ where
assert_eq!(self.rank(), a.rank()); assert_eq!(self.rank(), a.rank());
} }
(0..self.rank() + 1).for_each(|i| { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
module.vec_znx_copy(self, i, a, i); let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
}); });
self.set_k(a.k()); self.set_k(a.k());
self.set_basek(a.basek()); self.set_basek(a.basek());
} }
pub fn rsh(&mut self, k: usize, scratch: &mut Scratch) { fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
let basek: usize = self.basek(); let basek: usize = self.basek();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); let mut self_mut: GLWECiphertext<&mut [u8]> = self.to_mut();
self_mut.rsh(basek, k, scratch); self_mut.data.rsh(basek, k, scratch);
} }
pub fn normalize<A>(&mut self, module: &Module<FFT64>, a: &A, scratch: &mut Scratch) fn normalize<A>(&mut self, module: &Module<FFT64>, a: &A, scratch: &mut Scratch)
where where
A: VecZnxToMut + Infos, A: GLWECiphertextToRef + Infos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
@@ -222,20 +241,24 @@ where
assert_eq!(self.rank(), a.rank()); assert_eq!(self.rank(), a.rank());
} }
(0..self.rank() + 1).for_each(|i| { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
module.vec_znx_normalize(a.basek(), self, i, a, i, scratch); let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
}); });
self.set_basek(a.basek()); self.set_basek(a.basek());
self.set_k(a.k()); self.set_k(a.k());
} }
pub fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) { fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
{ {
assert_eq!(self.n(), module.n()); assert_eq!(self.n(), module.n());
} }
(0..self.rank() + 1).for_each(|i| { let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
module.vec_znx_normalize_inplace(self.basek(), self, i, scratch); (0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
}); });
} }
} }
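Because GLWEOps now carries default method bodies and is blanket-implemented for GLWECiphertext<D>, arithmetic composes across mixed data backings. A hedged sketch (hypothetical helper; `res`, `a`, `b` are assumed to share basek and have compatible ranks):

fn add_then_rotate(
    module: &Module<FFT64>,
    res: &mut GLWECiphertext<Vec<u8>>,
    a: &GLWECiphertext<Vec<u8>>,
    b: &GLWECiphertext<&[u8]>,
) {
    res.add(module, a, b); // column-wise addition; res metadata is updated
    res.rotate_inplace(module, 1); // negacyclic rotation of each column by X^1
}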

View File

@@ -1,4 +1,4 @@
use backend::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; use backend::{Backend, Module, VecZnx, VecZnxAlloc};
use crate::{elem::Infos, utils::derive_size}; use crate::{elem::Infos, utils::derive_size};
@@ -24,24 +24,6 @@ impl<T> Infos for GLWEPlaintext<T> {
} }
} }
impl<C> VecZnxToMut for GLWEPlaintext<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for GLWEPlaintext<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl GLWEPlaintext<Vec<u8>> { impl GLWEPlaintext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self { pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
Self { Self {

View File

@@ -1,7 +1,6 @@
 use backend::{
-    Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
-    ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
-    ZnxZero,
+    Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnxDft,
+    ZnxInfos, ZnxZero,
 };
 use sampling::source::Source;
@@ -43,10 +42,7 @@ impl<DataSelf> SecretKey<DataSelf> {
     }
 }
-impl<S> SecretKey<S>
-where
-    S: AsMut<[u8]> + AsRef<[u8]>,
-{
+impl<S: AsMut<[u8]> + AsRef<[u8]>> SecretKey<S> {
     pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
         (0..self.rank()).for_each(|i| {
             self.data.fill_ternary_prob(i, prob, source);
@@ -67,24 +63,6 @@ where
     }
 }
-impl<C> ScalarZnxToMut for SecretKey<C>
-where
-    ScalarZnx<C>: ScalarZnxToMut,
-{
-    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
-        self.data.to_mut()
-    }
-}
-impl<C> ScalarZnxToRef for SecretKey<C>
-where
-    ScalarZnx<C>: ScalarZnxToRef,
-{
-    fn to_ref(&self) -> ScalarZnx<&[u8]> {
-        self.data.to_ref()
-    }
-}
 pub struct SecretKeyFourier<T, B: Backend> {
     pub data: ScalarZnxDft<T, B>,
     pub dist: SecretDistribution,
@@ -111,12 +89,10 @@ impl<B: Backend> SecretKeyFourier<Vec<u8>, B> {
             dist: SecretDistribution::NONE,
         }
     }
-    pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
-    where
-        SecretKeyFourier<Vec<u8>, B>: ScalarZnxDftToMut<FFT64>,
-        SecretKey<S>: ScalarZnxToRef,
-    {
+}
+impl<D: AsRef<[u8]> + AsMut<[u8]>> SecretKeyFourier<D, FFT64> {
+    pub fn dft<S: AsRef<[u8]>>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>) {
         #[cfg(debug_assertions)]
         {
             match sk.dist {
@@ -130,30 +106,12 @@ impl<B: Backend> SecretKeyFourier<Vec<u8>, B> {
         }
         (0..self.rank()).for_each(|i| {
-            module.svp_prepare(self, i, sk, i);
+            module.svp_prepare(&mut self.data, i, &sk.data, i);
         });
         self.dist = sk.dist;
     }
 }
-impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyFourier<C, B>
-where
-    ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
-        self.data.to_mut()
-    }
-}
-impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyFourier<C, B>
-where
-    ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
-        self.data.to_ref()
-    }
-}
 pub struct GLWEPublicKey<D, B: Backend> {
     pub data: GLWECiphertextFourier<D, B>,
     pub dist: SecretDistribution,
@@ -190,36 +148,15 @@ impl<T, B: Backend> GLWEPublicKey<T, B> {
     }
 }
-impl<C, B: Backend> VecZnxDftToMut<B> for GLWEPublicKey<C, B>
-where
-    VecZnxDft<C, B>: VecZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
-        self.data.to_mut()
-    }
-}
-impl<C, B: Backend> VecZnxDftToRef<B> for GLWEPublicKey<C, B>
-where
-    VecZnxDft<C, B>: VecZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        self.data.to_ref()
-    }
-}
-impl<C> GLWEPublicKey<C, FFT64> {
-    pub fn generate_from_sk<S>(
+impl<C: AsRef<[u8]> + AsMut<[u8]>> GLWEPublicKey<C, FFT64> {
+    pub fn generate_from_sk<S: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         sk_dft: &SecretKeyFourier<S, FFT64>,
         source_xa: &mut Source,
         source_xe: &mut Source,
         sigma: f64,
-    ) where
-        VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
-        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             match sk_dft.dist {
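The where clauses tying each wrapper to a conversion trait give way to inline AsRef<[u8]>/AsMut<[u8]> bounds, so key generation reads as plain method calls. A minimal usage sketch assembled from calls shown in this diff (the driver function keygen and the 0.5 probability are illustrative):

    use backend::{FFT64, Module};
    use sampling::source::Source;

    // Sample a ternary secret key, then move it to the Fourier domain with
    // the new inline-bound dft signature; S: AsRef<[u8]> is inferred.
    fn keygen(module: &Module<FFT64>, rank: usize, source: &mut Source) {
        let mut sk: SecretKey<Vec<u8>> = SecretKey::alloc(module, rank);
        sk.fill_ternary_prob(0.5, source);
        let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(module, rank);
        sk_dft.dft(module, &sk);
    }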

View File

@@ -1,7 +1,4 @@
-use backend::{
-    Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
-    ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
-};
+use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, VecZnxDftAlloc, ZnxZero};
 use sampling::source::Source;
 use crate::{
@@ -52,45 +49,27 @@ impl<T, B: Backend> GLWESwitchingKey<T, B> {
     }
 }
-impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
-where
-    MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
-        self.0.data.to_mut()
+impl<C: AsRef<[u8]>> GetRow<FFT64> for GLWESwitchingKey<C, FFT64> {
+    fn get_row<R: AsMut<[u8]> + AsRef<[u8]>>(
+        &self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        res: &mut GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_extract_row(&mut res.data, &self.0.data, row_i, col_j);
     }
 }
-impl<DataSelf, B: Backend> MatZnxDftToRef<B> for GLWESwitchingKey<DataSelf, B>
-where
-    MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        self.0.data.to_ref()
-    }
-}
-impl<C> GetRow<FFT64> for GLWESwitchingKey<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
-{
-    fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
-    where
-        R: VecZnxDftToMut<FFT64>,
-    {
-        module.vmp_extract_row(res, self, row_i, col_j);
-    }
-}
-impl<C> SetRow<FFT64> for GLWESwitchingKey<C, FFT64>
-where
-    MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
-    where
-        R: VecZnxDftToRef<FFT64>,
-    {
-        module.vmp_prepare_row(self, row_i, col_j, a);
+impl<C: AsMut<[u8]> + AsRef<[u8]>> SetRow<FFT64> for GLWESwitchingKey<C, FFT64> {
+    fn set_row<R: AsRef<[u8]>>(
+        &mut self,
+        module: &Module<FFT64>,
+        row_i: usize,
+        col_j: usize,
+        a: &GLWECiphertextFourier<R, FFT64>,
+    ) {
+        module.vmp_prepare_row(&mut self.0.data, row_i, col_j, &a.data);
     }
 }
@@ -147,11 +126,8 @@ impl GLWESwitchingKey<Vec<u8>, FFT64> {
         tmp + ggsw
     }
 }
-impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    pub fn encrypt_sk<DataSkIn, DataSkOut>(
+impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWESwitchingKey<DataSelf, FFT64> {
+    pub fn generate_from_sk<DataSkIn: AsRef<[u8]>, DataSkOut: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         sk_in: &SecretKey<DataSkIn>,
@@ -160,11 +136,8 @@ where
         source_xe: &mut Source,
         sigma: f64,
         scratch: &mut Scratch,
-    ) where
-        ScalarZnx<DataSkIn>: ScalarZnxToRef,
-        ScalarZnxDft<DataSkOut, FFT64>: ScalarZnxDftToRef<FFT64>,
-    {
-        self.0.generate_from_sk(
+    ) {
+        self.0.encrypt_sk(
             module,
             &sk_in.data,
             sk_out_dft,
@@ -175,16 +148,13 @@ where
         );
     }
-    pub fn keyswitch<DataLhs, DataRhs>(
+    pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         lhs: &GLWESwitchingKey<DataLhs, FFT64>,
         rhs: &GLWESwitchingKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(
@@ -243,14 +213,12 @@ where
         });
     }
-    pub fn keyswitch_inplace<DataRhs>(
+    pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         rhs: &GLWESwitchingKey<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(
@@ -279,16 +247,13 @@ where
         });
     }
-    pub fn external_product<DataLhs, DataRhs>(
+    pub fn external_product<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         lhs: &GLWESwitchingKey<DataLhs, FFT64>,
         rhs: &GGSWCiphertext<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(
@@ -347,14 +312,12 @@ where
         });
     }
-    pub fn external_product_inplace<DataRhs>(
+    pub fn external_product_inplace<DataRhs: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         rhs: &GGSWCiphertext<DataRhs, FFT64>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
        {
            assert_eq!(
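GetRow/SetRow now move whole GLWECiphertextFourier values in and out of the switching key instead of accepting anything implementing the DFT-conversion traits. A minimal round-trip sketch under that assumption (the function name roundtrip_row is hypothetical):

    use backend::{FFT64, Module};

    // Extract one decomposition row, then write it back unchanged:
    // get_row wraps vmp_extract_row, set_row wraps vmp_prepare_row.
    fn roundtrip_row(
        module: &Module<FFT64>,
        ksk: &mut GLWESwitchingKey<Vec<u8>, FFT64>,
        row: &mut GLWECiphertextFourier<Vec<u8>, FFT64>,
        row_i: usize,
        col_j: usize,
    ) {
        ksk.get_row(module, row_i, col_j, row);
        ksk.set_row(module, row_i, col_j, row);
    }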

View File

@@ -1,7 +1,4 @@
-use backend::{
-    Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc,
-    ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef,
-};
+use backend::{Backend, FFT64, MatZnxDft, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxDftOps};
 use sampling::source::Source;
 use crate::{
@@ -61,11 +58,8 @@ impl TensorKey<Vec<u8>, FFT64> {
     }
 }
-impl<DataSelf> TensorKey<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
-{
-    pub fn generate_from_sk<DataSk>(
+impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> TensorKey<DataSelf, FFT64> {
+    pub fn generate_from_sk<DataSk: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         sk_dft: &SecretKeyFourier<DataSk, FFT64>,
@@ -73,9 +67,7 @@ where
         source_xe: &mut Source,
         sigma: f64,
         scratch: &mut Scratch,
-    ) where
-        ScalarZnxDft<DataSk, FFT64>: VecZnxDftToRef<FFT64> + ScalarZnxDftToRef<FFT64>,
-    {
+    ) {
         #[cfg(debug_assertions)]
         {
             assert_eq!(self.rank(), sk_dft.rank());
@@ -98,7 +90,7 @@ where
                     dist: sk_dft.dist,
                 };
-                self.at_mut(i, j).encrypt_sk(
+                self.at_mut(i, j).generate_from_sk(
                     module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1,
                 );
             });
@@ -115,10 +107,7 @@ where
     }
 }
-impl<DataSelf> TensorKey<DataSelf, FFT64>
-where
-    MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
-{
+impl<DataSelf: AsRef<[u8]>> TensorKey<DataSelf, FFT64> {
     // Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j])
     pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey<DataSelf, FFT64> {
         if i > j {
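Since s[i]*s[j] = s[j]*s[i], only one triangle of the tensor key is stored and at reorders the indices before the lookup. A standalone sketch of that normalization (hypothetical helper, not part of the crate):

    // The switching key for s[i]*s[j] and s[j]*s[i] is the same object,
    // so canonicalize the index pair before indexing the stored triangle.
    fn normalized(mut i: usize, mut j: usize) -> (usize, usize) {
        if i > j {
            std::mem::swap(&mut i, &mut j);
        }
        (i, j)
    }

    fn main() {
        assert_eq!(normalized(2, 0), normalized(0, 2));
    }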

View File

@@ -87,7 +87,13 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
     let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
     sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
     (0..rank).for_each(|i| {
-        module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
+        module.scalar_znx_automorphism(
+            module.galois_element_inv(p0 * p1),
+            &mut sk_auto.data,
+            i,
+            &sk.data,
+            i,
+        );
     });
     let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
@@ -98,7 +104,7 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
             auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
             let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -178,7 +184,13 @@ fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk
     let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::alloc(&module, rank);
     sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
     (0..rank).for_each(|i| {
-        module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
+        module.scalar_znx_automorphism(
+            module.galois_element_inv(p0 * p1),
+            &mut sk_auto.data,
+            i,
+            &sk.data,
+            i,
+        );
     });
     let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
@@ -189,7 +201,7 @@ fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk
             auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
             let noise_want: f64 = log2_std_noise_gglwe_product(

View File

@@ -100,7 +100,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank_out);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -117,7 +117,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in
         (0..ksk.rows()).for_each(|row_i| {
             ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i);
             let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2();
             assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
         });
@@ -179,7 +179,7 @@ fn test_key_switch(
     sk2_dft.dft(&module, &sk2);
     // gglwe_{s1}(s0) = s0 -> s1
-    ct_gglwe_s0s1.encrypt_sk(
+    ct_gglwe_s0s1.generate_from_sk(
         &module,
         &sk0,
         &sk1_dft,
@@ -190,7 +190,7 @@ fn test_key_switch(
     );
     // gglwe_{s2}(s1) -> s1 -> s2
-    ct_gglwe_s1s2.encrypt_sk(
+    ct_gglwe_s1s2.generate_from_sk(
         &module,
         &sk1,
         &sk2_dft,
@@ -211,7 +211,7 @@ fn test_key_switch(
         (0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
             ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk0.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
             let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -280,7 +280,7 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64,
     sk2_dft.dft(&module, &sk2);
     // gglwe_{s1}(s0) = s0 -> s1
-    ct_gglwe_s0s1.encrypt_sk(
+    ct_gglwe_s0s1.generate_from_sk(
         &module,
         &sk0,
         &sk1_dft,
@@ -291,7 +291,7 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64,
     );
     // gglwe_{s2}(s1) -> s1 -> s2
-    ct_gglwe_s1s2.encrypt_sk(
+    ct_gglwe_s1s2.generate_from_sk(
         &module,
         &sk1,
         &sk2_dft,
@@ -314,7 +314,7 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64,
         (0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
             ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk0.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
             let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -385,7 +385,7 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_
     sk_out_dft.dft(&module, &sk_out);
     // gglwe_{s1}(s0) = s0 -> s1
-    ct_gglwe_in.encrypt_sk(
+    ct_gglwe_in.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -432,7 +432,7 @@ fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_
         (0..ct_gglwe_out.rows()).for_each(|row_i| {
             ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
@@ -505,7 +505,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6
     sk_out_dft.dft(&module, &sk_out);
     // gglwe_{s1}(s0) = s0 -> s1
-    ct_gglwe.encrypt_sk(
+    ct_gglwe.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -539,7 +539,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f6
         (0..ct_gglwe.rows()).for_each(|row_i| {
             ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
             ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
-            module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
+            module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_in.data, col_i);
             let noise_have: f64 = pt.data.std(0, basek).log2();
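These key-switch tests rely on composition: a GGLWE of s0 under s1, switched with a GGLWE of s1 under s2, should yield a GGLWE of s0 under s2, which is then checked row by row against sk0. A minimal sketch of just the composition step, using the keyswitch signature shown earlier in this diff (the wrapper function is hypothetical; allocations and encryptions are omitted):

    use backend::{FFT64, Module, Scratch};

    // lhs encodes s0 under s1, rhs encodes s1 under s2; the output then
    // encodes s0 under s2.
    fn compose_switching_keys(
        module: &Module<FFT64>,
        out: &mut GLWESwitchingKey<Vec<u8>, FFT64>,
        s0_to_s1: &GLWESwitchingKey<Vec<u8>, FFT64>,
        s1_to_s2: &GLWESwitchingKey<Vec<u8>, FFT64>,
        scratch: &mut Scratch,
    ) {
        out.keyswitch(module, s0_to_s1, s1_to_s2, scratch);
    }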

View File

@@ -116,21 +116,21 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank:
     (0..ct.rank() + 1).for_each(|col_j| {
         (0..ct.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
             // mul with sk[col_j-1]
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
             let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2();
             assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
@@ -185,7 +185,7 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64)
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -223,21 +223,21 @@ fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64)
     (0..ct_out.rank() + 1).for_each(|col_j| {
         (0..ct_out.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
             // mul with sk[col_j-1]
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt_have.data.std(0, basek).log2();
             let noise_want: f64 = noise_ggsw_keyswitch(
@@ -304,7 +304,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -342,21 +342,21 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sig
     (0..ct.rank() + 1).for_each(|col_j| {
         (0..ct.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
             // mul with sk[col_j-1]
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt_have.data.std(0, basek).log2();
             let noise_want: f64 = noise_ggsw_keyswitch(
@@ -514,21 +514,21 @@ fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize,
     (0..ct_out.rank() + 1).for_each(|col_j| {
         (0..ct_out.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
             // mul with sk[col_j-1]
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt_have.data.std(0, basek).log2();
             let noise_want: f64 = noise_ggsw_keyswitch(
@@ -627,21 +627,21 @@ fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank:
     (0..ct.rank() + 1).for_each(|col_j| {
         (0..ct.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
             // mul with sk[col_j-1]
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt_have.data.std(0, basek).log2();
             let noise_want: f64 = noise_ggsw_keyswitch(
@@ -740,19 +740,19 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize,
     (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
         (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_ggsw_lhs, 0);
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt.data.std(0, basek).log2();
@@ -853,19 +853,19 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank
     (0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
         (0..ct_ggsw_lhs.rows()).for_each(|row_i| {
-            module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
+            module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_ggsw_lhs, 0);
             if col_j > 0 {
-                module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
-                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
+                module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
+                module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
                 module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
-                module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
             }
             ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
             ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
-            module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
+            module.vec_znx_sub_ab_inplace(&mut pt.data, 0, &pt_want.data, 0);
             let noise_have: f64 = pt.data.std(0, basek).log2();
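Throughout these GGSW tests, the expected value for row (row_i, col_j) is the message placed at limb row_i, multiplied by sk[col_j - 1] for the non-trivial columns; the multiplication happens in the DFT domain and is then renormalized. The recurring reference computation, annotated (a fragment of the tests above, not standalone code):

    // pt_want accumulates the message at limb row_i ...
    module.vec_znx_add_scalar_inplace(&mut pt_want.data, 0, row_i, &pt_scalar, 0);
    if col_j > 0 {
        // ... and, for columns j > 0, is multiplied by sk[col_j - 1]:
        // forward DFT, scalar-vector product, inverse DFT, renormalization.
        module.vec_znx_dft(&mut pt_dft, 0, &pt_want.data, 0);
        module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft.data, col_j - 1);
        module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
        module.vec_znx_big_normalize(basek, &mut pt_want.data, 0, &pt_big, 0, scratch.borrow());
    }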

View File

@@ -1,6 +1,6 @@
 use backend::{
-    Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
-    ZnxViewMut, ZnxZero,
+    Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut,
+    ZnxZero,
 };
 use itertools::izip;
 use sampling::source::Source;
@@ -232,7 +232,7 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma:
     ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0);
     let noise_have: f64 = pt_want.data.std(0, basek).log2();
     let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64);
@@ -299,7 +299,7 @@ fn test_keyswitch(
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank_out);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -323,7 +323,7 @@ fn test_keyswitch(
     ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
     let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -384,7 +384,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
     let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
     sk1_dft.dft(&module, &sk1);
-    ct_grlwe.encrypt_sk(
+    ct_grlwe.generate_from_sk(
         &module,
         &sk0,
         &sk1_dft,
@@ -408,7 +408,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
     ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
     let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -494,9 +494,9 @@ fn test_automorphism(
     ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow());
     ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
-    module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
+    module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
+    module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow());
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
@@ -576,9 +576,9 @@ fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usiz
     ct.automorphism_inplace(&module, &autokey, scratch.borrow());
     ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
-    module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
+    module.vec_znx_automorphism_inplace(p, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
+    module.vec_znx_normalize_inplace(basek, &mut pt_have.data, 0, scratch.borrow());
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
     let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -623,7 +623,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
         .data
         .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
-    pt_want.to_mut().at_mut(0, 0)[1] = 1;
+    pt_want.data.at_mut(0, 0)[1] = 1;
     let k: usize = 1;
@@ -672,9 +672,9 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
     ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
@@ -726,7 +726,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
         .data
         .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
-    pt_want.to_mut().at_mut(0, 0)[1] = 1;
+    pt_want.data.at_mut(0, 0)[1] = 1;
     let k: usize = 1;
@@ -769,9 +769,9 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
     ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
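The external-product tests pick the GGSW message X^k (with k = 1 via at_mut(0, 0)[1] = 1), so the product should rotate the GLWE plaintext by k slots; the reference plaintext is rotated the same way before the noise comparison. The check, annotated (a fragment of the tests above):

    // Mirror the expected rotation by X^k on the reference plaintext,
    // then subtract to leave only the noise.
    module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);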

View File

@@ -8,7 +8,7 @@ use crate::{
     keyswitch_key::GLWESwitchingKey,
     test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
 };
-use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
+use backend::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, ZnxViewMut};
 use sampling::source::Source;
 #[test]
@@ -104,7 +104,7 @@ fn test_keyswitch(
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank_out);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -130,7 +130,7 @@ fn test_keyswitch(
     ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
     let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -192,7 +192,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
     let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::alloc(&module, rank);
     sk_out_dft.dft(&module, &sk_out);
-    ksk.encrypt_sk(
+    ksk.generate_from_sk(
         &module,
         &sk_in,
         &sk_out_dft,
@@ -218,7 +218,7 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
     ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
     let noise_want: f64 = log2_std_noise_gglwe_product(
@@ -265,7 +265,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
         .data
         .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
-    pt_want.to_mut().at_mut(0, 0)[1] = 1;
+    pt_want.data.at_mut(0, 0)[1] = 1;
     let k: usize = 1;
@@ -310,9 +310,9 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usi
     ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();
@@ -365,7 +365,7 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
         .data
         .fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
-    pt_want.to_mut().at_mut(0, 0)[1] = 1;
+    pt_want.data.at_mut(0, 0)[1] = 1;
     let k: usize = 1;
@@ -410,9 +410,9 @@ fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct
     ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
-    module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+    module.vec_znx_rotate_inplace(k as i64, &mut pt_want.data, 0);
+    module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt_want.data, 0);
     let noise_have: f64 = pt_have.data.std(0, basek).log2();

View File

@@ -69,7 +69,7 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize
                         .at(i, j)
                         .get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
                     ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
-                    module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
+                    module.vec_znx_sub_scalar_inplace(&mut pt.data, 0, row_i, &sk_ij, col_i);
                    let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
                    assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
                });

View File

@@ -91,8 +91,8 @@ fn test_trace_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank: us
     ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
-    module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
-    module.vec_znx_normalize_inplace(basek, &mut pt_want, 0, scratch.borrow());
+    module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0);
+    module.vec_znx_normalize_inplace(basek, &mut pt_want.data, 0, scratch.borrow());
     let noise_have = pt_want.data.std(0, basek).log2();

View File

@@ -1,8 +1,13 @@
 use std::collections::HashMap;
-use backend::{FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxToMut, VecZnxToRef};
-use crate::{automorphism::AutomorphismKey, glwe_ciphertext::GLWECiphertext};
+use backend::{FFT64, Module, Scratch};
+use crate::{
+    automorphism::AutomorphismKey,
+    elem::{Infos, SetMetaData},
+    glwe_ciphertext::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef},
+    glwe_ops::GLWEOps,
+};
 impl GLWECiphertext<Vec<u8>> {
     pub fn trace_galois_elements(module: &Module<FFT64>) -> Vec<i64> {
@@ -32,11 +37,11 @@ impl GLWECiphertext<Vec<u8>> {
     }
 }
-impl<DataSelf> GLWECiphertext<DataSelf>
+impl<DataSelf: AsRef<[u8]> + AsMut<[u8]>> GLWECiphertext<DataSelf>
 where
-    VecZnx<DataSelf>: VecZnxToMut,
+    GLWECiphertext<DataSelf>: GLWECiphertextToMut + Infos + SetMetaData,
 {
-    pub fn trace<DataLhs, DataAK>(
+    pub fn trace<DataLhs: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         start: usize,
@@ -45,23 +50,20 @@ where
         auto_keys: &HashMap<i64, AutomorphismKey<DataAK, FFT64>>,
         scratch: &mut Scratch,
     ) where
-        VecZnx<DataLhs>: VecZnxToRef,
-        MatZnxDft<DataAK, FFT64>: MatZnxDftToRef<FFT64>,
+        GLWECiphertext<DataLhs>: GLWECiphertextToRef + Infos,
     {
         self.copy(module, lhs);
         self.trace_inplace(module, start, end, auto_keys, scratch);
     }
-    pub fn trace_inplace<DataAK>(
+    pub fn trace_inplace<DataAK: AsRef<[u8]>>(
         &mut self,
         module: &Module<FFT64>,
         start: usize,
         end: usize,
         auto_keys: &HashMap<i64, AutomorphismKey<DataAK, FFT64>>,
         scratch: &mut Scratch,
-    ) where
-        MatZnxDft<DataAK, FFT64>: MatZnxDftToRef<FFT64>,
-    {
+    ) {
         (start..end).for_each(|i| {
             self.rsh(1, scratch);