mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
reworked scalar
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Encoding, FFT64, Module, Sampling, Scalar, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
Encoding, FFT64, Module, Sampling, Scalar, ScalarOps, ScalarZnxDft, ScalarZnxDftOps, VecZnx, VecZnxBig, VecZnxBigOps,
|
||||||
VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
|
VecZnxDft, VecZnxDftOps, VecZnxOps, ZnxInfos, alloc_aligned,
|
||||||
};
|
};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
@@ -19,14 +19,14 @@ fn main() {
|
|||||||
let mut source: Source = Source::new(seed);
|
let mut source: Source = Source::new(seed);
|
||||||
|
|
||||||
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
|
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
|
||||||
let mut s: Scalar = Scalar::new(n);
|
let mut s: Scalar = module.new_scalar(1);
|
||||||
s.fill_ternary_prob(0.5, &mut source);
|
s.fill_ternary_prob(0, 0.5, &mut source);
|
||||||
|
|
||||||
// Buffer to store s in the DFT domain
|
// Buffer to store s in the DFT domain
|
||||||
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft();
|
let mut s_dft: ScalarZnxDft<FFT64> = module.new_scalar_znx_dft(s.cols());
|
||||||
|
|
||||||
// s_dft <- DFT(s)
|
// s_dft <- DFT(s)
|
||||||
module.svp_prepare(&mut s_dft, &s);
|
module.svp_prepare(&mut s_dft, 0, &s, 0);
|
||||||
|
|
||||||
// Allocates a VecZnx with two columns: ct=(0, 0)
|
// Allocates a VecZnx with two columns: ct=(0, 0)
|
||||||
let mut ct: VecZnx = module.new_vec_znx(
|
let mut ct: VecZnx = module.new_vec_znx(
|
||||||
@@ -48,6 +48,7 @@ fn main() {
|
|||||||
&mut buf_dft, // DFT(ct[1] * s)
|
&mut buf_dft, // DFT(ct[1] * s)
|
||||||
0, // Selects the first column of res
|
0, // Selects the first column of res
|
||||||
&s_dft, // DFT(s)
|
&s_dft, // DFT(s)
|
||||||
|
0, // Selects the first column of s_dft
|
||||||
&ct,
|
&ct,
|
||||||
1, // Selects the second column of ct
|
1, // Selects the second column of ct
|
||||||
);
|
);
|
||||||
@@ -106,6 +107,7 @@ fn main() {
|
|||||||
&mut buf_dft,
|
&mut buf_dft,
|
||||||
0, // Selects the first column of res.
|
0, // Selects the first column of res.
|
||||||
&s_dft,
|
&s_dft,
|
||||||
|
0,
|
||||||
&ct,
|
&ct,
|
||||||
1, // Selects the second column of ct (ct[1])
|
1, // Selects the second column of ct (ct[1])
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ pub mod ffi;
|
|||||||
pub mod mat_znx_dft;
|
pub mod mat_znx_dft;
|
||||||
pub mod module;
|
pub mod module;
|
||||||
pub mod sampling;
|
pub mod sampling;
|
||||||
|
pub mod scalar_znx;
|
||||||
pub mod scalar_znx_dft;
|
pub mod scalar_znx_dft;
|
||||||
|
pub mod scalar_znx_dft_ops;
|
||||||
pub mod stats;
|
pub mod stats;
|
||||||
pub mod vec_znx;
|
pub mod vec_znx;
|
||||||
pub mod vec_znx_big;
|
pub mod vec_znx_big;
|
||||||
@@ -19,8 +21,11 @@ pub use encoding::*;
|
|||||||
pub use mat_znx_dft::*;
|
pub use mat_znx_dft::*;
|
||||||
pub use module::*;
|
pub use module::*;
|
||||||
pub use sampling::*;
|
pub use sampling::*;
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
pub use scalar_znx::*;
|
||||||
pub use scalar_znx_dft::*;
|
pub use scalar_znx_dft::*;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
|
pub use scalar_znx_dft_ops::*;
|
||||||
pub use stats::*;
|
pub use stats::*;
|
||||||
pub use vec_znx::*;
|
pub use vec_znx::*;
|
||||||
pub use vec_znx_big::*;
|
pub use vec_znx_big::*;
|
||||||
@@ -50,13 +55,13 @@ pub fn assert_alignement<T>(ptr: *const T) {
|
|||||||
|
|
||||||
pub fn cast<T, V>(data: &[T]) -> &[V] {
|
pub fn cast<T, V>(data: &[T]) -> &[V] {
|
||||||
let ptr: *const V = data.as_ptr() as *const V;
|
let ptr: *const V = data.as_ptr() as *const V;
|
||||||
let len: usize = data.len() / std::mem::size_of::<V>();
|
let len: usize = data.len() / size_of::<V>();
|
||||||
unsafe { std::slice::from_raw_parts(ptr, len) }
|
unsafe { std::slice::from_raw_parts(ptr, len) }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
|
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
|
||||||
let ptr: *mut V = data.as_ptr() as *mut V;
|
let ptr: *mut V = data.as_ptr() as *mut V;
|
||||||
let len: usize = data.len() / std::mem::size_of::<V>();
|
let len: usize = data.len() / size_of::<V>();
|
||||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -70,7 +75,7 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
|||||||
align
|
align
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
(size * std::mem::size_of::<u8>()) % align,
|
(size * size_of::<u8>()) % align,
|
||||||
0,
|
0,
|
||||||
"size={} must be a multiple of align={}",
|
"size={} must be a multiple of align={}",
|
||||||
size,
|
size,
|
||||||
@@ -98,22 +103,25 @@ fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
|||||||
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
|
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
|
||||||
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
(size * std::mem::size_of::<T>()) % align,
|
(size * size_of::<T>()) % align,
|
||||||
0,
|
0,
|
||||||
"size={} must be a multiple of align={}",
|
"size={} must be a multiple of align={}",
|
||||||
size,
|
size,
|
||||||
align
|
align
|
||||||
);
|
);
|
||||||
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
|
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
|
||||||
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
|
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
|
||||||
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
|
let len: usize = vec_u8.len() / size_of::<T>();
|
||||||
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
|
let cap: usize = vec_u8.capacity() / size_of::<T>();
|
||||||
std::mem::forget(vec_u8);
|
std::mem::forget(vec_u8);
|
||||||
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Allocates an aligned of size equal to the smallest multiple
|
/// Allocates an aligned vector of size equal to the smallest multiple
|
||||||
/// of [DEFAULTALIGN] that is equal or greater to `size`.
|
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
|
||||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||||
alloc_aligned_custom::<T>(size + (size % DEFAULTALIGN), DEFAULTALIGN)
|
alloc_aligned_custom::<T>(
|
||||||
|
size + (size % (DEFAULTALIGN / size_of::<T>())),
|
||||||
|
DEFAULTALIGN,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ pub trait MatZnxDftOps<B: Backend> {
|
|||||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
|
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
|
||||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||||
/// * `row_i`: the index of the row to extract.
|
/// * `row_i`: the index of the row to extract.
|
||||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, a: &MatZnxDft<B>, row_i: usize);
|
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<B>, row_i: usize, a: &MatZnxDft<B>);
|
||||||
|
|
||||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
|
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft].
|
||||||
///
|
///
|
||||||
@@ -170,7 +170,7 @@ pub trait MatZnxDftOps<B: Backend> {
|
|||||||
/// * `a_size`: number of size of the input [VecZnx].
|
/// * `a_size`: number of size of the input [VecZnx].
|
||||||
/// * `rows`: number of rows of the input [MatZnxDft].
|
/// * `rows`: number of rows of the input [MatZnxDft].
|
||||||
/// * `size`: number of size of the input [MatZnxDft].
|
/// * `size`: number of size of the input [MatZnxDft].
|
||||||
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, rows: usize, size: usize) -> usize;
|
fn vmp_apply_dft_tmp_bytes(&self, c_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize;
|
||||||
|
|
||||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||||
///
|
///
|
||||||
@@ -404,7 +404,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, a: &MatZnxDft<FFT64>, row_i: usize) {
|
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft<FFT64>, row_i: usize, a: &MatZnxDft<FFT64>) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), b.n());
|
assert_eq!(a.n(), b.n());
|
||||||
@@ -422,14 +422,14 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, gct_rows: usize, gct_size: usize) -> usize {
|
fn vmp_apply_dft_tmp_bytes(&self, res_size: usize, a_size: usize, b_rows: usize, b_size: usize) -> usize {
|
||||||
unsafe {
|
unsafe {
|
||||||
vmp::vmp_apply_dft_tmp_bytes(
|
vmp::vmp_apply_dft_tmp_bytes(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
res_size as u64,
|
res_size as u64,
|
||||||
a_size as u64,
|
a_size as u64,
|
||||||
gct_rows as u64,
|
b_rows as u64,
|
||||||
gct_size as u64,
|
b_size as u64,
|
||||||
) as usize
|
) as usize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -595,7 +595,7 @@ mod tests {
|
|||||||
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
|
assert_eq!(vmpmat_0.raw(), vmpmat_1.raw());
|
||||||
|
|
||||||
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
|
// Checks that a_dft = extract_dft(prepare(mat_znx_dft, a), b_dft)
|
||||||
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
|
module.vmp_extract_row_dft(&mut b_dft, row_i, &vmpmat_0);
|
||||||
assert_eq!(a_dft.raw(), b_dft.raw());
|
assert_eq!(a_dft.raw(), b_dft.raw());
|
||||||
|
|
||||||
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
|
// Checks that a_big = extract(prepare_dft(mat_znx_dft, a_dft), b_big)
|
||||||
|
|||||||
113
base2k/src/scalar_znx.rs
Normal file
113
base2k/src/scalar_znx.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||||
|
use crate::{Backend, GetZnxBase, Module, VecZnx};
|
||||||
|
use rand::seq::SliceRandom;
|
||||||
|
use rand_core::RngCore;
|
||||||
|
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
|
pub const SCALAR_ZNX_ROWS: usize = 1;
|
||||||
|
pub const SCALAR_ZNX_SIZE: usize = 1;
|
||||||
|
|
||||||
|
pub struct Scalar {
|
||||||
|
pub inner: ZnxBase,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GetZnxBase for Scalar {
|
||||||
|
fn znx(&self) -> &ZnxBase {
|
||||||
|
&self.inner
|
||||||
|
}
|
||||||
|
|
||||||
|
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||||
|
&mut self.inner
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZnxInfos for Scalar {}
|
||||||
|
|
||||||
|
impl<B: Backend> ZnxAlloc<B> for Scalar {
|
||||||
|
type Scalar = i64;
|
||||||
|
|
||||||
|
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
|
||||||
|
Self {
|
||||||
|
inner: ZnxBase::from_bytes_borrow(module.n(), SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
|
||||||
|
debug_assert_eq!(
|
||||||
|
_rows, SCALAR_ZNX_ROWS,
|
||||||
|
"rows != {} not supported for Scalar",
|
||||||
|
SCALAR_ZNX_ROWS
|
||||||
|
);
|
||||||
|
debug_assert_eq!(
|
||||||
|
_size, SCALAR_ZNX_SIZE,
|
||||||
|
"rows != {} not supported for Scalar",
|
||||||
|
SCALAR_ZNX_SIZE
|
||||||
|
);
|
||||||
|
module.n() * cols * std::mem::size_of::<self::Scalar>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZnxLayout for Scalar {
|
||||||
|
type Scalar = i64;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ZnxSliceSize for Scalar {
|
||||||
|
fn sl(&self) -> usize {
|
||||||
|
self.n()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scalar {
|
||||||
|
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
|
||||||
|
let choices: [i64; 3] = [-1, 0, 1];
|
||||||
|
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
||||||
|
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
|
||||||
|
self.at_mut(col, 0)
|
||||||
|
.iter_mut()
|
||||||
|
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
|
||||||
|
assert!(hw <= self.n());
|
||||||
|
self.at_mut(col, 0)[..hw]
|
||||||
|
.iter_mut()
|
||||||
|
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||||
|
self.at_mut(col, 0).shuffle(source);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alias_as_vec_znx(&self) -> VecZnx {
|
||||||
|
VecZnx {
|
||||||
|
inner: ZnxBase {
|
||||||
|
n: self.n(),
|
||||||
|
rows: 1,
|
||||||
|
cols: 1,
|
||||||
|
size: 1,
|
||||||
|
data: Vec::new(),
|
||||||
|
ptr: self.ptr() as *mut u8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait ScalarOps {
|
||||||
|
fn bytes_of_scalar(&self, cols: usize) -> usize;
|
||||||
|
fn new_scalar(&self, cols: usize) -> Scalar;
|
||||||
|
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> Scalar;
|
||||||
|
fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> ScalarOps for Module<B> {
|
||||||
|
fn bytes_of_scalar(&self, cols: usize) -> usize {
|
||||||
|
Scalar::bytes_of(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE)
|
||||||
|
}
|
||||||
|
fn new_scalar(&self, cols: usize) -> Scalar {
|
||||||
|
Scalar::new(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE)
|
||||||
|
}
|
||||||
|
fn new_scalar_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> Scalar {
|
||||||
|
Scalar::from_bytes(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
|
||||||
|
}
|
||||||
|
fn new_scalar_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> Scalar {
|
||||||
|
Scalar::from_bytes_borrow(self, SCALAR_ZNX_ROWS, cols, SCALAR_ZNX_SIZE, bytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,279 +1,66 @@
|
|||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
use crate::ffi::svp::{self, svp_ppol_t};
|
use crate::ffi::svp;
|
||||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
use crate::znx_base::{ZnxAlloc, ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||||
use crate::znx_base::{ZnxBase, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
use crate::{Backend, FFT64, GetZnxBase, Module};
|
||||||
use crate::{Backend, FFT64, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement, cast_mut};
|
|
||||||
use rand::seq::SliceRandom;
|
|
||||||
use rand_core::RngCore;
|
|
||||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
|
||||||
use sampling::source::Source;
|
|
||||||
|
|
||||||
pub struct Scalar {
|
pub const SCALAR_ZNX_DFT_ROWS: usize = 1;
|
||||||
pub n: usize,
|
pub const SCALAR_ZNX_DFT_SIZE: usize = 1;
|
||||||
pub data: Vec<i64>,
|
|
||||||
pub ptr: *mut i64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<B: Backend> Module<B> {
|
|
||||||
pub fn new_scalar(&self) -> Scalar {
|
|
||||||
Scalar::new(self.n())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Scalar {
|
|
||||||
pub fn new(n: usize) -> Self {
|
|
||||||
let mut data: Vec<i64> = alloc_aligned::<i64>(n);
|
|
||||||
let ptr: *mut i64 = data.as_mut_ptr();
|
|
||||||
Self {
|
|
||||||
n: n,
|
|
||||||
data: data,
|
|
||||||
ptr: ptr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn n(&self) -> usize {
|
|
||||||
self.n
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn bytes_of(n: usize) -> usize {
|
|
||||||
n * std::mem::size_of::<i64>()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
|
|
||||||
let size: usize = Self::bytes_of(n);
|
|
||||||
debug_assert!(
|
|
||||||
bytes.len() == size,
|
|
||||||
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
|
|
||||||
bytes.len(),
|
|
||||||
n,
|
|
||||||
size
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_alignement(bytes.as_ptr())
|
|
||||||
}
|
|
||||||
unsafe {
|
|
||||||
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
|
|
||||||
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
|
|
||||||
Self {
|
|
||||||
n: n,
|
|
||||||
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()),
|
|
||||||
ptr: ptr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
|
|
||||||
let size: usize = Self::bytes_of(n);
|
|
||||||
debug_assert!(
|
|
||||||
bytes.len() == size,
|
|
||||||
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
|
|
||||||
bytes.len(),
|
|
||||||
n,
|
|
||||||
size
|
|
||||||
);
|
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_alignement(bytes.as_ptr())
|
|
||||||
}
|
|
||||||
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
|
|
||||||
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
|
|
||||||
Self {
|
|
||||||
n: n,
|
|
||||||
data: Vec::new(),
|
|
||||||
ptr: ptr,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_ptr(&self) -> *const i64 {
|
|
||||||
self.ptr
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn raw(&self) -> &[i64] {
|
|
||||||
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn raw_mut(&self) -> &mut [i64] {
|
|
||||||
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
|
||||||
let choices: [i64; 3] = [-1, 0, 1];
|
|
||||||
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
|
||||||
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
|
|
||||||
self.data
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
|
||||||
assert!(hw <= self.n());
|
|
||||||
self.data[..hw]
|
|
||||||
.iter_mut()
|
|
||||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
|
||||||
self.data.shuffle(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn as_vec_znx(&self) -> VecZnx {
|
|
||||||
VecZnx {
|
|
||||||
inner: ZnxBase {
|
|
||||||
n: self.n,
|
|
||||||
rows: 1,
|
|
||||||
cols: 1,
|
|
||||||
size: 1,
|
|
||||||
data: Vec::new(),
|
|
||||||
ptr: self.ptr as *mut u8,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait ScalarOps {
|
|
||||||
fn bytes_of_scalar(&self) -> usize;
|
|
||||||
fn new_scalar(&self) -> Scalar;
|
|
||||||
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
|
|
||||||
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
|
|
||||||
}
|
|
||||||
impl<B: Backend> ScalarOps for Module<B> {
|
|
||||||
fn bytes_of_scalar(&self) -> usize {
|
|
||||||
Scalar::bytes_of(self.n())
|
|
||||||
}
|
|
||||||
fn new_scalar(&self) -> Scalar {
|
|
||||||
Scalar::new(self.n())
|
|
||||||
}
|
|
||||||
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
|
|
||||||
Scalar::from_bytes(self.n(), bytes)
|
|
||||||
}
|
|
||||||
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
|
|
||||||
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct ScalarZnxDft<B: Backend> {
|
pub struct ScalarZnxDft<B: Backend> {
|
||||||
pub n: usize,
|
pub inner: ZnxBase,
|
||||||
pub data: Vec<u8>,
|
|
||||||
pub ptr: *mut u8,
|
|
||||||
_marker: PhantomData<B>,
|
_marker: PhantomData<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
|
impl<B: Backend> GetZnxBase for ScalarZnxDft<B> {
|
||||||
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
|
fn znx(&self) -> &ZnxBase {
|
||||||
impl ScalarZnxDft<FFT64> {
|
&self.inner
|
||||||
pub fn new(module: &Module<FFT64>) -> Self {
|
|
||||||
module.new_scalar_znx_dft()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the ring degree of the [SvpPPol].
|
fn znx_mut(&mut self) -> &mut ZnxBase {
|
||||||
pub fn n(&self) -> usize {
|
&mut self.inner
|
||||||
self.n
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn bytes_of(module: &Module<FFT64>) -> usize {
|
impl<B: Backend> ZnxInfos for ScalarZnxDft<B> {}
|
||||||
module.bytes_of_scalar_znx_dft()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_bytes(module: &Module<FFT64>, bytes: &mut [u8]) -> Self {
|
impl<B: Backend> ZnxAlloc<B> for ScalarZnxDft<B> {
|
||||||
#[cfg(debug_assertions)]
|
type Scalar = u8;
|
||||||
{
|
|
||||||
assert_alignement(bytes.as_ptr());
|
|
||||||
assert_eq!(bytes.len(), module.bytes_of_scalar_znx_dft());
|
|
||||||
}
|
|
||||||
unsafe {
|
|
||||||
Self {
|
|
||||||
n: module.n(),
|
|
||||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
|
||||||
ptr: bytes.as_mut_ptr(),
|
|
||||||
_marker: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn from_bytes_borrow(module: &Module<FFT64>, tmp_bytes: &mut [u8]) -> Self {
|
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, _size: usize, bytes: &mut [u8]) -> Self {
|
||||||
#[cfg(debug_assertions)]
|
|
||||||
{
|
|
||||||
assert_alignement(tmp_bytes.as_ptr());
|
|
||||||
assert_eq!(tmp_bytes.len(), module.bytes_of_scalar_znx_dft());
|
|
||||||
}
|
|
||||||
Self {
|
Self {
|
||||||
n: module.n(),
|
inner: ZnxBase::from_bytes_borrow(
|
||||||
data: Vec::new(),
|
module.n(),
|
||||||
ptr: tmp_bytes.as_mut_ptr(),
|
SCALAR_ZNX_DFT_ROWS,
|
||||||
|
cols,
|
||||||
|
SCALAR_ZNX_DFT_SIZE,
|
||||||
|
bytes,
|
||||||
|
),
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of cols of the [SvpPPol], which is always 1.
|
fn bytes_of(module: &Module<B>, _rows: usize, cols: usize, _size: usize) -> usize {
|
||||||
pub fn cols(&self) -> usize {
|
debug_assert_eq!(
|
||||||
1
|
_rows, SCALAR_ZNX_DFT_ROWS,
|
||||||
|
"rows != {} not supported for ScalarZnxDft",
|
||||||
|
SCALAR_ZNX_DFT_ROWS
|
||||||
|
);
|
||||||
|
debug_assert_eq!(
|
||||||
|
_size, SCALAR_ZNX_DFT_SIZE,
|
||||||
|
"rows != {} not supported for ScalarZnxDft",
|
||||||
|
SCALAR_ZNX_DFT_SIZE
|
||||||
|
);
|
||||||
|
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait ScalarZnxDftOps<B: Backend> {
|
impl ZnxLayout for ScalarZnxDft<FFT64> {
|
||||||
/// Allocates a new [SvpPPol].
|
type Scalar = f64;
|
||||||
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<B>;
|
|
||||||
|
|
||||||
/// Returns the minimum number of bytes necessary to allocate
|
|
||||||
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
|
|
||||||
fn bytes_of_scalar_znx_dft(&self) -> usize;
|
|
||||||
|
|
||||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
|
||||||
/// The array of bytes is owned by the [SvpPPol].
|
|
||||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
|
||||||
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
|
||||||
|
|
||||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
|
||||||
/// The array of bytes is borrowed by the [SvpPPol].
|
|
||||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
|
||||||
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
|
||||||
|
|
||||||
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
|
|
||||||
fn svp_prepare(&self, svp_ppol: &mut ScalarZnxDft<B>, a: &Scalar);
|
|
||||||
|
|
||||||
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
|
|
||||||
/// the [VecZnxDft] is multiplied with [SvpPPol].
|
|
||||||
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, b: &VecZnx, b_col: usize);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
impl ZnxSliceSize for ScalarZnxDft<FFT64> {
|
||||||
fn new_scalar_znx_dft(&self) -> ScalarZnxDft<FFT64> {
|
fn sl(&self) -> usize {
|
||||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_scalar_znx_dft());
|
self.n()
|
||||||
let ptr: *mut u8 = data.as_mut_ptr();
|
|
||||||
ScalarZnxDft::<FFT64> {
|
|
||||||
data: data,
|
|
||||||
ptr: ptr,
|
|
||||||
n: self.n(),
|
|
||||||
_marker: PhantomData,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn bytes_of_scalar_znx_dft(&self) -> usize {
|
|
||||||
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_scalar_znx_dft_from_bytes(&self, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
|
||||||
ScalarZnxDft::from_bytes(self, bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_scalar_znx_dft_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
|
||||||
ScalarZnxDft::from_bytes_borrow(self, tmp_bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, a: &Scalar) {
|
|
||||||
unsafe { svp::svp_prepare(self.ptr, res.ptr as *mut svp_ppol_t, a.as_ptr()) }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn svp_apply_dft(&self, res: &mut VecZnxDft<FFT64>, res_col: usize, a: &ScalarZnxDft<FFT64>, b: &VecZnx, b_col: usize) {
|
|
||||||
unsafe {
|
|
||||||
svp::svp_apply_dft(
|
|
||||||
self.ptr,
|
|
||||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
|
||||||
res.size() as u64,
|
|
||||||
a.ptr as *const svp_ppol_t,
|
|
||||||
b.at_ptr(b_col, 0),
|
|
||||||
b.size() as u64,
|
|
||||||
b.sl() as u64,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
63
base2k/src/scalar_znx_dft_ops.rs
Normal file
63
base2k/src/scalar_znx_dft_ops.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
use crate::ffi::svp::{self, svp_ppol_t};
|
||||||
|
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||||
|
use crate::znx_base::{ZnxAlloc, ZnxInfos, ZnxLayout, ZnxSliceSize};
|
||||||
|
use crate::{Backend, FFT64, Module, SCALAR_ZNX_DFT_ROWS, SCALAR_ZNX_DFT_SIZE, Scalar, ScalarZnxDft, VecZnx, VecZnxDft};
|
||||||
|
|
||||||
|
pub trait ScalarZnxDftOps<B: Backend> {
|
||||||
|
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft<B>;
|
||||||
|
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
|
||||||
|
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDft<B>;
|
||||||
|
fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<B>;
|
||||||
|
fn svp_prepare(&self, res: &mut ScalarZnxDft<B>, res_col: usize, a: &Scalar, a_col: usize);
|
||||||
|
fn svp_apply_dft(&self, res: &mut VecZnxDft<B>, res_col: usize, a: &ScalarZnxDft<B>, a_col: usize, b: &VecZnx, b_col: usize);
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||||
|
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDft<FFT64> {
|
||||||
|
ScalarZnxDft::<FFT64>::new(&self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
|
||||||
|
ScalarZnxDft::<FFT64>::bytes_of(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDft<FFT64> {
|
||||||
|
ScalarZnxDft::from_bytes(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn new_scalar_znx_dft_from_bytes_borrow(&self, cols: usize, bytes: &mut [u8]) -> ScalarZnxDft<FFT64> {
|
||||||
|
ScalarZnxDft::from_bytes_borrow(self, SCALAR_ZNX_DFT_ROWS, cols, SCALAR_ZNX_DFT_SIZE, bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn svp_prepare(&self, res: &mut ScalarZnxDft<FFT64>, res_col: usize, a: &Scalar, a_col: usize) {
|
||||||
|
unsafe {
|
||||||
|
svp::svp_prepare(
|
||||||
|
self.ptr,
|
||||||
|
res.at_mut_ptr(res_col, 0) as *mut svp_ppol_t,
|
||||||
|
a.at_ptr(a_col, 0),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn svp_apply_dft(
|
||||||
|
&self,
|
||||||
|
res: &mut VecZnxDft<FFT64>,
|
||||||
|
res_col: usize,
|
||||||
|
a: &ScalarZnxDft<FFT64>,
|
||||||
|
a_col: usize,
|
||||||
|
b: &VecZnx,
|
||||||
|
b_col: usize,
|
||||||
|
) {
|
||||||
|
unsafe {
|
||||||
|
svp::svp_apply_dft(
|
||||||
|
self.ptr,
|
||||||
|
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||||
|
res.size() as u64,
|
||||||
|
a.at_ptr(a_col, 0) as *const svp_ppol_t,
|
||||||
|
b.at_ptr(b_col, 0),
|
||||||
|
b.size() as u64,
|
||||||
|
b.sl() as u64,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,7 +26,7 @@ impl<B: Backend> ZnxAlloc<B> for VecZnxDft<B> {
|
|||||||
type Scalar = u8;
|
type Scalar = u8;
|
||||||
|
|
||||||
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
fn from_bytes_borrow(module: &Module<B>, _rows: usize, cols: usize, size: usize, bytes: &mut [u8]) -> Self {
|
||||||
VecZnxDft {
|
Self {
|
||||||
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
|
inner: ZnxBase::from_bytes_borrow(module.n(), VEC_ZNX_DFT_ROWS, cols, size, bytes),
|
||||||
_marker: PhantomData,
|
_marker: PhantomData,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,47 +47,47 @@ pub trait VecZnxOps {
|
|||||||
&self,
|
&self,
|
||||||
log_base2k: usize,
|
log_base2k: usize,
|
||||||
res: &mut VecZnx,
|
res: &mut VecZnx,
|
||||||
col_res: usize,
|
res_col: usize,
|
||||||
a: &VecZnx,
|
a: &VecZnx,
|
||||||
col_a: usize,
|
a_col: usize,
|
||||||
tmp_bytes: &mut [u8],
|
tmp_bytes: &mut [u8],
|
||||||
);
|
);
|
||||||
|
|
||||||
/// Normalizes the selected column of `a`.
|
/// Normalizes the selected column of `a`.
|
||||||
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]);
|
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]);
|
||||||
|
|
||||||
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`.
|
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `c`.
|
||||||
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);
|
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize);
|
||||||
|
|
||||||
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`.
|
/// Adds the selected column of `a` to the selected column of `b` and write the result on the selected column of `res`.
|
||||||
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`.
|
/// Subtracts the selected column of `b` to the selected column of `a` and write the result on the selected column of `res`.
|
||||||
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize);
|
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize);
|
||||||
|
|
||||||
/// Subtracts the selected column of `a` to the selected column of `res`.
|
/// Subtracts the selected column of `a` to the selected column of `res`.
|
||||||
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`.
|
/// Subtracts the selected column of `a` to the selected column of `res` and negates the selected column of `res`.
|
||||||
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
// Negates the selected column of `a` and stores the result on the selected column of `res`.
|
// Negates the selected column of `a` and stores the result on the selected column of `res`.
|
||||||
fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Negates the selected column of `a`.
|
/// Negates the selected column of `a`.
|
||||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize);
|
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`.
|
/// Multiplies the selected column of `a` by X^k and stores the result on the selected column of `res`.
|
||||||
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Multiplies the selected column of `a` by X^k.
|
/// Multiplies the selected column of `a` by X^k.
|
||||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);
|
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`.
|
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result on the selected column of `res`.
|
||||||
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize);
|
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize);
|
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize);
|
||||||
|
|
||||||
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
|
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
|
||||||
///
|
///
|
||||||
@@ -95,7 +95,7 @@ pub trait VecZnxOps {
|
|||||||
///
|
///
|
||||||
/// This method requires that all [VecZnx] of b have the same ring degree
|
/// This method requires that all [VecZnx] of b have the same ring degree
|
||||||
/// and that b.n() * b.len() <= a.n()
|
/// and that b.n() * b.len() <= a.n()
|
||||||
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx);
|
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx);
|
||||||
|
|
||||||
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
||||||
///
|
///
|
||||||
@@ -103,7 +103,7 @@ pub trait VecZnxOps {
|
|||||||
///
|
///
|
||||||
/// This method requires that all [VecZnx] of a have the same ring degree
|
/// This method requires that all [VecZnx] of a have the same ring degree
|
||||||
/// and that a.n() * a.len() <= b.n()
|
/// and that a.n() * a.len() <= b.n()
|
||||||
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize);
|
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> VecZnxOps for Module<B> {
|
impl<B: Backend> VecZnxOps for Module<B> {
|
||||||
@@ -131,9 +131,9 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
&self,
|
&self,
|
||||||
log_base2k: usize,
|
log_base2k: usize,
|
||||||
res: &mut VecZnx,
|
res: &mut VecZnx,
|
||||||
col_res: usize,
|
res_col: usize,
|
||||||
a: &VecZnx,
|
a: &VecZnx,
|
||||||
col_a: usize,
|
a_col: usize,
|
||||||
tmp_bytes: &mut [u8],
|
tmp_bytes: &mut [u8],
|
||||||
) {
|
) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@@ -147,10 +147,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
vec_znx::vec_znx_normalize_base2k(
|
vec_znx::vec_znx_normalize_base2k(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
log_base2k as u64,
|
log_base2k as u64,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
tmp_bytes.as_mut_ptr(),
|
tmp_bytes.as_mut_ptr(),
|
||||||
@@ -158,22 +158,22 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, col_a: usize, tmp_bytes: &mut [u8]) {
|
fn vec_znx_normalize_inplace(&self, log_base2k: usize, a: &mut VecZnx, a_col: usize, tmp_bytes: &mut [u8]) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||||
Self::vec_znx_normalize(
|
Self::vec_znx_normalize(
|
||||||
self,
|
self,
|
||||||
log_base2k,
|
log_base2k,
|
||||||
&mut *a_ptr,
|
&mut *a_ptr,
|
||||||
col_a,
|
a_col,
|
||||||
&*a_ptr,
|
&*a_ptr,
|
||||||
col_a,
|
a_col,
|
||||||
tmp_bytes,
|
tmp_bytes,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_add(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
|
fn vec_znx_add(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), self.n());
|
assert_eq!(a.n(), self.n());
|
||||||
@@ -184,27 +184,27 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
unsafe {
|
unsafe {
|
||||||
vec_znx::vec_znx_add(
|
vec_znx::vec_znx_add(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
b.at_ptr(col_b, 0),
|
b.at_ptr(b_col, 0),
|
||||||
b.size() as u64,
|
b.size() as u64,
|
||||||
b.sl() as u64,
|
b.sl() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_add_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_add_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
||||||
Self::vec_znx_add(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
|
Self::vec_znx_add(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_sub(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize, b: &VecZnx, col_b: usize) {
|
fn vec_znx_sub(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize, b: &VecZnx, b_col: usize) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), self.n());
|
assert_eq!(a.n(), self.n());
|
||||||
@@ -215,34 +215,34 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
unsafe {
|
unsafe {
|
||||||
vec_znx::vec_znx_sub(
|
vec_znx::vec_znx_sub(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
b.at_ptr(col_b, 0),
|
b.at_ptr(b_col, 0),
|
||||||
b.size() as u64,
|
b.size() as u64,
|
||||||
b.sl() as u64,
|
b.sl() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_sub_ab_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
||||||
Self::vec_znx_sub(self, &mut *res_ptr, col_res, a, col_a, &*res_ptr, col_res);
|
Self::vec_znx_sub(self, &mut *res_ptr, res_col, a, a_col, &*res_ptr, res_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_sub_ba_inplace(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
let res_ptr: *mut VecZnx = res as *mut VecZnx;
|
||||||
Self::vec_znx_sub(self, &mut *res_ptr, col_res, &*res_ptr, col_res, a, col_a);
|
Self::vec_znx_sub(self, &mut *res_ptr, res_col, &*res_ptr, res_col, a, a_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_negate(&self, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_negate(&self, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), self.n());
|
assert_eq!(a.n(), self.n());
|
||||||
@@ -251,24 +251,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
unsafe {
|
unsafe {
|
||||||
vec_znx::vec_znx_negate(
|
vec_znx::vec_znx_negate(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, col_a: usize) {
|
fn vec_znx_negate_inplace(&self, a: &mut VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||||
Self::vec_znx_negate(self, &mut *a_ptr, col_a, &*a_ptr, col_a);
|
Self::vec_znx_negate(self, &mut *a_ptr, a_col, &*a_ptr, a_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_rotate(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), self.n());
|
assert_eq!(a.n(), self.n());
|
||||||
@@ -278,24 +278,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
vec_znx::vec_znx_rotate(
|
vec_znx::vec_znx_rotate(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
k,
|
k,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
|
fn vec_znx_rotate_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||||
Self::vec_znx_rotate(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
|
Self::vec_znx_rotate(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, col_res: usize, a: &VecZnx, col_a: usize) {
|
fn vec_znx_automorphism(&self, k: i64, res: &mut VecZnx, res_col: usize, a: &VecZnx, a_col: usize) {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
{
|
{
|
||||||
assert_eq!(a.n(), self.n());
|
assert_eq!(a.n(), self.n());
|
||||||
@@ -305,24 +305,24 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
vec_znx::vec_znx_automorphism(
|
vec_znx::vec_znx_automorphism(
|
||||||
self.ptr,
|
self.ptr,
|
||||||
k,
|
k,
|
||||||
res.at_mut_ptr(col_res, 0),
|
res.at_mut_ptr(res_col, 0),
|
||||||
res.size() as u64,
|
res.size() as u64,
|
||||||
res.sl() as u64,
|
res.sl() as u64,
|
||||||
a.at_ptr(col_a, 0),
|
a.at_ptr(a_col, 0),
|
||||||
a.size() as u64,
|
a.size() as u64,
|
||||||
a.sl() as u64,
|
a.sl() as u64,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, col_a: usize) {
|
fn vec_znx_automorphism_inplace(&self, k: i64, a: &mut VecZnx, a_col: usize) {
|
||||||
unsafe {
|
unsafe {
|
||||||
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
let a_ptr: *mut VecZnx = a as *mut VecZnx;
|
||||||
Self::vec_znx_automorphism(self, k, &mut *a_ptr, col_a, &*a_ptr, col_a);
|
Self::vec_znx_automorphism(self, k, &mut *a_ptr, a_col, &*a_ptr, a_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, col_res: usize, a: &VecZnx, col_a: usize, buf: &mut VecZnx) {
|
fn vec_znx_split(&self, res: &mut Vec<VecZnx>, res_col: usize, a: &VecZnx, a_col: usize, buf: &mut VecZnx) {
|
||||||
let (n_in, n_out) = (a.n(), res[0].n());
|
let (n_in, n_out) = (a.n(), res[0].n());
|
||||||
|
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
@@ -339,16 +339,16 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
|
|
||||||
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
switch_degree(bi, col_res, a, col_a);
|
switch_degree(bi, res_col, a, a_col);
|
||||||
self.vec_znx_rotate(-1, buf, 0, a, col_a);
|
self.vec_znx_rotate(-1, buf, 0, a, a_col);
|
||||||
} else {
|
} else {
|
||||||
switch_degree(bi, col_res, buf, col_a);
|
switch_degree(bi, res_col, buf, a_col);
|
||||||
self.vec_znx_rotate_inplace(-1, buf, col_a);
|
self.vec_znx_rotate_inplace(-1, buf, a_col);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn vec_znx_merge(&self, res: &mut VecZnx, col_res: usize, a: &Vec<VecZnx>, col_a: usize) {
|
fn vec_znx_merge(&self, res: &mut VecZnx, res_col: usize, a: &Vec<VecZnx>, a_col: usize) {
|
||||||
let (n_in, n_out) = (res.n(), a[0].n());
|
let (n_in, n_out) = (res.n(), a[0].n());
|
||||||
|
|
||||||
debug_assert!(
|
debug_assert!(
|
||||||
@@ -364,10 +364,10 @@ impl<B: Backend> VecZnxOps for Module<B> {
|
|||||||
});
|
});
|
||||||
|
|
||||||
a.iter().enumerate().for_each(|(_, ai)| {
|
a.iter().enumerate().for_each(|(_, ai)| {
|
||||||
switch_degree(res, col_res, ai, col_a);
|
switch_degree(res, res_col, ai, a_col);
|
||||||
self.vec_znx_rotate_inplace(-1, res, col_res);
|
self.vec_znx_rotate_inplace(-1, res, res_col);
|
||||||
});
|
});
|
||||||
|
|
||||||
self.vec_znx_rotate_inplace(a.len() as i64, res, col_res);
|
self.vec_znx_rotate_inplace(a.len() as i64, res, res_col);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user