Added grlwe ops + tests

This commit is contained in:
Jean-Philippe Bossuat
2025-05-09 10:39:00 +02:00
parent de3b34477d
commit 9913040aa1
16 changed files with 1435 additions and 385 deletions

View File

@@ -1,5 +1,7 @@
use crate::znx_base::ZnxInfos;
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut};
use crate::{
Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -144,7 +146,7 @@ impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
}
}
impl VecZnxToMut for ScalarZnx<Vec<u8>>{
impl VecZnxToMut for ScalarZnx<Vec<u8>> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
@@ -165,7 +167,7 @@ impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
}
}
impl VecZnxToRef for ScalarZnx<Vec<u8>>{
impl VecZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),

View File

@@ -1,6 +1,7 @@
use crate::DataView;
use crate::DataViewMut;
use crate::ZnxSliceSize;
use crate::ZnxZero;
use crate::alloc_aligned;
use crate::assert_alignement;
use crate::cast_mut;
@@ -182,6 +183,39 @@ fn normalize<D: AsMut<[u8]> + AsRef<[u8]>>(log_base2k: usize, a: &mut VecZnx<D>,
}
}
impl<D> VecZnx<D>
where
VecZnx<D>: VecZnxToMut + ZnxInfos,
{
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnx<C>, a_col: usize)
where
VecZnx<C>: VecZnxToRef + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
let a_ref: VecZnx<&[u8]> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
impl<D: AsRef<[u8]>> fmt::Display for VecZnx<D> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(

View File

@@ -1,6 +1,6 @@
use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, alloc_aligned};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
use std::fmt;
use std::marker::PhantomData;
@@ -94,6 +94,39 @@ impl<D, B: Backend> VecZnxBig<D, B> {
}
}
impl<D> VecZnxBig<D, FFT64>
where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
where
VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
pub trait VecZnxBigToRef<B: Backend> {

View File

@@ -115,7 +115,9 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
A: VecZnxToRef;
/// Negates `a` inplace.
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut<BACKEND>;
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
/// Normalizes `a` and stores the result on `b`.
///
@@ -506,7 +508,10 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut<FFT64> {
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{

View File

@@ -2,7 +2,9 @@ use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, alloc_aligned};
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
};
use std::fmt;
pub struct VecZnxDft<D, B: Backend> {
@@ -89,6 +91,39 @@ impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
}
}
impl<D> VecZnxDft<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
{
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
impl<D, B: Backend> VecZnxDft<D, B> {

View File

@@ -47,7 +47,9 @@ pub trait VecZnxDftOps<B: Backend> {
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>, a_cols: usize) -> VecZnxBig<D, FFT64>
/// Consumes a to return IDFT(a) in big coeff space.
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
@@ -103,25 +105,28 @@ impl VecZnxDftOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>, a_col: usize) -> VecZnxBig<D, FFT64>
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
unsafe {
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
(0..a_mut.size()).for_each(|j| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
(0..a_mut.cols()).for_each(|i| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
});
a.into_big()
}
a.into_big()
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {

View File

@@ -101,25 +101,25 @@ pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
pub trait ZnxZero: ZnxViewMut
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
where
Self: Sized,
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.sl() * self.poly_count());
}
}
fn zero_at(&mut self, i: usize, j: usize) {
unsafe {
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.sl());
}
}
}
// Blanket implementations
impl<T> ZnxZero for T where T: ZnxViewMut {}
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};