Added basic GLWE ops

This commit is contained in:
Jean-Philippe Bossuat
2025-05-22 16:08:44 +02:00
parent dbbbe2bd92
commit 3084978976
22 changed files with 535 additions and 294 deletions

View File

@@ -90,7 +90,7 @@ fn main() {
// ct[0] <- ct[0] + e // ct[0] <- ct[0] + e
ct.add_normal( ct.add_normal(
basek, basek,
0, // Selects the first column of ct (ct[0]) 0, // Selects the first column of ct (ct[0])
basek * ct_size, // Scaling of the noise: 2^{-basek * limbs} basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
&mut source, &mut source,
3.2, // Standard deviation 3.2, // Standard deviation

View File

@@ -1,232 +1,232 @@
use crate::znx_base::ZnxInfos; use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
use std::marker::PhantomData; use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], /// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array. /// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. /// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
/// ///
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. /// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [MatZnxDftOps] for additional information. /// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<D, B: Backend> { pub struct MatZnxDft<D, B: Backend> {
data: D, data: D,
n: usize, n: usize,
size: usize, size: usize,
rows: usize, rows: usize,
cols_in: usize, cols_in: usize,
cols_out: usize, cols_out: usize,
_phantom: PhantomData<B>, _phantom: PhantomData<B>,
} }
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> { impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
fn cols(&self) -> usize { fn cols(&self) -> usize {
self.cols_in self.cols_in
} }
fn rows(&self) -> usize { fn rows(&self) -> usize {
self.rows self.rows
} }
fn n(&self) -> usize { fn n(&self) -> usize {
self.n self.n
} }
fn size(&self) -> usize { fn size(&self) -> usize {
self.size self.size
} }
} }
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> { impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
fn sl(&self) -> usize { fn sl(&self) -> usize {
self.n() * self.cols_out() self.n() * self.cols_out()
} }
} }
impl<D, B: Backend> DataView for MatZnxDft<D, B> { impl<D, B: Backend> DataView for MatZnxDft<D, B> {
type D = D; type D = D;
fn data(&self) -> &Self::D { fn data(&self) -> &Self::D {
&self.data &self.data
} }
} }
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> { impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D { fn data_mut(&mut self) -> &mut Self::D {
&mut self.data &mut self.data
} }
} }
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> { impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
type Scalar = f64; type Scalar = f64;
} }
impl<D, B: Backend> MatZnxDft<D, B> { impl<D, B: Backend> MatZnxDft<D, B> {
pub fn cols_in(&self) -> usize { pub fn cols_in(&self) -> usize {
self.cols_in self.cols_in
} }
pub fn cols_out(&self) -> usize { pub fn cols_out(&self) -> usize {
self.cols_out self.cols_out
} }
} }
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> { impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe { unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat( crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr, module.ptr,
(rows * cols_in) as u64, (rows * cols_in) as u64,
(size * cols_out) as u64, (size * cols_out) as u64,
) as usize ) as usize
} }
} }
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self { Self {
data: data.into(), data: data.into(),
n: module.n(), n: module.n(),
size, size,
rows, rows,
cols_in, cols_in,
cols_out, cols_out,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
pub(crate) fn new_from_bytes( pub(crate) fn new_from_bytes(
module: &Module<B>, module: &Module<B>,
rows: usize, rows: usize,
cols_in: usize, cols_in: usize,
cols_out: usize, cols_out: usize,
size: usize, size: usize,
bytes: impl Into<Vec<u8>>, bytes: impl Into<Vec<u8>>,
) -> Self { ) -> Self {
let data: Vec<u8> = bytes.into(); let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
Self { Self {
data: data.into(), data: data.into(),
n: module.n(), n: module.n(),
size, size,
rows, rows,
cols_in, cols_in,
cols_out, cols_out,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> { impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `row`: row index (i). /// * `row`: row index (i).
/// * `col`: col index (j). /// * `col`: col index (j).
#[allow(dead_code)] #[allow(dead_code)]
fn at(&self, row: usize, col: usize) -> Vec<f64> { fn at(&self, row: usize, col: usize) -> Vec<f64> {
let n: usize = self.n(); let n: usize = self.n();
let mut res: Vec<f64> = alloc_aligned(n); let mut res: Vec<f64> = alloc_aligned(n);
if n < 8 { if n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
} else { } else {
(0..n >> 3).for_each(|blk| { (0..n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
}); });
} }
res res
} }
#[allow(dead_code)] #[allow(dead_code)]
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows(); let nrows: usize = self.rows();
let nsize: usize = self.size(); let nsize: usize = self.size();
if col == (nsize - 1) && (nsize & 1 == 1) { if col == (nsize - 1) && (nsize & 1 == 1) {
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
} else { } else {
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
} }
} }
} }
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>; pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
pub trait MatZnxDftToRef<B: Backend> { pub trait MatZnxDftToRef<B: Backend> {
fn to_ref(&self) -> MatZnxDft<&[u8], B>; fn to_ref(&self) -> MatZnxDft<&[u8], B>;
} }
pub trait MatZnxDftToMut<B: Backend> { pub trait MatZnxDftToMut<B: Backend>: MatZnxDftToRef<B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
} }
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> { impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft { MatZnxDft {
data: self.data.as_mut_slice(), data: self.data.as_mut_slice(),
n: self.n, n: self.n,
rows: self.rows, rows: self.rows,
cols_in: self.cols_in, cols_in: self.cols_in,
cols_out: self.cols_out, cols_out: self.cols_out,
size: self.size, size: self.size,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> { impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> { fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft { MatZnxDft {
data: self.data.as_slice(), data: self.data.as_slice(),
n: self.n, n: self.n,
rows: self.rows, rows: self.rows,
cols_in: self.cols_in, cols_in: self.cols_in,
cols_out: self.cols_out, cols_out: self.cols_out,
size: self.size, size: self.size,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> { impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft { MatZnxDft {
data: self.data, data: self.data,
n: self.n, n: self.n,
rows: self.rows, rows: self.rows,
cols_in: self.cols_in, cols_in: self.cols_in,
cols_out: self.cols_out, cols_out: self.cols_out,
size: self.size, size: self.size,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> { impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> { fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft { MatZnxDft {
data: self.data, data: self.data,
n: self.n, n: self.n,
rows: self.rows, rows: self.rows,
cols_in: self.cols_in, cols_in: self.cols_in,
cols_out: self.cols_out, cols_out: self.cols_out,
size: self.size, size: self.size,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> { impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> { fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft { MatZnxDft {
data: self.data, data: self.data,
n: self.n, n: self.n,
rows: self.rows, rows: self.rows,
cols_in: self.cols_in, cols_in: self.cols_in,
cols_out: self.cols_out, cols_out: self.cols_out,
size: self.size, size: self.size,
_phantom: PhantomData, _phantom: PhantomData,
} }
} }
} }

View File

@@ -313,7 +313,7 @@ pub trait VecZnxToRef {
fn to_ref(&self) -> VecZnx<&[u8]>; fn to_ref(&self) -> VecZnx<&[u8]>;
} }
pub trait VecZnxToMut { pub trait VecZnxToMut: VecZnxToRef {
fn to_mut(&mut self) -> VecZnx<&mut [u8]>; fn to_mut(&mut self) -> VecZnx<&mut [u8]>;
} }

View File

@@ -125,15 +125,8 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
/// ///
/// * `basek`: normalization basis. /// * `basek`: normalization basis.
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize<R, A>( fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
&self, where
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut, R: VecZnxToMut,
A: VecZnxBigToRef<BACKEND>; A: VecZnxBigToRef<BACKEND>;
@@ -530,15 +523,8 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
} }
} }
fn vec_znx_big_normalize<R, A>( fn vec_znx_big_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
&self, where
basek: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut, R: VecZnxToMut,
A: VecZnxBigToRef<FFT64>, A: VecZnxBigToRef<FFT64>,
{ {

View File

@@ -142,7 +142,7 @@ pub trait VecZnxDftToRef<B: Backend> {
fn to_ref(&self) -> VecZnxDft<&[u8], B>; fn to_ref(&self) -> VecZnxDft<&[u8], B>;
} }
pub trait VecZnxDftToMut<B: Backend> { pub trait VecZnxDftToMut<B: Backend>: VecZnxDftToRef<B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
} }

View File

@@ -152,6 +152,11 @@ pub trait VecZnxOps {
where where
R: VecZnxToMut, R: VecZnxToMut,
A: VecZnxToRef; A: VecZnxToRef;
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
} }
pub trait VecZnxScratch { pub trait VecZnxScratch {
@@ -174,6 +179,26 @@ impl<B: Backend> VecZnxAlloc for Module<B> {
} }
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> { impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
fn vec_znx_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res_mut: VecZnx<&mut [u8]> = res.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
(0..min_size).for_each(|j| {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) fn vec_znx_normalize<R, A>(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where where
R: VecZnxToMut, R: VecZnxToMut,

View File

@@ -1,11 +1,11 @@
use backend::{Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, FFT64}; use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned};
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use core::{ use core::{
elem::Infos, elem::Infos,
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext, glwe_ciphertext::GLWECiphertext,
keys::{SecretKey, SecretKeyFourier}, keys::{SecretKey, SecretKeyFourier},
}; };
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use sampling::source::Source; use sampling::source::Source;
fn bench_external_product_glwe_fft64(c: &mut Criterion) { fn bench_external_product_glwe_fft64(c: &mut Criterion) {

View File

@@ -1,11 +1,11 @@
use backend::{FFT64, Module, ScratchOwned}; use backend::{FFT64, Module, ScratchOwned};
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use core::{ use core::{
elem::Infos, elem::Infos,
glwe_ciphertext::GLWECiphertext, glwe_ciphertext::GLWECiphertext,
keys::{SecretKey, SecretKeyFourier}, keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey, keyswitch_key::GLWESwitchingKey,
}; };
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use sampling::source::Source; use sampling::source::Source;
fn bench_keyswitch_glwe_fft64(c: &mut Criterion) { fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {

View File

@@ -168,7 +168,7 @@ impl AutomorphismKey<Vec<u8>, FFT64> {
impl<DataSelf> AutomorphismKey<DataSelf, FFT64> impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{ {
pub fn generate_from_sk<DataSk>( pub fn generate_from_sk<DataSk>(
&mut self, &mut self,
@@ -221,7 +221,7 @@ where
impl<DataSelf> AutomorphismKey<DataSelf, FFT64> impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{ {
pub fn automorphism<DataLhs, DataRhs>( pub fn automorphism<DataLhs, DataRhs>(
&mut self, &mut self,

View File

@@ -27,6 +27,10 @@ pub trait Infos {
self.inner().cols() self.inner().cols()
} }
fn rank(&self) -> usize {
self.cols() - 1
}
/// Returns the number of size per polynomial. /// Returns the number of size per polynomial.
fn size(&self) -> usize { fn size(&self) -> usize {
let size: usize = self.inner().size(); let size: usize = self.inner().size();
@@ -46,6 +50,11 @@ pub trait Infos {
fn k(&self) -> usize; fn k(&self) -> usize;
} }
pub trait SetMetaData {
fn set_basek(&mut self, basek: usize);
fn set_k(&mut self, k: usize);
}
pub trait GetRow<B: Backend> { pub trait GetRow<B: Backend> {
fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut R) fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut R)
where where

View File

@@ -1,8 +1,8 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps,
VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, ZnxInfos,
VecZnxToRef, ZnxInfos, ZnxZero, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -196,7 +196,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64> impl<DataSelf> GGSWCiphertext<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{ {
pub fn encrypt_sk<DataPt, DataSk>( pub fn encrypt_sk<DataPt, DataSk>(
&mut self, &mut self,
@@ -639,7 +639,7 @@ where
ksk: &GLWESwitchingKey<DataKsk, FFT64>, ksk: &GLWESwitchingKey<DataKsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) where
VecZnx<DataRes>: VecZnxToMut + VecZnxToRef, VecZnx<DataRes>: VecZnxToMut,
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>, MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]

View File

@@ -2,16 +2,17 @@ use backend::{
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, copy_vec_znx_from, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
SIX_SIGMA, SIX_SIGMA,
automorphism::AutomorphismKey, automorphism::AutomorphismKey,
elem::Infos, elem::{Infos, SetMetaData},
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_ops::GLWEOps,
glwe_plaintext::GLWEPlaintext, glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey, keyswitch_key::GLWESwitchingKey,
@@ -201,9 +202,24 @@ impl GLWECiphertext<Vec<u8>> {
} }
} }
impl<DataSelf> SetMetaData for GLWECiphertext<DataSelf>
where
VecZnx<DataSelf>: VecZnxToMut,
{
fn set_k(&mut self, k: usize) {
self.k = k
}
fn set_basek(&mut self, basek: usize) {
self.basek = basek
}
}
impl<DataSelf> GLWEOps<FFT64> for GLWECiphertext<DataSelf> where VecZnx<DataSelf>: VecZnxToMut {}
impl<DataSelf> GLWECiphertext<DataSelf> impl<DataSelf> GLWECiphertext<DataSelf>
where where
VecZnx<DataSelf>: VecZnxToMut + VecZnxToRef, VecZnx<DataSelf>: VecZnxToMut,
{ {
pub fn encrypt_sk<DataPt, DataSk>( pub fn encrypt_sk<DataPt, DataSk>(
&mut self, &mut self,
@@ -281,21 +297,6 @@ where
self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch);
} }
pub fn copy<DataOther>(&mut self, other: &GLWECiphertext<DataOther>)
where
VecZnx<DataOther>: VecZnxToRef,
{
copy_vec_znx_from(&mut self.data.to_mut(), &other.to_ref());
self.k = other.k;
self.basek = other.basek;
}
pub fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
let basek: usize = self.basek();
let mut self_mut: VecZnx<&mut [u8]> = self.data.to_mut();
self_mut.rsh(basek, k, scratch);
}
pub fn automorphism<DataLhs, DataRhs>( pub fn automorphism<DataLhs, DataRhs>(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,

View File

@@ -1,7 +1,7 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -126,7 +126,7 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64> impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
where where
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>, VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64>,
{ {
pub fn encrypt_zero_sk<DataSk>( pub fn encrypt_zero_sk<DataSk>(
&mut self, &mut self,
@@ -261,7 +261,7 @@ where
sk_dft: &SecretKeyFourier<DataSk, FFT64>, sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) where ) where
VecZnx<DataPt>: VecZnxToMut + VecZnxToRef, VecZnx<DataPt>: VecZnxToMut,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>, ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]

213
core/src/glwe_ops.rs Normal file
View File

@@ -0,0 +1,213 @@
use backend::{Backend, Module, Scratch, VecZnx, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero};
use crate::elem::{Infos, SetMetaData};
pub trait GLWEOps<BACKEND: Backend>
where
Self: Sized + VecZnxToMut + SetMetaData + Infos,
{
fn add<A, B>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
where
A: VecZnxToRef + Infos,
B: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(a.basek(), b.basek());
assert!(self.rank() >= a.rank().max(b.rank()));
}
let min_col: usize = a.rank().min(b.rank()) + 1;
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
(0..min_col).for_each(|i| {
module.vec_znx_add(self, i, a, i, b, i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
});
} else {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i);
});
}
let size: usize = self.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
self_mut.zero_at(i, j);
});
});
self.set_basek(a.basek());
self.set_k(a.k().max(b.k()));
}
fn add_inplace<A>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(self.basek(), a.basek());
assert!(self.rank() >= a.rank())
}
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_add_inplace(self, i, a, i);
});
self.set_k(a.k().max(self.k()));
}
fn sub<A, B>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
where
A: VecZnxToRef + Infos,
B: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(b.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(a.basek(), b.basek());
assert!(self.rank() >= a.rank().max(b.rank()));
}
let min_col: usize = a.rank().min(b.rank()) + 1;
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
(0..min_col).for_each(|i| {
module.vec_znx_sub(self, i, a, i, b, i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
});
} else {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i);
module.vec_znx_negate_inplace(self, i);
});
}
let size: usize = self.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
self_mut.zero_at(i, j);
});
});
self.set_basek(a.basek());
self.set_k(a.k().max(b.k()));
}
fn sub_inplace_ab<A>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(self.basek(), a.basek());
assert!(self.rank() >= a.rank())
}
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ab_inplace(self, i, a, i);
});
self.set_k(a.k().max(self.k()));
}
fn sub_inplace_ba<A>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(self.basek(), a.basek());
assert!(self.rank() >= a.rank())
}
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ba_inplace(self, i, a, i);
});
self.set_k(a.k().max(self.k()));
}
fn rotate<A>(&mut self, module: &Module<BACKEND>, k: i64, a: &A)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(self.basek(), a.basek());
assert_eq!(self.rank(), a.rank())
}
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_rotate(k, self, i, a, i);
});
self.set_k(a.k());
}
fn rotate_inplace<A>(&mut self, module: &Module<BACKEND>, k: i64)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_rotate_inplace(k, self, i);
});
}
fn copy<A>(&mut self, module: &Module<BACKEND>, a: &A)
where
A: VecZnxToRef + Infos,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
assert_eq!(a.n(), module.n());
}
let cols: usize = self.rank().min(a.rank()) + 1;
(0..cols).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
});
self.set_k(a.k());
self.set_basek(a.basek());
}
fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
let basek: usize = self.basek();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
self_mut.rsh(basek, k, scratch);
}
}

View File

@@ -217,7 +217,7 @@ impl<C> GLWEPublicKey<C, FFT64> {
source_xe: &mut Source, source_xe: &mut Source,
sigma: f64, sigma: f64,
) where ) where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>, VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos, ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
{ {
#[cfg(debug_assertions)] #[cfg(debug_assertions)]

View File

@@ -149,7 +149,7 @@ impl GLWESwitchingKey<Vec<u8>, FFT64> {
} }
impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64> impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{ {
pub fn encrypt_sk<DataSkIn, DataSkOut>( pub fn encrypt_sk<DataSkIn, DataSkOut>(
&mut self, &mut self,

View File

@@ -4,6 +4,7 @@ pub mod gglwe_ciphertext;
pub mod ggsw_ciphertext; pub mod ggsw_ciphertext;
pub mod glwe_ciphertext; pub mod glwe_ciphertext;
pub mod glwe_ciphertext_fourier; pub mod glwe_ciphertext_fourier;
pub mod glwe_ops;
pub mod glwe_plaintext; pub mod glwe_plaintext;
pub mod keys; pub mod keys;
pub mod keyswitch_key; pub mod keyswitch_key;

View File

@@ -63,7 +63,7 @@ impl TensorKey<Vec<u8>, FFT64> {
impl<DataSelf> TensorKey<DataSelf, FFT64> impl<DataSelf> TensorKey<DataSelf, FFT64>
where where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>, MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{ {
pub fn encrypt_sk<DataSk>( pub fn encrypt_sk<DataSk>(
&mut self, &mut self,

View File

@@ -110,7 +110,8 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in
scratch.borrow(), scratch.borrow(),
); );
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out); let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out);
(0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rank_in()).for_each(|col_i| {
(0..ksk.rows()).for_each(|row_i| { (0..ksk.rows()).for_each(|row_i| {
@@ -202,7 +203,8 @@ fn test_key_switch(
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s1s2); let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s1s2);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ksk); let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ksk);
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {
@@ -304,7 +306,8 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64,
let ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> = ct_gglwe_s0s1; let ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> = ct_gglwe_s0s1;
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s0s1); let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s0s1);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ksk); let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ksk);
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {

View File

@@ -61,7 +61,8 @@ fn test_keyswitch(
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in, rank_out); let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in, rank_out);
let mut ct_glwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank_in); let mut ct_glwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank_in);
let mut ct_glwe_dft_in: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank_in); let mut ct_glwe_dft_in: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank_in);
let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank_out); let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank_out);
let mut ct_glwe_dft_out: GLWECiphertextFourier<Vec<u8>, FFT64> = let mut ct_glwe_dft_out: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::alloc(&module, basek, k_ct_out, rank_out); GLWECiphertextFourier::alloc(&module, basek, k_ct_out, rank_out);

View File

@@ -1,4 +1,6 @@
use backend::{Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps, FFT64}; use backend::{
FFT64, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps,
};
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use backend::{FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxToMut, VecZnxToRef}; use backend::{FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxToMut, VecZnxToRef};
use crate::{automorphism::AutomorphismKey, glwe_ciphertext::GLWECiphertext}; use crate::{automorphism::AutomorphismKey, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps};
impl GLWECiphertext<Vec<u8>> { impl GLWECiphertext<Vec<u8>> {
pub fn trace_galois_elements(module: &Module<FFT64>) -> Vec<i64> { pub fn trace_galois_elements(module: &Module<FFT64>) -> Vec<i64> {
@@ -34,7 +34,7 @@ impl GLWECiphertext<Vec<u8>> {
impl<DataSelf> GLWECiphertext<DataSelf> impl<DataSelf> GLWECiphertext<DataSelf>
where where
VecZnx<DataSelf>: VecZnxToMut + VecZnxToRef, VecZnx<DataSelf>: VecZnxToMut,
{ {
pub fn trace<DataLhs, DataAK>( pub fn trace<DataLhs, DataAK>(
&mut self, &mut self,
@@ -48,7 +48,7 @@ where
VecZnx<DataLhs>: VecZnxToRef, VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataAK, FFT64>: MatZnxDftToRef<FFT64>, MatZnxDft<DataAK, FFT64>: MatZnxDftToRef<FFT64>,
{ {
self.copy(lhs); self.copy(module, lhs);
self.trace_inplace(module, start, end, auto_keys, scratch); self.trace_inplace(module, start, end, auto_keys, scratch);
} }