Remove Zn (replaced by VecZnx), add more cross-base2k ops & tests

This commit is contained in:
Pro7ech
2025-11-18 01:08:20 +01:00
parent a3264b8851
commit f39e3e2865
52 changed files with 952 additions and 1550 deletions

View File

@@ -1,5 +1,5 @@
use itertools::izip;
use poulpy_backend::cpu_spqlios::FFT64Spqlios;
use poulpy_backend::cpu_fft64_ref::FFT64Ref;
use poulpy_hal::{
api::{
ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPrepare, VecZnxAddNormal,
@@ -16,9 +16,9 @@ fn main() {
let ct_size: usize = 3;
let msg_size: usize = 2;
let log_scale: usize = msg_size * base2k - 5;
let module: Module<FFT64Spqlios> = Module::<FFT64Spqlios>::new(n as u64);
let module: Module<FFT64Ref> = Module::<FFT64Ref>::new(n as u64);
let mut scratch: ScratchOwned<FFT64Spqlios> = ScratchOwned::<FFT64Spqlios>::alloc(module.vec_znx_big_normalize_tmp_bytes());
let mut scratch: ScratchOwned<FFT64Ref> = ScratchOwned::<FFT64Ref>::alloc(module.vec_znx_big_normalize_tmp_bytes());
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
@@ -28,7 +28,7 @@ fn main() {
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_dft: SvpPPol<Vec<u8>, FFT64Spqlios> = module.svp_ppol_alloc(s.cols());
let mut s_dft: SvpPPol<Vec<u8>, FFT64Ref> = module.svp_ppol_alloc(s.cols());
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, 0, &s, 0);
@@ -43,7 +43,7 @@ fn main() {
// Fill the second column with random values: ct = (0, a)
module.vec_znx_fill_uniform(base2k, &mut ct, 1, &mut source);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64Spqlios> = module.vec_znx_dft_alloc(1, ct_size);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64Ref> = module.vec_znx_dft_alloc(1, ct_size);
module.vec_znx_dft_apply(1, 0, &mut buf_dft, 0, &ct, 1);
@@ -58,7 +58,7 @@ fn main() {
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
let mut buf_big: VecZnxBig<Vec<u8>, FFT64Spqlios> = module.vec_znx_big_alloc(1, ct_size);
let mut buf_big: VecZnxBig<Vec<u8>, FFT64Ref> = module.vec_znx_big_alloc(1, ct_size);
module.vec_znx_idft_apply_tmpa(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column

View File

@@ -7,7 +7,6 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp;
mod zn;
mod znx_avx;
pub struct FFT64Avx {}

View File

@@ -1,73 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, ZnToMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
source::Source,
};
use crate::cpu_fft64_avx::FFT64Avx;
unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Avx {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Avx
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Avx>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Avx {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Avx {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -6,7 +6,6 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp;
mod zn;
mod znx;
#[cfg(test)]

View File

@@ -1,73 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, ZnToMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl, ZnNormalizeTmpBytesImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform, zn_normalize_inplace, zn_normalize_tmp_bytes},
source::Source,
};
use crate::cpu_fft64_ref::FFT64Ref;
unsafe impl ZnNormalizeTmpBytesImpl<Self> for FFT64Ref {
fn zn_normalize_tmp_bytes_impl(n: usize) -> usize {
zn_normalize_tmp_bytes(n)
}
}
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Ref
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, scratch: &mut Scratch<Self>)
where
R: ZnToMut,
{
let (carry, _) = scratch.take_slice(n);
zn_normalize_inplace::<R, FFT64Ref>(n, base2k, res, res_col, carry);
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Ref {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Ref {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}

View File

@@ -5,7 +5,6 @@ mod vec_znx;
mod vec_znx_big;
mod vec_znx_dft;
mod vmp_pmat;
mod zn;
mod znx;
pub struct FFT64Spqlios;

View File

@@ -1,82 +0,0 @@
use poulpy_hal::{
api::TakeSlice,
layouts::{Scratch, Zn, ZnToMut, ZnxInfos, ZnxSliceSize, ZnxView, ZnxViewMut},
oep::{TakeSliceImpl, ZnAddNormalImpl, ZnFillNormalImpl, ZnFillUniformImpl, ZnNormalizeInplaceImpl},
reference::zn::{zn_add_normal, zn_fill_normal, zn_fill_uniform},
source::Source,
};
use crate::cpu_spqlios::{FFT64Spqlios, ffi::zn64};
unsafe impl ZnNormalizeInplaceImpl<Self> for FFT64Spqlios
where
Self: TakeSliceImpl<Self>,
{
fn zn_normalize_inplace_impl<A>(n: usize, base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch<Self>)
where
A: ZnToMut,
{
let mut a: Zn<&mut [u8]> = a.to_mut();
let (tmp_bytes, _) = scratch.take_slice(n * size_of::<i64>());
unsafe {
zn64::zn64_normalize_base2k_ref(
n as u64,
base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
}
unsafe impl ZnFillUniformImpl<Self> for FFT64Spqlios {
fn zn_fill_uniform_impl<R>(n: usize, base2k: usize, res: &mut R, res_col: usize, source: &mut Source)
where
R: ZnToMut,
{
zn_fill_uniform(n, base2k, res, res_col, source);
}
}
unsafe impl ZnFillNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_fill_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_fill_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}
unsafe impl ZnAddNormalImpl<Self> for FFT64Spqlios {
#[allow(clippy::too_many_arguments)]
fn zn_add_normal_impl<R>(
n: usize,
base2k: usize,
res: &mut R,
res_col: usize,
k: usize,
source: &mut Source,
sigma: f64,
bound: f64,
) where
R: ZnToMut,
{
zn_add_normal(n, base2k, res, res_col, k, source, sigma, bound);
}
}