merge VecZnxApi + Infos into VecZnxCommon + updated VecZnxApi generics

This commit is contained in:
Jean-Philippe Bossuat
2025-02-24 16:39:47 +01:00
parent cac4b3549d
commit 1a583ea0db
3 changed files with 98 additions and 50 deletions

View File

@@ -12,7 +12,7 @@ pub trait VecZnxVec {
fn dblptr_mut(&mut self) -> Vec<&mut [i64]>; fn dblptr_mut(&mut self) -> Vec<&mut [i64]>;
} }
impl<T: VecZnxApi + Infos> VecZnxVec for Vec<T> { impl<T: VecZnxCommon> VecZnxVec for Vec<T> {
fn dblptr(&self) -> Vec<&[i64]> { fn dblptr(&self) -> Vec<&[i64]> {
self.iter().map(|v| v.raw()).collect() self.iter().map(|v| v.raw()).collect()
} }
@@ -23,7 +23,7 @@ impl<T: VecZnxApi + Infos> VecZnxVec for Vec<T> {
} }
pub trait VecZnxApi: AsRef<Self> + AsMut<Self> { pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
type Owned: VecZnxApi + Infos; type Owned: VecZnxCommon;
fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned; fn from_bytes(n: usize, cols: usize, bytes: &mut [u8]) -> Self::Owned;
@@ -32,9 +32,9 @@ pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
fn bytes_of(n: usize, cols: usize) -> usize; fn bytes_of(n: usize, cols: usize) -> usize;
/// Copy the data of a onto self. /// Copy the data of a onto self.
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T) fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where where
Self: AsMut<T>; Self: AsMut<B>;
/// Returns the backing array. /// Returns the backing array.
fn raw(&self) -> &[i64]; fn raw(&self) -> &[i64];
@@ -86,9 +86,9 @@ pub trait VecZnxApi: AsRef<Self> + AsMut<Self> {
/// # Arguments /// # Arguments
/// ///
/// * `a`: the receiver polynomial in which the extracted coefficients are stored. /// * `a`: the receiver polynomial in which the extracted coefficients are stored.
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T) fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where where
Self: AsRef<T>; Self: AsRef<B>;
fn print(&self, cols: usize, n: usize); fn print(&self, cols: usize, n: usize);
} }
@@ -115,6 +115,8 @@ impl AsRef<VecZnxBorrow> for VecZnxBorrow {
} }
} }
impl VecZnxCommon for VecZnxBorrow {}
impl VecZnxApi for VecZnxBorrow { impl VecZnxApi for VecZnxBorrow {
type Owned = VecZnxBorrow; type Owned = VecZnxBorrow;
@@ -145,11 +147,11 @@ impl VecZnxApi for VecZnxBorrow {
bytes_of_vec_znx(n, cols) bytes_of_vec_znx(n, cols)
} }
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T) fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where where
Self: AsMut<T>, Self: AsMut<B>,
{ {
copy_vec_znx_from::<T>(self.as_mut(), a); copy_vec_znx_from::<A, B>(self.as_mut(), a);
} }
fn as_ptr(&self) -> *const i64 { fn as_ptr(&self) -> *const i64 {
@@ -200,9 +202,9 @@ impl VecZnxApi for VecZnxBorrow {
rsh(log_base2k, self, k, carry) rsh(log_base2k, self, k, carry)
} }
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T) fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where where
Self: AsRef<T>, Self: AsRef<B>,
{ {
switch_degree(a, self.as_ref()); switch_degree(a, self.as_ref());
} }
@@ -212,6 +214,8 @@ impl VecZnxApi for VecZnxBorrow {
} }
} }
impl VecZnxCommon for VecZnx {}
impl VecZnxApi for VecZnx { impl VecZnxApi for VecZnx {
type Owned = VecZnx; type Owned = VecZnx;
@@ -242,9 +246,9 @@ impl VecZnxApi for VecZnx {
bytes_of_vec_znx(n, cols) bytes_of_vec_znx(n, cols)
} }
fn copy_from<T: VecZnxApi + Infos>(&mut self, a: &T) fn copy_from<A: VecZnxCommon, B: VecZnxCommon>(&mut self, a: &A)
where where
Self: AsMut<T>, Self: AsMut<B>,
{ {
copy_vec_znx_from(self.as_mut(), a); copy_vec_znx_from(self.as_mut(), a);
} }
@@ -295,9 +299,9 @@ impl VecZnxApi for VecZnx {
rsh(log_base2k, self, k, carry) rsh(log_base2k, self, k, carry)
} }
fn switch_degree<T: VecZnxApi + Infos>(&self, a: &mut T) fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(&self, a: &mut A)
where where
Self: AsRef<T>, Self: AsRef<B>,
{ {
switch_degree(a, self.as_ref()) switch_degree(a, self.as_ref())
} }
@@ -332,7 +336,7 @@ impl AsRef<VecZnx> for VecZnx {
/// Copies the coefficients of `a` on the receiver. /// Copies the coefficients of `a` on the receiver.
/// Copy is done with the minimum size matching both backing arrays. /// Copy is done with the minimum size matching both backing arrays.
pub fn copy_vec_znx_from<T: VecZnxApi + Infos>(b: &mut T, a: &T) { pub fn copy_vec_znx_from<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
let data_a: &[i64] = a.raw(); let data_a: &[i64] = a.raw();
let data_b: &mut [i64] = b.raw_mut(); let data_b: &mut [i64] = b.raw_mut();
let size = min(data_b.len(), data_a.len()); let size = min(data_b.len(), data_a.len());
@@ -373,7 +377,7 @@ impl VecZnx {
} }
} }
pub fn switch_degree<T: VecZnxApi + Infos>(b: &mut T, a: &T) { pub fn switch_degree<A: VecZnxCommon, B: VecZnxCommon>(b: &mut B, a: &A) {
let (n_in, n_out) = (a.n(), b.n()); let (n_in, n_out) = (a.n(), b.n());
let (gap_in, gap_out): (usize, usize); let (gap_in, gap_out): (usize, usize);
@@ -395,7 +399,7 @@ pub fn switch_degree<T: VecZnxApi + Infos>(b: &mut T, a: &T) {
}); });
} }
fn normalize<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, carry: &mut [u8]) { fn normalize<T: VecZnxCommon>(log_base2k: usize, a: &mut T, carry: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
assert!( assert!(
@@ -422,7 +426,7 @@ fn normalize<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, carry: &mut [u8
} }
} }
pub fn rsh<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) { pub fn rsh<T: VecZnxCommon>(log_base2k: usize, a: &mut T, k: usize, carry: &mut [u8]) {
let n: usize = a.n(); let n: usize = a.n();
assert!( assert!(
@@ -462,6 +466,8 @@ pub fn rsh<T: VecZnxApi + Infos>(log_base2k: usize, a: &mut T, k: usize, carry:
} }
} }
pub trait VecZnxCommon: VecZnxApi + Infos {}
pub trait VecZnxOps { pub trait VecZnxOps {
/// Allocates a new [VecZnx]. /// Allocates a new [VecZnx].
/// ///
@@ -475,34 +481,50 @@ pub trait VecZnxOps {
fn bytes_of_vec_znx(&self, cols: usize) -> usize; fn bytes_of_vec_znx(&self, cols: usize) -> usize;
/// c <- a + b. /// c <- a + b.
fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T); fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b + a. /// b <- b + a.
fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T); fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// c <- a - b. /// c <- a - b.
fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T); fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
);
/// b <- b - a. /// b <- b - a.
fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T); fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// b <- -a. /// b <- -a.
fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T); fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A);
/// b <- -b. /// b <- -b.
fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T); fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A);
/// b <- a * X^k (mod X^{n} + 1) /// b <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T); fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A);
/// a <- a * X^k (mod X^{n} + 1) /// a <- a * X^k (mod X^{n} + 1)
fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T); fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, a_cols: usize); fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>(
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1)) /// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, a_cols: usize); fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize);
/// Splits b into subrings and copies them them into a. /// Splits b into subrings and copies them them into a.
/// ///
@@ -510,7 +532,12 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of b have the same ring degree /// This method requires that all [VecZnx] of b have the same ring degree
/// and that b.n() * b.len() <= a.n() /// and that b.n() * b.len() <= a.n()
fn vec_znx_split<T: VecZnxApi + Infos>(&self, b: &mut Vec<T>, a: &T, buf: &mut T); fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
);
/// Merges the subrings a into b. /// Merges the subrings a into b.
/// ///
@@ -518,7 +545,7 @@ pub trait VecZnxOps {
/// ///
/// This method requires that all [VecZnx] of a have the same ring degree /// This method requires that all [VecZnx] of a have the same ring degree
/// and that a.n() * a.len() <= b.n() /// and that a.n() * a.len() <= b.n()
fn vec_znx_merge<T: VecZnxApi + Infos>(&self, b: &mut T, a: &Vec<T>); fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>);
} }
impl VecZnxOps for Module { impl VecZnxOps for Module {
@@ -531,7 +558,12 @@ impl VecZnxOps for Module {
} }
// c <- a + b // c <- a + b
fn vec_znx_add<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) { fn vec_znx_add<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx::vec_znx_add(
self.0, self.0,
@@ -549,7 +581,7 @@ impl VecZnxOps for Module {
} }
// b <- a + b // b <- a + b
fn vec_znx_add_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) { fn vec_znx_add_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe { unsafe {
vec_znx::vec_znx_add( vec_znx::vec_znx_add(
self.0, self.0,
@@ -567,7 +599,12 @@ impl VecZnxOps for Module {
} }
// c <- a + b // c <- a + b
fn vec_znx_sub<T: VecZnxApi + Infos>(&self, c: &mut T, a: &T, b: &T) { fn vec_znx_sub<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
c: &mut C,
a: &A,
b: &B,
) {
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx::vec_znx_sub(
self.0, self.0,
@@ -585,7 +622,7 @@ impl VecZnxOps for Module {
} }
// b <- a + b // b <- a + b
fn vec_znx_sub_inplace<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) { fn vec_znx_sub_inplace<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe { unsafe {
vec_znx::vec_znx_sub( vec_znx::vec_znx_sub(
self.0, self.0,
@@ -602,7 +639,7 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_negate<T: VecZnxApi + Infos>(&self, b: &mut T, a: &T) { fn vec_znx_negate<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &A) {
unsafe { unsafe {
vec_znx::vec_znx_negate( vec_znx::vec_znx_negate(
self.0, self.0,
@@ -616,7 +653,7 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_negate_inplace<T: VecZnxApi + Infos>(&self, a: &mut T) { fn vec_znx_negate_inplace<A: VecZnxCommon>(&self, a: &mut A) {
unsafe { unsafe {
vec_znx::vec_znx_negate( vec_znx::vec_znx_negate(
self.0, self.0,
@@ -630,22 +667,22 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_rotate<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, b: &T) { fn vec_znx_rotate<A: VecZnxCommon, B: VecZnxCommon>(&self, k: i64, b: &mut B, a: &A) {
unsafe { unsafe {
vec_znx::vec_znx_rotate( vec_znx::vec_znx_rotate(
self.0, self.0,
k, k,
a.as_mut_ptr(), b.as_mut_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr(),
b.cols() as u64, b.cols() as u64,
b.n() as u64, b.n() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
) )
} }
} }
fn vec_znx_rotate_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T) { fn vec_znx_rotate_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A) {
unsafe { unsafe {
vec_znx::vec_znx_rotate( vec_znx::vec_znx_rotate(
self.0, self.0,
@@ -697,7 +734,13 @@ impl VecZnxOps for Module {
/// }); /// });
/// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// izip!(b.data.iter(), c.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ``` /// ```
fn vec_znx_automorphism<T: VecZnxApi + Infos>(&self, k: i64, b: &mut T, a: &T, a_cols: usize) { fn vec_znx_automorphism<A: VecZnxCommon, B: VecZnxCommon>(
&self,
k: i64,
b: &mut B,
a: &A,
a_cols: usize,
) {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n()); assert_eq!(b.n(), self.n());
assert!(a.cols() >= a_cols); assert!(a.cols() >= a_cols);
@@ -750,7 +793,7 @@ impl VecZnxOps for Module {
/// }); /// });
/// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); /// izip!(a.data.iter(), b.data.iter()).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
/// ``` /// ```
fn vec_znx_automorphism_inplace<T: VecZnxApi + Infos>(&self, k: i64, a: &mut T, a_cols: usize) { fn vec_znx_automorphism_inplace<A: VecZnxCommon>(&self, k: i64, a: &mut A, a_cols: usize) {
assert_eq!(a.n(), self.n()); assert_eq!(a.n(), self.n());
assert!(a.cols() >= a_cols); assert!(a.cols() >= a_cols);
unsafe { unsafe {
@@ -767,7 +810,12 @@ impl VecZnxOps for Module {
} }
} }
fn vec_znx_split<T: VecZnxApi + Infos>(&self, b: &mut Vec<T>, a: &T, buf: &mut T) { fn vec_znx_split<A: VecZnxCommon, B: VecZnxCommon, C: VecZnxCommon>(
&self,
b: &mut Vec<B>,
a: &A,
buf: &mut C,
) {
let (n_in, n_out) = (a.n(), b[0].n()); let (n_in, n_out) = (a.n(), b[0].n());
assert!( assert!(
@@ -793,7 +841,7 @@ impl VecZnxOps for Module {
}) })
} }
fn vec_znx_merge<T: VecZnxApi + Infos>(&self, b: &mut T, a: &Vec<T>) { fn vec_znx_merge<A: VecZnxCommon, B: VecZnxCommon>(&self, b: &mut B, a: &Vec<A>) {
let (n_in, n_out) = (b.n(), a[0].n()); let (n_in, n_out) = (b.n(), a[0].n());
assert!( assert!(

View File

@@ -1,4 +1,4 @@
use base2k::{Infos, Module, VecZnx, VecZnxApi, VecZnxBorrow, VecZnxOps, VmpPMat, VmpPMatOps}; use base2k::{Infos, Module, VecZnx, VecZnxBorrow, VecZnxOps, VmpPMat, VmpPMatOps};
use crate::parameters::Parameters; use crate::parameters::Parameters;
@@ -19,7 +19,7 @@ pub struct Elem<T> {
pub log_scale: usize, pub log_scale: usize,
} }
pub trait VecZnxCommon: VecZnxApi + Infos {} pub trait VecZnxCommon: base2k::VecZnxCommon {}
impl VecZnxCommon for VecZnx {} impl VecZnxCommon for VecZnx {}
impl VecZnxCommon for VecZnxBorrow {} impl VecZnxCommon for VecZnxBorrow {}

View File

@@ -6,7 +6,7 @@ use crate::plaintext::Plaintext;
use base2k::sampling::Sampling; use base2k::sampling::Sampling;
use base2k::{ use base2k::{
Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow,
VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut,
}; };
use sampling::source::{Source, new_seed}; use sampling::source::{Source, new_seed};