This commit is contained in:
Jean-Philippe Bossuat
2025-05-27 17:49:43 +02:00
parent dec3481a6f
commit a295085724
32 changed files with 897 additions and 1375 deletions

View File

@@ -1,19 +1,15 @@
use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxZero};
use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero};
use crate::{
elem::{Infos, SetMetaData},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef},
};
impl<DataSelf> GLWECiphertext<DataSelf>
where
Self: Infos,
VecZnx<DataSelf>: VecZnxToMut,
{
pub fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where
A: VecZnxToRef + Infos,
B: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
B: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -28,25 +24,28 @@ where
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
(0..min_col).for_each(|i| {
module.vec_znx_add(self, i, a, i, b, i);
module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
});
} else {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i);
module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
});
}
let size: usize = self.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
let size: usize = self_mut.size();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
self_mut.zero_at(i, j);
self_mut.data.zero_at(i, j);
});
});
@@ -54,9 +53,9 @@ where
self.set_k(a.k().max(b.k()));
}
pub fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A)
fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A)
where
A: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -66,17 +65,20 @@ where
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_add_inplace(self, i, a, i);
module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
}
pub fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where
A: VecZnxToRef + Infos,
B: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
B: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -91,26 +93,29 @@ where
let max_col: usize = a.rank().max(b.rank() + 1);
let self_col: usize = self.rank() + 1;
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
(0..min_col).for_each(|i| {
module.vec_znx_sub(self, i, a, i, b, i);
module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
});
if a.rank() > b.rank() {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
});
} else {
(min_col..max_col).for_each(|i| {
module.vec_znx_copy(self, i, b, i);
module.vec_znx_negate_inplace(self, i);
module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
module.vec_znx_negate_inplace(&mut self_mut.data, i);
});
}
let size: usize = self.size();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
let size: usize = self_mut.size();
(max_col..self_col).for_each(|i| {
(0..size).for_each(|j| {
self_mut.zero_at(i, j);
self_mut.data.zero_at(i, j);
});
});
@@ -118,9 +123,9 @@ where
self.set_k(a.k().max(b.k()));
}
pub fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A)
fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A)
where
A: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -130,16 +135,19 @@ where
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ab_inplace(self, i, a, i);
module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
}
pub fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A)
fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A)
where
A: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -149,16 +157,19 @@ where
assert!(self.rank() >= a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_sub_ba_inplace(self, i, a, i);
module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
}
pub fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A)
fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A)
where
A: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -167,28 +178,33 @@ where
assert_eq!(self.rank(), a.rank())
}
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..a.rank() + 1).for_each(|i| {
module.vec_znx_rotate(k, self, i, a, i);
module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
});
self.set_basek(a.basek());
self.set_k(a.k());
}
pub fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) {
fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) {
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_rotate_inplace(k, self, i);
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_rotate_inplace(k, &mut self_mut.data, i);
});
}
pub fn copy<A>(&mut self, module: &Module<FFT64>, a: &A)
fn copy<A>(&mut self, module: &Module<FFT64>, a: &A)
where
A: VecZnxToRef + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -197,23 +213,26 @@ where
assert_eq!(self.rank(), a.rank());
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_copy(self, i, a, i);
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k());
self.set_basek(a.basek());
}
pub fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
fn rsh(&mut self, k: usize, scratch: &mut Scratch) {
let basek: usize = self.basek();
let mut self_mut: VecZnx<&mut [u8]> = self.to_mut();
self_mut.rsh(basek, k, scratch);
let mut self_mut: GLWECiphertext<&mut [u8]> = self.to_mut();
self_mut.data.rsh(basek, k, scratch);
}
pub fn normalize<A>(&mut self, module: &Module<FFT64>, a: &A, scratch: &mut Scratch)
fn normalize<A>(&mut self, module: &Module<FFT64>, a: &A, scratch: &mut Scratch)
where
A: VecZnxToMut + Infos,
A: GLWECiphertextToRef + Infos,
{
#[cfg(debug_assertions)]
{
@@ -222,20 +241,24 @@ where
assert_eq!(self.rank(), a.rank());
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_normalize(a.basek(), self, i, a, i, scratch);
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
});
self.set_basek(a.basek());
self.set_k(a.k());
}
pub fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) {
fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) {
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_normalize_inplace(self.basek(), self, i, scratch);
let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
(0..self_mut.rank() + 1).for_each(|i| {
module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
});
}
}