This commit is contained in:
Jean-Philippe Bossuat
2025-04-22 18:50:51 +02:00
parent d3e3594ae8
commit fbdb4436b2
18 changed files with 908 additions and 403 deletions

View File

@@ -36,6 +36,21 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a: *const i64,
a_size: u64,
a_sl: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_tmp_bytes(
module: *const MODULE,
@@ -60,6 +75,20 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_add(
module: *const MODULE,
res: *mut VEC_ZNX_DFT,
res_size: u64,
a_dft: *const VEC_ZNX_DFT,
a_size: u64,
pmat: *const VMP_PMAT,
nrows: u64,
ncols: u64,
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub unsafe fn vmp_apply_dft_to_dft_tmp_bytes(
module: *const MODULE,

View File

@@ -51,27 +51,23 @@ impl Module {
(self.n() << 1) as _
}
// GALOISGENERATOR^|gen| * sign(gen)
// Returns GALOISGENERATOR^|gen| * sign(gen)
pub fn galois_element(&self, gen: i64) -> i64 {
if gen == 0 {
return 1;
}
((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64)
* gen.signum()
}
let mut gal_el: u64 = 1;
let mut gen_1_pow: u64 = GALOISGENERATOR;
let mut e: usize = gen.abs() as usize;
while e > 0 {
if e & 1 == 1 {
gal_el = gal_el.wrapping_mul(gen_1_pow);
}
gen_1_pow = gen_1_pow.wrapping_mul(gen_1_pow);
e >>= 1;
// Returns gen^-1
pub fn galois_element_inv(&self, gen: i64) -> i64 {
if gen == 0 {
panic!("cannot invert 0")
}
gal_el &= self.cyclotomic_order() - 1;
(gal_el as i64) * gen.signum()
((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize)
& (self.cyclotomic_order() - 1)) as i64)
* gen.signum()
}
pub fn free(self) {
@@ -79,3 +75,17 @@ impl Module {
drop(self);
}
}
fn mod_exp_u64(x: u64, e: usize) -> u64 {
let mut y: u64 = 1;
let mut x_pow: u64 = x;
let mut exp = e;
while exp > 0 {
if exp & 1 == 1 {
y = y.wrapping_mul(x_pow);
}
x_pow = x_pow.wrapping_mul(x_pow);
exp >>= 1;
}
y
}

View File

@@ -1,6 +1,6 @@
use crate::ffi::svp;
use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{assert_alignement, Module, VecZnx, VecZnxDft};
use crate::{assert_alignement, Module, VecZnx, VecZnxDft, BACKEND};
use crate::{alloc_aligned, cast_mut, Infos};
use rand::seq::SliceRandom;
@@ -35,15 +35,15 @@ impl Scalar {
self.n
}
pub fn buffer_size(n: usize) -> usize {
n
pub fn bytes_of(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
pub fn from_buffer(&mut self, n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::buffer_size(n);
pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.buffer_size(n={})={}",
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
@@ -63,6 +63,28 @@ impl Scalar {
}
}
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::new(),
ptr: ptr,
}
}
pub fn as_ptr(&self) -> *const i64 {
self.ptr
}
@@ -87,26 +109,89 @@ impl Scalar {
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.data.shuffle(source);
}
pub fn as_vec_znx(&self) -> VecZnx {
VecZnx {
n: self.n,
cols: 1,
data: Vec::new(),
ptr: self.ptr,
}
}
}
pub struct SvpPPol(pub *mut svp::svp_ppol_t, pub usize);
pub trait ScalarOps {
fn bytes_of_scalar(&self) -> usize;
fn new_scalar(&self) -> Scalar;
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
}
impl ScalarOps for Module {
fn bytes_of_scalar(&self) -> usize {
Scalar::bytes_of(self.n())
}
fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes(self.n(), bytes)
}
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
}
}
pub struct SvpPPol {
pub n: usize,
pub data: Vec<u8>,
pub ptr: *mut u8,
pub backend: BACKEND,
}
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
/// The backend array of an [SvpPPol] is allocated in C and must be freed manually.
impl SvpPPol {
/// Returns the ring degree of the [SvpPPol].
pub fn n(&self) -> usize {
self.1
pub fn new(module: &Module) -> Self {
module.new_svp_ppol()
}
pub fn from_bytes(size: usize, bytes: &mut [u8]) -> SvpPPol {
/// Returns the ring degree of the [SvpPPol].
pub fn n(&self) -> usize {
self.n
}
pub fn bytes_of(module: &Module) -> usize {
module.bytes_of_svp_ppol()
}
pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
}
unsafe {
Self {
n: module.n(),
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
backend: module.backend(),
}
}
}
pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
}
Self {
n: module.n(),
data: Vec::new(),
ptr: tmp_bytes.as_mut_ptr(),
backend: module.backend(),
}
debug_assert!(bytes.len() << 3 >= size);
SvpPPol(bytes.as_mut_ptr() as *mut svp::svp_ppol_t, size)
}
/// Returns the number of cols of the [SvpPPol], which is always 1.
@@ -120,45 +205,64 @@ pub trait SvpPPolOps {
fn new_svp_ppol(&self) -> SvpPPol;
/// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes].
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
fn bytes_of_svp_ppol(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize);
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
}
impl SvpPPolOps for Module {
fn new_svp_ppol(&self) -> SvpPPol {
unsafe { SvpPPol(svp::new_svp_ppol(self.ptr), self.n()) }
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
let ptr: *mut u8 = data.as_mut_ptr();
SvpPPol {
data: data,
ptr: ptr,
n: self.n(),
backend: self.backend(),
}
}
fn bytes_of_svp_ppol(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
}
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.0, a.as_ptr()) }
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol {
SvpPPol::from_bytes(self, bytes)
}
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx, b_cols: usize) {
debug_assert!(
c.cols() >= b_cols,
"invalid c_vector: c_vector.cols()={} < b.cols()={}",
c.cols(),
b_cols
);
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol {
SvpPPol::from_bytes_borrow(self, tmp_bytes)
}
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
}
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
unsafe {
svp::svp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
b_cols as u64,
a.0,
c.cols() as u64,
a.ptr as *const svp_ppol_t,
b.as_ptr(),
b_cols as u64,
b.cols() as u64,
b.n() as u64,
)
}

View File

@@ -12,16 +12,16 @@ use std::cmp::min;
#[derive(Clone)]
pub struct VecZnx {
/// Polynomial degree.
n: usize,
pub n: usize,
/// Number of columns.
cols: usize,
pub cols: usize,
/// Polynomial coefficients, as a contiguous array. Each col is equally spaced by n.
data: Vec<i64>,
pub data: Vec<i64>,
/// Pointer to data (data can be enpty if [VecZnx] borrows space instead of owning it).
ptr: *mut i64,
pub ptr: *mut i64,
}
pub trait VecZnxVec {
@@ -363,10 +363,10 @@ pub trait VecZnxOps {
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx);
/// b <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize);
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx);
/// a <- phi_k(a) where phi_k: X^i -> X^{i*k} (mod (X^{n} + 1))
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize);
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx);
/// Splits b into subrings and copies them them into a.
///
@@ -540,10 +540,9 @@ impl VecZnxOps for Module {
/// # Panics
///
/// The method will panic if the argument `a` is greater than `a.cols()`.
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx, a_cols: usize) {
fn vec_znx_automorphism(&self, k: i64, b: &mut VecZnx, a: &VecZnx) {
debug_assert_eq!(a.n(), self.n());
debug_assert_eq!(b.n(), self.n());
debug_assert!(a.cols() >= a_cols);
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
@@ -552,7 +551,7 @@ impl VecZnxOps for Module {
b.cols() as u64,
b.n() as u64,
a.as_ptr(),
a_cols as u64,
a.cols() as u64,
a.n() as u64,
);
}
@@ -569,9 +568,8 @@ impl VecZnxOps for Module {
/// # Panics
///
/// The method will panic if the argument `cols` is greater than `self.cols()`.
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_cols: usize) {
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx) {
debug_assert_eq!(a.n(), self.n());
debug_assert!(a.cols() >= a_cols);
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
@@ -580,7 +578,7 @@ impl VecZnxOps for Module {
a.cols() as u64,
a.n() as u64,
a.as_ptr(),
a_cols as u64,
a.cols() as u64,
a.n() as u64,
);
}

View File

@@ -16,6 +16,7 @@ impl VecZnxBig {
pub fn from_bytes(module: &Module, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(cols));
assert_alignement(bytes.as_ptr())
};
unsafe {
@@ -54,14 +55,6 @@ impl VecZnxBig {
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn backend(&self) -> BACKEND {
self.backend
}
@@ -77,12 +70,36 @@ impl VecZnxBig {
}
}
impl Infos for VecZnxBig {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
pub trait VecZnxBigOps {
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
fn new_vec_znx_big(&self, cols: usize) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxBig].
@@ -92,6 +109,19 @@ pub trait VecZnxBigOps {
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxBig].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize;
@@ -151,19 +181,13 @@ impl VecZnxBigOps for Module {
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
debug_assert!(
bytes.len() >= <Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
bytes.len(),
<Module as VecZnxBigOps>::bytes_of_vec_znx_big(self, cols)
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
VecZnxBig::from_bytes(self, cols, bytes)
}
fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig {
VecZnxBig::from_bytes_borrow(self, cols, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize }
}

View File

@@ -61,14 +61,6 @@ impl VecZnxDft {
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn cols(&self) -> usize {
self.cols
}
pub fn backend(&self) -> BACKEND {
self.backend
}
@@ -102,12 +94,36 @@ impl VecZnxDft {
}
}
impl Infos for VecZnxDft {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
}
}
pub trait VecZnxDftOps {
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
fn new_vec_znx_dft(&self, cols: usize) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
@@ -117,6 +133,19 @@ pub trait VecZnxDftOps {
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// # Arguments
@@ -133,28 +162,15 @@ pub trait VecZnxDftOps {
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);
fn vec_znx_dft_automorphism(
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
);
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
}
@@ -173,37 +189,25 @@ impl VecZnxDftOps for Module {
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
debug_assert!(
tmp_bytes.len() >= Self::bytes_of_vec_znx_dft(self, cols),
"invalid bytes: bytes.len()={} < bytes_of_vec_znx_dft={}",
tmp_bytes.len(),
Self::bytes_of_vec_znx_dft(self, cols)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
VecZnxDft::from_bytes(self, cols, tmp_bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft::from_bytes_borrow(self, cols, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize }
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft, a_cols: usize) {
debug_assert!(
b.cols() >= a_cols,
"invalid c_vector: b_vector.cols()={} < a_cols={}",
b.cols(),
a_cols
);
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
)
}
}
@@ -216,41 +220,23 @@ impl VecZnxDftOps for Module {
///
/// # Panics
/// If b.cols < a_cols
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx, a_cols: usize) {
debug_assert!(
b.cols() >= a_cols,
"invalid a_cols: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
a.as_ptr(),
a_cols as u64,
a.cols() as u64,
a.n() as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, a_cols: usize, tmp_bytes: &mut [u8]) {
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
@@ -263,65 +249,31 @@ impl VecZnxDftOps for Module {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_big_t,
a.cols() as u64,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_dft_automorphism(
&self,
k: i64,
b: &mut VecZnxDft,
b_cols: usize,
a: &VecZnxDft,
a_cols: usize,
) {
#[cfg(debug_assertions)]
{
assert!(
b.cols() >= a_cols,
"invalid c_vector: b.cols()={} < a_cols={}",
b.cols(),
a_cols
);
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
}
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b_cols as u64,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
fn vec_znx_dft_automorphism_inplace(
&self,
k: i64,
a: &mut VecZnxDft,
a_cols: usize,
tmp_bytes: &mut [u8],
) {
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
a.cols() >= a_cols,
"invalid c_vector: a.cols()={} < a_cols={}",
a.cols(),
a_cols
);
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
@@ -335,9 +287,9 @@ impl VecZnxDftOps for Module {
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a_cols as u64,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
@@ -379,16 +331,16 @@ mod tests {
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a, cols);
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, cols, &mut tmp_bytes);
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a, cols);
module.vec_znx_automorphism_inplace(p, &mut a);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a, cols);
module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw(&module);
let b_f64: &[f64] = b_dft.raw(&module);

View File

@@ -253,6 +253,32 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver.
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
///
/// # Arguments
@@ -296,6 +322,39 @@ pub trait VmpPMatOps {
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(
&self,
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
buf: &mut [u8],
);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
@@ -503,6 +562,30 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(
&self,
res_cols: usize,
@@ -551,6 +634,36 @@ impl VmpPMatOps for Module {
}
}
fn vmp_apply_dft_to_dft_add(
&self,
c: &mut VecZnxDft,
a: &VecZnxDft,
b: &VmpPMat,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len()
>= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols())
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len()
@@ -604,7 +717,7 @@ mod tests {
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
module.vec_znx_dft(&mut a_dft, &a, vpmat_cols);
module.vec_znx_dft(&mut a_dft, &a);
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
@@ -617,7 +730,7 @@ mod tests {
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, vpmat_cols, &mut tmp_bytes);
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
}