From 30849789764e2ee5f96d97e5ef349f056ec84f13 Mon Sep 17 00:00:00 2001
From: Jean-Philippe Bossuat
Date: Thu, 22 May 2025 16:08:44 +0200
Subject: [PATCH] Added basic GLWE ops

---
 backend/examples/rlwe_encrypt.rs            |   2 +-
 backend/src/mat_znx_dft.rs                  | 464 ++++++++++----------
 backend/src/vec_znx.rs                      |   2 +-
 backend/src/vec_znx_big_ops.rs              |  22 +-
 backend/src/vec_znx_dft.rs                  |   2 +-
 backend/src/vec_znx_ops.rs                  |  25 ++
 core/benches/external_product_glwe_fft64.rs |   4 +-
 core/benches/keyswitch_glwe_fft64.rs        |   2 +-
 core/src/automorphism.rs                    |   4 +-
 core/src/elem.rs                            |   9 +
 core/src/ggsw_ciphertext.rs                 |   8 +-
 core/src/glwe_ciphertext.rs                 |  37 +-
 core/src/glwe_ciphertext_fourier.rs         |   6 +-
 core/src/glwe_ops.rs                        | 213 +++++++++
 core/src/keys.rs                            |   2 +-
 core/src/keyswitch_key.rs                   |   2 +-
 core/src/lib.rs                             |   1 +
 core/src/tensor_key.rs                      |   2 +-
 core/src/test_fft64/gglwe.rs                |   9 +-
 core/src/test_fft64/glwe_fourier.rs         |   3 +-
 core/src/test_fft64/tensor_key.rs           |   4 +-
 core/src/trace.rs                           |   6 +-
 22 files changed, 535 insertions(+), 294 deletions(-)
 create mode 100644 core/src/glwe_ops.rs

diff --git a/backend/examples/rlwe_encrypt.rs b/backend/examples/rlwe_encrypt.rs
index a16437f..84f85ad 100644
--- a/backend/examples/rlwe_encrypt.rs
+++ b/backend/examples/rlwe_encrypt.rs
@@ -90,7 +90,7 @@ fn main() {
     // ct[0] <- ct[0] + e
     ct.add_normal(
         basek,
-        0, // Selects the first column of ct (ct[0])
+        0,               // Selects the first column of ct (ct[0])
         basek * ct_size, // Scaling of the noise: 2^{-basek * limbs}
         &mut source,
         3.2, // Standard deviation
diff --git a/backend/src/mat_znx_dft.rs b/backend/src/mat_znx_dft.rs
index 209c696..1fe67eb 100644
--- a/backend/src/mat_znx_dft.rs
+++ b/backend/src/mat_znx_dft.rs
@@ -1,232 +1,232 @@
-use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
-use std::marker::PhantomData;
-
-/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
-/// stored as a 3D matrix in the DFT domain in a single contiguous array.
-/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
-///
-/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
-/// See the trait [MatZnxDftOps] for additional information.
-pub struct MatZnxDft { - data: D, - n: usize, - size: usize, - rows: usize, - cols_in: usize, - cols_out: usize, - _phantom: PhantomData, -} - -impl ZnxInfos for MatZnxDft { - fn cols(&self) -> usize { - self.cols_in - } - - fn rows(&self) -> usize { - self.rows - } - - fn n(&self) -> usize { - self.n - } - - fn size(&self) -> usize { - self.size - } -} - -impl ZnxSliceSize for MatZnxDft { - fn sl(&self) -> usize { - self.n() * self.cols_out() - } -} - -impl DataView for MatZnxDft { - type D = D; - fn data(&self) -> &Self::D { - &self.data - } -} - -impl DataViewMut for MatZnxDft { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data - } -} - -impl> ZnxView for MatZnxDft { - type Scalar = f64; -} - -impl MatZnxDft { - pub fn cols_in(&self) -> usize { - self.cols_in - } - - pub fn cols_out(&self) -> usize { - self.cols_out - } -} - -impl>, B: Backend> MatZnxDft { - pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { - unsafe { - crate::ffi::vmp::bytes_of_vmp_pmat( - module.ptr, - (rows * cols_in) as u64, - (size * cols_out) as u64, - ) as usize - } - } - - pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { - let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } - - pub(crate) fn new_from_bytes( - module: &Module, - rows: usize, - cols_in: usize, - cols_out: usize, - size: usize, - bytes: impl Into>, - ) -> Self { - let data: Vec = bytes.into(); - assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); - Self { - data: data.into(), - n: module.n(), - size, - rows, - cols_in, - cols_out, - _phantom: PhantomData, - } - } -} - -impl> MatZnxDft { - /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. - /// - /// # Arguments - /// - /// * `row`: row index (i). - /// * `col`: col index (j). - #[allow(dead_code)] - fn at(&self, row: usize, col: usize) -> Vec { - let n: usize = self.n(); - - let mut res: Vec = alloc_aligned(n); - - if n < 8 { - res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); - } else { - (0..n >> 3).for_each(|blk| { - res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); - }); - } - - res - } - - #[allow(dead_code)] - fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { - let nrows: usize = self.rows(); - let nsize: usize = self.size(); - if col == (nsize - 1) && (nsize & 1 == 1) { - &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] - } else { - &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
- } - } -} - -pub type MatZnxDftOwned = MatZnxDft, B>; - -pub trait MatZnxDftToRef { - fn to_ref(&self) -> MatZnxDft<&[u8], B>; -} - -pub trait MatZnxDftToMut { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; -} - -impl MatZnxDftToMut for MatZnxDft, B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data.as_mut_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft, B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data.as_slice(), - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} - -impl MatZnxDftToRef for MatZnxDft<&[u8], B> { - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - MatZnxDft { - data: self.data, - n: self.n, - rows: self.rows, - cols_in: self.cols_in, - cols_out: self.cols_out, - size: self.size, - _phantom: PhantomData, - } - } -} +use crate::znx_base::ZnxInfos; +use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned}; +use std::marker::PhantomData; + +/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx], +/// stored as a 3D matrix in the DFT domain in a single contiguous array. +/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft]. +/// +/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft]. +/// See the trait [MatZnxDftOps] for additional information. 
+pub struct MatZnxDft { + data: D, + n: usize, + size: usize, + rows: usize, + cols_in: usize, + cols_out: usize, + _phantom: PhantomData, +} + +impl ZnxInfos for MatZnxDft { + fn cols(&self) -> usize { + self.cols_in + } + + fn rows(&self) -> usize { + self.rows + } + + fn n(&self) -> usize { + self.n + } + + fn size(&self) -> usize { + self.size + } +} + +impl ZnxSliceSize for MatZnxDft { + fn sl(&self) -> usize { + self.n() * self.cols_out() + } +} + +impl DataView for MatZnxDft { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for MatZnxDft { + fn data_mut(&mut self) -> &mut Self::D { + &mut self.data + } +} + +impl> ZnxView for MatZnxDft { + type Scalar = f64; +} + +impl MatZnxDft { + pub fn cols_in(&self) -> usize { + self.cols_in + } + + pub fn cols_out(&self) -> usize { + self.cols_out + } +} + +impl>, B: Backend> MatZnxDft { + pub(crate) fn bytes_of(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize { + unsafe { + crate::ffi::vmp::bytes_of_vmp_pmat( + module.ptr, + (rows * cols_in) as u64, + (size * cols_out) as u64, + ) as usize + } + } + + pub(crate) fn new(module: &Module, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self { + let data: Vec = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } + + pub(crate) fn new_from_bytes( + module: &Module, + rows: usize, + cols_in: usize, + cols_out: usize, + size: usize, + bytes: impl Into>, + ) -> Self { + let data: Vec = bytes.into(); + assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size)); + Self { + data: data.into(), + n: module.n(), + size, + rows, + cols_in, + cols_out, + _phantom: PhantomData, + } + } +} + +impl> MatZnxDft { + /// Returns a copy of the backend array at index (i, j) of the [MatZnxDft]. + /// + /// # Arguments + /// + /// * `row`: row index (i). + /// * `col`: col index (j). + #[allow(dead_code)] + fn at(&self, row: usize, col: usize) -> Vec { + let n: usize = self.n(); + + let mut res: Vec = alloc_aligned(n); + + if n < 8 { + res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]); + } else { + (0..n >> 3).for_each(|blk| { + res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]); + }); + } + + res + } + + #[allow(dead_code)] + fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] { + let nrows: usize = self.rows(); + let nsize: usize = self.size(); + if col == (nsize - 1) && (nsize & 1 == 1) { + &self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..] + } else { + &self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..] 
+ } + } +} + +pub type MatZnxDftOwned = MatZnxDft, B>; + +pub trait MatZnxDftToRef { + fn to_ref(&self) -> MatZnxDft<&[u8], B>; +} + +pub trait MatZnxDftToMut: MatZnxDftToRef { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>; +} + +impl MatZnxDftToMut for MatZnxDft, B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data.as_mut_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft, B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data.as_slice(), + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToMut for MatZnxDft<&mut [u8], B> { + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&mut [u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} + +impl MatZnxDftToRef for MatZnxDft<&[u8], B> { + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + MatZnxDft { + data: self.data, + n: self.n, + rows: self.rows, + cols_in: self.cols_in, + cols_out: self.cols_out, + size: self.size, + _phantom: PhantomData, + } + } +} diff --git a/backend/src/vec_znx.rs b/backend/src/vec_znx.rs index 950fae9..feba7ce 100644 --- a/backend/src/vec_znx.rs +++ b/backend/src/vec_znx.rs @@ -313,7 +313,7 @@ pub trait VecZnxToRef { fn to_ref(&self) -> VecZnx<&[u8]>; } -pub trait VecZnxToMut { +pub trait VecZnxToMut: VecZnxToRef { fn to_mut(&mut self) -> VecZnx<&mut [u8]>; } diff --git a/backend/src/vec_znx_big_ops.rs b/backend/src/vec_znx_big_ops.rs index d23dc22..a88dd27 100644 --- a/backend/src/vec_znx_big_ops.rs +++ b/backend/src/vec_znx_big_ops.rs @@ -125,15 +125,8 @@ pub trait VecZnxBigOps { /// /// * `basek`: normalization basis. /// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize]. 
- fn vec_znx_big_normalize( - &self, - basek: usize, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - scratch: &mut Scratch, - ) where + fn vec_znx_big_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where R: VecZnxToMut, A: VecZnxBigToRef; @@ -530,15 +523,8 @@ impl VecZnxBigOps for Module { } } - fn vec_znx_big_normalize( - &self, - basek: usize, - res: &mut R, - res_col: usize, - a: &A, - a_col: usize, - scratch: &mut Scratch, - ) where + fn vec_znx_big_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) + where R: VecZnxToMut, A: VecZnxBigToRef, { diff --git a/backend/src/vec_znx_dft.rs b/backend/src/vec_znx_dft.rs index 7b4ec29..c304089 100644 --- a/backend/src/vec_znx_dft.rs +++ b/backend/src/vec_znx_dft.rs @@ -142,7 +142,7 @@ pub trait VecZnxDftToRef { fn to_ref(&self) -> VecZnxDft<&[u8], B>; } -pub trait VecZnxDftToMut { +pub trait VecZnxDftToMut: VecZnxDftToRef { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>; } diff --git a/backend/src/vec_znx_ops.rs b/backend/src/vec_znx_ops.rs index 90321a5..106f777 100644 --- a/backend/src/vec_znx_ops.rs +++ b/backend/src/vec_znx_ops.rs @@ -152,6 +152,11 @@ pub trait VecZnxOps { where R: VecZnxToMut, A: VecZnxToRef; + + fn vec_znx_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef; } pub trait VecZnxScratch { @@ -174,6 +179,26 @@ impl VecZnxAlloc for Module { } impl VecZnxOps for Module { + fn vec_znx_copy(&self, res: &mut R, res_col: usize, a: &A, a_col: usize) + where + R: VecZnxToMut, + A: VecZnxToRef, + { + let mut res_mut: VecZnx<&mut [u8]> = res.to_mut(); + let a_ref: VecZnx<&[u8]> = a.to_ref(); + + let min_size: usize = min(res_mut.size(), a_ref.size()); + + (0..min_size).for_each(|j| { + res_mut + .at_mut(res_col, j) + .copy_from_slice(a_ref.at(a_col, j)); + }); + (min_size..res_mut.size()).for_each(|j| { + res_mut.zero_at(res_col, j); + }) + } + fn vec_znx_normalize(&self, basek: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch) where R: VecZnxToMut, diff --git a/core/benches/external_product_glwe_fft64.rs b/core/benches/external_product_glwe_fft64.rs index b99e875..7c57a2f 100644 --- a/core/benches/external_product_glwe_fft64.rs +++ b/core/benches/external_product_glwe_fft64.rs @@ -1,11 +1,11 @@ -use backend::{Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, FFT64}; -use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; +use backend::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned}; use core::{ elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, }; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use sampling::source::Source; fn bench_external_product_glwe_fft64(c: &mut Criterion) { diff --git a/core/benches/keyswitch_glwe_fft64.rs b/core/benches/keyswitch_glwe_fft64.rs index 16ea862..0d30b80 100644 --- a/core/benches/keyswitch_glwe_fft64.rs +++ b/core/benches/keyswitch_glwe_fft64.rs @@ -1,11 +1,11 @@ use backend::{FFT64, Module, ScratchOwned}; -use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use core::{ elem::Infos, glwe_ciphertext::GLWECiphertext, keys::{SecretKey, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, }; +use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main}; use sampling::source::Source; fn 
bench_keyswitch_glwe_fft64(c: &mut Criterion) { diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index a120594..91f13e5 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -168,7 +168,7 @@ impl AutomorphismKey, FFT64> { impl AutomorphismKey where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToMut, { pub fn generate_from_sk( &mut self, @@ -221,7 +221,7 @@ where impl AutomorphismKey where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToMut, { pub fn automorphism( &mut self, diff --git a/core/src/elem.rs b/core/src/elem.rs index 554b743..426e23d 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -27,6 +27,10 @@ pub trait Infos { self.inner().cols() } + fn rank(&self) -> usize { + self.cols() - 1 + } + /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); @@ -46,6 +50,11 @@ pub trait Infos { fn k(&self) -> usize; } +pub trait SetMetaData { + fn set_basek(&mut self, basek: usize); + fn set_k(&mut self, k: usize); +} + pub trait GetRow { fn get_row(&self, module: &Module, row_i: usize, col_j: usize, res: &mut R) where diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index fdbc225..ca22faf 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -1,8 +1,8 @@ use backend::{ Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, - VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, - VecZnxToRef, ZnxInfos, ZnxZero, + VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, ZnxInfos, + ZnxZero, }; use sampling::source::Source; @@ -196,7 +196,7 @@ impl GGSWCiphertext, FFT64> { impl GGSWCiphertext where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToMut, { pub fn encrypt_sk( &mut self, @@ -639,7 +639,7 @@ where ksk: &GLWESwitchingKey, scratch: &mut Scratch, ) where - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, MatZnxDft: MatZnxDftToRef, { #[cfg(debug_assertions)] diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index 155eca4..18c2ce0 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -2,16 +2,17 @@ use backend::{ AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, - VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, copy_vec_znx_from, + VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero, }; use sampling::source::Source; use crate::{ SIX_SIGMA, automorphism::AutomorphismKey, - elem::Infos, + elem::{Infos, SetMetaData}, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, + glwe_ops::GLWEOps, glwe_plaintext::GLWEPlaintext, keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier}, keyswitch_key::GLWESwitchingKey, @@ -201,9 +202,24 @@ impl GLWECiphertext> { } } +impl SetMetaData for GLWECiphertext +where + VecZnx: VecZnxToMut, +{ + fn set_k(&mut self, k: usize) { + self.k = k + } + + fn set_basek(&mut self, basek: 
usize) { + self.basek = basek + } +} + +impl GLWEOps for GLWECiphertext where VecZnx: VecZnxToMut {} + impl GLWECiphertext where - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, { pub fn encrypt_sk( &mut self, @@ -281,21 +297,6 @@ where self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch); } - pub fn copy(&mut self, other: &GLWECiphertext) - where - VecZnx: VecZnxToRef, - { - copy_vec_znx_from(&mut self.data.to_mut(), &other.to_ref()); - self.k = other.k; - self.basek = other.basek; - } - - pub fn rsh(&mut self, k: usize, scratch: &mut Scratch) { - let basek: usize = self.basek(); - let mut self_mut: VecZnx<&mut [u8]> = self.data.to_mut(); - self_mut.rsh(basek, k, scratch); - } - pub fn automorphism( &mut self, module: &Module, diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index 20b4f72..921fd55 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -1,7 +1,7 @@ use backend::{ Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, - VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero, + VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, ZnxZero, }; use sampling::source::Source; @@ -126,7 +126,7 @@ impl GLWECiphertextFourier, FFT64> { impl GLWECiphertextFourier where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToMut, { pub fn encrypt_zero_sk( &mut self, @@ -261,7 +261,7 @@ where sk_dft: &SecretKeyFourier, scratch: &mut Scratch, ) where - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, ScalarZnxDft: ScalarZnxDftToRef, { #[cfg(debug_assertions)] diff --git a/core/src/glwe_ops.rs b/core/src/glwe_ops.rs new file mode 100644 index 0000000..691e34d --- /dev/null +++ b/core/src/glwe_ops.rs @@ -0,0 +1,213 @@ +use backend::{Backend, Module, Scratch, VecZnx, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero}; + +use crate::elem::{Infos, SetMetaData}; + +pub trait GLWEOps +where + Self: Sized + VecZnxToMut + SetMetaData + Infos, +{ + fn add(&mut self, module: &Module, a: &A, b: &B) + where + A: VecZnxToRef + Infos, + B: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(a.basek(), b.basek()); + assert!(self.rank() >= a.rank().max(b.rank())); + } + + let min_col: usize = a.rank().min(b.rank()) + 1; + let max_col: usize = a.rank().max(b.rank() + 1); + let self_col: usize = self.rank() + 1; + + (0..min_col).for_each(|i| { + module.vec_znx_add(self, i, a, i, b, i); + }); + + if a.rank() > b.rank() { + (min_col..max_col).for_each(|i| { + module.vec_znx_copy(self, i, a, i); + }); + } else { + (min_col..max_col).for_each(|i| { + module.vec_znx_copy(self, i, b, i); + }); + } + + let size: usize = self.size(); + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + (max_col..self_col).for_each(|i| { + (0..size).for_each(|j| { + self_mut.zero_at(i, j); + }); + }); + + self.set_basek(a.basek()); + self.set_k(a.k().max(b.k())); + } + + fn add_inplace(&mut self, module: &Module, a: &A) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(self.basek(), a.basek()); + assert!(self.rank() >= a.rank()) + } + + 
(0..a.rank() + 1).for_each(|i| { + module.vec_znx_add_inplace(self, i, a, i); + }); + + self.set_k(a.k().max(self.k())); + } + + fn sub(&mut self, module: &Module, a: &A, b: &B) + where + A: VecZnxToRef + Infos, + B: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(b.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(a.basek(), b.basek()); + assert!(self.rank() >= a.rank().max(b.rank())); + } + + let min_col: usize = a.rank().min(b.rank()) + 1; + let max_col: usize = a.rank().max(b.rank() + 1); + let self_col: usize = self.rank() + 1; + + (0..min_col).for_each(|i| { + module.vec_znx_sub(self, i, a, i, b, i); + }); + + if a.rank() > b.rank() { + (min_col..max_col).for_each(|i| { + module.vec_znx_copy(self, i, a, i); + }); + } else { + (min_col..max_col).for_each(|i| { + module.vec_znx_copy(self, i, b, i); + module.vec_znx_negate_inplace(self, i); + }); + } + + let size: usize = self.size(); + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + (max_col..self_col).for_each(|i| { + (0..size).for_each(|j| { + self_mut.zero_at(i, j); + }); + }); + + self.set_basek(a.basek()); + self.set_k(a.k().max(b.k())); + } + + fn sub_inplace_ab(&mut self, module: &Module, a: &A) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(self.basek(), a.basek()); + assert!(self.rank() >= a.rank()) + } + + (0..a.rank() + 1).for_each(|i| { + module.vec_znx_sub_ab_inplace(self, i, a, i); + }); + + self.set_k(a.k().max(self.k())); + } + + fn sub_inplace_ba(&mut self, module: &Module, a: &A) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(self.basek(), a.basek()); + assert!(self.rank() >= a.rank()) + } + + (0..a.rank() + 1).for_each(|i| { + module.vec_znx_sub_ba_inplace(self, i, a, i); + }); + + self.set_k(a.k().max(self.k())); + } + + fn rotate(&mut self, module: &Module, k: i64, a: &A) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(a.n(), module.n()); + assert_eq!(self.n(), module.n()); + assert_eq!(self.basek(), a.basek()); + assert_eq!(self.rank(), a.rank()) + } + + (0..a.rank() + 1).for_each(|i| { + module.vec_znx_rotate(k, self, i, a, i); + }); + + self.set_k(a.k()); + } + + fn rotate_inplace(&mut self, module: &Module, k: i64) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + } + + (0..self.rank() + 1).for_each(|i| { + module.vec_znx_rotate_inplace(k, self, i); + }); + } + + fn copy(&mut self, module: &Module, a: &A) + where + A: VecZnxToRef + Infos, + { + #[cfg(debug_assertions)] + { + assert_eq!(self.n(), module.n()); + assert_eq!(a.n(), module.n()); + } + + let cols: usize = self.rank().min(a.rank()) + 1; + + (0..cols).for_each(|i| { + module.vec_znx_copy(self, i, a, i); + }); + + self.set_k(a.k()); + self.set_basek(a.basek()); + } + + fn rsh(&mut self, k: usize, scratch: &mut Scratch) { + let basek: usize = self.basek(); + let mut self_mut: VecZnx<&mut [u8]> = self.to_mut(); + self_mut.rsh(basek, k, scratch); + } +} diff --git a/core/src/keys.rs b/core/src/keys.rs index 3b0af56..e8af9b1 100644 --- a/core/src/keys.rs +++ b/core/src/keys.rs @@ -217,7 +217,7 @@ impl GLWEPublicKey { source_xe: &mut Source, sigma: f64, ) where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + VecZnxDft: VecZnxDftToMut, ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, { 
#[cfg(debug_assertions)] diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index a4ace4a..2ee24ed 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -149,7 +149,7 @@ impl GLWESwitchingKey, FFT64> { } impl GLWESwitchingKey where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToMut, { pub fn encrypt_sk( &mut self, diff --git a/core/src/lib.rs b/core/src/lib.rs index 249ad94..82b3c4b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -4,6 +4,7 @@ pub mod gglwe_ciphertext; pub mod ggsw_ciphertext; pub mod glwe_ciphertext; pub mod glwe_ciphertext_fourier; +pub mod glwe_ops; pub mod glwe_plaintext; pub mod keys; pub mod keyswitch_key; diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index ee31c8d..985f90d 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -63,7 +63,7 @@ impl TensorKey, FFT64> { impl TensorKey where - MatZnxDft: MatZnxDftToMut + MatZnxDftToRef, + MatZnxDft: MatZnxDftToMut, { pub fn encrypt_sk( &mut self, diff --git a/core/src/test_fft64/gglwe.rs b/core/src/test_fft64/gglwe.rs index 5324a02..52339f5 100644 --- a/core/src/test_fft64/gglwe.rs +++ b/core/src/test_fft64/gglwe.rs @@ -110,7 +110,8 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in scratch.borrow(), ); - let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out); + let mut ct_glwe_fourier: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out); (0..ksk.rank_in()).for_each(|col_i| { (0..ksk.rows()).for_each(|row_i| { @@ -202,7 +203,8 @@ fn test_key_switch( // gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0) ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow()); - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s1s2); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s1s2); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { @@ -304,7 +306,8 @@ fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, let ct_gglwe_s0s2: GLWESwitchingKey, FFT64> = ct_gglwe_s0s1; - let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s0s1); + let mut ct_glwe_dft: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::alloc(&module, basek, k_ksk, rank_out_s0s1); let mut pt: GLWEPlaintext> = GLWEPlaintext::alloc(&module, basek, k_ksk); (0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| { diff --git a/core/src/test_fft64/glwe_fourier.rs b/core/src/test_fft64/glwe_fourier.rs index 4a8c7c1..4532fb8 100644 --- a/core/src/test_fft64/glwe_fourier.rs +++ b/core/src/test_fft64/glwe_fourier.rs @@ -61,7 +61,8 @@ fn test_keyswitch( let mut ksk: GLWESwitchingKey, FFT64> = GLWESwitchingKey::alloc(&module, basek, k_ksk, rows, rank_in, rank_out); let mut ct_glwe_in: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_in, rank_in); - let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank_in); + let mut ct_glwe_dft_in: GLWECiphertextFourier, FFT64> = + GLWECiphertextFourier::alloc(&module, basek, k_ct_in, rank_in); let mut ct_glwe_out: GLWECiphertext> = GLWECiphertext::alloc(&module, basek, k_ct_out, rank_out); let mut ct_glwe_dft_out: GLWECiphertextFourier, 
FFT64> = GLWECiphertextFourier::alloc(&module, basek, k_ct_out, rank_out); diff --git a/core/src/test_fft64/tensor_key.rs b/core/src/test_fft64/tensor_key.rs index ea90413..a897253 100644 --- a/core/src/test_fft64/tensor_key.rs +++ b/core/src/test_fft64/tensor_key.rs @@ -1,4 +1,6 @@ -use backend::{Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps, FFT64}; +use backend::{ + FFT64, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps, +}; use sampling::source::Source; use crate::{ diff --git a/core/src/trace.rs b/core/src/trace.rs index 1169795..07ce1b3 100644 --- a/core/src/trace.rs +++ b/core/src/trace.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use backend::{FFT64, MatZnxDft, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxToMut, VecZnxToRef}; -use crate::{automorphism::AutomorphismKey, glwe_ciphertext::GLWECiphertext}; +use crate::{automorphism::AutomorphismKey, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps}; impl GLWECiphertext> { pub fn trace_galois_elements(module: &Module) -> Vec { @@ -34,7 +34,7 @@ impl GLWECiphertext> { impl GLWECiphertext where - VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToMut, { pub fn trace( &mut self, @@ -48,7 +48,7 @@ where VecZnx: VecZnxToRef, MatZnxDft: MatZnxDftToRef, { - self.copy(lhs); + self.copy(module, lhs); self.trace_inplace(module, start, end, auto_keys, scratch); }
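
A note on the new GLWEOps trait added in core/src/glwe_ops.rs: for operands of possibly different rank,
add and sub appear to work column by column. A rank-r GLWE ciphertext carries r + 1 polynomial columns;
columns present in both operands are combined coefficient-wise, the surplus columns of the higher-rank
operand are copied (and negated for sub when they come from b), and any leftover columns of a higher-rank
output are zeroed. The standalone Rust sketch below illustrates that convention for add, with plain
Vec<i64> columns standing in for VecZnx (the basek limb decomposition and the SetMetaData updates are
omitted) and a hypothetical helper name; it is an illustration of the intended semantics, not the
crate's API.

// Illustration only: each ciphertext is a slice of `rank + 1` coefficient
// columns; the per-column limb decomposition of the real VecZnx is omitted.
fn glwe_add_columns(res: &mut [Vec<i64>], a: &[Vec<i64>], b: &[Vec<i64>]) {
    let n = a[0].len();
    assert!(a.iter().chain(b.iter()).all(|c| c.len() == n));
    assert!(res.len() >= a.len().max(b.len()));

    let min_cols = a.len().min(b.len());
    let max_cols = a.len().max(b.len());
    let longer = if a.len() > b.len() { a } else { b };

    // Columns carried by both operands: coefficient-wise addition.
    for i in 0..min_cols {
        for j in 0..n {
            res[i][j] = a[i][j] + b[i][j];
        }
    }
    // Columns carried only by the higher-rank operand: plain copy.
    for i in min_cols..max_cols {
        res[i].copy_from_slice(&longer[i]);
    }
    // Any remaining columns of a higher-rank output: zeroed.
    for col in res.iter_mut().skip(max_cols) {
        col.iter_mut().for_each(|x| *x = 0);
    }
}

fn main() {
    // rank-1 (2 columns) + rank-2 (3 columns) into a rank-3 (4 columns) output.
    let a = vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8]];
    let b = vec![vec![10, 20, 30, 40], vec![1, 1, 1, 1], vec![2, 2, 2, 2]];
    let mut res = vec![vec![0i64; 4]; 4];
    glwe_add_columns(&mut res, &a, &b);
    assert_eq!(res[0], vec![11, 22, 33, 44]); // added
    assert_eq!(res[1], vec![6, 7, 8, 9]);     // added
    assert_eq!(res[2], vec![2, 2, 2, 2]);     // copied from b
    assert_eq!(res[3], vec![0, 0, 0, 0]);     // zeroed
}

The asserts in main document the expected column layout; the real trait additionally propagates basek
and k to the result and normalizes nothing, matching the plain vec_znx_add/vec_znx_copy calls above.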