This commit is contained in:
Jean-Philippe Bossuat
2025-02-04 17:13:46 +01:00
parent e4a976ec9e
commit a790ff37cc
14 changed files with 1097 additions and 683 deletions

View File

@@ -1,10 +1,65 @@
use crate::ffi::svp::{delete_svp_ppol, new_svp_ppol, svp_apply_dft, svp_ppol_t, svp_prepare};
use crate::scalar::Scalar;
use crate::ffi::svp;
use crate::{Free, Module, VecZnx, VecZnxDft};
pub struct SvpPPol(pub *mut svp_ppol_t, pub usize);
use crate::Infos;
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, WeightedIndex};
use sampling::source::Source;
/// A prepared [crate::Scalar] for [ScalarVectorProduct::svp_apply_dft].
pub struct Scalar(pub Vec<i64>);
impl Module {
pub fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
}
impl Scalar {
pub fn new(n: usize) -> Self {
Self(vec![i64::default(); Self::buffer_size(n)])
}
pub fn buffer_size(n: usize) -> usize {
n
}
pub fn from_buffer(&mut self, n: usize, buf: &[i64]) {
let size: usize = Self::buffer_size(n);
assert!(
buf.len() >= size,
"invalid buffer: buf.len()={} < self.buffer_size(n={})={}",
buf.len(),
n,
size
);
self.0 = Vec::from(&buf[..size])
}
pub fn as_ptr(&self) -> *const i64 {
self.0.as_ptr()
}
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.0
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
self.0[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.0.shuffle(source);
}
}
pub struct SvpPPol(pub *mut svp::svp_ppol_t, pub usize);
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
/// The backend array of an [SvpPPol] is allocated in C and must be freed manually.
impl SvpPPol {
@@ -19,15 +74,8 @@ impl SvpPPol {
}
}
impl Free for SvpPPol {
fn free(self) {
unsafe { delete_svp_ppol(self.0) };
let _ = drop(self);
}
}
pub trait ScalarVectorProduct {
/// Prepares a [crate::Scalar] for a [ScalarVectorProduct::svp_apply_dft].
pub trait SvpPPolOps {
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
/// Allocates a new [SvpPPol].
@@ -38,16 +86,16 @@ pub trait ScalarVectorProduct {
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
}
impl Module {
pub fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
impl SvpPPolOps for Module {
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp::svp_prepare(self.0, svp_ppol.0, a.as_ptr()) }
}
pub fn svp_new_ppol(&self) -> SvpPPol {
unsafe { SvpPPol(new_svp_ppol(self.0), self.n()) }
fn svp_new_ppol(&self) -> SvpPPol {
unsafe { SvpPPol(svp::new_svp_ppol(self.0), self.n()) }
}
pub fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
let limbs: u64 = b.limbs() as u64;
assert!(
c.limbs() as u64 >= limbs,
@@ -55,6 +103,6 @@ impl Module {
c.limbs(),
limbs
);
unsafe { svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
unsafe { svp::svp_apply_dft(self.0, c.0, limbs, a.0, b.as_ptr(), limbs, b.n() as u64) }
}
}