merge VecZnxApi + Infos into VecZnxCommon + updated VecZnxApi generics

Jean-Philippe Bossuat
2025-02-24 16:39:47 +01:00
parent cac4b3549d
commit 1a583ea0db
3 changed files with 98 additions and 50 deletions

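The diff below makes two mechanical changes throughout the module: the repeated `VecZnxApi + Infos` bound is folded into a single `VecZnxCommon` supertrait, and functions that previously shared one generic parameter `T` now take independent parameters (`A`, `B`, `C`), presumably so that different `VecZnxCommon` implementors (such as `VecZnx` and `VecZnxBorrow`) can be mixed in a single call. The following minimal, self-contained sketch illustrates the pattern; `OwnedVec`, `BorrowedVec`, and their method bodies are hypothetical stand-ins rather than the crate's real definitions.

// Minimal sketch of the pattern introduced by this commit. `OwnedVec` and
// `BorrowedVec` are hypothetical stand-ins, not the crate's real types.

trait Infos {
    fn n(&self) -> usize;
}

trait VecZnxApi {
    fn raw(&self) -> &[i64];
    fn raw_mut(&mut self) -> &mut [i64];
}

// One name instead of repeating `VecZnxApi + Infos` in every generic bound.
trait VecZnxCommon: VecZnxApi + Infos {}

struct OwnedVec {
    data: Vec<i64>,
}

struct BorrowedVec<'a> {
    data: &'a mut [i64],
}

impl Infos for OwnedVec {
    fn n(&self) -> usize {
        self.data.len()
    }
}

impl VecZnxApi for OwnedVec {
    fn raw(&self) -> &[i64] {
        &self.data
    }
    fn raw_mut(&mut self) -> &mut [i64] {
        &mut self.data
    }
}

impl VecZnxCommon for OwnedVec {}

impl Infos for BorrowedVec<'_> {
    fn n(&self) -> usize {
        self.data.len()
    }
}

impl VecZnxApi for BorrowedVec<'_> {
    fn raw(&self) -> &[i64] {
        &*self.data
    }
    fn raw_mut(&mut self) -> &mut [i64] {
        &mut *self.data
    }
}

impl VecZnxCommon for BorrowedVec<'_> {}

// Independent generic parameters: source and destination may be different
// `VecZnxCommon` implementors, which a single shared `T` would forbid.
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
    let len = a.raw().len().min(b.raw().len());
    b.raw_mut()[..len].copy_from_slice(&a.raw()[..len]);
}

fn main() {
    let src = OwnedVec { data: vec![1, 2, 3, 4] };
    let mut backing = [0i64; 4];
    let mut dst = BorrowedVec { data: &mut backing };
    copy_from(&mut dst, &src); // owned source, borrowed destination
    assert_eq!(src.n(), dst.n());
    assert_eq!(dst.raw(), &[1i64, 2, 3, 4]);
}

With a single shared `T: VecZnxApi + Infos` bound, the final `copy_from` call would not type-check, since `OwnedVec` and `BorrowedVec` are distinct types; lifting that restriction is what the new per-argument generics provide.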

@@ -12,7 +12,7 @@ pub trait VecZnxVec {
fn dblptr_mut(&mut self) -> Vec<&mut [i64]>;
}
impl<T: VecZnxApi + Infos> VecZnxVec for Vec<T> {
impl<T: VecZnxCommon> VecZnxVec for Vec<T> {
fn dblptr(&self) -> Vec<&[i64]> {
self.iter().map(|v| v.raw()).collect()
}
@@ -23,7 +23,7 @@ impl<T: VecZnxApi + Infos> VecZnxVec for Vec<T> {
}
pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
type Owned: VecZnxApi + Infos;
type Owned: VecZnxCommon;
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned;
@@ -32,9 +32,9 @@ pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
fn bytes_of(n: usize, cols: usize) -> usize;
/// Copy the data of a onto self.
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T)
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where
Self: AsMut<T>;
Self: AsMut<B>;
/// Returns the backing array.
fn raw(&self) -> &[i64];
@@ -86,9 +86,9 @@ pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
/// # Arguments
///
/// * `a`: the receiver polynomial in which the extracted coefficients are stored.
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T)
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where
Self: AsRef<T>;
Self: AsRef<B>;
fn print(&self, cols: usize, n: usize);
}
@@ -115,6 +115,8 @@ impl AsRef<VecZnxBorrow> for VecZnxBorrow {
}
}
impl VecZnxCommon for VecZnxBorrow {}
impl VecZnxApi for VecZnxBorrow {
type Owned = VecZnxBorrow;
@@ -145,11 +147,11 @@ impl VecZnxApi for VecZnxBorrow {
bytes_of_vec_znx(n, cols)
}
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T)
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where
Self: AsMut<T>,
Self: AsMut<B>,
{
copy_vec_znx_from::<T>(self.as_mut(), a);
copy_vec_znx_from::<A, B>(self.as_mut(), a);
}
fn as_ptr(&self) -> *const i64 {
@@ -200,9 +202,9 @@ impl VecZnxApi for VecZnxBorrow {
rsh(log_base2k, self, k, carry)
}
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T)
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where
Self: AsRef<T>,
Self: AsRef<B>,
{
switch_degree(a, self.as_ref());
}
@@ -212,6 +214,8 @@ impl VecZnxApi for VecZnxBorrow {
}
}
impl VecZnxCommon for VecZnx {}
impl VecZnxApi for VecZnx {
type Owned = VecZnx;
@@ -242,9 +246,9 @@ impl VecZnxApi for VecZnx {
bytes_of_vec_znx(n, cols)
}
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T)
fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where
Self: AsMut<T>,
Self: AsMut<B>,
{
copy_vec_znx_from(self.as_mut(), a);
}
@@ -295,9 +299,9 @@ impl VecZnxApi for VecZnx {
rsh(log_base2k, self, k, carry)
}
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T)
fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where
Self: AsRef<T>,
Self: AsRef<B>,
{
switch_degree(a, self.as_ref())
}
@@ -332,7 +336,7 @@ impl AsRef<VecZnx> for VecZnx {
/// Copies the coefficients of `a` onto the receiver.
/// The copy length is the minimum of the two backing-array lengths.
pub fn copy_vec_znx_from<T: VecZnxApi + Infos>(b: &mut T, a: &T) {
pub fn copy_vec_znx_from<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len());
@@ -373,7 +377,7 @@ impl VecZnx {
}
}
pub fn switch_degree<T: VecZnxApi + Infos>(b: &mut T, a: &T) {
pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
let (n_in, n_out) = (a.n(), b.n());
let (gap_in, gap_out): (usize, usize);
@@ -395,7 +399,7 @@ pub fn switch_degree<T: VecZnxApi + Infos>(b: &mut T, a: &T) {
});
}
fn normalize<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
let n: usize = a.n();
assert!(
@@ -422,7 +426,7 @@ fn normalize<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, carry: &mut [u8
}
}
pub fn rsh<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) {
pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) {
let n: usize = a.n();
assert!(
@@ -462,6 +466,8 @@ pub fn rsh<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, k: usize, carry:
}
}
pub trait VecZnxCommon: VecZnxApi + Infos {}
pub trait VecZnxOps {
/// Allocates a new [VecZnx].
///
@@ -475,34 +481,50 @@ pub trait VecZnxOps {
fn bytes_of_vec_znx(&self, cols: usize) -> usize;
/// c <- a + b.
fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T);
fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b + a.
fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// c <- a - b.
fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T);
fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b - a.
fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// b <- -a.
fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T);
fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// b <- -b.
fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T);
fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A);
/// b <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T);
fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A);
/// a <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T);
fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, a_cols: usize);
fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>(
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, a_cols: usize);
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize);
/// Splits b into subrings and copies them into a.
///
@@ -510,7 +532,12 @@ pub trait VecZnxOps {
///
/// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n()
fn vec_znx_split<T: VecZnxApi + Infos>(&self, b: &mut Vec<T>, a: &T, buf: &mut T);
fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
);
/// Merges the subrings a into b.
///
@@ -518,7 +545,7 @@ pub trait VecZnxOps {
///
/// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n()
fn vec_znx_merge<T: VecZnxApi + Infos>(&self, b: &mut T, a: &Vec<T>);
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>);
}
impl VecZnxOps for Module {
@@ -531,7 +558,12 @@ impl VecZnxOps for Module {
}
// c <- a + b
fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) {
fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe {
vec_znx::vec_znx_add(
self.0,
@@ -549,7 +581,7 @@ impl VecZnxOps for Module {
}
// b <- a + b
fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe {
vec_znx::vec_znx_add(
self.0,
@@ -567,7 +599,12 @@ impl VecZnxOps for Module {
}
// c <- a - b
fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) {
fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe {
vec_znx::vec_znx_sub(
self.0,
@@ -585,7 +622,7 @@ impl VecZnxOps for Module {
}
// b <- b - a
fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe {
vec_znx::vec_znx_sub(
self.0,
@@ -602,7 +639,7 @@ impl VecZnxOps for Module {
}
}
fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) {
fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe {
vec_znx::vec_znx_negate(
self.0,
@@ -616,7 +653,7 @@ impl VecZnxOps for Module {
}
}
fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T) {
fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A) {
unsafe {
vec_znx::vec_znx_negate(
self.0,
@@ -630,22 +667,22 @@ impl VecZnxOps for Module {
}
}
fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, b: &T) {
fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A) {
unsafe {
vec_znx::vec_znx_rotate(
self.0,
k,
a.as_mut_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr(),
b.as_mut_ptr(),
b.cols() as u64,
b.n() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
)
}
}
fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T) {
fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A) {
unsafe {
vec_znx::vec_znx_rotate(
self.0,
@@ -697,7 +734,13 @@ impl VecZnxOps for Module {
/// });
/// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ```
fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, a_cols: usize) {
fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>(
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
) {
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert!(a.cols() >= a_cols);
@@ -750,7 +793,7 @@ impl VecZnxOps for Module {
/// });
/// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ```
fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, a_cols: usize) {
fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) {
assert_eq!(a.n(), self.n());
assert!(a.cols() >= a_cols);
unsafe {
@@ -767,7 +810,12 @@ impl VecZnxOps for Module {
}
}
fn vec_znx_split<T: VecZnxApi + Infos>(&self, b: &mut Vec<T>, a: &T, buf: &mut T) {
fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
) {
let (n_in, n_out) = (a.n(), b[0].n());
assert!(
@@ -793,7 +841,7 @@ impl VecZnxOps for Module {
})
}
fn vec_znx_merge<T: VecZnxApi + Infos>(&self, b: &mut T, a: &Vec<T>) {
fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) {
let (n_in, n_out) = (b.n(), a[0].n());
assert!(