Merge pull request #20 from phantomzone-org/jay/restructure-base2k

Major refactoring of memory layout, memory safety & basic functionality
Janmajayamall
2025-05-21 15:02:09 +05:30
committed by GitHub
69 changed files with 11366 additions and 5443 deletions

View File

@@ -1,6 +1,6 @@
[workspace]
members = ["base2k", "rlwe", "sampling", "utils"]
members = ["base2k", "core", "sampling", "utils"]
resolver = "3"
[workspace.dependencies]
rug = "1.27"

base2k/.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,11 @@
{
"github.copilot.enable": {
"*": false,
"plaintext": false,
"markdown": false,
"scminput": false
},
"files.associations": {
"random": "c"
}
}

View File

@@ -1,7 +1,7 @@
[package]
name = "base2k"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
rug = {workspace = true}

View File

@@ -3,10 +3,11 @@ use std::path::absolute;
fn main() {
println!(
"cargo:rustc-link-search=native={}",
absolute("./spqlios-arithmetic/build/spqlios")
absolute("spqlios-arithmetic/build/spqlios")
.unwrap()
.to_str()
.unwrap()
);
println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios"
println!("cargo:rustc-link-lib=static=spqlios");
// println!("cargo:rustc-link-lib=dylib=spqlios")
}

View File

@@ -1,6 +1,7 @@
use base2k::{
BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
VecZnxDftOps, VecZnxOps, alloc_aligned,
AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
};
use itertools::izip;
use sampling::source::Source;
@@ -8,89 +9,125 @@ use sampling::source::Source;
fn main() {
let n: usize = 16;
let log_base2k: usize = 18;
let cols: usize = 3;
let msg_cols: usize = 2;
let log_scale: usize = msg_cols * log_base2k - 5;
let module: Module = Module::new(n, BACKEND::FFT64);
let ct_size: usize = 3;
let msg_size: usize = 2;
let log_scale: usize = msg_size * log_base2k - 5;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
let seed: [u8; 32] = [0; 32];
let mut source: Source = Source::new(seed);
let mut res: VecZnx = module.new_vec_znx(1, cols);
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: Scalar = Scalar::new(n);
s.fill_ternary_prob(0.5, &mut source);
let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain
let mut s_ppol: SvpPPol = module.new_svp_ppol();
let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());
// s_ppol <- DFT(s)
module.svp_prepare(&mut s_ppol, &s);
// s_dft <- DFT(s)
module.svp_prepare(&mut s_dft, 0, &s, 0);
// a <- Z_{2^prec}[X]/(X^{N}+1)
let mut a: VecZnx = module.new_vec_znx(1, cols);
module.fill_uniform(log_base2k, &mut a, cols, &mut source);
// Allocates a VecZnx with two columns: ct=(0, 0)
let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
2, // Number of columns
ct_size, // Number of small poly per column
);
// Scratch space for DFT values
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols());
// Fill the second column with random values: ct = (0, a)
ct.fill_uniform(log_base2k, 1, ct_size, &mut source);
// Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
// Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
// buf_big <- IDFT(buf_dft) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
// Applies DFT(ct[1]) * DFT(s)
module.svp_apply_inplace(
&mut buf_dft, // DFT(ct[1] * s)
0, // Selects the first column of buf_dft
&s_dft, // DFT(s)
0, // Selects the first column of s_dft
);
let mut m: VecZnx = module.new_vec_znx(1, msg_cols);
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// Creates a plaintext: VecZnx with 1 column
let mut m = module.new_vec_znx(
1, // Number of columns
msg_size, // Number of small polynomials
);
let mut want: Vec<i64> = vec![0; n];
want.iter_mut()
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
// Encodes `want` on m with scaling 2^{log_scale}
m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
m.normalize(log_base2k, &mut carry);
module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow());
// buf_big <- m - buf_big
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
// b <- normalize(buf_big) + e
let mut b: VecZnx = module.new_vec_znx(1, cols);
module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry);
module.add_normal(
log_base2k,
&mut b,
log_base2k * cols,
&mut source,
3.2,
19.0,
// m - BIG(ct[1] * s)
module.vec_znx_big_sub_small_b_inplace(
&mut buf_big,
0, // Selects the first column of the receiver
&m,
0, // Selects the first column of the message
);
// Decrypt
// Normalizes back to VecZnx
// ct[0] <- m - BIG(ct[1] * s)
module.vec_znx_big_normalize(
log_base2k,
&mut ct,
0, // Selects the first column of ct (ct[0])
&buf_big,
0, // Selects the first column of buf_big
scratch.borrow(),
);
// buf_big <- a * s
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
// Add noise to ct[0]
// ct[0] <- ct[0] + e
ct.add_normal(
log_base2k,
0, // Selects the first column of ct (ct[0])
log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * ct_size}
&mut source,
3.2, // Standard deviation
19.0, // Truncation bound
);
// buf_big <- a * s + b
module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
// Final ciphertext: ct = (-a * s + m + e, a)
// res <- normalize(buf_big)
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
// Decryption
// DFT(ct[1] * s)
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
module.svp_apply_inplace(
&mut buf_dft,
0, // Selects the first column of buf_dft.
&s_dft,
0,
);
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s))
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
// BIG(ct[1] * s) + ct[0]
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
// m + e <- BIG(ct[1] * s + ct[0])
let mut res = module.new_vec_znx(1, ct_size);
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow());
// have = m * 2^{log_scale} + e
let mut have: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have);
res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have);
let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64;
let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64;
izip!(want.iter(), have.iter())
.enumerate()
.for_each(|(i, (a, b))| {
println!("{}: {} {}", i, a, (*b as f64) / scale);
})
});
}

View File

@@ -1,58 +0,0 @@
use base2k::{
BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps,
alloc_aligned,
};
fn main() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module = Module::new(n, BACKEND::FFT64);
let log_base2k: usize = 15;
let cols: usize = 5;
let log_k: usize = log_base2k * cols - 5;
let rows: usize = cols;
let cols: usize = cols + 1;
// Maximum size of the byte scratch needed
let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
let mut a_values: Vec<i64> = vec![i64::default(); n];
a_values[1] = (1 << log_base2k) + 1;
let mut a: VecZnx = module.new_vec_znx(1, rows);
a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32);
a.normalize(log_base2k, &mut buf);
a.print(0, a.cols(), n);
println!();
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols);
(0..a.cols()).for_each(|row_i| {
let mut tmp: VecZnx = module.new_vec_znx(1, cols);
tmp.at_mut(row_i)[1] = 1 as i64;
module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf);
});
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
let mut res: VecZnx = module.new_vec_znx(1, rows);
module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);
let mut values_res: Vec<i64> = vec![i64::default(); n];
res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
res.print(0, res.cols(), n);
module.free();
println!("{:?}", values_res)
}

View File

@@ -1,5 +1,6 @@
use crate::ffi::znx::znx_zero_i64_ref;
use crate::{Infos, VecZnx};
use crate::znx_base::{ZnxView, ZnxViewMut};
use crate::{VecZnx, znx_base::ZnxInfos};
use itertools::izip;
use rug::{Assign, Float};
use std::cmp::min;
@@ -9,129 +10,141 @@ pub trait Encoding {
///
/// # Arguments
///
/// * `poly_idx`: the index of the poly where to encode the data.
/// * `col_i`: the index of the poly on which to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two negative logarithm of the scaling of the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
/// decode a vector of i64 from the receiver.
///
/// # Arguments
///
/// * `poly_idx`: the index of the poly where to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two logarithm of the scaling of the data.
/// * `data`: data to decode from the receiver.
fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]);
/// decode a vector of Float from the receiver.
///
/// # Arguments
/// * `poly_idx`: the index of the poly where to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `data`: data to decode from the receiver.
fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]);
fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
/// encodes a single i64 on the receiver at the given index.
///
/// # Arguments
///
/// * `poly_idx`: the index of the poly where to encode the data.
/// * `col_i`: the index of the poly on which to encode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient on which to encode the data.
/// * `data`: data to encode on the receiver.
/// * `log_max`: base two logarithm of the infinity norm of the input data.
fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
}
pub trait Decoding {
/// decode a vector of i64 from the receiver.
///
/// # Arguments
///
/// * `col_i`: the index of the poly from which to decode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two logarithm of the scaling of the data.
/// * `data`: data to decode from the receiver.
fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]);
/// decode a vector of Float from the receiver.
///
/// # Arguments
/// * `col_i`: the index of the poly from which to decode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `data`: data to decode from the receiver.
fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]);
/// decodes a single i64 from the receiver at the given index.
///
/// # Arguments
///
/// * `poly_idx`: the index of the poly where to encode the data.
/// * `col_i`: the index of the poly from which to decode the data.
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
/// * `log_k`: base two negative logarithm of the scaling of the data.
/// * `i`: index of the coefficient to decode.
/// * `data`: data to decode from the receiver.
fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
}
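For orientation, here is a minimal round trip through the two traits above. This is a hedged sketch, not part of the diff: it mirrors the `test_set_get_i64_lo_norm` test at the end of this file, and the helper name is illustrative.

```rust
use base2k::{Decoding, Encoding, FFT64, Module, VecZnxAlloc};

// Illustrative helper (not in the PR); API names are taken from this diff.
fn encode_decode_roundtrip() {
    let n: usize = 16;
    let module: Module<FFT64> = Module::<FFT64>::new(n);
    let (log_base2k, size): (usize, usize) = (17, 5);
    let log_k: usize = size * log_base2k - 5; // data lives at scale 2^{log_k}
    let mut a = module.new_vec_znx(1, size); // 1 column, `size` small polys
    let have: Vec<i64> = (0..n as i64).collect();
    // Encode on column 0; the infinity norm of `have` is below 2^10.
    a.encode_vec_i64(0, log_base2k, log_k, &have, 10);
    let mut want: Vec<i64> = vec![0; n];
    a.decode_vec_i64(0, log_base2k, log_k, &mut want);
    assert_eq!(have, want);
}
```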
impl Encoding for VecZnx {
fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max)
impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max)
}
fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
decode_vec_i64(self, poly_idx, log_base2k, log_k, data)
}
fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) {
decode_vec_float(self, poly_idx, log_base2k, data)
}
fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max)
}
fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
decode_coeff_i64(self, poly_idx, log_base2k, log_k, i)
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max)
}
}
fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
decode_vec_i64(self, col_i, log_base2k, log_k, data)
}
fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]) {
decode_vec_float(self, col_i, log_base2k, data)
}
fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
decode_coeff_i64(self, col_i, log_base2k, log_k, i)
}
}
fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
a: &mut VecZnx<D>,
col_i: usize,
log_base2k: usize,
log_k: usize,
data: &[i64],
log_max: usize,
) {
let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)]
{
assert!(
cols <= a.cols(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
cols,
a.cols()
size <= a.size(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
size,
a.size()
);
assert!(poly_idx < a.size);
assert!(col_i < a.cols());
assert!(data.len() <= a.n())
}
let data_len: usize = data.len();
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
(0..a.cols()).for_each(|i| unsafe {
znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i));
// Zeroes all limbs of the col_i-th column
(0..a.size()).for_each(|i| unsafe {
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
});
// If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]);
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
} else {
let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
(cols - steps..cols)
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size)
.rev()
.enumerate()
.for_each(|(i, i_rev)| {
let shift: usize = i * log_base2k;
izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
})
}
// Case where self.prec % self.k != 0.
if log_k_rem != log_base2k {
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
(cols - steps..cols).rev().for_each(|i| {
a.at_poly_mut(poly_idx, i)[..data_len]
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size).rev().for_each(|i| {
a.at_mut(col_i, i)[..data_len]
.iter_mut()
.for_each(|x| *x <<= log_k_rem);
})
}
}
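The branch above either copies the data into the last limb or splits it. The split in isolation, as a standalone sketch (unsigned view; sign handling and carries are left to normalization; the helper name is illustrative):

```rust
// Standalone illustration of the base-2^k decomposition performed above:
// the most-significant limb lands in slot 0, the least-significant in the
// last slot, each masked to log_base2k bits.
fn decompose_base2k(value: i64, log_base2k: usize, size: usize) -> Vec<i64> {
    let mask: i64 = (1 << log_base2k) - 1;
    (0..size)
        .map(|j| (value >> ((size - 1 - j) * log_base2k)) & mask)
        .collect()
}

// decompose_base2k(0x2_3456, 8, 3) == vec![0x02, 0x34, 0x56]
// i.e. 0x23456 = 0x02 * 2^16 + 0x34 * 2^8 + 0x56.
```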
fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)]
{
assert!(
@@ -140,26 +153,26 @@ fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize,
data.len(),
a.n()
);
assert!(poly_idx < a.size());
assert!(col_i < a.cols());
}
data.copy_from_slice(a.at_poly(poly_idx, 0));
data.copy_from_slice(a.at(col_i, 0));
let rem: usize = log_base2k - (log_k % log_base2k);
(1..cols).for_each(|i| {
if i == cols - 1 && rem != log_base2k {
(1..size).for_each(|i| {
if i == size - 1 && rem != log_base2k {
let k_rem: usize = log_base2k - rem;
izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| {
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << k_rem) + (x >> rem);
});
} else {
izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| {
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
*y = (*y << log_base2k) + x;
});
}
})
}
fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) {
let cols: usize = a.cols();
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, data: &mut [Float]) {
let size: usize = a.size();
#[cfg(debug_assertions)]
{
assert!(
@@ -168,23 +181,23 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [
data.len(),
a.n()
);
assert!(poly_idx < a.size());
assert!(col_i < a.cols());
}
let prec: u32 = (log_base2k * cols) as u32;
let prec: u32 = (log_base2k * size) as u32;
// 2^{log_base2k}
let base = Float::with_val(prec, (1 << log_base2k) as f64);
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
(0..cols).for_each(|i| {
(0..size).for_each(|i| {
if i == 0 {
izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
y.assign(*x);
*y /= &base;
});
} else {
izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
*y += Float::with_val(prec, *x);
*y /= &base;
});
@@ -192,54 +205,62 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [
});
}
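The loop above is a Horner evaluation of y = sum_j x[j] * 2^{-log_base2k*(j+1)}, starting from the least-significant limb. The same accumulation in plain f64 (the crate uses rug::Float to keep log_base2k * size bits of precision; the helper name is illustrative):

```rust
// Plain-f64 sketch of the Horner accumulation used by decode_vec_float.
fn limbs_to_f64(limbs: &[i64], log_base2k: usize) -> f64 {
    let base: f64 = (1u64 << log_base2k) as f64;
    limbs
        .iter()
        .rev() // least-significant limb first
        .fold(0.0, |acc, &x| (acc + x as f64) / base)
}

// limbs_to_f64(&[0x02, 0x34, 0x56], 8) == 0x23456 as f64 / 2f64.powi(24)
```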
fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
a: &mut VecZnx<D>,
col_i: usize,
log_base2k: usize,
log_k: usize,
i: usize,
value: i64,
log_max: usize,
) {
let size: usize = (log_k + log_base2k - 1) / log_base2k;
#[cfg(debug_assertions)]
{
assert!(i < a.n());
assert!(
cols <= a.cols(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
cols,
a.cols()
size <= a.size(),
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
size,
a.size()
);
assert!(poly_idx < a.size());
assert!(col_i < a.cols());
}
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
(0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0);
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
// values on the last limb.
// Else we decompose values base2k.
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
a.at_poly_mut(poly_idx, cols - 1)[i] = value;
a.at_mut(col_i, size - 1)[i] = value;
} else {
let mask: i64 = (1 << log_base2k) - 1;
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
(cols - steps..cols)
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size)
.rev()
.enumerate()
.for_each(|(j, j_rev)| {
a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask;
a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask;
})
}
// Case where prec % k != 0.
if log_k_rem != log_base2k {
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
(cols - steps..cols).rev().for_each(|j| {
a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem;
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
(size - steps..size).rev().for_each(|j| {
a.at_mut(col_i, j)[i] <<= log_k_rem;
})
}
}
fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
#[cfg(debug_assertions)]
{
assert!(i < a.n());
assert!(poly_idx < a.size())
assert!(col_i < a.cols())
}
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
@@ -261,27 +282,30 @@ fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize
#[cfg(test)]
mod tests {
use crate::{Encoding, Infos, VecZnx};
use crate::vec_znx_ops::*;
use crate::znx_base::*;
use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
use itertools::izip;
use sampling::source::Source;
#[test]
fn test_set_get_i64_lo_norm() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let cols: usize = 5;
let log_k: usize = cols * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(n, 2, cols);
let size: usize = 5;
let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source: Source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.size()).for_each(|poly_idx| {
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut()
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10);
a.encode_vec_i64(col_i, log_base2k, log_k, &have, 10);
let mut want: Vec<i64> = vec![i64::default(); n];
a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want);
a.decode_vec_i64(col_i, log_base2k, log_k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
});
}
@@ -289,19 +313,20 @@ mod tests {
#[test]
fn test_set_get_i64_hi_norm() {
let n: usize = 8;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let cols: usize = 5;
let log_k: usize = cols * log_base2k - 5;
let mut a: VecZnx = VecZnx::new(n, 2, cols);
let size: usize = 5;
let log_k: usize = size * log_base2k - 5;
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.size()).for_each(|poly_idx| {
(0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| *x = source.next_i64());
a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64);
a.encode_vec_i64(col_i, log_base2k, log_k, &have, 64);
let mut want = vec![i64::default(); n];
a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want);
a.decode_vec_i64(col_i, log_base2k, log_k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})
}

View File

@@ -3,8 +3,6 @@ pub struct module_info_t {
}
pub type module_type_t = ::std::os::raw::c_uint;
pub const module_type_t_FFT64: module_type_t = 0;
pub const module_type_t_NTT120: module_type_t = 1;
pub use self::module_type_t as MODULE_TYPE;
pub type MODULE = module_info_t;

View File

@@ -33,3 +33,16 @@ unsafe extern "C" {
a_sl: u64,
);
}
unsafe extern "C" {
pub unsafe fn svp_apply_dft_to_dft(
module: *const MODULE,
res: *const VEC_ZNX_DFT,
res_size: u64,
res_cols: u64,
ppol: *const SVP_PPOL,
a: *const VEC_ZNX_DFT,
a_size: u64,
a_cols: u64,
);
}

View File

@@ -8,17 +8,17 @@ pub struct vec_znx_big_t {
pub type VEC_ZNX_BIG = vec_znx_big_t;
unsafe extern "C" {
pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
}
unsafe extern "C" {
pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
}
unsafe extern "C" {
pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
}
unsafe extern "C" {
pub fn vec_znx_big_add(
pub unsafe fn vec_znx_big_add(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -29,7 +29,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_add_small(
pub unsafe fn vec_znx_big_add_small(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -41,7 +41,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_add_small2(
pub unsafe fn vec_znx_big_add_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -54,7 +54,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub(
pub unsafe fn vec_znx_big_sub(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -65,7 +65,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small_b(
pub unsafe fn vec_znx_big_sub_small_b(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -77,7 +77,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small_a(
pub unsafe fn vec_znx_big_sub_small_a(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -89,7 +89,7 @@ unsafe extern "C" {
);
}
unsafe extern "C" {
pub fn vec_znx_big_sub_small2(
pub unsafe fn vec_znx_big_sub_small2(
module: *const MODULE,
res: *mut VEC_ZNX_BIG,
res_size: u64,
@@ -101,8 +101,13 @@ unsafe extern "C" {
b_sl: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_normalize_base2k(
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
@@ -113,34 +118,9 @@ unsafe extern "C" {
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub fn vec_znx_big_automorphism(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_rotate(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub fn vec_znx_big_range_normalize_base2k(
pub unsafe fn vec_znx_big_range_normalize_base2k(
module: *const MODULE,
log2_base2k: u64,
res: *mut i64,
@@ -153,6 +133,29 @@ unsafe extern "C" {
tmp_space: *mut u8,
);
}
unsafe extern "C" {
pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_automorphism(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}
unsafe extern "C" {
pub unsafe fn vec_znx_big_rotate(
module: *const MODULE,
p: i64,
res: *mut VEC_ZNX_BIG,
res_size: u64,
a: *const VEC_ZNX_BIG,
a_size: u64,
);
}

View File

@@ -1,22 +0,0 @@
use crate::LAYOUT;
pub trait Infos {
/// Returns the ring degree of the receiver.
fn n(&self) -> usize;
/// Returns the base two logarithm of the ring dimension of the receiver.
fn log_n(&self) -> usize;
/// Returns the number of stacked polynomials.
fn size(&self) -> usize;
/// Returns the memory layout of the stacked polynomials.
fn layout(&self) -> LAYOUT;
/// Returns the number of columns of the receiver.
/// This method is equivalent to [Infos::cols].
fn cols(&self) -> usize;
/// Returns the number of rows of the receiver.
fn rows(&self) -> usize;
}

View File

@@ -2,39 +2,43 @@ pub mod encoding;
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
// Other modules and exports
pub mod ffi;
pub mod infos;
pub mod mat_znx_dft;
pub mod mat_znx_dft_ops;
pub mod module;
pub mod sampling;
pub mod scalar_znx;
pub mod scalar_znx_dft;
pub mod scalar_znx_dft_ops;
pub mod stats;
pub mod svp;
pub mod vec_znx;
pub mod vec_znx_big;
pub mod vec_znx_big_ops;
pub mod vec_znx_dft;
pub mod vmp;
pub mod vec_znx_dft_ops;
pub mod vec_znx_ops;
pub mod znx_base;
pub use encoding::*;
pub use infos::*;
pub use mat_znx_dft::*;
pub use mat_znx_dft_ops::*;
pub use module::*;
pub use sampling::*;
#[allow(unused_imports)]
pub use scalar_znx::*;
pub use scalar_znx_dft::*;
pub use scalar_znx_dft_ops::*;
pub use stats::*;
pub use svp::*;
pub use vec_znx::*;
pub use vec_znx_big::*;
pub use vec_znx_big_ops::*;
pub use vec_znx_dft::*;
pub use vmp::*;
pub use vec_znx_dft_ops::*;
pub use vec_znx_ops::*;
pub use znx_base::*;
pub const GALOISGENERATOR: u64 = 5;
pub const DEFAULTALIGN: usize = 64;
#[derive(Copy, Clone)]
#[repr(u8)]
pub enum LAYOUT {
ROW,
COL,
}
pub fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
(ptr as usize) % align == 0
}
@@ -51,38 +55,35 @@ pub fn assert_alignement<T>(ptr: *const T) {
pub fn cast<T, V>(data: &[T]) -> &[V] {
let ptr: *const V = data.as_ptr() as *const V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
let ptr: *mut V = data.as_ptr() as *mut V;
let len: usize = data.len() / std::mem::size_of::<V>();
let len: usize = data.len() / size_of::<V>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
use std::alloc::{Layout, alloc};
use std::ptr;
/// Allocates a block of bytes with a custom alignment.
/// Alignment must be a power of two and size a multiple of the alignment.
/// Allocated memory is initialized to zero.
pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
assert!(
align.is_power_of_two(),
"Alignment must be a power of two but is {}",
align
);
assert_eq!(
(size * std::mem::size_of::<u8>()) % align,
(size * size_of::<u8>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
unsafe {
let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment");
let ptr: *mut u8 = alloc(layout);
let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment");
let ptr: *mut u8 = std::alloc::alloc(layout);
if ptr.is_null() {
panic!("Memory allocation failed");
}
@@ -93,36 +94,158 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
align
);
// Init allocated memory to zero
ptr::write_bytes(ptr, 0, size);
std::ptr::write_bytes(ptr, 0, size);
Vec::from_raw_parts(ptr, size, size)
}
}
/// Allocates a block of bytes aligned with [DEFAULTALIGN].
/// Size must be a multiple of [DEFAULTALIGN].
/// Allocated memory is initialized to zero.
pub fn alloc_aligned_u8(size: usize) -> Vec<u8> {
alloc_aligned_custom_u8(size, DEFAULTALIGN)
}
/// Allocates a block of T aligned with [DEFAULTALIGN].
/// Size of T * size must be a multiple of [DEFAULTALIGN].
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
assert_eq!(
(size * std::mem::size_of::<T>()) % align,
(size * size_of::<T>()) % align,
0,
"size={} must be a multiple of align={}",
size,
align
);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
let len: usize = vec_u8.len() / size_of::<T>();
let cap: usize = vec_u8.capacity() / size_of::<T>();
std::mem::forget(vec_u8);
unsafe { Vec::from_raw_parts(ptr, len, cap) }
}
/// Allocates an aligned vector of size equal to the smallest multiple
/// of [DEFAULTALIGN]/size_of::<T>() that is equal to or greater than `size`.
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
alloc_aligned_custom::<T>(size, DEFAULTALIGN)
alloc_aligned_custom::<T>(
size + (size % (DEFAULTALIGN / size_of::<T>())),
DEFAULTALIGN,
)
}
// Scratch implementation below
pub struct ScratchOwned(Vec<u8>);
impl ScratchOwned {
pub fn new(byte_count: usize) -> Self {
let data: Vec<u8> = alloc_aligned(byte_count);
Self(data)
}
pub fn borrow(&mut self) -> &mut Scratch {
Scratch::new(&mut self.0)
}
}
pub struct Scratch {
data: [u8],
}
impl Scratch {
fn new(data: &mut [u8]) -> &mut Self {
unsafe { &mut *(data as *mut [u8] as *mut Self) }
}
pub fn available(&self) -> usize {
let ptr: *const u8 = self.data.as_ptr();
let self_len: usize = self.data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
self_len.saturating_sub(aligned_offset)
}
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
let ptr: *mut u8 = data.as_mut_ptr();
let self_len: usize = data.len();
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
unsafe {
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
return (take_slice, rem_slice);
}
} else {
panic!(
"Attempted to take {} from scratch with {} aligned bytes left",
take_len,
aligned_len,
// type_name::<T>(),
// aligned_len
);
}
}
pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
unsafe {
(
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
Self::new(rem_slice),
)
}
}
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
(
ScalarZnx::from_data(take_slice, module.n(), cols),
Self::new(rem_slice),
)
}
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
(
ScalarZnxDft::from_data(take_slice, module.n(), cols),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx_dft<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
size: usize,
) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
(
VecZnxDft::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx_big<B: Backend>(
&mut self,
module: &Module<B>,
cols: usize,
size: usize,
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
(
VecZnxBig::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
(
VecZnx::from_data(take_slice, module.n(), cols, size),
Self::new(rem_slice),
)
}
}
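A hedged usage sketch of the scratch arena defined above: one aligned allocation up front, then typed, aligned sub-slices carved off sequentially, with each call returning the remainder as a fresh `&mut Scratch`. Names and sizes below are illustrative.

```rust
// Sketch only; uses the ScratchOwned/Scratch API exactly as defined above.
fn scratch_usage() {
    let mut owned: ScratchOwned = ScratchOwned::new(1 << 16);
    let scratch: &mut Scratch = owned.borrow();
    // Take 1024 i64 (8192 bytes, aligned to DEFAULTALIGN)...
    let (coeffs, rest) = scratch.tmp_slice::<i64>(1024);
    // ...then 512 f64 from what is left.
    let (floats, rest) = rest.tmp_slice::<f64>(512);
    coeffs[0] = 1;
    floats[0] = 0.5;
    // At most the original budget minus the two takes remains.
    assert!(rest.available() <= (1usize << 16) - 8192 - 4096);
}
```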

base2k/src/mat_znx_dft.rs Normal file
View File

@@ -0,0 +1,232 @@
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
use std::marker::PhantomData;
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
///
/// [MatZnxDft] is used to perform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
/// See the trait [MatZnxDftOps] for additional information.
pub struct MatZnxDft<D, B: Backend> {
data: D,
n: usize,
size: usize,
rows: usize,
cols_in: usize,
cols_out: usize,
_phantom: PhantomData<B>,
}
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols_in
}
fn rows(&self) -> usize {
self.rows
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
self.size
}
}
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols_out()
}
}
impl<D, B: Backend> DataView for MatZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
type Scalar = f64;
}
impl<D, B: Backend> MatZnxDft<D, B> {
pub fn cols_in(&self) -> usize {
self.cols_in
}
pub fn cols_out(&self) -> usize {
self.cols_out
}
}
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
unsafe {
crate::ffi::vmp::bytes_of_vmp_pmat(
module.ptr,
(rows * cols_in) as u64,
(size * cols_out) as u64,
) as usize
}
}
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(
module: &Module<B>,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: impl Into<Vec<u8>>,
) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
Self {
data: data.into(),
n: module.n(),
size,
rows,
cols_in,
cols_out,
_phantom: PhantomData,
}
}
}
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
///
/// # Arguments
///
/// * `row`: row index (i).
/// * `col`: col index (j).
#[allow(dead_code)]
fn at(&self, row: usize, col: usize) -> Vec<f64> {
let n: usize = self.n();
let mut res: Vec<f64> = alloc_aligned(n);
if n < 8 {
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
} else {
(0..n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
});
}
res
}
#[allow(dead_code)]
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
let nrows: usize = self.rows();
let nsize: usize = self.size();
if col == (nsize - 1) && (nsize & 1 == 1) {
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
} else {
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
}
}
}
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
pub trait MatZnxDftToRef<B: Backend> {
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
}
pub trait MatZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data.as_slice(),
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
MatZnxDft {
data: self.data,
n: self.n,
rows: self.rows,
cols_in: self.cols_in,
cols_out: self.cols_out,
size: self.size,
_phantom: PhantomData,
}
}
}
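The `MatZnxDftToRef`/`MatZnxDftToMut` impls above let the ops traits accept owned (`Vec<u8>`) and borrowed (`&[u8]`, `&mut [u8]`) storage through a single generic bound. A hedged illustration (the helper name is invented):

```rust
// Any storage flavour of MatZnxDft can be passed where a read-only view is
// needed; `to_ref` reduces all of them to MatZnxDft<&[u8], B>.
fn inspect<A: MatZnxDftToRef<FFT64>>(a: &A) -> (usize, usize) {
    let view: MatZnxDft<&[u8], FFT64> = a.to_ref();
    (view.cols_in(), view.cols_out()) // inherent accessors, no extra traits
}
```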

View File

@@ -0,0 +1,487 @@
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
VecZnxDftToRef,
};
pub trait MatZnxDftAlloc<B: Backend> {
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
///
/// # Arguments
///
/// * `rows`: number of rows (number of [VecZnxDft]).
/// * `size`: size of each [VecZnxDft] (its number of small polynomials).
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
fn new_mat_znx_dft_from_bytes(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftOwned<B>;
}
pub trait MatZnxDftScratch {
/// Returns the size of the scratch space necessary for [MatZnxDftOps::vmp_apply].
fn vmp_apply_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize;
}
/// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
pub trait MatZnxDftOps<BACKEND: Backend> {
/// Prepares the i-th row of a [MatZnxDft] from a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the [MatZnxDft] on which the values are encoded.
/// * `res_row`: the index of the row to prepare.
/// * `res_col_in`: the index of the input column to prepare.
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>;
/// Extracts the i-th row of a [MatZnxDft] into a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the [VecZnxDft] into which the row of the [MatZnxDft] is extracted.
/// * `a`: the [MatZnxDft] from which the values are read.
/// * `a_row`: the index of the row to extract.
/// * `a_col_in`: the index of the input column to extract.
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: MatZnxDftToRef<BACKEND>;
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
/// The scratch size is given by [MatZnxDftScratch::vmp_apply_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of svp products
/// ([crate::ScalarZnxDftOps]), where each limb of the input [VecZnxDft]
/// plays the role of a [crate::ScalarZnxDft] multiplied with one
/// [VecZnxDft] row of the [MatZnxDft].
///
/// As such, given an input [VecZnxDft] of size `i` and a [MatZnxDft] with
/// `i` rows and size `j`, the output is a [VecZnxDft] of size `j`.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
///             |h i j|
///             |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `res`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
/// * `scratch`: scratch space, whose size can be obtained with [MatZnxDftScratch::vmp_apply_tmp_bytes].
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
// Same as [MatZnxDftOps::vmp_apply], except the result is added to `res` instead of overwriting it.
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<BACKEND>,
A: VecZnxDftToRef<BACKEND>,
B: MatZnxDftToRef<BACKEND>;
}
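A hedged sketch of the call sequence for the trait above; the `vmp_apply` test below gives a complete, self-checking version. Sizes and the helper name are illustrative.

```rust
// Assumes MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch and VecZnxDftAlloc
// are in scope, with the signatures defined in this file.
fn vmp_sketch(module: &Module<FFT64>) {
    let (rows, cols_in, cols_out, size) = (4, 1, 1, 4);
    // Scratch sized for the apply step.
    let mut scratch: ScratchOwned = ScratchOwned::new(
        module.vmp_apply_tmp_bytes(size, size, rows, cols_in, cols_out, size),
    );
    let mut mat = module.new_mat_znx_dft(rows, cols_in, cols_out, size);
    let row_dft = module.new_vec_znx_dft(cols_out, size);
    // One prepared VecZnxDft per (row, input-column) slot; column 0 here.
    (0..rows).for_each(|r| module.vmp_prepare_row(&mut mat, r, 0, &row_dft));
    let a_dft = module.new_vec_znx_dft(cols_in, size);
    let mut res_dft = module.new_vec_znx_dft(cols_out, size);
    module.vmp_apply(&mut res_dft, &a_dft, &mat, scratch.borrow());
}
```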
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
}
fn new_mat_znx_dft_from_bytes(
&self,
rows: usize,
cols_in: usize,
cols_out: usize,
size: usize,
bytes: Vec<u8>,
) -> MatZnxDftOwned<B> {
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
}
}
impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
fn vmp_apply_tmp_bytes(
&self,
res_size: usize,
a_size: usize,
b_rows: usize,
b_cols_in: usize,
b_cols_out: usize,
b_size: usize,
) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
(res_size * b_cols_out) as u64,
(a_size * b_cols_in) as u64,
(b_rows * b_cols_in) as u64,
(b_size * b_cols_out) as u64,
) as usize
}
}
}
impl MatZnxDftOps<FFT64> for Module<FFT64> {
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
where
R: MatZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
a.cols(),
res.cols_out(),
"a.cols(): {} != res.cols_out(): {}",
a.cols(),
res.cols_out()
);
assert!(
res_row < res.rows(),
"res_row: {} >= res.rows(): {}",
res_row,
res.rows()
);
assert!(
res_col_in < res.cols_in(),
"res_col_in: {} >= res.cols_in(): {}",
res_col_in,
res.cols_in()
);
assert_eq!(
res.size(),
a.size(),
"res.size(): {} != a.size(): {}",
res.size(),
a.size()
);
}
unsafe {
vmp::vmp_prepare_row_dft(
self.ptr,
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
a.as_ptr() as *const vec_znx_dft_t,
(res_row * res.cols_in() + res_col_in) as u64,
(res.rows() * res.cols_in()) as u64,
(res.size() * res.cols_out()) as u64,
);
}
}
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
where
R: VecZnxDftToMut<FFT64>,
A: MatZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: MatZnxDft<&[u8], _> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
a.cols_out(),
"res.cols(): {} != a.cols_out(): {}",
res.cols(),
a.cols_out()
);
assert!(
a_row < a.rows(),
"a_row: {} >= a.rows(): {}",
a_row,
a.rows()
);
assert!(
a_col_in < a.cols_in(),
"a_col_in: {} >= a.cols_in(): {}",
a_col_in,
a.cols_in()
);
assert_eq!(
res.size(),
a.size(),
"res.size(): {} != a.size(): {}",
res.size(),
a.size()
);
}
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp::vmp_pmat_t,
(a_row * a.cols_in() + a_col_in) as u64,
(a.rows() * a.cols_in()) as u64,
(a.size() * a.cols_out()) as u64,
);
}
}
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: MatZnxDft<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
B: MatZnxDftToRef<FFT64> {
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
let a: VecZnxDft<&[u8], _> = a.to_ref();
let b: MatZnxDft<&[u8], _> = b.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(res.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(a.n(), self.n());
assert_eq!(
res.cols(),
b.cols_out(),
"res.cols(): {} != b.cols_out: {}",
res.cols(),
b.cols_out()
);
assert_eq!(
a.cols(),
b.cols_in(),
"a.cols(): {} != b.cols_in: {}",
a.cols(),
b.cols_in()
);
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
res.size(),
a.size(),
b.rows(),
b.cols_in(),
b.cols_out(),
b.size(),
));
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
res.as_mut_ptr() as *mut vec_znx_dft_t,
(res.size() * res.cols()) as u64,
a.as_ptr() as *const vec_znx_dft_t,
(a.size() * a.cols()) as u64,
b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
};
use sampling::source::Source;
use super::{MatZnxDftAlloc, MatZnxDftScratch};
#[test]
fn vmp_prepare_row() {
let module: Module<FFT64> = Module::<FFT64>::new(16);
let log_base2k: usize = 8;
let mat_rows: usize = 4;
let mat_cols_in: usize = 2;
let mat_cols_out: usize = 2;
let mat_size: usize = 5;
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
for col_in in 0..mat_cols_in {
for row_i in 0..mat_rows {
let mut source: Source = Source::new([0u8; 32]);
(0..mat_cols_out).for_each(|col_out| {
a.fill_uniform(log_base2k, col_out, mat_size, &mut source);
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
});
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
assert_eq!(a_dft.raw(), b_dft.raw());
}
}
}
#[test]
fn vmp_apply() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 15;
let a_size: usize = 5;
let mat_size: usize = 6;
let res_size: usize = 5;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let res_cols: usize = mat_cols_out;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
a.at_mut(i, 2)[i + 1] = 1;
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
// Constructs a [MatZnxDft] that performs cyclic rotations on each submatrix.
(0..a.size()).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
});
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow());
});
(0..mat_cols_out).for_each(|col_i| {
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
(0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
});
res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
});
});
});
}
}

View File

@@ -1,5 +1,6 @@
use crate::GALOISGENERATOR;
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
use std::marker::PhantomData;
#[derive(Copy, Clone)]
#[repr(u8)]
@@ -8,37 +9,50 @@ pub enum BACKEND {
NTT120,
}
pub struct Module {
pub ptr: *mut MODULE,
pub n: usize,
pub backend: BACKEND,
pub trait Backend {
const KIND: BACKEND;
fn module_type() -> u32;
}
impl Module {
pub struct FFT64;
pub struct NTT120;
impl Backend for FFT64 {
const KIND: BACKEND = BACKEND::FFT64;
fn module_type() -> u32 {
0
}
}
impl Backend for NTT120 {
const KIND: BACKEND = BACKEND::NTT120;
fn module_type() -> u32 {
1
}
}
pub struct Module<B: Backend> {
pub ptr: *mut MODULE,
n: usize,
_marker: PhantomData<B>,
}
impl<B: Backend> Module<B> {
// Instantiates a new module.
pub fn new(n: usize, module_type: BACKEND) -> Self {
pub fn new(n: usize) -> Self {
unsafe {
let module_type_u32: u32;
match module_type {
BACKEND::FFT64 => module_type_u32 = 0,
BACKEND::NTT120 => module_type_u32 = 1,
}
let m: *mut module_info_t = new_module_info(n as u64, module_type_u32);
let m: *mut module_info_t = new_module_info(n as u64, B::module_type());
if m.is_null() {
panic!("Failed to create module.");
}
Self {
ptr: m,
n: n,
backend: module_type,
_marker: PhantomData,
}
}
}
pub fn backend(&self) -> BACKEND {
self.backend
}
pub fn n(&self) -> usize {
self.n
}
@@ -51,26 +65,27 @@ impl Module {
(self.n() << 1) as _
}
// Returns GALOISGENERATOR^|gen| * sign(gen)
pub fn galois_element(&self, gen: i64) -> i64 {
if gen == 0 {
// Returns GALOISGENERATOR^|generator| * sign(generator)
pub fn galois_element(&self, generator: i64) -> i64 {
if generator == 0 {
return 1;
}
((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum()
((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum()
}
// Returns gal_el^-1
pub fn galois_element_inv(&self, gen: i64) -> i64 {
if gen == 0 {
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
if gal_el == 0 {
panic!("cannot invert 0")
}
((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
* gen.signum()
((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
* gal_el.signum()
}
}
pub fn free(self) {
impl<B: Backend> Drop for Module<B> {
fn drop(&mut self) {
unsafe { delete_module_info(self.ptr) }
drop(self);
}
}
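The backend is now selected by a zero-sized type parameter at compile time rather than by the runtime `BACKEND` value, and the spqlios handle is released on drop instead of through an explicit `free()`. A short hedged sketch (helper name illustrative):

```rust
fn module_lifecycle() {
    // The type parameter picks the spqlios module type (FFT64 -> 0).
    let module: Module<FFT64> = Module::<FFT64>::new(1024);
    assert_eq!(module.n(), 1024);
    // GALOISGENERATOR^|gen| mod 2n, carrying the sign of the exponent.
    let _gal: i64 = module.galois_element(2);
} // delete_module_info runs here via Drop; no explicit free() call.
```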

View File

@@ -1,56 +1,132 @@
use crate::{Infos, Module, VecZnx};
use crate::znx_base::ZnxViewMut;
use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut};
use rand_distr::{Distribution, Normal};
use sampling::source::Source;
pub trait Sampling {
/// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source);
pub trait FillUniform {
/// Fills the first `size` limbs of the `col_i`-th column with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source);
}
/// Adds a vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
&self,
pub trait FillDistF64 {
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
}
impl Sampling for Module {
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) {
pub trait AddDistF64 {
/// Adds a vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
);
}
pub trait FillNormal {
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
}
pub trait AddNormal {
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
}
impl<T> FillUniform for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
let base2k: u64 = 1 << log_base2k;
let mask: u64 = base2k - 1;
let base2k_half: i64 = (base2k >> 1) as i64;
let size: usize = a.n() * cols;
a.raw_mut()[..size]
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
(0..size).for_each(|j| {
a.at_mut(col_i, j)
.iter_mut()
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
})
}
}
fn add_dist_f64<D: Distribution<f64>>(
&self,
impl<T> FillDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
a: &mut VecZnx,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let log_base2k_rem: usize = log_k % log_base2k;
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 {
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
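A worked example of the limb arithmetic above (hypothetical values, standalone demo function): with log_base2k = 17 and log_k = 30 the sample lands in limb 1, shifted left by 4 bits, i.e. at scale 2^-30 instead of limb 1's native 2^-34.
fn limb_demo() {
    let (log_base2k, log_k) = (17usize, 30usize);
    let limb = (log_k + log_base2k - 1) / log_base2k - 1; // = 1 (second limb)
    let rem = (limb + 1) * log_base2k - log_k; // = 4 bits of left shift
    assert_eq!((limb, rem), (1, 4));
}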
impl<T> AddDistF64 for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnx<&mut [u8]> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -58,7 +134,7 @@ impl Sampling for Module {
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
@@ -67,11 +143,16 @@ impl Sampling for Module {
});
}
}
}
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.add_dist_f64(
impl<T> FillNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.fill_dist_f64(
log_base2k,
a,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
@@ -79,3 +160,206 @@ impl Sampling for Module {
);
}
}
impl<T> AddNormal for VecZnx<T>
where
VecZnx<T>: VecZnxToMut,
{
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.add_dist_f64(
log_base2k,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
impl<T> FillDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn fill_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a = dist_f64.round() as i64
});
}
}
}
impl<T> AddDistF64 for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn add_dist_f64<D: Distribution<f64>>(
&mut self,
log_base2k: usize,
col_i: usize,
log_k: usize,
source: &mut Source,
dist: D,
bound: f64,
) {
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
assert!(
(bound.log2().ceil() as i64) < 64,
"invalid bound: ceil(log2(bound))={} > 63",
(bound.log2().ceil() as i64)
);
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
if log_base2k_rem != 0 {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += (dist_f64.round() as i64) << log_base2k_rem;
});
} else {
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
let mut dist_f64: f64 = dist.sample(source);
while dist_f64.abs() > bound {
dist_f64 = dist.sample(source)
}
*a += dist_f64.round() as i64
});
}
}
}
impl<T> FillNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.fill_dist_f64(
log_base2k,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
impl<T> AddNormal for VecZnxBig<T, FFT64>
where
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
{
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
self.add_dist_f64(
log_base2k,
col_i,
log_k,
source,
Normal::new(0.0, sigma).unwrap(),
bound,
);
}
}
#[cfg(test)]
mod tests {
use super::{AddNormal, FillUniform};
use crate::vec_znx_ops::*;
use crate::znx_base::*;
use crate::{FFT64, Module, Stats, VecZnx};
use sampling::source::Source;
#[test]
fn vec_znx_fill_uniform() {
let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let size: usize = 5;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let one_12_sqrt: f64 = 0.28867513459481287;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
a.fill_uniform(log_base2k, col_i, size, &mut source);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, log_base2k);
assert!(
(std - one_12_sqrt).abs() < 0.01,
"std={} ~!= {}",
std,
one_12_sqrt
);
}
})
});
}
#[test]
fn vec_znx_add_normal() {
let n: usize = 4096;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let log_base2k: usize = 17;
let log_k: usize = 2 * 17;
let size: usize = 5;
let sigma: f64 = 3.2;
let bound: f64 = 6.0 * sigma;
let mut source: Source = Source::new([0u8; 32]);
let cols: usize = 2;
let zero: Vec<i64> = vec![0; n];
let k_f64: f64 = (1u64 << log_k as u64) as f64;
(0..cols).for_each(|col_i| {
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
a.add_normal(log_base2k, col_i, log_k, &mut source, sigma, bound);
(0..cols).for_each(|col_j| {
if col_j != col_i {
(0..size).for_each(|limb_i| {
assert_eq!(a.at(col_j, limb_i), zero);
})
} else {
let std: f64 = a.std(col_i, log_base2k) * k_f64;
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
}
})
});
}
}

306
base2k/src/scalar_znx.rs Normal file
View File

@@ -0,0 +1,306 @@
use crate::ffi::vec_znx;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
pub struct ScalarZnx<D> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
}
impl<D> ZnxInfos for ScalarZnx<D> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D> ZnxSliceSize for ScalarZnx<D> {
fn sl(&self) -> usize {
self.n()
}
}
impl<D> DataView for ScalarZnx<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D> DataViewMut for ScalarZnx<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for ScalarZnx<D> {
type Scalar = i64;
}
impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.at_mut(col, 0)
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.at_mut(col, 0)[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.at_mut(col, 0).shuffle(source);
}
}
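Hypothetical usage of the two ternary samplers above (demo function is mine): one secret with Pr[coeff != 0] = 0.5, one with exact Hamming weight 64.
use base2k::{FFT64, Module, ScalarZnxAlloc};
use sampling::source::Source;

fn ternary_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(1024);
    let mut source: Source = Source::new([0u8; 32]);
    let mut s = module.new_scalar_znx(1);
    s.fill_ternary_prob(0, 0.5, &mut source); // coeffs in {-1, 0, 1}
    s.fill_ternary_hw(0, 64, &mut source); // exactly 64 non-zero coeffs, shuffled
}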
impl<D: From<Vec<u8>>> ScalarZnx<D> {
pub(crate) fn bytes_of<S: Sized>(n: usize, cols: usize) -> usize {
n * cols * size_of::<S>()
}
pub(crate) fn new<S: Sized>(n: usize, cols: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of::<S>(n, cols));
Self {
data: data.into(),
n,
cols,
}
}
pub(crate) fn new_from_bytes<S: Sized>(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of::<S>(n, cols));
Self {
data: data.into(),
n,
cols,
}
}
}
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
ScalarZnxOwned::bytes_of::<i64>(module.n(), cols)
}
pub trait ScalarZnxAlloc {
fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
}
impl<B: Backend> ScalarZnxAlloc for Module<B> {
fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
ScalarZnxOwned::bytes_of::<i64>(self.n(), cols)
}
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
ScalarZnxOwned::new::<i64>(self.n(), cols)
}
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
ScalarZnxOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
}
}
pub trait ScalarZnxOps {
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxToMut,
A: ScalarZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut;
}
impl<B: Backend> ScalarZnxOps for Module<B> {
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxToMut,
A: ScalarZnxToRef,
{
let a: ScalarZnx<&[u8]> = a.to_ref();
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: ScalarZnxToMut,
{
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
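A hypothetical sketch of the automorphism above: X -> X^k on a single column, where k must be odd so that the map is a ring automorphism of Z[X]/(X^N+1); k = -1 is the conjugation X -> X^{-1}.
use base2k::{FFT64, Module, ScalarZnxAlloc, ScalarZnxOps};

fn automorphism_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let a = module.new_scalar_znx(1);
    let mut res = module.new_scalar_znx(1);
    // res[0] <- a[0](X^{-1})
    module.scalar_znx_automorphism(-1, &mut res, 0, &a, 0);
}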
impl<D> ScalarZnx<D> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self { data, n, cols }
}
}
pub trait ScalarZnxToRef {
fn to_ref(&self) -> ScalarZnx<&[u8]>;
}
pub trait ScalarZnxToMut {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
}
impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<Vec<u8>> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
data: self.data,
n: self.n,
cols: self.cols,
}
}
}
impl VecZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}

View File

@@ -0,0 +1,233 @@
use std::marker::PhantomData;
use crate::ffi::svp;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView,
alloc_aligned,
};
pub struct ScalarZnxDft<D, B: Backend> {
data: D,
n: usize,
cols: usize,
_phantom: PhantomData<B>,
}
impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
fn n(&self) -> usize {
self.n
}
fn size(&self) -> usize {
1
}
}
impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n()
}
}
impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
type Scalar = f64;
}
pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(module, cols)
}
impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
}
pub(crate) fn new(module: &Module<B>, cols: usize) -> Self {
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols));
Self {
data: data.into(),
n: module.n(),
cols,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == Self::bytes_of(module, cols));
Self {
data: data.into(),
n: module.n(),
cols,
_phantom: PhantomData,
}
}
}
impl<D, B: Backend> ScalarZnxDft<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
Self {
data,
n,
cols,
_phantom: PhantomData,
}
}
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
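A hypothetical sketch of `as_vec_znx_dft` above: a ScalarZnxDft is layout-compatible with a one-limb VecZnxDft, so the conversion moves the backing buffer without copying.
use base2k::{FFT64, Module, ScalarZnxDftAlloc, VecZnxDft, ZnxInfos};

fn reinterpret_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let s_dft = module.new_scalar_znx_dft(2);
    let v_dft: VecZnxDft<Vec<u8>, FFT64> = s_dft.as_vec_znx_dft();
    assert_eq!(v_dft.size(), 1); // always a single limb per column
}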
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
pub trait ScalarZnxDftToRef<B: Backend> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
}
pub trait ScalarZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
ScalarZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
_phantom: PhantomData,
}
}
}

View File

@@ -0,0 +1,103 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxDftToRef<FFT64>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}
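A hypothetical end-to-end sketch of the SVP pipeline above (demo function is mine; trait names as re-exported at the crate root in this PR): prepare a scalar once, then multiply it into a VecZnxDft limb-wise, in place.
use base2k::{
    FFT64, Module, ScalarZnxAlloc, ScalarZnxDftAlloc, ScalarZnxDftOps, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps,
};

fn svp_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let s = module.new_scalar_znx(1);
    let mut s_dft = module.new_scalar_znx_dft(1);
    module.svp_prepare(&mut s_dft, 0, &s, 0); // s_dft <- DFT(s)

    let a = module.new_vec_znx(1, 3);
    let mut a_dft = module.new_vec_znx_dft(1, 3);
    module.vec_znx_dft(&mut a_dft, 0, &a, 0); // a_dft <- DFT(a)
    module.svp_apply_inplace(&mut a_dft, 0, &s_dft, 0); // a_dft <- s_dft * a_dft
}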

View File

@@ -1,13 +1,19 @@
use crate::{Encoding, Infos, VecZnx};
use crate::znx_base::ZnxInfos;
use crate::{Decoding, VecZnx};
use rug::Float;
use rug::float::Round;
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
impl VecZnx {
pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 {
let prec: u32 = (self.cols() * log_base2k) as u32;
pub trait Stats {
/// Returns the standard deviation of the `col_i`-th polynomial.
fn std(&self, col_i: usize, log_base2k: usize) -> f64;
}
impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
fn std(&self, col_i: usize, log_base2k: usize) -> f64 {
let prec: u32 = (self.size() * log_base2k) as u32;
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
self.decode_vec_float(poly_idx, log_base2k, &mut data);
self.decode_vec_float(col_i, log_base2k, &mut data);
// std = sqrt(sum((xi - avg)^2) / n)
let mut avg: Float = Float::with_val(prec, 0);
data.iter().for_each(|x| {

View File

@@ -1,276 +0,0 @@
use crate::ffi::svp::{self, svp_ppol_t};
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement};
use crate::{Infos, alloc_aligned, cast_mut};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
use sampling::source::Source;
pub struct Scalar {
pub n: usize,
pub data: Vec<i64>,
pub ptr: *mut i64,
}
impl Module {
pub fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
}
impl Scalar {
pub fn new(n: usize) -> Self {
let mut data: Vec<i64> = alloc_aligned::<i64>(n);
let ptr: *mut i64 = data.as_mut_ptr();
Self {
n: n,
data: data,
ptr: ptr,
}
}
pub fn n(&self) -> usize {
self.n
}
pub fn bytes_of(n: usize) -> usize {
n * std::mem::size_of::<i64>()
}
pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
unsafe {
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: ptr,
}
}
}
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
let size: usize = Self::bytes_of(n);
debug_assert!(
bytes.len() == size,
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
bytes.len(),
n,
size
);
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr())
}
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
Self {
n: n,
data: Vec::new(),
ptr: ptr,
}
}
pub fn as_ptr(&self) -> *const i64 {
self.ptr
}
pub fn raw(&self) -> &[i64] {
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
}
pub fn raw_mut(&self) -> &mut [i64] {
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
}
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
let choices: [i64; 3] = [-1, 0, 1];
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
self.data
.iter_mut()
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
assert!(hw <= self.n());
self.data[..hw]
.iter_mut()
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
self.data.shuffle(source);
}
pub fn as_vec_znx(&self) -> VecZnx {
VecZnx {
n: self.n,
size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR
cols: 1,
layout: LAYOUT::COL,
data: Vec::new(),
ptr: self.ptr,
}
}
}
pub trait ScalarOps {
fn bytes_of_scalar(&self) -> usize;
fn new_scalar(&self) -> Scalar;
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
}
impl ScalarOps for Module {
fn bytes_of_scalar(&self) -> usize {
Scalar::bytes_of(self.n())
}
fn new_scalar(&self) -> Scalar {
Scalar::new(self.n())
}
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes(self.n(), bytes)
}
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
}
}
pub struct SvpPPol {
pub n: usize,
pub data: Vec<u8>,
pub ptr: *mut u8,
pub backend: BACKEND,
}
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
/// An [SvpPPol] can be seen as a [VecZnxDft] of one limb.
impl SvpPPol {
pub fn new(module: &Module) -> Self {
module.new_svp_ppol()
}
/// Returns the ring degree of the [SvpPPol].
pub fn n(&self) -> usize {
self.n
}
pub fn bytes_of(module: &Module) -> usize {
module.bytes_of_svp_ppol()
}
pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)]
{
assert_alignement(bytes.as_ptr());
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
}
unsafe {
Self {
n: module.n(),
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
backend: module.backend(),
}
}
}
pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol {
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
}
Self {
n: module.n(),
data: Vec::new(),
ptr: tmp_bytes.as_mut_ptr(),
backend: module.backend(),
}
}
/// Returns the number of cols of the [SvpPPol], which is always 1.
pub fn cols(&self) -> usize {
1
}
}
pub trait SvpPPolOps {
/// Allocates a new [SvpPPol].
fn new_svp_ppol(&self) -> SvpPPol;
/// Returns the minimum number of bytes necessary to allocate
/// a new [SvpPPol] through [SvpPPol::from_bytes] or [SvpPPol::from_bytes_borrow].
fn bytes_of_svp_ppol(&self) -> usize;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is owned by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol;
/// Allocates a new [SvpPPol] from an array of bytes.
/// The array of bytes is borrowed by the [SvpPPol].
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol;
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
/// the [VecZnxDft] is multiplied with [SvpPPol].
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
}
impl SvpPPolOps for Module {
fn new_svp_ppol(&self) -> SvpPPol {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
let ptr: *mut u8 = data.as_mut_ptr();
SvpPPol {
data: data,
ptr: ptr,
n: self.n(),
backend: self.backend(),
}
}
fn bytes_of_svp_ppol(&self) -> usize {
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
}
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol {
SvpPPol::from_bytes(self, bytes)
}
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol {
SvpPPol::from_bytes_borrow(self, tmp_bytes)
}
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
}
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
unsafe {
svp::svp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.ptr as *const svp_ppol_t,
b.as_ptr(),
b.cols() as u64,
b.n() as u64,
)
}
}
}

File diff suppressed because it is too large

View File

@@ -1,90 +1,26 @@
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement};
use crate::ffi::vec_znx_big;
use crate::znx_base::{ZnxInfos, ZnxView};
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
use std::fmt;
use std::marker::PhantomData;
pub struct VecZnxBig {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub size: usize,
pub cols: usize,
pub layout: LAYOUT,
pub backend: BACKEND,
pub struct VecZnxBig<D, B: Backend> {
data: D,
n: usize,
cols: usize,
size: usize,
_phantom: PhantomData<B>,
}
impl VecZnxBig {
/// Returns a new [VecZnxBig] with the provided data as backing array.
/// User must ensure that data is properly aligned and that
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols));
assert_alignement(bytes.as_ptr())
};
unsafe {
Self {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: module.backend,
}
}
impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
fn cols(&self) -> usize {
self.cols
}
pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols));
assert_alignement(bytes.as_ptr());
}
Self {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: module.backend,
}
fn rows(&self) -> usize {
1
}
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
VecZnxDft {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
size: self.size,
layout: LAYOUT::COL,
cols: self.cols,
backend: self.backend,
}
}
pub fn backend(&self) -> BACKEND {
self.backend
}
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw<T>(&self, module: &Module) -> &[T] {
let ptr: *const T = self.ptr as *const T;
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { &std::slice::from_raw_parts(ptr, len) }
}
}
impl Infos for VecZnxBig {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
@@ -92,270 +28,217 @@ impl Infos for VecZnxBig {
fn size(&self) -> usize {
self.size
}
}
fn layout(&self) -> LAYOUT {
self.layout
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
pub trait VecZnxBigOps {
/// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxBig].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxBig].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize;
/// b <- b - a
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
/// c <- b - a
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
/// c <- b + a
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
/// b <- b + a
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
/// b <- normalize(a)
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]);
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
res: &mut VecZnx,
a: &VecZnxBig,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
);
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig);
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig);
impl<D, B: Backend> DataView for VecZnxBig<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl VecZnxBigOps for Module {
fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_big(size, cols));
let ptr: *mut u8 = data.as_mut_ptr();
impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
type Scalar = i64;
}
pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
}
impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
}
impl<D, B: Backend> VecZnxBig<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
impl<D> VecZnxBig<D, FFT64>
where
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
{
// Consumes the VecZnxBig to return a VecZnx.
// Useful when no normalization is needed.
pub fn to_vec_znx_small(self) -> VecZnx<D> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
}
}
/// Extracts the `a_col`-th column of `a` and stores it in the `self_col`-th column of [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
where
VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
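A hypothetical sketch of `extract_column` above: the shared limbs are copied and the tail is zeroed when the destination has more limbs than the source.
use base2k::{FFT64, Module, VecZnxBigAlloc};

fn extract_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let a = module.new_vec_znx_big(2, 3); // 2 columns, 3 limbs each
    let mut b = module.new_vec_znx_big(1, 4); // 1 column, 4 limbs
    b.extract_column(0, &a, 1); // limbs 0..3 copied from column 1, limb 3 zeroed
}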
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
pub trait VecZnxBigToRef<B: Backend> {
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
}
pub trait VecZnxBigToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: data,
ptr: ptr,
n: self.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: self.backend(),
}
}
fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
VecZnxBig::from_bytes(self, size, cols, bytes)
}
fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig {
VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes)
}
fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize {
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size }
}
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
)
}
}
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
unsafe {
vec_znx_big::vec_znx_big_sub_small_a(
self.ptr,
c.ptr as *mut vec_znx_big_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
)
}
}
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
c.ptr as *mut vec_znx_big_t,
c.cols() as u64,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
)
}
}
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
unsafe {
vec_znx_big::vec_znx_big_add_small(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
)
}
}
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) {
debug_assert!(
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_normalize_base2k(
self.ptr,
log_base2k as u64,
b.as_mut_ptr(),
b.cols() as u64,
b.n() as u64,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
fn vec_znx_big_range_normalize_base2k(
&self,
log_base2k: usize,
res: &mut VecZnx,
a: &VecZnxBig,
a_range_begin: usize,
a_range_xend: usize,
a_range_step: usize,
tmp_bytes: &mut [u8],
) {
debug_assert!(
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
tmp_bytes.len(),
<Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
);
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_big::vec_znx_big_range_normalize_base2k(
self.ptr,
log_base2k as u64,
res.as_mut_ptr(),
res.cols() as u64,
res.n() as u64,
a.ptr as *mut vec_znx_big_t,
a_range_begin as u64,
a_range_xend as u64,
a_range_step as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
);
}
}
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) {
unsafe {
vec_znx_big::vec_znx_big_automorphism(
self.ptr,
gal_el,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
a.ptr as *mut vec_znx_big_t,
a.cols() as u64,
);
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
VecZnxBig {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxBig(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,632 @@
use crate::ffi::vec_znx;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
};
pub trait VecZnxBigAlloc<B: Backend> {
/// Allocates a vector Z[X]/(X^N+1) that stores non-normalized values.
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of columns.
/// * `size`: the number of polynomials per column.
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
// /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
// ///
// /// Behavior: the backing array is only borrowed.
// ///
// /// # Arguments
// ///
// /// * `cols`: the number of columns.
// /// * `size`: the number of polynomials per column.
// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
// ///
// /// # Panics
// /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
// fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
/// Returns the minimum number of bytes necessary to allocate
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxBigOps<BACKEND: Backend> {
/// Adds `a` to `b` and stores the result in `res`.
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxBigToRef<BACKEND>;
/// Adds `a` to `res` and stores the result in `res`.
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
/// Adds the small vector `b` to `a` and stores the result in `res`.
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxToRef;
/// Adds the small vector `a` to `res` and stores the result in `res`.
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
/// Subtracts `b` from `a` and stores the result in `res`.
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxBigToRef<BACKEND>;
/// Subtracts `a` from `res` and stores the result in `res`.
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
/// Subtracts `res` from `a` and stores the result in `res`.
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
/// Subtracts the big vector `b` from the small vector `a` and stores the result in `res`.
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef,
B: VecZnxBigToRef<BACKEND>;
/// Subtracts `a` from `res` and stores the result on `res`.
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
/// Subtracts the small vector `b` from the big vector `a` and stores the result in `res`.
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>,
B: VecZnxToRef;
/// Subtracts `res` from `a` and stores the result on `res`.
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
/// Negates `a` inplace.
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
/// Normalizes `a` and stores the result on `b`.
///
/// # Arguments
///
/// * `log_base2k`: normalization basis.
/// * `scratch`: scratch space of size at least [VecZnxBigScratch::vec_znx_big_normalize_tmp_bytes].
fn vec_znx_big_normalize<R, A>(
&self,
log_base2k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<BACKEND>;
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<BACKEND>,
A: VecZnxBigToRef<BACKEND>;
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<BACKEND>;
}
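A hypothetical sketch of the naming convention above (demo function is mine): the `_small_a` / `_small_b` suffix says which operand is the plain (small) VecZnx, the other being a VecZnxBig.
use base2k::{FFT64, Module, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps};

fn sub_demo() {
    let module: Module<FFT64> = Module::<FFT64>::new(16);
    let mut r = module.new_vec_znx_big(1, 2);
    let big = module.new_vec_znx_big(1, 2);
    let small = module.new_vec_znx(1, 2);
    module.vec_znx_big_sub_small_a(&mut r, 0, &small, 0, &big, 0); // r <- small - big
    module.vec_znx_big_sub_small_b(&mut r, 0, &big, 0, &small, 0); // r <- big - small
}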
pub trait VecZnxBigScratch {
/// Returns the minimum number of scratch bytes needed to apply [VecZnxBigOps::vec_znx_big_normalize].
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
}
impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
VecZnxBig::new(self, cols, size)
}
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
VecZnxBig::new_from_bytes(self, cols, size, bytes)
}
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
bytes_of_vec_znx_big(self, cols, size)
}
}
impl VecZnxBigOps<FFT64> for Module<FFT64> {
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
B: VecZnxBigToRef<FFT64>,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
B: VecZnxToRef,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_normalize<R, A>(
&self,
log_base2k: usize,
res: &mut R,
res_col: usize,
a: &A,
a_col: usize,
scratch: &mut Scratch,
) where
R: VecZnxToMut,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
//(Jay) Note: this calls VecZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
// In the FFT backend the tmp sizes are the same, but they will differ in the NTT backend.
// assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
// assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
&self,
));
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
log_base2k as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxBigToRef<FFT64>,
{
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxBigToMut<FFT64>,
{
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
}
impl<B: Backend> VecZnxBigScratch for Module<B> {
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
<Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
}
}
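As a usage sketch of the `*_small` family above: add a small (base-2^k) vector onto a big-coefficient operand, then normalize back into a `VecZnx`. A `new_vec_znx_big` allocator and `ScratchOwned::borrow` (yielding the `&mut Scratch` these methods take) are assumed here by analogy with the other allocators and the scratch handling in this PR.

```rust
let module: Module<FFT64> = Module::<FFT64>::new(16);
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
let big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, 3); // assumed allocator
let small: VecZnx<Vec<u8>> = module.new_vec_znx(1, 3);
let mut sum: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, 3);
let mut res: VecZnx<Vec<u8>> = module.new_vec_znx(1, 3);
// sum[0] <- big[0] + small[0], then res[0] <- normalize_base2k(sum[0]).
module.vec_znx_big_add_small(&mut sum, 0, &big, 0, &small, 0);
module.vec_znx_big_normalize(18, &mut res, 0, &sum, 0, scratch.borrow()); // borrow() assumed
```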

View File

@@ -1,114 +1,35 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use std::marker::PhantomData;
use crate::ffi::vec_znx_dft;
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement};
use crate::{DEFAULTALIGN, VecZnx, alloc_aligned};
use crate::znx_base::ZnxInfos;
use crate::{
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
};
use std::fmt;
pub struct VecZnxDft {
pub data: Vec<u8>,
pub ptr: *mut u8,
pub n: usize,
pub size: usize,
pub layout: LAYOUT,
pub cols: usize,
pub backend: BACKEND,
pub struct VecZnxDft<D, B: Backend> {
pub(crate) data: D,
pub(crate) n: usize,
pub(crate) cols: usize,
pub(crate) size: usize,
pub(crate) _phantom: PhantomData<B>,
}
impl VecZnxDft {
/// Returns a new [VecZnxDft] with the provided data as backing array.
/// The user must ensure that the data is properly aligned and that
/// its size is at least equal to [Module::bytes_of_vec_znx_dft].
pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols));
assert_alignement(bytes.as_ptr())
}
unsafe {
VecZnxDft {
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
ptr: bytes.as_mut_ptr(),
n: module.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: module.backend,
}
}
}
pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
#[cfg(debug_assertions)]
{
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols));
assert_alignement(bytes.as_ptr());
}
VecZnxDft {
data: Vec::new(),
ptr: bytes.as_mut_ptr(),
n: module.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: module.backend,
}
}
/// Cast a [VecZnxDft] into a [VecZnxBig].
/// The returned [VecZnxBig] shares the backing array
/// with the original [VecZnxDft].
pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
VecZnxBig {
data: Vec::new(),
ptr: self.ptr,
n: self.n,
layout: LAYOUT::COL,
size: self.size,
cols: self.cols,
backend: self.backend,
}
}
pub fn backend(&self) -> BACKEND {
self.backend
}
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw<T>(&self, module: &Module) -> &[T] {
let ptr: *const T = self.ptr as *const T;
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
pub fn at<T>(&self, module: &Module, col_i: usize) -> &[T] {
&self.raw::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
}
/// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is cols * n.
pub fn raw_mut<T>(&self, module: &Module) -> &mut [T] {
let ptr: *mut T = self.ptr as *mut T;
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
pub fn at_mut<T>(&self, module: &Module, col_i: usize) -> &mut [T] {
&mut self.raw_mut::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
impl<D, B: Backend> VecZnxDft<D, B> {
pub fn into_big(self) -> VecZnxBig<D, B> {
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
}
}
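`into_big` is the memory-safe replacement for the old `as_vec_znx_big`: instead of aliasing the buffer through a raw pointer, it consumes the DFT vector and re-tags the same backing data. A minimal sketch:

```rust
// Consumes the DFT vector; the backing bytes are reused, nothing is copied,
// and no second handle to the same memory survives the conversion.
let a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 2);
let a_big: VecZnxBig<Vec<u8>, FFT64> = a_dft.into_big();
```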
impl Infos for VecZnxDft {
/// Returns the base 2 logarithm of the [VecZnx] degree.
fn log_n(&self) -> usize {
(usize::BITS - (self.n - 1).leading_zeros()) as _
impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
fn cols(&self) -> usize {
self.cols
}
fn rows(&self) -> usize {
1
}
/// Returns the [VecZnx] degree.
fn n(&self) -> usize {
self.n
}
@@ -116,254 +37,206 @@ impl Infos for VecZnxDft {
fn size(&self) -> usize {
self.size
}
}
fn layout(&self) -> LAYOUT {
self.layout
}
/// Returns the number of cols of the [VecZnx].
fn cols(&self) -> usize {
self.cols
}
/// Returns the number of rows of the [VecZnx].
fn rows(&self) -> usize {
1
impl<D> ZnxSliceSize for VecZnxDft<D, FFT64> {
fn sl(&self) -> usize {
self.n() * self.cols()
}
}
pub trait VecZnxDftOps {
/// Allocates a vector of Z[X]/(X^N+1) polynomials stored in the DFT space.
fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: the backing array is only borrowed.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
/// Returns the number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDftOps::new_vec_znx_dft_from_bytes]
/// or [VecZnxDftOps::new_vec_znx_dft_from_bytes_borrow].
fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize;
/// Returns the minimum number of bytes of scratch space necessary for [VecZnxDftOps::vec_znx_idft].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
impl<D, B: Backend> DataView for VecZnxDft<D, B> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
}
}
impl VecZnxDftOps for Module {
fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_dft(size, cols));
let ptr: *mut u8 = data.as_mut_ptr();
VecZnxDft {
data: data,
ptr: ptr,
n: self.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
backend: self.backend(),
impl<D, B: Backend> DataViewMut for VecZnxDft<D, B> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
}
}
impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
type Scalar = f64;
}
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
}
impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft::from_bytes(self, size, cols, tmp_bytes)
}
fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes)
}
fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize {
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size }
}
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *mut vec_znx_dft_t,
a.cols() as u64,
)
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
let data: Vec<u8> = bytes.into();
assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
Self {
data: data.into(),
n: module.n(),
cols,
size,
_phantom: PhantomData,
}
}
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// b <- DFT(a)
///
/// # Panics
/// If `b.cols() < a.cols()`
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
unsafe {
vec_znx_dft::vec_znx_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
)
}
}
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
impl<D> VecZnxDft<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
{
/// Extracts the a_col-th column of `a` and stores it in the self_col-th column of [Self].
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
where
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_idft_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
assert!(self_col < self.cols());
assert!(a_col < a.cols());
}
unsafe {
vec_znx_dft::vec_znx_idft(
self.ptr,
b.ptr as *mut vec_znx_big_t,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
[0u8; 0].as_mut_ptr(),
);
}
}
let min_size: usize = self.size.min(a.size());
let max_size: usize = self.size;
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert!(
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
tmp_bytes.len(),
Self::vec_znx_dft_automorphism_tmp_bytes(self)
);
assert_alignement(tmp_bytes.as_ptr())
}
unsafe {
vec_znx_dft::vec_znx_dft_automorphism(
self.ptr,
k,
a.ptr as *mut vec_znx_dft_t,
a.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
unsafe {
std::cmp::max(
vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
DEFAULTALIGN,
)
}
}
}
#[cfg(test)]
mod tests {
use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
use itertools::izip;
use sampling::source::{Source, new_seed};
#[test]
fn test_automorphism_dft() {
let module: Module = Module::new(128, BACKEND::FFT64);
let cols: usize = 2;
let log_base2k: usize = 17;
let mut a: VecZnx = module.new_vec_znx(1, cols);
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
let mut source: Source = Source::new(new_seed());
module.fill_uniform(log_base2k, &mut a, cols, &mut source);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
let p: i64 = -5;
// a_dft <- DFT(a)
module.vec_znx_dft(&mut a_dft, &a);
// a_dft <- AUTO(a_dft)
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
// a <- AUTO(a)
module.vec_znx_automorphism_inplace(p, &mut a);
// b_dft <- DFT(AUTO(a))
module.vec_znx_dft(&mut b_dft, &a);
let a_f64: &[f64] = a_dft.raw(&module);
let b_f64: &[f64] = b_dft.raw(&module);
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
(0..min_size).for_each(|i: usize| {
self_mut
.at_mut(self_col, i)
.copy_from_slice(a_ref.at(a_col, i));
});
module.free()
(min_size..max_size).for_each(|i| {
self_mut.zero_at(self_col, i);
});
}
}
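A sketch of `extract_column` under the semantics implemented above: the common limbs are copied and the remaining limbs of the destination are zeroed.

```rust
let src: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(2, 2); // 2 columns, 2 limbs
let mut dst: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 3); // 1 column, 3 limbs
// dst[0] <- src[1]; limb 2 of dst[0] is zeroed since src only has 2 limbs.
dst.extract_column(0, &src, 1);
```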
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
impl<D, B: Backend> VecZnxDft<D, B> {
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
Self {
data,
n,
cols,
size,
_phantom: PhantomData,
}
}
}
pub trait VecZnxDftToRef<B: Backend> {
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
}
pub trait VecZnxDftToMut<B: Backend> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
VecZnxDft {
data: self.data,
n: self.n,
cols: self.cols,
size: self.size,
_phantom: PhantomData,
}
}
}
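The `ToRef`/`ToMut` conversions above exist so that one generic routine can accept owned (`Vec<u8>`) and borrowed (`&[u8]`, `&mut [u8]`) vectors alike. A hypothetical caller-side helper (not part of this PR; assumes `ZnxView` is in scope for `at`):

```rust
// Accepts VecZnxDft<Vec<u8>, FFT64>, VecZnxDft<&[u8], FFT64>, etc.
fn first_limb_energy<A: VecZnxDftToRef<FFT64>>(a: &A) -> f64 {
    let a: VecZnxDft<&[u8], FFT64> = a.to_ref();
    a.at(0, 0).iter().map(|x| x * x).sum()
}
```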
impl<D: AsRef<[u8]>> fmt::Display for VecZnxDft<D, FFT64> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(
f,
"VecZnxDft(n={}, cols={}, size={})",
self.n, self.cols, self.size
)?;
for col in 0..self.cols {
writeln!(f, "Column {}:", col)?;
for size in 0..self.size {
let coeffs = self.at(col, size);
write!(f, " Size {}: [", size)?;
let max_show = 100;
let show_count = coeffs.len().min(max_show);
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{}", coeff)?;
}
if coeffs.len() > max_show {
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
}
writeln!(f, "]")?;
}
}
Ok(())
}
}

View File

@@ -0,0 +1,287 @@
use crate::ffi::{vec_znx_big, vec_znx_dft};
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
use crate::znx_base::ZnxInfos;
use crate::{
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
ZnxSliceSize,
};
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
use std::cmp::min;
pub trait VecZnxDftAlloc<B: Backend> {
/// Allocates a vector of Z[X]/(X^N+1) polynomials stored in the DFT space.
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
///
/// Behavior: takes ownership of the backing array.
///
/// # Arguments
///
/// * `cols`: the number of cols of the [VecZnxDft].
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
///
/// # Panics
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
/// Returns the number of bytes necessary to allocate
/// a new [VecZnxDft] through [VecZnxDftAlloc::new_vec_znx_dft_from_bytes].
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxDftOps<B: Backend> {
/// Returns the minimum number of bytes of scratch space necessary for [VecZnxDftOps::vec_znx_idft].
fn vec_znx_idft_tmp_bytes(&self) -> usize;
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>,
D: VecZnxDftToRef<B>;
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxDftToRef<B>;
/// b <- IDFT(a), uses a as scratch space.
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToMut<B>;
/// Consumes a to return IDFT(a) in big coeff space.
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxBigToMut<B>,
A: VecZnxDftToRef<B>;
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<B>,
A: VecZnxToRef;
}
impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
VecZnxDftOwned::new(&self, cols, size)
}
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
}
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
bytes_of_vec_znx_dft(self, cols, size)
}
}
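An allocation sketch: the `from_bytes` path lets callers pool buffers, provided the length matches `bytes_of_vec_znx_dft` exactly (enforced by the `assert!` in `new_from_bytes` above).

```rust
let bytes: Vec<u8> = alloc_aligned::<u8>(module.bytes_of_vec_znx_dft(2, 3));
let b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft_from_bytes(2, 3, bytes);
assert_eq!((b_dft.cols(), b_dft.size()), (2, 3));
```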
impl VecZnxDftOps<FFT64> for Module<FFT64> {
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
D: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = res_mut.size().min(a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_dft_add(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1,
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1,
);
});
}
}
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
(0..min_size).for_each(|j| {
res_mut
.at_mut(res_col, j)
.copy_from_slice(a_ref.at(a_col, j));
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToMut<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
let min_size: usize = min(res_mut.size(), a_mut.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
})
}
}
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
where
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
{
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
unsafe {
// Iterate sizes (outer) then columns (inner) because ZnxDft.sl() >= ZnxBig.sl()
(0..a_mut.size()).for_each(|j| {
(0..a_mut.cols()).for_each(|i| {
vec_znx_dft::vec_znx_idft_tmp_a(
self.ptr,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
)
});
});
}
a.into_big()
}
fn vec_znx_idft_tmp_bytes(&self) -> usize {
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
}
/// res <- DFT(a)
///
/// # Panics
/// If `res_col` or `a_col` is out of range.
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: VecZnxToRef,
{
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
let min_size: usize = min(res_mut.size(), a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_dft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
1 as u64,
a_ref.at_ptr(a_col, j),
1 as u64,
a_ref.sl() as u64,
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}
}
// res <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxBigToMut<FFT64>,
A: VecZnxDftToRef<FFT64>,
{
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
let min_size: usize = min(res_mut.size(), a_ref.size());
unsafe {
(0..min_size).for_each(|j| {
vec_znx_dft::vec_znx_idft(
self.ptr,
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
1 as u64,
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
1 as u64,
tmp_bytes.as_mut_ptr(),
)
});
(min_size..res_mut.size()).for_each(|j| {
res_mut.zero_at(res_col, j);
});
}
}
}
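A round-trip sketch over the new column-wise API: forward DFT of column 0, then inverse DFT into big-coefficient space, with scratch sized by `vec_znx_idft_tmp_bytes`. As before, `new_vec_znx_big` and `ScratchOwned::borrow` are assumptions.

```rust
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_idft_tmp_bytes());
let a: VecZnx<Vec<u8>> = module.new_vec_znx(1, 2);
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, 2);
let mut a_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, 2); // assumed allocator
module.vec_znx_dft(&mut a_dft, 0, &a, 0);
module.vec_znx_idft(&mut a_big, 0, &a_dft, 0, scratch.borrow()); // borrow() assumed
// Alternatively, skip the scratch entirely by consuming the DFT buffer:
let a_big2: VecZnxBig<Vec<u8>, FFT64> = module.vec_znx_idft_consume(a_dft);
```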

694
base2k/src/vec_znx_ops.rs Normal file
View File

@@ -0,0 +1,694 @@
use crate::ffi::vec_znx;
use crate::{
Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use std::cmp::min;
pub trait VecZnxAlloc {
/// Allocates a new [VecZnx].
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column.
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
/// Instantiates a new [VecZnx] from a slice of bytes.
/// The returned [VecZnx] takes ownership of the slice of bytes.
///
/// # Arguments
///
/// * `cols`: the number of polynomials.
/// * `size`: the number small polynomials per column.
///
/// # Panics
/// Requires the length of the slice of bytes to be equal to [VecZnxAlloc::bytes_of_vec_znx].
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
/// Returns the number of bytes necessary to allocate
/// a new [VecZnx] through [VecZnxAlloc::new_vec_znx_from_bytes].
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
}
pub trait VecZnxOps {
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Normalizes the selected column of `a`.
fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
where
A: VecZnxToMut;
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Adds the selected column of `a` onto the selected column and limb of `res`.
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef;
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
///
/// res[res_col] -= a[a_col]
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Subtracts the selected column of `res` from the selected column of `a`, writing the result in place to `res`.
///
/// res[res_col] = a[a_col] - res[res_col]
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Subtracts the selected column of `a` from the selected column and limb of `res`.
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef;
/// Negates the selected column of `a` and stores the result in `res_col` of `res`.
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Negates the selected column of `a`.
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Multiplies the selected column of `a` by X^k.
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut;
/// Splits the selected column of `a` into subrings and copies them into the selected column of each [VecZnx] of `res`.
///
/// # Panics
///
/// This method requires that all [VecZnx] of `res` have the same ring degree
/// and that `res[0].n() * res.len() <= a.n()`
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef;
/// Merges the subrings of the selected column of each [VecZnx] of `a` into the selected column of `res`.
///
/// # Panics
///
/// This method requires that all [VecZnx] of `a` have the same ring degree
/// and that `a[0].n() * a.len() <= res.n()`
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
fn switch_degree<R, A>(&self, r: &mut R, col_b: usize, a: &A, col_a: usize)
where
R: VecZnxToMut,
A: VecZnxToRef;
}
pub trait VecZnxScratch {
/// Returns the minimum number of bytes necessary for normalization.
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
}
impl<B: Backend> VecZnxAlloc for Module<B> {
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
VecZnxOwned::new::<i64>(self.n(), cols, size)
}
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
VecZnxOwned::bytes_of::<i64>(self.n(), cols, size)
}
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
VecZnxOwned::new_from_bytes::<i64>(self.n(), cols, size, bytes)
}
}
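A sketch of the column-wise arithmetic implemented below: add column 0 of two vectors into a result, then renormalize in place (`ScratchOwned::borrow` assumed as in the other sketches):

```rust
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_normalize_tmp_bytes());
let a: VecZnxOwned = module.new_vec_znx(1, 3);
let b: VecZnxOwned = module.new_vec_znx(1, 3);
let mut c: VecZnxOwned = module.new_vec_znx(1, 3);
module.vec_znx_add(&mut c, 0, &a, 0, &b, 0);
module.vec_znx_normalize_inplace(18, &mut c, 0, scratch.borrow());
```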
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
log_base2k as u64,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
unsafe {
vec_znx::vec_znx_normalize_base2k(
self.ptr,
log_base2k as u64,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_add(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
B: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let b: VecZnx<&[u8]> = b.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(b.n(), self.n());
assert_eq!(res.n(), self.n());
assert_ne!(a.as_ptr(), b.as_ptr());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
b.at_ptr(b_col, 0),
b.size() as u64,
b.sl() as u64,
)
}
}
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: ScalarZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, res_limb),
1 as u64,
res.sl() as u64,
)
}
}
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_sub(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
res.at_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
)
}
}
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_rotate(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert_eq!(res.n(), self.n());
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
res.at_mut_ptr(res_col, 0),
res.size() as u64,
res.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
where
A: VecZnxToMut,
{
let mut a: VecZnx<&mut [u8]> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
assert!(
k & 1 != 0,
"invalid galois element: must be odd but is {}",
k
);
}
unsafe {
vec_znx::vec_znx_automorphism(
self.ptr,
k,
a.at_mut_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(a_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
res[1..].iter_mut().for_each(|bi| {
debug_assert_eq!(
bi.to_mut().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
res.iter_mut().enumerate().for_each(|(i, bi)| {
if i == 0 {
self.switch_degree(bi, res_col, &a, a_col);
self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
} else {
self.switch_degree(bi, res_col, &mut buf, a_col);
self.vec_znx_rotate_inplace(-1, &mut buf, a_col);
}
})
}
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (res.n(), a[0].to_ref().n());
debug_assert!(
n_out < n_in,
"invalid a: output ring degree should be smaller"
);
a[1..].iter().for_each(|ai| {
debug_assert_eq!(
ai.to_ref().n(),
n_out,
"invalid input a: all VecZnx must have the same degree"
)
});
a.iter().for_each(|ai| {
self.switch_degree(&mut res, res_col, ai, a_col);
self.vec_znx_rotate_inplace(-1, &mut res, res_col);
});
self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
}
fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxToMut,
A: VecZnxToRef,
{
let a: VecZnx<&[u8]> = a.to_ref();
let mut res: VecZnx<&mut [u8]> = res.to_mut();
let (n_in, n_out) = (a.n(), res.n());
let (gap_in, gap_out): (usize, usize);
if n_in > n_out {
(gap_in, gap_out) = (n_in / n_out, 1)
} else {
(gap_in, gap_out) = (1, n_out / n_in);
res.zero();
}
let size: usize = min(a.size(), res.size());
(0..size).for_each(|i| {
izip!(
a.at(a_col, i).iter().step_by(gap_in),
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
)
.for_each(|(x_in, x_out)| *x_out = *x_in);
});
}
}
impl<B: Backend> VecZnxScratch for Module<B> {
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
}
}
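A rotation/automorphism sketch over the same API (the in-place automorphism asserts an odd galois element, as enforced above):

```rust
let a: VecZnxOwned = module.new_vec_znx(1, 2);
let mut r: VecZnxOwned = module.new_vec_znx(1, 2);
module.vec_znx_rotate(3, &mut r, 0, &a, 0); // r[0] <- X^3 * a[0] (negacyclic)
module.vec_znx_automorphism_inplace(-1, &mut r, 0); // X^i -> X^{-i}, -1 is odd
```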

View File

@@ -1,694 +0,0 @@
use crate::ffi::vec_znx_big::vec_znx_big_t;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::ffi::vmp::{self, vmp_pmat_t};
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
/// Each row of the [VmpPMat] can be seen as a [VecZnxDft].
///
/// The backend array of [VmpPMat] is allocated in C,
/// and thus must be manually freed.
///
/// [VmpPMat] is used to perform a vector matrix product between a [VecZnx] and a [VmpPMat].
/// See the trait [VmpPMatOps] for additional information.
pub struct VmpPMat {
/// Raw data, is empty if borrowing scratch space.
data: Vec<u8>,
/// Pointer to data. Can point to scratch space.
ptr: *mut u8,
/// The number of [VecZnxDft].
rows: usize,
/// The number of cols in each [VecZnxDft].
cols: usize,
/// The ring degree of each [VecZnxDft].
n: usize,
/// The number of stacked [VmpPMat], must be a square.
size: usize,
/// The memory layout of the stacked [VmpPMat].
layout: LAYOUT,
/// The backend fft or ntt.
backend: BACKEND,
}
impl Infos for VmpPMat {
/// Returns the ring dimension of the [VmpPMat].
fn n(&self) -> usize {
self.n
}
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
fn size(&self) -> usize {
self.size
}
fn layout(&self) -> LAYOUT {
self.layout
}
/// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat]
fn rows(&self) -> usize {
self.rows
}
/// Returns the number of cols of the [VmpPMat].
/// The number of cols refers to the number of cols
/// of each [VecZnxDft].
/// This method is equivalent to [Self::cols].
fn cols(&self) -> usize {
self.cols
}
}
impl VmpPMat {
pub fn as_ptr(&self) -> *const u8 {
self.ptr
}
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
pub fn borrowed(&self) -> bool {
self.data.len() == 0
}
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is rows * cols * n.
pub fn raw<T>(&self) -> &[T] {
let ptr: *const T = self.ptr as *const T;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts(ptr, len) }
}
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
/// The length of the returned array is rows * cols * n.
pub fn raw_mut<T>(&self) -> &mut [T] {
let ptr: *mut T = self.ptr as *mut T;
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
}
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
///
/// # Arguments
///
/// * `row`: row index (i).
/// * `col`: col index (j).
pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
let mut res: Vec<T> = alloc_aligned(self.n);
if self.n < 8 {
res.copy_from_slice(
&self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows() + 1) * self.n()],
);
} else {
(0..self.n >> 3).for_each(|blk| {
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
});
}
res
}
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
fn at_block<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
let nrows: usize = self.rows();
let ncols: usize = self.cols();
if col == (ncols - 1) && (ncols & 1 == 1) {
&self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
} else {
&self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
}
}
fn backend(&self) -> BACKEND {
self.backend
}
}
/// This trait implements methods for vector matrix product,
/// that is, multiplying a [VecZnx] with a [VmpPMat].
pub trait VmpPMatOps {
fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize;
/// Allocates a new [VmpPMat] with the given number of rows and columns.
///
/// # Arguments
///
/// * `rows`: number of rows (number of [VecZnxDft]).
/// * `cols`: number of cols (number of cols of each [VecZnxDft]).
fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat;
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
///
/// # Arguments
///
/// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
/// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize;
/// Prepares a [VmpPMat] from a contiguous array of [i64].
/// The helper struct [Matrix3D] can be used to construct and populate
/// the appropriate contiguous array.
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]);
/// Prepares a [VmpPMat] from a vector of [VecZnx].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
///
/// # Arguments
///
/// * `b`: the [VecZnxBig] into which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: [VmpPMat] on which the values are encoded.
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
/// * `row_i`: the index of the row to prepare.
///
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the [VecZnxDft] into which to extract the row of the [VmpPMat].
/// * `a`: [VmpPMat] on which the values are encoded.
/// * `row_i`: the index of the row to extract.
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft].
///
/// # Arguments
///
/// * `c_cols`: number of cols of the output [VecZnxDft].
/// * `a_cols`: number of cols of the input [VecZnx].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds the result to the receiver.
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnx] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
/// Returns the size of the scratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
///
/// # Arguments
///
/// * `c_cols`: number of cols of the output [VecZnxDft].
/// * `a_cols`: number of cols of the input [VecZnxDft].
/// * `rows`: number of rows of the input [VmpPMat].
/// * `cols`: number of cols of the input [VmpPMat].
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwriting it.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
///
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
///
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
/// `j` cols, the output is a [VecZnx] of `j` cols.
///
/// If there is a mismatch between the dimensions, the largest valid ones are used.
///
/// ```text
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
/// |h i j|
/// |k l m|
/// ```
/// where each element is a [VecZnxDft].
///
/// # Arguments
///
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
/// * `a`: the right operand [VmpPMat] of the vector matrix product.
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]);
}
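For reference, a minimal sketch of driving this removed API, assembled only from the signatures visible in this deleted file (pre-refactor `Module`, `BACKEND::FFT64`, and the `size`/`rows`/`cols` argument order); it is not valid against the new API.

```rust
let module: Module = Module::new(128, BACKEND::FFT64);
let (rows, cols): (usize, usize) = (3, 4);
let pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols);
let a: VecZnx = module.new_vec_znx(1, rows); // input with `rows` cols
let mut c: VecZnxDft = module.new_vec_znx_dft(1, cols); // output with `cols` cols
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_apply_dft_tmp_bytes(cols, rows, rows, cols));
// c <- a x pmat: each limb of a multiplies one row of pmat, summed.
module.vmp_apply_dft(&mut c, &a, &pmat, &mut tmp_bytes);
```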
impl VmpPMatOps for Module {
fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize {
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size }
}
fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat {
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vmp_pmat(size, rows, cols));
let ptr: *mut u8 = data.as_mut_ptr();
VmpPMat {
data: data,
ptr: ptr,
n: self.n(),
size: size,
layout: LAYOUT::COL,
cols: cols,
rows: rows,
backend: self.backend(),
}
}
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize }
}
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_prepare_contiguous(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(),
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) {
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
#[cfg(debug_assertions)]
{
debug_assert_eq!(a.len(), b.rows);
a.iter().for_each(|ai| {
debug_assert_eq!(ai.len(), b.n * b.cols);
});
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_prepare_dblptr(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
ptrs.as_ptr(),
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
#[cfg(debug_assertions)]
{
assert_eq!(a.len(), b.cols() * self.n());
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_prepare_row(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.as_ptr(),
row_i as u64,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
);
}
}
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row(
self.ptr,
b.ptr as *mut vec_znx_big_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_prepare_row_dft(
self.ptr,
b.as_mut_ptr() as *mut vmp_pmat_t,
a.ptr as *const vec_znx_dft_t,
row_i as u64,
b.rows() as u64,
b.cols() as u64,
);
}
}
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), b.n());
assert_eq!(a.cols(), b.cols());
}
unsafe {
vmp::vmp_extract_row_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
a.as_ptr() as *const vmp_pmat_t,
row_i as u64,
a.rows() as u64,
a.cols() as u64,
);
}
}
fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_tmp_bytes(
self.ptr,
res_cols as u64,
a_cols as u64,
gct_rows as u64,
gct_cols as u64,
) as usize
}
}
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.as_ptr(),
a.cols() as u64,
a.n() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
unsafe {
vmp::vmp_apply_dft_to_dft_tmp_bytes(
self.ptr,
res_cols as u64,
a_cols as u64,
gct_rows as u64,
gct_cols as u64,
) as usize
}
}
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft_add(
self.ptr,
c.ptr as *mut vec_znx_dft_t,
c.cols() as u64,
a.ptr as *const vec_znx_dft_t,
a.cols() as u64,
b.as_ptr() as *const vmp_pmat_t,
b.rows() as u64,
b.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
#[cfg(debug_assertions)]
{
assert_alignement(tmp_bytes.as_ptr());
}
unsafe {
vmp::vmp_apply_dft_to_dft(
self.ptr,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
b.ptr as *mut vec_znx_dft_t,
b.cols() as u64,
a.as_ptr() as *const vmp_pmat_t,
a.rows() as u64,
a.cols() as u64,
tmp_bytes.as_mut_ptr(),
)
}
}
}
#[cfg(test)]
mod tests {
use crate::{
Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
};
use sampling::source::Source;
#[test]
fn vmp_prepare_row_dft() {
let module: Module = Module::new(32, crate::BACKEND::FFT64);
let vpmat_rows: usize = 4;
let vpmat_cols: usize = 5;
let log_base2k: usize = 8;
let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols);
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols);
let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols);
let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols);
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols);
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols);
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols);
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
for row_i in 0..vpmat_rows {
let mut source: Source = Source::new([0u8; 32]);
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
module.vec_znx_dft(&mut a_dft, &a);
module.vmp_prepare_row(&mut vmpmat_0, a.raw(), row_i, &mut tmp_bytes);
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
}
module.free();
}
}
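// Editor's sketch: the full prepare/apply workflow for the trait above,
// mirroring the calls used in `vmp_prepare_row_dft`; the dimensions and
// `log_base2k` are illustrative, not prescribed by the API.
#[cfg(test)]
mod vmp_usage_sketch {
    use crate::{Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned};
    use sampling::source::Source;
    #[test]
    fn prepare_then_apply() {
        let module: Module = Module::new(32, crate::BACKEND::FFT64);
        let (rows, cols, log_base2k): (usize, usize, usize) = (4, 5, 8);
        // Input vector: one column of `cols` limbs, filled uniformly.
        let mut a: VecZnx = module.new_vec_znx(1, cols);
        let mut source: Source = Source::new([0u8; 32]);
        module.fill_uniform(log_base2k, &mut a, cols, &mut source);
        // One buffer sized for both prepare and apply (bitwise OR >= max of the two).
        let mut tmp_bytes: Vec<u8> = alloc_aligned(
            module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols),
        );
        // Prepared matrix: every row is the same vector, purely for illustration.
        let mut pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols);
        (0..rows).for_each(|row_i| module.vmp_prepare_row(&mut pmat, a.raw(), row_i, &mut tmp_bytes));
        // c_dft <- decompose(a) x pmat; the DFT of `a` happens internally.
        let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
        module.vmp_apply_dft(&mut c_dft, &a, &pmat, &mut tmp_bytes);
        module.free();
    }
}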

199
base2k/src/znx_base.rs Normal file
View File

@@ -0,0 +1,199 @@
use itertools::izip;
use rand_distr::num_traits::Zero;
pub trait ZnxInfos {
/// Returns the ring degree of the polynomials.
fn n(&self) -> usize;
/// Returns the base two logarithm of the ring degree of the polynomials.
fn log_n(&self) -> usize {
(usize::BITS - (self.n() - 1).leading_zeros()) as _
}
/// Returns the number of rows.
fn rows(&self) -> usize;
/// Returns the number of polynomials in each row.
fn cols(&self) -> usize;
/// Returns the size of each polynomial, i.e. its number of base-2^k limbs.
fn size(&self) -> usize;
/// Returns the total number of small polynomials.
fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
}
pub trait ZnxSliceSize {
/// Returns the slice size, i.e. the offset between
/// two consecutive limbs of the same column.
fn sl(&self) -> usize;
}
pub trait DataView {
type D;
fn data(&self) -> &Self::D;
}
pub trait DataViewMut: DataView {
fn data_mut(&mut self) -> &mut Self::D;
}
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
type Scalar: Copy;
/// Returns a non-mutable pointer to the underlying coefficients array.
fn as_ptr(&self) -> *const Self::Scalar {
self.data().as_ref().as_ptr() as *const Self::Scalar
}
/// Returns a non-mutable reference to the entire underlying coefficient array.
fn raw(&self) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
}
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_ptr().add(offset) }
}
/// Returns non-mutable reference to the (i, j)-th small polynomial.
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
}
}
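// Editor's sketch: the `at_ptr` offset arithmetic above is limb-major — all
// columns of limb j are stored before limb j+1, and columns are contiguous
// within a limb. A standalone check of the formula with illustrative sizes:
#[cfg(test)]
mod layout_sketch {
    #[test]
    fn at_ptr_offset() {
        let (n, cols): (usize, usize) = (4, 3);
        let (i, j): (usize, usize) = (2, 1); // column 2, limb 1
        // Limb 0 stores 3 polynomials of 4 coefficients, then limb 1 starts;
        // within limb 1, columns 0 and 1 precede the requested polynomial.
        assert_eq!(n * (j * cols + i), 20);
    }
}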
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
/// Returns a mutable pointer to the underlying coefficients array.
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
}
/// Returns a mutable reference to the entire underlying coefficient array.
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
}
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
#[cfg(debug_assertions)]
{
assert!(i < self.cols());
assert!(j < self.size());
}
let offset: usize = self.n() * (j * self.cols() + i);
unsafe { self.as_mut_ptr().add(offset) }
}
/// Returns mutable reference to the (i, j)-th small polynomial.
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
}
}
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
where
Self: Sized,
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
}
}
fn zero_at(&mut self, i: usize, j: usize) {
unsafe {
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
}
}
}
// Blanket implementations
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
use crate::Scratch;
pub trait Integer:
Copy
+ Default
+ PartialEq
+ PartialOrd
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Neg<Output = Self>
+ Shl<Output = Self>
+ Shr<Output = Self>
+ AddAssign
{
const BITS: u32;
}
impl Integer for i64 {
const BITS: u32 = 64;
}
impl Integer for i128 {
const BITS: u32 = 128;
}
//(Jay)Note: `rsh` impl. ignores the column
pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
where
V::Scalar: From<usize> + Integer + Zero,
{
let n: usize = a.n();
let cols: usize = a.cols();
let size: usize = a.size();
let steps: usize = k / log_base2k;
a.raw_mut().rotate_right(n * steps * cols);
(0..cols).for_each(|i| {
(0..steps).for_each(|j| {
a.zero_at(i, j);
})
});
let k_rem: usize = k % log_base2k;
if k_rem != 0 {
let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
carry.iter_mut().for_each(|ci| *ci = V::Scalar::zero());
let log_base2k_t = V::Scalar::from(log_base2k);
let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
let k_rem_t = V::Scalar::from(k_rem);
(0..cols).for_each(|i| {
(steps..size).for_each(|j| {
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
*xi += *ci << log_base2k_t;
*ci = (*xi << shift) >> shift;
*xi = (*xi - *ci) >> k_rem_t;
});
});
carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
})
}
}
pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
n * std::mem::size_of::<T>()
}
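// Editor's sketch of the carry arithmetic in `rsh`, on two plain i64 limbs and
// outside the scratch machinery (values are illustrative): the
// `(x << shift) >> shift` trick sign-extends the low k_rem bits, so the
// remainder digits come out balanced, in [-2^(k_rem-1), 2^(k_rem-1)).
#[cfg(test)]
mod rsh_sketch {
    #[test]
    fn carry_propagation() {
        let (log_base2k, k_rem): (i64, i64) = (4, 2);
        let shift: i64 = i64::BITS as i64 - k_rem;
        // Limb 0 is the most significant: value = 3 * 2^4 + 5 = 53.
        let mut limbs: [i64; 2] = [3, 5];
        let mut carry: i64 = 0;
        for x in limbs.iter_mut() {
            *x += carry << log_base2k;
            carry = (*x << shift) >> shift; // sign-extended low k_rem bits
            *x = (*x - carry) >> k_rem;
        }
        // 53 >> 2 = 13, encoded with balanced digits as 1 * 2^4 - 3.
        assert_eq!(limbs, [1, -3]);
    }
}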

View File

@@ -1,5 +1,3 @@
cargo-features = ["edition2024"]
[package]
name = "rlwe"
version = "0.1.0"
@@ -14,5 +12,9 @@ rand_distr = {workspace = true}
itertools = {workspace = true}
[[bench]]
name = "gadget_product"
name = "external_product_glwe_fft64"
harness = false
[[bench]]
name = "keyswitch_glwe_fft64"
harness = false

View File

@@ -0,0 +1,202 @@
use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned};
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rlwe::{
elem::Infos,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
keys::{SecretKey, SecretKeyFourier},
};
use sampling::source::Source;
fn bench_external_product_glwe_fft64(c: &mut Criterion) {
let mut group = c.benchmark_group("external_product_glwe_fft64");
struct Params {
log_n: usize,
basek: usize,
k_ct_in: usize,
k_ct_out: usize,
k_ggsw: usize,
rank: usize,
}
fn runner(p: Params) -> impl FnMut() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
let basek: usize = p.basek;
let k_ct_in: usize = p.k_ct_in;
let k_ct_out: usize = p.k_ct_out;
let k_ggsw: usize = p.k_ggsw;
let rank: usize = p.rank;
let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek; // ceil(k_ct_in / basek): one gadget row per base-2^k limb
let sigma: f64 = 3.2;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
let pt_rgsw: base2k::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut scratch = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| GLWECiphertext::external_product_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_rgsw.size(),
rank,
),
);
let mut source_xs = Source::new([0u8; 32]);
let mut source_xe = Source::new([0u8; 32]);
let mut source_xa = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe_in.encrypt_zero_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
move || {
ct_rlwe_out.external_product(
black_box(&module),
black_box(&ct_rlwe_in),
black_box(&ct_rgsw),
black_box(scratch.borrow()),
);
}
}
let params_set: Vec<Params> = vec![Params {
log_n: 10,
basek: 7,
k_ct_in: 27,
k_ct_out: 27,
k_ggsw: 27,
rank: 1,
}];
for params in params_set {
let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_FFT64", "");
let mut runner = runner(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
let mut group = c.benchmark_group("external_product_glwe_inplace_fft64");
struct Params {
log_n: usize,
basek: usize,
k_ct: usize,
k_ggsw: usize,
rank: usize,
}
fn runner(p: Params) -> impl FnMut() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
let basek: usize = p.basek;
let k_glwe: usize = p.k_ct;
let k_ggsw: usize = p.k_ggsw;
let rank: usize = p.rank;
let rows: usize = (p.k_ct + p.basek - 1) / p.basek;
let sigma: f64 = 3.2;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_glwe, rank);
let pt_rgsw: base2k::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut scratch = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank),
);
let mut source_xs = Source::new([0u8; 32]);
let mut source_xe = Source::new([0u8; 32]);
let mut source_xa = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe.encrypt_zero_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
move || {
let scratch_borrow = scratch.borrow();
(0..687).for_each(|_| {
ct_rlwe.external_product_inplace(
black_box(&module),
black_box(&ct_rgsw),
black_box(scratch_borrow),
);
});
}
}
let params_set: Vec<Params> = vec![Params {
log_n: 12,
basek: 18,
k_ct: 54,
k_ggsw: 54,
rank: 1,
}];
for params in params_set {
let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_INPLACE_FFT64", "");
let mut runner = runner(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
criterion_group!(
benches,
bench_external_product_glwe_fft64,
bench_external_product_glwe_inplace_fft64
);
criterion_main!(benches);

View File

@@ -0,0 +1,211 @@
use base2k::{FFT64, Module, ScratchOwned};
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
use rlwe::{
elem::Infos,
glwe_ciphertext::GLWECiphertext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
};
use sampling::source::Source;
fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
let mut group = c.benchmark_group("keyswitch_glwe_fft64");
struct Params {
log_n: usize,
basek: usize,
k_ct_in: usize,
k_ct_out: usize,
k_ksk: usize,
rank_in: usize,
rank_out: usize,
}
fn runner(p: Params) -> impl FnMut() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
let basek: usize = p.basek;
let k_rlwe_in: usize = p.k_ct_in;
let k_rlwe_out: usize = p.k_ct_out;
let k_grlwe: usize = p.k_ksk;
let rank_in: usize = p.rank_in;
let rank_out: usize = p.rank_out;
let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek;
let sigma: f64 = 3.2;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_grlwe, rows, rank_in, rank_out);
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_rlwe_in, rank_in);
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_rlwe_out, rank_out);
let mut scratch = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
| GLWECiphertext::keyswitch_scratch_space(
&module,
ct_out.size(),
ct_in.size(),
ksk.size(),
rank_in,
rank_out,
),
);
let mut source_xs = Source::new([0u8; 32]);
let mut source_xe = Source::new([0u8; 32]);
let mut source_xa = Source::new([0u8; 32]);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_in.encrypt_zero_sk(
&module,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
move || {
ct_out.keyswitch(
black_box(&module),
black_box(&ct_in),
black_box(&ksk),
black_box(scratch.borrow()),
);
}
}
let params_set: Vec<Params> = vec![Params {
log_n: 16,
basek: 50,
k_ct_in: 1250,
k_ct_out: 1250,
k_ksk: 1250 + 66,
rank_in: 1,
rank_out: 1,
}];
for params in params_set {
let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", "");
let mut runner = runner(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
let mut group = c.benchmark_group("keyswitch_glwe_inplace_fft64");
struct Params {
log_n: usize,
basek: usize,
k_ct: usize,
k_ksk: usize,
rank: usize,
}
fn runner(p: Params) -> impl FnMut() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
let basek: usize = p.basek;
let k_ct: usize = p.k_ct;
let k_ksk: usize = p.k_ksk;
let rank: usize = p.rank;
let rows: usize = (p.k_ct + p.basek - 1) / p.basek;
let sigma: f64 = 3.2;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut scratch = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), rank),
);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.encrypt_zero_sk(
&module,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
move || {
ct.keyswitch_inplace(
black_box(&module),
black_box(&ksk),
black_box(scratch.borrow()),
);
}
}
let params_set: Vec<Params> = vec![Params {
log_n: 9,
basek: 18,
k_ct: 27,
k_ksk: 27,
rank: 1,
}];
for params in params_set {
let id = BenchmarkId::new("KEYSWITCH_GLWE_INPLACE_FFT64", "");
let mut runner = runner(params);
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
}
group.finish();
}
criterion_group!(
benches,
bench_keyswitch_glwe_fft64,
bench_keyswitch_glwe_inplace_fft64
);
criterion_main!(benches);
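Both benchmark files size a single ScratchOwned with the bitwise OR of the individual scratch requirements. For unsigned sizes, a | b always lies between max(a, b) and a + b, so the OR is a safe, slightly generous bound for regions that are live at different times. A self-contained illustration of that bound (plain integers, no library calls):

fn main() {
    let (a, b): (usize, usize) = (12_288, 20_480);
    let bound = a | b;
    assert!(bound >= a.max(b)); // big enough for either consumer
    assert!(bound <= a + b); // never worse than allocating both
    println!("max = {}, or = {}, sum = {}", a.max(b), bound, a + b);
}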

386
core/src/automorphism.rs Normal file
View File

@@ -0,0 +1,386 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
};
pub struct AutomorphismKey<Data, B: Backend> {
pub(crate) key: GLWESwitchingKey<Data, B>,
pub(crate) p: i64,
}
impl AutomorphismKey<Vec<u8>, FFT64> {
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
AutomorphismKey {
key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank),
p: 0,
}
}
}
impl<T, B: Backend> Infos for AutomorphismKey<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
self.key.inner()
}
fn basek(&self) -> usize {
self.key.basek()
}
fn k(&self) -> usize {
self.key.k()
}
}
impl<T, B: Backend> AutomorphismKey<T, B> {
pub fn p(&self) -> i64 {
self.p
}
pub fn rank(&self) -> usize {
self.key.rank()
}
pub fn rank_in(&self) -> usize {
self.key.rank_in()
}
pub fn rank_out(&self) -> usize {
self.key.rank_out()
}
}
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for AutomorphismKey<DataSelf, B>
where
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.key.to_mut()
}
}
impl<DataSelf, B: Backend> MatZnxDftToRef<B> for AutomorphismKey<DataSelf, B>
where
MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.key.to_ref()
}
}
impl<C> GetRow<FFT64> for AutomorphismKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
where
R: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, row_i, col_j);
}
}
impl<C> SetRow<FFT64> for AutomorphismKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
where
R: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl AutomorphismKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size)
}
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ksk_size: usize,
rank: usize,
) -> usize {
GLWESwitchingKey::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size)
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size)
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ksk_size: usize,
rank: usize,
) -> usize {
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let tmp_idft: usize = module.bytes_of_vec_znx_big(rank + 1, out_size);
let idft: usize = module.vec_znx_idft_tmp_bytes();
let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, out_size, rank, ksk_size);
tmp_dft + tmp_idft + idft + keyswitch
}
pub fn automorphism_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, ksk_size: usize, rank: usize) -> usize {
AutomorphismKey::automorphism_scratch_space(module, out_size, out_size, ksk_size, rank)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
GLWESwitchingKey::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank)
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
GLWESwitchingKey::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank)
}
}
impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataSk>(
&mut self,
module: &Module<FFT64>,
p: i64,
sk: &SecretKey<DataSk>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnx<DataSk>: ScalarZnxToRef,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.rank_out(), self.rank_in());
assert_eq!(sk.rank(), self.rank());
}
let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank());
let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier {
data: sk_out_dft_data,
dist: sk.dist,
};
{
(0..self.rank()).for_each(|i| {
let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1);
module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i);
module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0);
});
}
self.key.encrypt_sk(
module,
&sk,
&sk_out_dft,
source_xa,
source_xe,
sigma,
scratch_1,
);
self.p = p;
}
}
impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &AutomorphismKey<DataLhs, FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
lhs.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let cols_out: usize = rhs.rank_out() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
// Extracts relevant row
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
// Get a VecZnxBig from scratch space
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
// Switches the input out of the DFT domain
(0..cols_out).for_each(|i| {
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
});
// Consumes the big buffer into a small-coefficient VecZnx
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
// Reverts the automorphism key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
(0..cols_out).for_each(|i| {
module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i);
});
// Wraps into ciphertext
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: tmp_idft_small_data,
basek: self.basek(),
k: self.k(),
};
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
// and switches back to DFT domain
(0..self.rank_out() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft, i);
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
});
// Sets back the relevant row
self.set_row(module, row_j, col_i, &tmp_dft);
});
});
tmp_dft.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_dft);
});
});
self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
}
pub fn automorphism_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut AutomorphismKey<DataSelf, FFT64> = self as *mut AutomorphismKey<DataSelf, FFT64>;
self.automorphism(&module, &*self_ptr, rhs, scratch);
}
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &AutomorphismKey<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.key.keyswitch(module, &lhs.key, rhs, scratch);
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.key.keyswitch_inplace(module, &rhs, scratch);
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &AutomorphismKey<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.key.external_product(module, &lhs.key, rhs, scratch);
}
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.key.external_product_inplace(module, rhs, scratch);
}
}
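A minimal sketch of allocating and generating an automorphism key with the API above. The parameters, seeds, and `rlwe::` import paths follow the benchmark files; treat them as illustrative while the crate is mid-rename, and note that `p` merely selects the Galois automorphism (-1 here is an arbitrary choice):

use base2k::{FFT64, Module, ScratchOwned};
use rlwe::{automorphism::AutomorphismKey, elem::Infos, keys::SecretKey};
use sampling::source::Source;

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 10);
    let (basek, k, rank): (usize, usize, usize) = (18, 54, 1);
    let rows: usize = (k + basek - 1) / basek; // one gadget row per limb
    let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
    let mut scratch = ScratchOwned::new(AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size()));
    let mut source_xs = Source::new([0u8; 32]);
    let mut source_xa = Source::new([1u8; 32]);
    let mut source_xe = Source::new([2u8; 32]);
    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
    sk.fill_ternary_prob(0.5, &mut source_xs);
    // encrypt_sk derives pi^{-1}(s) internally via galois_element_inv(p).
    auto_key.encrypt_sk(&module, -1, &sk, &mut source_xa, &mut source_xe, 3.2, scratch.borrow());
}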

59
core/src/elem.rs Normal file
View File

@@ -0,0 +1,59 @@
use base2k::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos};
use crate::utils::derive_size;
pub trait Infos {
type Inner: ZnxInfos;
fn inner(&self) -> &Self::Inner;
/// Returns the ring degree of the polynomials.
fn n(&self) -> usize {
self.inner().n()
}
/// Returns the base two logarithm of the ring degree of the polynomials.
fn log_n(&self) -> usize {
self.inner().log_n()
}
/// Returns the number of rows.
fn rows(&self) -> usize {
self.inner().rows()
}
/// Returns the number of polynomials in each row.
fn cols(&self) -> usize {
self.inner().cols()
}
/// Returns the size of each polynomial, i.e. its number of base-2^k limbs.
fn size(&self) -> usize {
let size: usize = self.inner().size();
debug_assert_eq!(size, derive_size(self.basek(), self.k()));
size
}
/// Returns the total number of small polynomials.
fn poly_count(&self) -> usize {
self.rows() * self.cols() * self.size()
}
/// Returns the base 2 logarithm of the ciphertext base.
fn basek(&self) -> usize;
/// Returns the bit precision of the ciphertext.
fn k(&self) -> usize;
}
pub trait GetRow<B: Backend> {
fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut R)
where
R: VecZnxDftToMut<B>;
}
pub trait SetRow<B: Backend> {
fn set_row<R>(&mut self, module: &Module<B>, row_i: usize, col_j: usize, a: &R)
where
R: VecZnxDftToRef<B>;
}
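derive_size itself does not appear in this hunk; the debug assertion in size() together with the rows computation in the benchmarks implies it equals ceil(k / basek). A standalone sketch under that assumption (hypothetical reimplementation, not the crate's code):

// Assumed behavior of derive_size(basek, k): the number of base-2^k limbs
// needed to hold k bits of precision.
fn derive_size(basek: usize, k: usize) -> usize {
    (k + basek - 1) / basek // ceil(k / basek)
}

fn main() {
    assert_eq!(derive_size(18, 54), 3); // 54 bits fit exactly in 3 limbs
    assert_eq!(derive_size(18, 55), 4); // one extra bit forces a 4th limb
}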

View File

@@ -0,0 +1,211 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos,
ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
utils::derive_size,
};
pub struct GGLWECiphertext<C, B: Backend> {
pub(crate) data: MatZnxDft<C, B>,
pub(crate) basek: usize,
pub(crate) k: usize,
}
impl<B: Backend> GGLWECiphertext<Vec<u8>, B> {
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)),
basek: basek,
k,
}
}
}
impl<T, B: Backend> Infos for GGLWECiphertext<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn basek(&self) -> usize {
self.basek
}
fn k(&self) -> usize {
self.k
}
}
impl<T, B: Backend> GGLWECiphertext<T, B> {
pub fn rank(&self) -> usize {
self.data.cols_out() - 1
}
pub fn rank_in(&self) -> usize {
self.data.cols_in()
}
pub fn rank_out(&self) -> usize {
self.data.cols_out() - 1
}
}
impl<C, B: Backend> MatZnxDftToMut<B> for GGLWECiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> MatZnxDftToRef<B> for GGLWECiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GGLWECiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GLWECiphertext::encrypt_sk_scratch_space(module, size)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub fn encrypt_pk_scratch_space(_module: &Module<FFT64>, _rank: usize, _pk_size: usize) -> usize {
unimplemented!()
}
}
impl<DataSelf> GGLWECiphertext<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + ZnxInfos,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnx<DataPt>: ScalarZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank_in(), pt.cols());
assert_eq!(self.rank_out(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
assert_eq!(pt.n(), module.n());
}
let rows: usize = self.rows();
let size: usize = self.size();
let basek: usize = self.basek();
let k: usize = self.k();
let cols_in: usize = self.rank_in();
let cols_out: usize = self.rank_out() + 1;
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scratch_2) = scratch_1.tmp_vec_znx(module, cols_out, size);
let (tmp_znx_dft_ct, scratch_3) = scratch_2.tmp_vec_znx_dft(module, cols_out, size);
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
basek,
k,
};
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_znx_ct,
basek,
k,
};
let mut vec_znx_ct_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier {
data: tmp_znx_dft_ct,
basek,
k,
};
// For each input column (i.e. each component of the input secret), produces
// a GGLWE ciphertext of rank_out+1 columns.
//
// Example for ksk rank 2 to rank 3 (one row of each input column shown):
//
// (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
// (-(b0*s0 + b1*s1 + b2*s2) + s1', b0, b1, b2)
//
// Example for ksk rank 2 to rank 1:
//
// (-(a*s) + s0, a)
// (-(b*s) + s1, b)
(0..cols_in).for_each(|col_i| {
(0..rows).for_each(|row_i| {
vec_znx_pt.data.zero(); // resets the buffer left over from the previous iteration
// Adds the scalar_znx_pt to the row_i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the col_i-th column of pt
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk(
module,
&vec_znx_pt,
sk_dft,
source_xa,
source_xe,
sigma,
scratch_3,
);
// Switch vec_znx_ct into DFT domain
vec_znx_ct.dft(module, &mut vec_znx_ct_dft);
// Stores vec_znx_dft_ct into the (row_i, col_i) entry of the MatZnxDft
module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft);
});
});
}
}
impl<C> GetRow<FFT64> for GGLWECiphertext<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
where
R: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, row_i, col_j);
}
}
impl<C> SetRow<FFT64> for GGLWECiphertext<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
where
R: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
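A sketch of filling the GGLWE layout above, for a rank-2-to-rank-1 key matching the second example in the encryption comments. Imports follow the benchmark files and all parameters are illustrative:

use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned};
use rlwe::{
    elem::Infos,
    gglwe_ciphertext::GGLWECiphertext,
    keys::{SecretKey, SecretKeyFourier},
};
use sampling::source::Source;

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 10);
    let (basek, k, rank_in, rank_out): (usize, usize, usize, usize) = (18, 54, 2, 1);
    let rows: usize = (k + basek - 1) / basek;
    let mut ct: GGLWECiphertext<Vec<u8>, FFT64> = GGLWECiphertext::new(&module, basek, k, rows, rank_in, rank_out);
    let mut scratch = ScratchOwned::new(GGLWECiphertext::encrypt_sk_scratch_space(&module, rank_out, ct.size()));
    let mut source_xs = Source::new([0u8; 32]);
    let mut source_xa = Source::new([1u8; 32]);
    let mut source_xe = Source::new([2u8; 32]);
    // One plaintext column per input column (here: the two input secret components).
    let pt = module.new_scalar_znx(rank_in);
    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
    sk.fill_ternary_prob(0.5, &mut source_xs);
    let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
    sk_dft.dft(&module, &sk);
    ct.encrypt_sk(&module, &pt, &sk_dft, &mut source_xa, &mut source_xe, 3.2, scratch.borrow());
}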

684
core/src/ggsw_ciphertext.rs Normal file
View File

@@ -0,0 +1,684 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps,
VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut,
VecZnxToRef, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
keyswitch_key::GLWESwitchingKey,
tensor_key::TensorKey,
utils::derive_size,
};
pub struct GGSWCiphertext<C, B: Backend> {
pub data: MatZnxDft<C, B>,
pub basek: usize,
pub k: usize,
}
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
basek: basek,
k: k,
}
}
}
impl<T, B: Backend> Infos for GGSWCiphertext<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn basek(&self) -> usize {
self.basek
}
fn k(&self) -> usize {
self.k
}
}
impl<T, B: Backend> GGSWCiphertext<T, B> {
pub fn rank(&self) -> usize {
self.data.cols_out() - 1
}
}
impl<C, B: Backend> MatZnxDftToMut<B> for GGSWCiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> MatZnxDftToRef<B> for GGSWCiphertext<C, B>
where
MatZnxDft<C, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GLWECiphertext::encrypt_sk_scratch_space(module, size)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub(crate) fn expand_row_scratch_space(
module: &Module<FFT64>,
self_size: usize,
tensor_key_size: usize,
rank: usize,
) -> usize {
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size);
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
let vmp: usize =
tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size);
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size);
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm))
}
pub(crate) fn keyswitch_internal_col0_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ksk_size: usize,
rank: usize,
) -> usize {
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size)
+ module.bytes_of_vec_znx_dft(rank + 1, in_size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ksk_size: usize,
tensor_key_size: usize,
rank: usize,
) -> usize {
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank);
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
res_znx + ci_dft + (ks | expand_rows | res_dft)
}
pub fn keyswitch_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ksk_size: usize,
tensor_key_size: usize,
rank: usize,
) -> usize {
GGSWCiphertext::keyswitch_scratch_space(module, out_size, out_size, ksk_size, tensor_key_size, rank)
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
auto_key_size: usize,
tensor_key_size: usize,
rank: usize,
) -> usize {
GGSWCiphertext::keyswitch_scratch_space(
module,
out_size,
in_size,
auto_key_size,
tensor_key_size,
rank,
)
}
pub fn automorphism_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
auto_key_size: usize,
tensor_key_size: usize,
rank: usize,
) -> usize {
GGSWCiphertext::automorphism_scratch_space(
module,
out_size,
out_size,
auto_key_size,
tensor_key_size,
rank,
)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
tmp_in + tmp_out + ggsw
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
tmp + ggsw
}
}
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnx<DataPt>: ScalarZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
}
let size: usize = self.size();
let basek: usize = self.basek();
let k: usize = self.k();
let cols: usize = self.rank() + 1;
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scratch_2) = scratch_1.tmp_vec_znx(module, cols, size);
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
basek: basek,
k: k,
};
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_znx_ct,
basek: basek,
k,
};
(0..self.rows()).for_each(|row_i| {
vec_znx_pt.data.zero(); // resets the buffer left over from the previous iteration
// Adds the scalar_znx_pt to the row_i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_2);
(0..cols).for_each(|col_j| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk_private(
module,
Some((&vec_znx_pt, col_j)),
sk_dft,
source_xa,
source_xe,
sigma,
scratch_2,
);
// Switch vec_znx_ct into DFT domain
{
let (mut vec_znx_dft_ct, _) = scratch_2.tmp_vec_znx_dft(module, cols, size);
(0..cols).for_each(|i| {
module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i);
});
self.set_row(module, row_i, col_j, &vec_znx_dft_ct);
}
});
});
}
pub(crate) fn expand_row<R, DataCi, DataTsk>(
&mut self,
module: &Module<FFT64>,
col_j: usize,
res: &mut R,
ci_dft: &VecZnxDft<DataCi, FFT64>,
tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch,
) where
R: VecZnxToMut,
VecZnxDft<DataCi, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
let cols: usize = self.rank() + 1;
// Example for rank 3:
//
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column actually
// comprises that many rows; we focus on the single row implicitly
// selected by ci_dft.
//
// # Input
//
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 )
// col 1: (0, 0, 0, 0)
// col 2: (0, 0, 0, 0)
// col 3: (0, 0, 0, 0)
//
// # Output
//
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 )
// col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 )
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i])
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
{
let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size());
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
//
// # Example for col=1
//
// a0 * (-(f0s0 + f1s1 + f2s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f2s2) + a0s0^2, a0f0, a0f1, a0f2)
// +
// a1 * (-(g0s0 + g1s1 + g2s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g2s2) + a1s0s1, a1g0, a1g1, a1g2)
// +
// a2 * (-(h0s0 + h1s1 + h2s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h2s2) + a2s0s2, a2h0, a2h1, a2h2)
// =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
(1..cols).for_each(|col_i| {
// Extracts a[i] and multiplies with Enc(s[i]s[j])
tmp_dft_col_data.extract_column(0, ci_dft, col_i);
if col_i == 1 {
module.vmp_apply(
&mut tmp_dft_i,
&tmp_dft_col_data,
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j])
scratch2,
);
} else {
module.vmp_apply_add(
&mut tmp_dft_i,
&tmp_dft_col_data,
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j])
scratch2,
);
}
});
}
// Adds (-(sum a[i] * s[i]) + M[i]) to the col_j-th column of tmp_dft_i
//
// (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
// +
// (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
// =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
// =
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0);
let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size());
(0..cols).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2);
});
}
pub fn keyswitch<DataLhs, DataKsk, DataTsk>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
let cols: usize = self.rank() + 1;
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size());
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_data,
basek: self.basek(),
k: self.k(),
};
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2);
// Isolates DFT(a[i])
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i);
});
self.set_row(module, row_i, 0, &ci_dft);
// Generates
//
// col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 )
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i])
(1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2);
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
(0..cols).for_each(|i| {
module.vec_znx_dft(&mut res_dft, i, &res, i);
});
self.set_row(module, row_i, col_j, &res_dft);
})
})
}
pub fn keyswitch_inplace<DataKsk, DataTsk>(
&mut self,
module: &Module<FFT64>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
}
}
pub fn automorphism<DataLhs, DataAk, DataTsk>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
auto_key: &AutomorphismKey<DataAk, FFT64>,
tensor_key: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataAk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
auto_key.rank(),
"ggsw_in rank: {} != auto_key rank: {}",
self.rank(),
auto_key.rank()
);
assert_eq!(
self.rank(),
tensor_key.rank(),
"ggsw_in rank: {} != tensor_key rank: {}",
self.rank(),
tensor_key.rank()
);
};
let cols: usize = self.rank() + 1;
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size());
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_data,
basek: self.basek(),
k: self.k(),
};
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2);
// Isolates DFT(AUTO(a[i]))
(0..cols).for_each(|col_i| {
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i);
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i);
});
self.set_row(module, row_i, 0, &ci_dft);
// Generates
//
// col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 )
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i]))
(1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2);
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
(0..cols).for_each(|i| {
module.vec_znx_dft(&mut res_dft, i, &res, i);
});
self.set_row(module, row_i, col_j, &res_dft);
})
})
}
pub fn automorphism_inplace<DataKsk, DataTsk>(
&mut self,
module: &Module<FFT64>,
auto_key: &AutomorphismKey<DataKsk, FFT64>,
tensor_key: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GGSWCiphertext<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
lhs.rank(),
"ggsw_out rank: {} != ggsw_in rank: {}",
self.rank(),
lhs.rank()
);
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_in rank: {} != ggsw_apply rank: {}",
self.rank(),
rhs.rank()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank() + 1).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank(),
rhs.rank(),
"ggsw_out rank: {} != ggsw_apply: {}",
self.rank(),
rhs.rank()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.external_product_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
}
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
{
pub(crate) fn keyswitch_internal_col0<DataRes, DataKsk>(
&self,
module: &Module<FFT64>,
row_i: usize,
res: &mut GLWECiphertext<DataRes>,
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataRes>: VecZnxToMut + VecZnxToRef,
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), ksk.rank());
assert_eq!(res.rank(), ksk.rank());
}
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_in_data,
basek: self.basek(),
k: self.k(),
};
self.get_row(module, row_i, 0, &mut tmp_dft_in);
res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2);
}
}
impl<DataSelf> GetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
{
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
where
R: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, row_i, col_j);
}
}
impl<DataSelf> SetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
{
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
where
R: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
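The GGSW x GGSW external product above is just the Fourier-domain GLWE external product applied row by row. A sketch of the call pattern, with setup borrowed from the benchmark files (illustrative parameters and import paths):

use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned};
use rlwe::{
    elem::Infos,
    ggsw_ciphertext::GGSWCiphertext,
    keys::{SecretKey, SecretKeyFourier},
};
use sampling::source::Source;

fn main() {
    let module: Module<FFT64> = Module::<FFT64>::new(1 << 10);
    let (basek, k, rank): (usize, usize, usize) = (18, 54, 1);
    let rows: usize = (k + basek - 1) / basek;
    let mut res: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
    let mut lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
    let mut rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
    let mut scratch = ScratchOwned::new(
        GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, lhs.size())
            | GGSWCiphertext::external_product_scratch_space(&module, res.size(), lhs.size(), rhs.size(), rank),
    );
    let mut source_xs = Source::new([0u8; 32]);
    let mut source_xa = Source::new([1u8; 32]);
    let mut source_xe = Source::new([2u8; 32]);
    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
    sk.fill_ternary_prob(0.5, &mut source_xs);
    let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
    sk_dft.dft(&module, &sk);
    let pt = module.new_scalar_znx(1); // zero plaintext, purely for illustration
    lhs.encrypt_sk(&module, &pt, &sk_dft, &mut source_xa, &mut source_xe, 3.2, scratch.borrow());
    rhs.encrypt_sk(&module, &pt, &sk_dft, &mut source_xa, &mut source_xe, 3.2, scratch.borrow());
    res.external_product(&module, &lhs, &rhs, scratch.borrow());
}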

696
core/src/glwe_ciphertext.rs Normal file
View File

@@ -0,0 +1,696 @@
use base2k::{
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
use crate::{
SIX_SIGMA,
automorphism::AutomorphismKey,
elem::Infos,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
utils::derive_size,
};
pub struct GLWECiphertext<C> {
pub data: VecZnx<C>,
pub basek: usize,
pub k: usize,
}
impl GLWECiphertext<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx(rank + 1, derive_size(basek, k)),
basek,
k,
}
}
}
impl<T> Infos for GLWECiphertext<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn basek(&self) -> usize {
self.basek
}
fn k(&self) -> usize {
self.k
}
}
impl<T> GLWECiphertext<T> {
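// rank = number of mask polynomials: a rank-r GLWE ciphertext holds r + 1
// columns (b, a_1, ..., a_r), hence cols() - 1.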
pub fn rank(&self) -> usize {
self.cols() - 1
}
}
impl<C> VecZnxToMut for GLWECiphertext<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for GLWECiphertext<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl<C> GLWECiphertext<C>
where
VecZnx<C>: VecZnxToRef,
{
#[allow(dead_code)]
pub(crate) fn dft<R>(&self, module: &Module<FFT64>, res: &mut GLWECiphertextFourier<R, FFT64>)
where
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), res.rank());
assert_eq!(self.basek(), res.basek())
}
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_dft(res, i, self, i);
})
}
}
impl GLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
module.vec_znx_big_normalize_tmp_bytes()
+ module.bytes_of_vec_znx_dft(1, ct_size)
+ module.bytes_of_vec_znx_big(1, ct_size)
}
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, pk_size: usize) -> usize {
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
+ module.bytes_of_scalar_znx_dft(1)
+ module.vec_znx_big_normalize_tmp_bytes()
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, ct_size))
+ module.bytes_of_vec_znx_big(1, ct_size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
+ module.bytes_of_vec_znx_dft(in_rank, in_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | normalize)
}
pub fn keyswitch_from_fourier_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
let res_dft = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
+ module.bytes_of_vec_znx_dft(in_rank, in_size);
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | norm)
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
autokey_size: usize,
) -> usize {
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, in_size, out_rank, autokey_size)
}
pub fn automorphism_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
autokey_size: usize,
) -> usize {
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, autokey_size)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
ggsw_size: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ggsw_size);
let vmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size, // rows
out_rank + 1, // cols in
out_rank + 1, // cols out
ggsw_size,
);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | normalize)
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
ggsw_size: usize,
) -> usize {
GLWECiphertext::external_product_scratch_space(module, out_size, out_rank, out_size, ggsw_size)
}
}
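Editor's note on the size arithmetic above: the bitwise `|` between byte counts appears to act as a cheap upper bound for `max`, since `max(a, b) <= a | b <= a + b` always holds, so two buffers that are live one after the other can share a single allocation, while `+` is used when both must be alive at once. A minimal sketch of that invariant (not part of the diff):
// Checks the bound relied on when scratch sizes are combined with `|`:
// max(a, b) <= a | b <= a + b, since a + b = (a | b) + (a & b).
fn main() {
for &(a, b) in &[(12usize, 40), (64, 64), (1000, 24), (0, 7)] {
let or = a | b;
assert!(or >= a.max(b) && or <= a + b);
}
}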
impl<DataSelf> GLWECiphertext<DataSelf>
where
VecZnx<DataSelf>: VecZnxToMut + VecZnxToRef,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self,
module: &Module<FFT64>,
pt: &GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
self.encrypt_sk_private(
module,
Some((pt, 0)),
sk_dft,
source_xa,
source_xe,
sigma,
scratch,
);
}
pub fn encrypt_zero_sk<DataSk>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
self.encrypt_sk_private(module, None, sk_dft, source_xa, source_xe, sigma, scratch);
}
pub fn encrypt_pk<DataPt, DataPk>(
&mut self,
module: &Module<FFT64>,
pt: &GLWEPlaintext<DataPt>,
pk: &GLWEPublicKey<DataPk, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToRef,
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
{
self.encrypt_pk_private(
module,
Some((pt, 0)),
pk,
source_xu,
source_xe,
sigma,
scratch,
);
}
pub fn encrypt_zero_pk<DataPk>(
&mut self,
module: &Module<FFT64>,
pk: &GLWEPublicKey<DataPk, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
{
self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch);
}
pub fn automorphism<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.keyswitch(module, lhs, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), self, i);
})
}
pub fn automorphism_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &AutomorphismKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
self.keyswitch_inplace(module, &rhs.key, scratch);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_automorphism_inplace(rhs.p(), self, i);
})
}
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_from_fourier_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
// Buffer for the VMP result in the DFT domain
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // TODO: optimize
{
// Applies VMP
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0);
// Brings the VMP result out of the DFT domain
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(lhs.rank(), rhs.rank_in());
assert_eq!(self.rank(), rhs.rank_out());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertext::keyswitch_scratch_space(
module,
self.size(),
self.rank(),
lhs.size(),
lhs.rank(),
rhs.size(),
)
);
}
let cols_in: usize = rhs.rank_in();
let cols_out: usize = rhs.rank_out() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // TODO: optimize
{
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
(0..cols_in).for_each(|col_i| {
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
});
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
(0..cols_out).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
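// NOTE (editor): `self` is passed as both destination and source via a raw
// pointer to bypass the borrow checker; this relies on `keyswitch` reading
// `lhs` fully into scratch before writing `self` (an assumption about the
// callee, not checked here). The same pattern recurs in the other *_inplace
// methods below.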
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.keyswitch(module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertext<DataLhs>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataLhs>: VecZnxToRef,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // TODO: optimize
{
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
});
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
}
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
});
}
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
self.external_product(module, &*self_ptr, rhs, scratch);
}
}
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(
&mut self,
module: &Module<FFT64>,
pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(sk_dft.n(), module.n());
assert_eq!(self.n(), module.n());
if let Some((pt, col)) = pt {
assert_eq!(pt.n(), module.n());
assert!(col < self.rank() + 1);
}
}
let log_base2k: usize = self.basek();
let log_k: usize = self.k();
let size: usize = self.size();
let cols: usize = self.rank() + 1;
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
c0_big.zero();
{
// c[i] = uniform
// c[0] -= c[i] * s[i]
(1..cols).for_each(|i| {
let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size);
// c[i] = uniform
self.data.fill_uniform(log_base2k, i, size, source_xa);
// c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i])))
module.vec_znx_dft(&mut ci_dft, 0, self, i);
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1);
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
// normalize c[i] * s[i] into c[0], used here as a temporary buffer (it is overwritten by the final normalization step)
module.vec_znx_big_normalize(log_base2k, self, 0, &ci_big, 0, scratch_2);
// c0_big -= c[i] * s[i], accumulating -sum_i c[i] * s[i]
module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0);
// c[i] += m if col = i
if let Some((pt, col)) = pt {
if i == col {
module.vec_znx_add_inplace(self, i, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, self, i, scratch_2);
}
}
});
}
// c[0] += e
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, sigma * SIX_SIGMA);
// c[0] += m if col = 0
if let Some((pt, col)) = pt {
if col == 0 {
module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0);
}
}
// c[0] = norm(c[0])
module.vec_znx_normalize(log_base2k, self, 0, &c0_big, 0, scratch_1);
}
pub(crate) fn encrypt_pk_private<DataPt, DataPk>(
&mut self,
module: &Module<FFT64>,
pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
pk: &GLWEPublicKey<DataPk, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToRef,
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.basek(), pk.basek());
assert_eq!(self.n(), module.n());
assert_eq!(pk.n(), module.n());
assert_eq!(self.rank(), pk.rank());
if let Some((pt, _)) = pt {
assert_eq!(pt.basek(), pk.basek());
assert_eq!(pt.n(), module.n());
}
}
let log_base2k: usize = pk.basek();
let size_pk: usize = pk.size();
let cols: usize = self.rank() + 1;
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
{
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
match pk.dist {
SecretDistribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
Self::generate"
),
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
SecretDistribution::ZERO => {}
}
module.svp_prepare(&mut u_dft, 0, &u, 0);
}
// ct[i] = pk[i] * u + ei (+ m if col = i)
(0..cols).for_each(|i| {
let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk);
// ci_dft = DFT(u) * DFT(pk[i])
module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i);
// ci_big = u * pk[i]
let mut ci_big = module.vec_znx_idft_consume(ci_dft);
// ci_big = u * pk[i] + e
ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA);
// ci_big = u * pk[i] + e + m (if col = i)
if let Some((pt, col)) = pt {
if col == i {
module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0);
}
}
// ct[i] = norm(ci_big)
module.vec_znx_big_normalize(log_base2k, self, i, &ci_big, 0, scratch_2);
});
}
}
impl<DataSelf> GLWECiphertext<DataSelf>
where
VecZnx<DataSelf>: VecZnxToRef,
{
pub fn decrypt<DataPt, DataSk>(
&self,
module: &Module<FFT64>,
pt: &mut GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToMut,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
}
let cols: usize = self.rank() + 1;
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct
c0_big.zero();
{
(1..cols).for_each(|i| {
// ci_dft = DFT(a[i]) * DFT(s[i])
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
module.vec_znx_dft(&mut ci_dft, 0, self, i);
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1);
let ci_big = module.vec_znx_idft_consume(ci_dft);
// c0_big += a[i] * s[i]
module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
});
}
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0);
// pt = norm(BIG(m + e))
module.vec_znx_big_normalize(self.basek(), pt, 0, &c0_big, 0, scratch_1);
pt.basek = self.basek();
pt.k = pt.k().min(self.k());
}
}
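For orientation, a hedged usage sketch of the encrypt/decrypt halves above (parameter values are illustrative; `size()` comes from the `Infos` trait, and `Source`/`ScratchOwned` from the workspace crates):
// Sketch only: secret-key encrypt/decrypt round trip through the API above.
// Plaintext encoding is omitted; decrypt recovers pt = m + e.
fn roundtrip_sketch() {
let module: Module<FFT64> = Module::<FFT64>::new(1 << 10);
let (basek, k, rank, sigma) = (12usize, 54usize, 1usize, 3.2f64);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut Source::new([1u8; 32]));
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k, rank);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct.size()),
);
let (mut xa, mut xe) = (Source::new([2u8; 32]), Source::new([3u8; 32]));
ct.encrypt_sk(&module, &pt, &sk_dft, &mut xa, &mut xe, sigma, scratch.borrow());
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
}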

323
core/src/glwe_ciphertext_fourier.rs Normal file

@@ -0,0 +1,323 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
};
pub struct GLWECiphertextFourier<C, B: Backend> {
pub data: VecZnxDft<C, B>,
pub basek: usize,
pub k: usize,
}
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
basek,
k,
}
}
}
impl<T, B: Backend> Infos for GLWECiphertextFourier<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn basek(&self) -> usize {
self.basek
}
fn k(&self) -> usize {
self.k
}
}
impl<T, B: Backend> GLWECiphertextFourier<T, B> {
pub fn rank(&self) -> usize {
self.cols() - 1
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, ct_size: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, ct_size)
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, ct_size)
| (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes()))
+ module.bytes_of_vec_znx_big(1, ct_size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
module.bytes_of_vec_znx(out_rank + 1, out_size)
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size)
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | (res_small + normalize))
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
}
}
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
where
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
{
pub fn encrypt_zero_sk<DataSk>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
let mut ct_idft = GLWECiphertext {
data: vec_znx_tmp,
basek: self.basek,
k: self.k,
};
ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1);
ct_idft.dft(module, self);
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let cols_out: usize = rhs.rank_out() + 1;
// Space for the normalized VMP result outside of the DFT domain
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_idft_data,
basek: lhs.basek,
k: lhs.k,
};
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
(0..cols_out).for_each(|i| {
module.vec_znx_dft(self, i, &res_idft, i);
});
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.keyswitch(module, &*self_ptr, rhs, scratch);
}
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let basek: usize = self.basek();
#[cfg(debug_assertions)]
{
assert_eq!(rhs.rank(), lhs.rank());
assert_eq!(rhs.rank(), self.rank());
assert_eq!(self.basek(), basek);
assert_eq!(lhs.basek(), basek);
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
}
let cols: usize = rhs.rank() + 1;
// Space for VMP result in DFT domain and high precision
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
}
// VMP result in high precision
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
// Space for VMP result normalized
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(self, i, &res_small, i);
});
}
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
unsafe {
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
self.external_product(module, &*self_ptr, rhs, scratch);
}
}
}
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
where
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
{
pub fn decrypt<DataPt, DataSk>(
&self,
module: &Module<FFT64>,
pt: &mut GLWEPlaintext<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<DataPt>: VecZnxToMut + VecZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(pt.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
}
let cols = self.rank() + 1;
let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct
pt_big.zero();
{
(1..cols).for_each(|i| {
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i);
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
});
}
{
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2);
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
}
// pt = norm(BIG(m + e))
module.vec_znx_big_normalize(self.basek(), pt, 0, &pt_big, 0, scratch_1);
pt.basek = self.basek();
pt.k = pt.k().min(self.k());
}
#[allow(dead_code)]
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
where
GLWECiphertext<DataRes>: VecZnxToMut,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), res.rank());
assert_eq!(self.basek(), res.basek())
}
let min_size: usize = self.size().min(res.size());
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
(0..self.rank() + 1).for_each(|i| {
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1);
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1);
});
}
}

53
core/src/glwe_plaintext.rs Normal file

@@ -0,0 +1,53 @@
use base2k::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef};
use crate::{elem::Infos, utils::derive_size};
pub struct GLWEPlaintext<C> {
pub data: VecZnx<C>,
pub basek: usize,
pub k: usize,
}
impl<T> Infos for GLWEPlaintext<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
}
fn basek(&self) -> usize {
self.basek
}
fn k(&self) -> usize {
self.k
}
}
impl<C> VecZnxToMut for GLWEPlaintext<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for GLWEPlaintext<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl GLWEPlaintext<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(basek, k)),
basek,
k,
}
}
}
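`derive_size` itself is outside this hunk; from its call sites (and the `rows = (k + basek - 1) / basek` pattern in the tests below) it is presumably the limb count `ceil(k / basek)`. A standalone sketch of that assumption:
// Assumption: derive_size(basek, k) = ceil(k / basek), the number of
// base-2^basek limbs needed to hold k bits of precision.
fn derive_size(basek: usize, k: usize) -> usize {
k.div_ceil(basek)
}
fn main() {
assert_eq!(derive_size(12, 54), 5); // 54 bits -> five 12-bit limbs
assert_eq!(derive_size(12, 60), 5);
assert_eq!(derive_size(12, 61), 6);
}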

247
core/src/keys.rs Normal file

@@ -0,0 +1,247 @@
use base2k::{
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
ZnxZero,
};
use sampling::source::Source;
use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier};
#[derive(Clone, Copy, Debug)]
pub enum SecretDistribution {
TernaryFixed(usize), // Ternary with fixed Hamming weight
TernaryProb(f64), // Ternary with probabilistic Hamming weight
ZERO, // Debug mode
NONE,
}
pub struct SecretKey<T> {
pub data: ScalarZnx<T>,
pub dist: SecretDistribution,
}
impl SecretKey<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, rank: usize) -> Self {
Self {
data: module.new_scalar_znx(rank),
dist: SecretDistribution::NONE,
}
}
}
impl<DataSelf> SecretKey<DataSelf> {
pub fn n(&self) -> usize {
self.data.n()
}
pub fn log_n(&self) -> usize {
self.data.log_n()
}
pub fn rank(&self) -> usize {
self.data.cols()
}
}
impl<S> SecretKey<S>
where
S: AsMut<[u8]> + AsRef<[u8]>,
{
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
(0..self.rank()).for_each(|i| {
self.data.fill_ternary_prob(i, prob, source);
});
self.dist = SecretDistribution::TernaryProb(prob);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
(0..self.rank()).for_each(|i| {
self.data.fill_ternary_hw(i, hw, source);
});
self.dist = SecretDistribution::TernaryFixed(hw);
}
pub fn fill_zero(&mut self) {
self.data.zero();
self.dist = SecretDistribution::ZERO;
}
}
impl<C> ScalarZnxToMut for SecretKey<C>
where
ScalarZnx<C>: ScalarZnxToMut,
{
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> ScalarZnxToRef for SecretKey<C>
where
ScalarZnx<C>: ScalarZnxToRef,
{
fn to_ref(&self) -> ScalarZnx<&[u8]> {
self.data.to_ref()
}
}
pub struct SecretKeyFourier<T, B: Backend> {
pub data: ScalarZnxDft<T, B>,
pub dist: SecretDistribution,
}
impl<DataSelf, B: Backend> SecretKeyFourier<DataSelf, B> {
pub fn n(&self) -> usize {
self.data.n()
}
pub fn log_n(&self) -> usize {
self.data.log_n()
}
pub fn rank(&self) -> usize {
self.data.cols()
}
}
impl<B: Backend> SecretKeyFourier<Vec<u8>, B> {
pub fn new(module: &Module<B>, rank: usize) -> Self {
Self {
data: module.new_scalar_znx_dft(rank),
dist: SecretDistribution::NONE,
}
}
pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
where
Self: ScalarZnxDftToMut<FFT64>,
SecretKey<S>: ScalarZnxToRef,
{
#[cfg(debug_assertions)]
{
match sk.dist {
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
_ => {}
}
assert_eq!(self.n(), module.n());
assert_eq!(sk.n(), module.n());
assert_eq!(self.rank(), sk.rank());
}
(0..self.rank()).for_each(|i| {
module.svp_prepare(self, i, sk, i);
});
self.dist = sk.dist;
}
}
impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyFourier<C, B>
where
ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
{
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyFourier<C, B>
where
ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
{
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
self.data.to_ref()
}
}
pub struct GLWEPublicKey<D, B: Backend> {
pub data: GLWECiphertextFourier<D, B>,
pub dist: SecretDistribution,
}
impl<B: Backend> GLWEPublicKey<Vec<u8>, B> {
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: GLWECiphertextFourier::new(module, basek, k, rank),
dist: SecretDistribution::NONE,
}
}
}
impl<T, B: Backend> Infos for GLWEPublicKey<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data.data
}
fn basek(&self) -> usize {
self.data.basek
}
fn k(&self) -> usize {
self.data.k
}
}
impl<T, B: Backend> GLWEPublicKey<T, B> {
pub fn rank(&self) -> usize {
self.cols() - 1
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for GLWEPublicKey<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for GLWEPublicKey<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl<C> GLWEPublicKey<C, FFT64> {
pub fn generate<S>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
) where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
match sk_dft.dist {
SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"),
_ => {}
}
}
// It's fine to allocate scratch space here, since the pk is usually generated only once.
let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space(
module,
self.rank(),
self.size(),
));
self.data.encrypt_zero_sk(
module,
sk_dft,
source_xa,
source_xe,
sigma,
scratch.borrow(),
);
self.dist = sk_dft.dist;
}
}
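A hedged sketch of the intended flow (sources and parameters are illustrative; `generate` requires the secret's distribution to be set, and `encrypt_pk_private` later checks `pk.dist`):
// Sketch only: derive a public key from an existing SecretKeyFourier; the
// result can then be fed to GLWECiphertext::encrypt_pk defined earlier.
fn make_pk_sketch(module: &Module<FFT64>, sk_dft: &SecretKeyFourier<Vec<u8>, FFT64>) -> GLWEPublicKey<Vec<u8>, FFT64> {
let (basek, k, sigma) = (12usize, 54usize, 3.2f64); // illustrative
let mut pk: GLWEPublicKey<Vec<u8>, FFT64> = GLWEPublicKey::new(module, basek, k, sk_dft.rank());
let (mut xa, mut xe) = (Source::new([4u8; 32]), Source::new([5u8; 32]));
pk.generate(module, sk_dft, &mut xa, &mut xe, sigma);
pk // pk.dist now mirrors sk_dft.dist
}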

385
core/src/keyswitch_key.rs Normal file

@@ -0,0 +1,385 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
};
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
GLWESwitchingKey(GGLWECiphertext::new(
module, basek, k, rows, rank_in, rank_out,
))
}
}
impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
self.0.inner()
}
fn basek(&self) -> usize {
self.0.basek()
}
fn k(&self) -> usize {
self.0.k()
}
}
impl<T, B: Backend> GLWESwitchingKey<T, B> {
pub fn rank(&self) -> usize {
self.0.data.cols_out() - 1
}
pub fn rank_in(&self) -> usize {
self.0.data.cols_in()
}
pub fn rank_out(&self) -> usize {
self.0.data.cols_out() - 1
}
}
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
where
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.0.data.to_mut()
}
}
impl<DataSelf, B: Backend> MatZnxDftToRef<B> for GLWESwitchingKey<DataSelf, B>
where
MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.0.data.to_ref()
}
}
impl<C> GetRow<FFT64> for GLWESwitchingKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
where
R: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, row_i, col_j);
}
}
impl<C> SetRow<FFT64> for GLWESwitchingKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
where
R: VecZnxDftToRef<FFT64>,
{
module.vmp_prepare_row(self, row_i, col_j, a);
}
}
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size)
}
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
out_rank: usize,
in_size: usize,
in_rank: usize,
ksk_size: usize,
) -> usize {
let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size);
tmp_in + tmp_out + ksk
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size);
tmp + ksk
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
tmp_in + tmp_out + ggsw
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
rank: usize,
) -> usize {
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
tmp + ggsw
}
}
impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataSkIn, DataSkOut>(
&mut self,
module: &Module<FFT64>,
sk_in: &SecretKey<DataSkIn>,
sk_out_dft: &SecretKeyFourier<DataSkOut, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnx<DataSkIn>: ScalarZnxToRef,
ScalarZnxDft<DataSkOut, FFT64>: ScalarZnxDftToRef<FFT64>,
{
self.0.encrypt_sk(
module,
&sk_in.data,
sk_out_dft,
source_xa,
source_xe,
sigma,
scratch,
);
}
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank_in(),
"ksk_in output rank: {} != ksk_apply input rank: {}",
lhs.rank_out(),
rhs.rank_in()
);
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.keyswitch(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank_out(),
"ksk_out output rank: {} != ksk_apply output rank: {}",
self.rank_out(),
rhs.rank_out()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.keyswitch_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_in(),
lhs.rank_in(),
"ksk_out input rank: {} != ksk_in input rank: {}",
self.rank_in(),
lhs.rank_in()
);
assert_eq!(
lhs.rank_out(),
rhs.rank(),
"ksk_in output rank: {} != ggsw rank: {}",
lhs.rank_out(),
rhs.rank()
);
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in);
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out);
});
});
tmp_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank_in()).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out);
});
});
}
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(
self.rank_out(),
rhs.rank(),
"ksk_out output rank: {} != ggsw rank: {}",
self.rank_out(),
rhs.rank()
);
}
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp);
tmp.external_product_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp);
});
});
}
}

15
core/src/lib.rs Normal file

@@ -0,0 +1,15 @@
pub mod automorphism;
pub mod elem;
pub mod gglwe_ciphertext;
pub mod ggsw_ciphertext;
pub mod glwe_ciphertext;
pub mod glwe_ciphertext_fourier;
pub mod glwe_plaintext;
pub mod keys;
pub mod keyswitch_key;
pub mod tensor_key;
#[cfg(test)]
mod test_fft64;
mod utils;
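// Presumably the truncation bound for noise sampling: add_normal is called
// with bound sigma * SIX_SIGMA, and the Gaussian mass beyond 6 sigma is ~2e-9.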
pub(crate) const SIX_SIGMA: f64 = 6.0;

130
core/src/tensor_key.rs Normal file

@@ -0,0 +1,130 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc,
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef,
};
use sampling::source::Source;
use crate::{
elem::Infos,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
};
pub struct TensorKey<C, B: Backend> {
pub(crate) keys: Vec<GLWESwitchingKey<C, B>>,
}
impl TensorKey<Vec<u8>, FFT64> {
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
let pairs: usize = ((rank + 1) * rank) >> 1;
(0..pairs).for_each(|_| {
keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank));
});
Self { keys }
}
}
impl<T, B: Backend> Infos for TensorKey<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.keys[0].inner()
}
fn basek(&self) -> usize {
self.keys[0].basek()
}
fn k(&self) -> usize {
self.keys[0].k()
}
}
impl<T, B: Backend> TensorKey<T, B> {
pub fn rank(&self) -> usize {
self.keys[0].rank()
}
pub fn rank_in(&self) -> usize {
self.keys[0].rank_in()
}
pub fn rank_out(&self) -> usize {
self.keys[0].rank_out()
}
}
impl TensorKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size)
}
}
impl<DataSelf> TensorKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataSk>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
ScalarZnxDft<DataSk, FFT64>: VecZnxDftToRef<FFT64> + ScalarZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.rank(), sk_dft.rank());
assert_eq!(self.n(), module.n());
assert_eq!(sk_dft.n(), module.n());
}
let rank: usize = self.rank();
(0..rank).for_each(|i| {
(i..rank).for_each(|j| {
let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1);
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
let sk_ij: ScalarZnx<&mut [u8]> = module
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
.to_vec_znx_small()
.to_scalar_znx();
let sk_ij: SecretKey<&mut [u8]> = SecretKey {
data: sk_ij,
dist: sk_dft.dist,
};
self.at_mut(i, j).encrypt_sk(
module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1,
);
});
})
}
// Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j])
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey<DataSelf, FFT64> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank();
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}
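The flat index in `at`/`at_mut` walks the upper triangle of the (i, j) pairs row by row; a quick check of the formula (editor's sketch, not part of the diff):
// For rank r there are r*(r+1)/2 unordered pairs (i, j) with i <= j, and
// index(i, j) = i*r + j - i*(i+1)/2 enumerates them densely, row by row.
fn main() {
let rank = 3;
let mut expected = 0;
for i in 0..rank {
for j in i..rank {
assert_eq!(i * rank + j - i * (i + 1) / 2, expected);
expected += 1;
}
}
assert_eq!(expected, rank * (rank + 1) / 2); // matches `pairs` in `new`
}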
impl<DataSelf> TensorKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
{
// Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j])
pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey<DataSelf, FFT64> {
if i > j {
std::mem::swap(&mut i, &mut j);
};
let rank: usize = self.rank();
&self.keys[i * rank + j - (i * (i + 1) / 2)]
}
}


@@ -0,0 +1,216 @@
use base2k::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps};
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos},
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
test_fft64::gglwe::log2_std_noise_gglwe_product,
};
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(-1, 5, 12, 12, 60, 3.2, rank);
});
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
});
}
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut auto_key_in: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut auto_key_out: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key_in.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key_out.size())
| AutomorphismKey::automorphism_scratch_space(
&module,
auto_key_out.size(),
auto_key_in.size(),
auto_key_apply.size(),
rank,
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
// auto_key_in = autokey(p0) under sk
auto_key_in.encrypt_sk(
&module,
p0,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// auto_key_apply = autokey(p1) under sk
auto_key_apply.encrypt_sk(
&module,
p1,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// auto_key_out = auto_key_in (x) auto_key_apply = autokey(p0 * p1)
auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_auto.fill_zero(); // necessary to avoid the unfilled-sk panic in dft()
(0..rank).for_each(|i| {
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
});
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_auto_dft.dft(&module, &sk_auto);
(0..auto_key_out.rank_in()).for_each(|col_i| {
(0..auto_key_out.rows()).for_each(|row_i| {
auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
// auto_key = autokey(p0) under sk
auto_key.encrypt_sk(
&module,
p0,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// auto_key_apply = autokey(p1) under sk
auto_key_apply.encrypt_sk(
&module,
p1,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// auto_key = auto_key (x) auto_key_apply = autokey(p0 * p1)
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_auto.fill_zero(); // necessary to avoid the unfilled-sk panic in dft()
(0..rank).for_each(|i| {
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
});
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_auto_dft.dft(&module, &sk_auto);
(0..auto_key.rank_in()).for_each(|col_i| {
(0..auto_key.rows()).for_each(|row_i| {
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}


@@ -0,0 +1,630 @@
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchOwned, Stats, VecZnxOps, ZnxViewMut};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos},
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
test_fft64::ggsw::noise_ggsw_product,
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out);
test_encrypt_sk(12, 8, 54, 3.2, rank_in, rank_out);
});
});
}
#[test]
fn key_switch() {
(1..4).for_each(|rank_in_s0s1| {
(1..4).for_each(|rank_out_s0s1| {
(1..4).for_each(|rank_out_s1s2| {
println!(
"test key_switch : ({},{},{})",
rank_in_s0s1, rank_out_s0s1, rank_out_s1s2
);
test_key_switch(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2);
})
});
});
}
#[test]
fn key_switch_inplace() {
(1..4).for_each(|rank_in_s0s1| {
(1..4).for_each(|rank_out_s0s1| {
println!(
"test key_switch_inplace : ({},{})",
rank_in_s0s1, rank_out_s0s1
);
test_key_switch_inplace(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1);
});
});
}
#[test]
fn external_product() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test external_product rank: {} {}", rank_in, rank_out);
test_external_product(12, 12, 60, 3.2, rank_in, rank_out);
});
});
}
#[test]
fn external_product_inplace() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!(
"test external_product_inplace rank: {} {}",
rank_in, rank_out
);
test_external_product_inplace(12, 12, 60, 3.2, rank_in, rank_out);
});
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ksk.size()),
);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out);
(0..ksk.rank_in()).for_each(|col_i| {
(0..ksk.rows()).for_each(|row_i| {
ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
});
});
}
fn test_key_switch(
log_n: usize,
basek: usize,
k_ksk: usize,
sigma: f64,
rank_in_s0s1: usize,
rank_out_s0s1: usize,
rank_out_s1s2: usize,
) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows = (k_ksk + basek - 1) / basek;
let mut ct_gglwe_s0s1: GLWESwitchingKey<Vec<u8>, FFT64> =
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1);
let mut ct_gglwe_s1s2: GLWESwitchingKey<Vec<u8>, FFT64> =
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s1s2);
let mut ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> =
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s1s2);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in_s0s1 | rank_out_s0s1, ct_gglwe_s0s1.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s2.size())
| GLWESwitchingKey::keyswitch_scratch_space(
&module,
ct_gglwe_s0s2.size(),
ct_gglwe_s0s2.rank(),
ct_gglwe_s0s1.size(),
ct_gglwe_s0s1.rank(),
ct_gglwe_s1s2.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in_s0s1);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
sk1_dft.dft(&module, &sk1);
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s1s2);
sk2.fill_ternary_prob(0.5, &mut source_xs);
let mut sk2_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s1s2);
sk2_dft.dft(&module, &sk2);
// gglwe_{s1}(s0) = s0 -> s1
ct_gglwe_s0s1.encrypt_sk(
&module,
&sk0,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s2}(s1) = s1 -> s2
ct_gglwe_s1s2.encrypt_sk(
&module,
&sk1,
&sk2_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s1s2);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {
(0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank_out_s0s1 as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}
fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in_s0s1: usize, rank_out_s0s1: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ksk + basek - 1) / basek;
let mut ct_gglwe_s0s1: GLWESwitchingKey<Vec<u8>, FFT64> =
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1);
let mut ct_gglwe_s1s2: GLWESwitchingKey<Vec<u8>, FFT64> =
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s0s1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out_s0s1, ct_gglwe_s0s1.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s1.size())
| GLWESwitchingKey::keyswitch_inplace_scratch_space(
&module,
ct_gglwe_s0s1.size(),
ct_gglwe_s0s1.rank(),
ct_gglwe_s1s2.size(),
),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in_s0s1);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
sk1_dft.dft(&module, &sk1);
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
sk2.fill_ternary_prob(0.5, &mut source_xs);
let mut sk2_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
sk2_dft.dft(&module, &sk2);
// gglwe_{s1}(s0) = s0 -> s1
ct_gglwe_s0s1.encrypt_sk(
&module,
&sk0,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s2}(s1) = s1 -> s2
ct_gglwe_s1s2.encrypt_sk(
&module,
&sk1,
&sk2_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
ct_gglwe_s0s1.keyswitch_inplace(&module, &ct_gglwe_s1s2, scratch.borrow());
let ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> = ct_gglwe_s0s1;
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s0s1);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {
(0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank_out_s0s1 as f64,
k_ksk,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}
fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct_gglwe_in: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
let mut ct_gglwe_out: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size())
| GLWESwitchingKey::external_product_scratch_space(
&module,
ct_gglwe_out.size(),
ct_gglwe_in.size(),
ct_rgsw.size(),
rank_out,
)
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
);
let r: usize = 1;
pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r}
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
// gglwe_{s1}(s0) = s0 -> s1
ct_gglwe_in.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe(m) (x) RGSW(X^{r}) = gglwe(m * X^{r})
ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow());
scratch = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size())
| GLWESwitchingKey::external_product_scratch_space(
&module,
ct_gglwe_out.size(),
ct_gglwe_in.size(),
ct_rgsw.size(),
rank_out,
)
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
);
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
(0..rank_in).for_each(|i| {
module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r}
});
(0..rank_in).for_each(|col_i| {
(0..ct_gglwe_out.rows()).for_each(|row_i| {
ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank_out as f64,
k,
k,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}
fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct_gglwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe.size())
| GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_gglwe.size(), ct_rgsw.size(), rank_out)
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
);
let r: usize = 1;
pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r}
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
// gglwe_{s1}(s0) = s0 -> s1
ct_gglwe.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
// gglwe(m) (x) RGSW(X^{r}) = gglwe(m * X^{r})
ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow());
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
(0..rank_in).for_each(|i| {
module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r}
});
(0..rank_in).for_each(|col_i| {
(0..ct_gglwe.rows()).for_each(|row_i| {
ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank_out as f64,
k,
k,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
});
});
}
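/// Predicted variance of the noise introduced by a GGLWE (gadget) product.
///
/// A sketch of the model implemented below (read off the code, not a formal
/// derivation): with gadget base B = 2^basek, Var(U(B)) = B^2 / 12, and
/// a_cols = ceil(min(a_logq, b_logq) / basek),
///
///   var ≈ rank_in * ( a_cols * n * (B^2 / 12) * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
///                   + n * var_msg * var_a_err * 2^(2 * (b_logq - a_logq)) ) / 2^(2 * b_logq)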
pub(crate) fn var_noise_gglwe_product(
n: f64,
basek: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
rank_in: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_logq: usize = a_logq.min(b_logq);
let a_cols: usize = (a_logq + basek - 1) / basek;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (basek)) as f64;
let var_base: f64 = base * base / 12f64;
// gadget term: a_cols * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
// message term (added below): var_msg * var_a_err * a_scale^2 * n
let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a_err * a_scale * a_scale * n;
noise *= rank_in;
noise /= b_scale * b_scale;
noise
}
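/// Same model as `var_noise_gglwe_product`, returned as the log2 of the
/// standard deviation and capped at -1, since normalized coefficients lie in
/// [-1/2, 1/2].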
pub(crate) fn log2_std_noise_gglwe_product(
n: f64,
basek: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
rank_in: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let mut noise: f64 = var_noise_gglwe_product(
n,
basek,
var_xs,
var_msg,
var_a_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank_in,
a_logq,
b_logq,
);
noise = noise.sqrt();
noise.log2().min(-1.0) // coefficients are normalized to [-2^{-1}, 2^{-1}], so log2(std) is capped at -1
}

934
core/src/test_fft64/ggsw.rs Normal file
View File

@@ -0,0 +1,934 @@
use base2k::{
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc,
VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
};
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::{GetRow, Infos},
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
tensor_key::TensorKey,
};
use super::gglwe::var_noise_gglwe_product;
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(11, 8, 54, 3.2, rank);
});
}
#[test]
fn keyswitch() {
(1..4).for_each(|rank| {
println!("test keyswitch rank: {}", rank);
test_keyswitch(12, 15, 60, rank, 3.2);
});
}
#[test]
fn keyswitch_inplace() {
(1..4).for_each(|rank| {
println!("test keyswitch_inplace rank: {}", rank);
test_keyswitch_inplace(12, 15, 60, rank, 3.2);
});
}
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(-5, 12, 15, 60, rank, 3.2);
});
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(-5, 12, 15, 60, rank, 3.2);
});
}
#[test]
fn external_product() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product(12, 12, 60, rank, 3.2);
});
}
#[test]
fn external_product_inplace() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product_inplace(12, 15, 60, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
(0..ct.rank() + 1).for_each(|col_j| {
(0..ct.rows()).for_each(|row_i| {
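// Expected plaintext for this row: the message at limb row_i, additionally
// multiplied by sk[col_j - 1] when col_j > 0 (GGSW columns encrypt m * s_j).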
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
pt_want.data.zero();
});
});
}
fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut ct_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut tsk: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size())
| GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| GGSWCiphertext::keyswitch_scratch_space(
&module,
ct_out.size(),
ct_in.size(),
ksk.size(),
tsk.size(),
rank,
),
);
let var_xs: f64 = 0.5;
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_in.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_out.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
tsk.encrypt_sk(
&module,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
ct_in.encrypt_sk(
&module,
&pt_scalar,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_out.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_out.size());
(0..ct_out.rank() + 1).for_each(|col_j| {
(0..ct_out.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = noise_ggsw_keyswitch(
module.n() as f64,
basek,
col_j,
var_xs,
0f64,
sigma * sigma,
0f64,
rank as f64,
k,
k,
);
println!("{} {}", noise_have, noise_want);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut tsk: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size())
| GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), tsk.size(), rank),
);
let var_xs: f64 = 0.5;
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_in.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_out.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
tsk.encrypt_sk(
&module,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
(0..ct.rank() + 1).for_each(|col_j| {
(0..ct.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = noise_ggsw_keyswitch(
module.n() as f64,
basek,
col_j,
var_xs,
0f64,
sigma * sigma,
0f64,
rank as f64,
k,
k,
);
println!("{} {}", noise_have, noise_want);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
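/// Predicted log2 standard deviation of the noise after a GGSW key switch.
///
/// Informal reading of the code below: column 0 is a plain GLWE key switch,
/// modeled by one gglwe-product term. Columns j > 0 additionally reconstruct
/// the s_i * s_j cross terms through the tensor key, which adds a second
/// gglwe-product term with message variance n * var_xs^2 (plus a 1/12
/// rounding term) and a final n * var_xs / 2 amplification.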
pub(crate) fn noise_ggsw_keyswitch(
n: f64,
basek: usize,
col: usize,
var_xs: f64,
var_a_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
rank: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let var_si_x_sj: f64 = n * var_xs * var_xs;
// Base key-switch noise (applies to every column)
let mut noise: f64 = var_noise_gglwe_product(
n,
basek,
var_xs,
var_xs,
var_a_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank,
a_logq,
b_logq,
);
// Additional noise from reconstructing the s_i * s_j cross terms when col > 0
if col > 0 {
noise += var_noise_gglwe_product(
n,
basek,
var_xs,
var_si_x_sj,
var_a_err + 1f64 / 12.0,
var_gct_err_lhs,
var_gct_err_rhs,
rank,
a_logq,
b_logq,
);
noise += n * noise * var_xs * 0.5;
}
noise = noise.sqrt();
noise.log2().min(-1.0) // coefficients are normalized to [-2^{-1}, 2^{-1}], so log2(std) is capped at -1
}
fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut ct_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size())
| AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
| TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size())
| GGSWCiphertext::automorphism_scratch_space(
&module,
ct_out.size(),
ct_in.size(),
auto_key.size(),
tensor_key.size(),
rank,
),
);
let var_xs: f64 = 0.5;
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
auto_key.encrypt_sk(
&module,
p,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
tensor_key.encrypt_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
ct_in.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow());
module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_out.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_out.size());
(0..ct_out.rank() + 1).for_each(|col_j| {
(0..ct_out.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = noise_ggsw_keyswitch(
module.n() as f64,
basek,
col_j,
var_xs,
0f64,
sigma * sigma,
0f64,
rank as f64,
k,
k,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size())
| AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
| TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size())
| GGSWCiphertext::automorphism_inplace_scratch_space(&module, ct.size(), auto_key.size(), tensor_key.size(), rank),
);
let var_xs: f64 = 0.5;
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(var_xs, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
auto_key.encrypt_sk(
&module,
p,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
tensor_key.encrypt_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.automorphism_inplace(&module, &auto_key, &tensor_key, scratch.borrow());
module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
(0..ct.rank() + 1).for_each(|col_j| {
(0..ct.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
// mul with sk[col_j-1]
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = noise_ggsw_keyswitch(
module.n() as f64,
basek,
col_j,
var_xs,
0f64,
sigma * sigma,
0f64,
rank as f64,
k,
k,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
let k: usize = 1;
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
| GGSWCiphertext::external_product_scratch_space(
&module,
ct_ggsw_lhs_out.size(),
ct_ggsw_lhs_in.size(),
ct_ggsw_rhs.size(),
rank,
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_ggsw_rhs.encrypt_sk(
&module,
&pt_ggsw_rhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs_in.encrypt_sk(
&module,
&pt_ggsw_lhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
(0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
(0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ggsw,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"have: {} want: {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ggsw + basek - 1) / basek;
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_ggsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
let k: usize = 1;
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size())
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size())
| GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_ggsw_rhs.encrypt_sk(
&module,
&pt_ggsw_rhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs.encrypt_sk(
&module,
&pt_ggsw_lhs,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow());
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size());
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
(0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
(0..ct_ggsw_lhs.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
if col_j > 0 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
let noise_have: f64 = pt.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ggsw,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"have: {} want: {}",
noise_have,
noise_want
);
pt_want.data.zero();
});
});
}
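/// Predicted log2 standard deviation of the noise after a GGSW external
/// product.
///
/// A sketch of the model implemented below (with B = 2^basek and
/// a_cols = ceil(min(a_logq, b_logq) / basek)):
///
///   var ≈ (rank + 1) * a_cols * n * (B^2 / 12) * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
///         + n * var_msg * 2^(2 * (b_logq - a_logq)) * (var_a0_err + rank * var_xs * var_a1_err)
///
/// scaled by 2^(-2 * b_logq); the result is sqrt(var) in log2, capped at -1.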
pub(crate) fn noise_ggsw_product(
n: f64,
basek: usize,
var_xs: f64,
var_msg: f64,
var_a0_err: f64,
var_a1_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
rank: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_logq: usize = a_logq.min(b_logq);
let a_cols: usize = (a_logq + basek - 1) / basek;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (basek)) as f64;
let var_base: f64 = base * base / 12f64;
// gadget term: (rank + 1) * a_cols * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
// message terms (added below): var_msg * a_scale^2 * n * (var_a0_err + rank * var_xs * var_a1_err)
let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
noise += var_msg * var_a0_err * a_scale * a_scale * n;
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank;
noise = noise.sqrt();
noise /= b_scale;
noise.log2().min(-1.0) // coefficients are normalized to [-2^{-1}, 2^{-1}], so log2(std) is capped at -1
}

805
core/src/test_fft64/glwe.rs Normal file
View File

@@ -0,0 +1,805 @@
use base2k::{
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
ZnxViewMut, ZnxZero,
};
use itertools::izip;
use sampling::source::Source;
use crate::{
automorphism::AutomorphismKey,
elem::Infos,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(11, 8, 54, 30, 3.2, rank);
});
}
#[test]
fn encrypt_zero_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_zero_sk rank: {}", rank);
test_encrypt_zero_sk(11, 8, 64, 3.2, rank);
});
}
#[test]
fn encrypt_pk() {
(1..4).for_each(|rank| {
println!("test encrypt_pk rank: {}", rank);
test_encrypt_pk(11, 8, 64, 64, 3.2, rank)
});
}
#[test]
fn keyswitch() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
});
});
}
#[test]
fn keyswitch_inplace() {
(1..4).for_each(|rank| {
println!("test keyswitch_inplace rank: {}", rank);
test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
});
}
#[test]
fn external_product() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product(12, 12, 60, 45, 60, rank, 3.2);
});
}
#[test]
fn external_product_inplace() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
});
}
#[test]
fn automorphism_inplace() {
(1..4).for_each(|rank| {
println!("test automorphism_inplace rank: {}", rank);
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
});
}
#[test]
fn automorphism() {
(1..4).for_each(|rank| {
println!("test automorphism rank: {}", rank);
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_pt);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10);
ct.encrypt_sk(
&module,
&pt,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
pt.data.zero();
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
let mut data_have: Vec<i64> = vec![0i64; module.n()];
pt.data
.decode_vec_i64(0, basek, pt.size() * basek, &mut data_have);
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
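// decode_vec_i64 reads pt.size() * basek bits, so the decoded values carry an
// extra factor of 2^(pt.size() * basek - k_pt) relative to data_want; `scale`
// removes it before comparing.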
let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64;
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
let b_scaled = (*b as f64) / scale;
assert!(
(*a as f64 - b_scaled).abs() < 0.1,
"{} {}",
*a as f64,
b_scaled
)
});
}
fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([1u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
let mut ct_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size())
| GLWECiphertextFourier::encrypt_sk_scratch_space(&module, rank, ct_dft.size()),
);
ct_dft.encrypt_zero_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()).abs() <= 0.2);
}
fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut source_xu: Source = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
let mut pk: GLWEPublicKey<Vec<u8>, FFT64> = GLWEPublicKey::new(&module, basek, k_pk, rank);
pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
| GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()),
);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0); // mask zeroes the message (the RNG still advances), so only the encryption noise is measured
pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10);
ct.encrypt_pk(
&module,
&pt_want,
&pk,
&mut source_xu,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
let noise_have: f64 = pt_want.data.std(0, basek).log2();
let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64);
assert!(
(noise_have - noise_want).abs() < 0.2,
"{} {}",
noise_have,
noise_want
);
}
fn test_keyswitch(
log_n: usize,
basek: usize,
k_keyswitch: usize,
k_ct_in: usize,
k_ct_out: usize,
rank_in: usize,
rank_out: usize,
sigma: f64,
) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_keyswitch, rows, rank_in, rank_out);
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in);
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
| GLWECiphertext::keyswitch_scratch_space(
&module,
ct_out.size(),
rank_out,
ct_in.size(),
rank_in,
ksk.size(),
),
);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_in.encrypt_sk(
&module,
&pt_want,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow());
ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank_in as f64,
k_ct_in,
k_keyswitch,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), rank, ct_grlwe.size()),
);
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk0.fill_ternary_prob(0.5, &mut source_xs);
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk0_dft.dft(&module, &sk0);
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk1.fill_ternary_prob(0.5, &mut source_xs);
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk1_dft.dft(&module, &sk1);
ct_grlwe.encrypt_sk(
&module,
&sk0,
&sk1_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
&pt_want,
&sk0_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ct,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_automorphism(
log_n: usize,
basek: usize,
p: i64,
k_autokey: usize,
k_ct_in: usize,
k_ct_out: usize,
rank: usize,
sigma: f64,
) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut autokey: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank);
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
| GLWECiphertext::automorphism_scratch_space(&module, ct_out.size(), rank, ct_in.size(), autokey.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
autokey.encrypt_sk(
&module,
p,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_in.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow());
ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
let noise_have: f64 = pt_have.data.std(0, basek).log2();
println!("{}", noise_have);
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ct_in,
k_autokey,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut autokey: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertext::automorphism_inplace_scratch_space(&module, ct.size(), rank, autokey.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
autokey.encrypt_sk(
&module,
p,
&sk,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.automorphism_inplace(&module, &autokey, scratch.borrow());
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ct,
k_autokey,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
| GLWECiphertext::external_product_scratch_space(
&module,
ct_rlwe_out.size(),
ct_rlwe_in.size(),
ct_rgsw.size(),
rank,
),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe_in.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe_out.external_product(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow());
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ct_in,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
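// The predictions above model the GGSW message X^{k} as a coefficient vector
// with a single 1 among n entries, whose variance is (n - 1)/n^2 ~ 1/n --
// the `var_msg` used in these tests. A stand-alone sanity check of that
// constant (plain Rust, no base2k types):
#[test]
fn monomial_variance_is_one_over_n() {
    let n: usize = 1 << 12;
    let mut coeffs = vec![0.0f64; n];
    coeffs[1] = 1.0; // X^{k} with k = 1, as in the tests above
    let mean = coeffs.iter().sum::<f64>() / n as f64;
    let var = coeffs.iter().map(|c| (c - mean) * (c - mean)).sum::<f64>() / n as f64;
    assert!((var - 1.0 / n as f64).abs() < 1e-6);
}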
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
| GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_rlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow());
ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ct,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}

View File

@@ -0,0 +1,445 @@
use crate::{
elem::Infos,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
};
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
use sampling::source::Source;
#[test]
fn keyswitch() {
(1..4).for_each(|rank_in| {
(1..4).for_each(|rank_out| {
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
});
});
}
#[test]
fn keyswitch_inplace() {
(1..4).for_each(|rank| {
println!("test keyswitch_inplace rank: {}", rank);
test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
});
}
#[test]
fn external_product() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product(12, 12, 60, 45, 60, rank, 3.2);
});
}
#[test]
fn external_product_inplace() {
(1..4).for_each(|rank| {
println!("test external_product rank: {}", rank);
test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
});
}
fn test_keyswitch(
log_n: usize,
basek: usize,
k_ksk: usize,
k_ct_in: usize,
k_ct_out: usize,
rank_in: usize,
rank_out: usize,
sigma: f64,
) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out);
let mut ct_glwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in);
let mut ct_glwe_dft_in: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank_in);
let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out);
let mut ct_glwe_dft_out: GLWECiphertextFourier<Vec<u8>, FFT64> =
GLWECiphertextFourier::new(&module, basek, k_ct_out, rank_out);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe_in.size())
| GLWECiphertextFourier::keyswitch_scratch_space(
&module,
ct_glwe_out.size(),
rank_out,
ct_glwe_in.size(),
rank_in,
ksk.size(),
),
);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_glwe_in.encrypt_sk(
&module,
&pt_want,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_glwe_in.dft(&module, &mut ct_glwe_dft_in);
ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow());
ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow());
ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank_in as f64,
k_ct_in,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
let mut scratch: ScratchOwned = ScratchOwned::new(
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe.size())
| GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ksk.size(), rank),
);
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_in.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_in_dft.dft(&module, &sk_in);
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk_out.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_out_dft.dft(&module, &sk_out);
ksk.encrypt_sk(
&module,
&sk_in,
&sk_out_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_glwe.encrypt_sk(
&module,
&pt_want,
&sk_in_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_glwe.dft(&module, &mut ct_rlwe_dft);
ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow());
ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow());
ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let noise_want: f64 = log2_std_noise_gglwe_product(
module.n() as f64,
basek,
0.5,
0.5,
0f64,
sigma * sigma,
0f64,
rank as f64,
k_ct,
k_ksk,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct_in + basek - 1) / basek;
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
let mut ct_in_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank);
let mut ct_out_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_out, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
| GLWECiphertextFourier::external_product_scratch_space(&module, ct_out.size(), ct_in.size(), ct_rgsw.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_rgsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_in.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct_in.dft(&module, &mut ct_in_dft);
ct_out_dft.external_product(&module, &ct_in_dft, &ct_rgsw, scratch.borrow());
ct_out_dft.idft(&module, &mut ct_out, scratch.borrow());
ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ct_in,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k_ct + basek - 1) / basek;
let mut ct_ggsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
// Random input plaintext
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
pt_want.to_mut().at_mut(0, 0)[1] = 1;
let k: usize = 1;
pt_rgsw.raw_mut()[k] = 1; // X^{k}
let mut scratch: ScratchOwned = ScratchOwned::new(
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw.size())
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
| GLWECiphertextFourier::external_product_inplace_scratch_space(&module, ct.size(), ct_ggsw.size(), rank),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
ct_ggsw.encrypt_sk(
&module,
&pt_rgsw,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.encrypt_sk(
&module,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
ct.dft(&module, &mut ct_rlwe_dft);
ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow());
ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow());
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let noise_have: f64 = pt_have.data.std(0, basek).log2();
let var_gct_err_lhs: f64 = sigma * sigma;
let var_gct_err_rhs: f64 = 0f64;
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
let var_a0_err: f64 = sigma * sigma;
let var_a1_err: f64 = 1f64 / 12f64;
let noise_want: f64 = noise_ggsw_product(
module.n() as f64,
basek,
0.5,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
rank as f64,
k_ct,
k_ggsw,
);
assert!(
(noise_have - noise_want).abs() <= 0.1,
"{} {}",
noise_have,
noise_want
);
}
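// The dft/idft round-trips above compute products in Z[X]/(X^N + 1):
// pointwise multiplication in the DFT domain corresponds to negacyclic
// convolution of the coefficients. A naive reference of that convolution,
// independent of the FFT64 backend:
#[test]
fn negacyclic_convolution_reference() {
    fn negacyclic_mul(a: &[i64], b: &[i64]) -> Vec<i64> {
        let n = a.len();
        let mut out = vec![0i64; n];
        for i in 0..n {
            for j in 0..n {
                let k = (i + j) % n;
                let c = a[i] * b[j];
                // X^N = -1: terms of degree >= N wrap around with a sign flip
                if i + j < n { out[k] += c } else { out[k] -= c }
            }
        }
        out
    }
    // (1 + X) * (1 + X^3) mod (X^4 + 1) = 1 + X + X^3 + X^4 = X + X^3
    assert_eq!(negacyclic_mul(&[1, 1, 0, 0], &[1, 0, 0, 1]), vec![0, 1, 0, 1]);
}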

View File

@@ -0,0 +1,6 @@
mod automorphism_key;
mod gglwe;
mod ggsw;
mod glwe;
mod glwe_fourier;
mod tensor_key;

View File

@@ -0,0 +1,77 @@
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos},
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::{SecretKey, SecretKeyFourier},
tensor_key::TensorKey,
};
#[test]
fn encrypt_sk() {
(1..4).for_each(|rank| {
println!("test encrypt_sk rank: {}", rank);
test_encrypt_sk(12, 16, 54, 3.2, rank);
});
}
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let rows: usize = (k + basek - 1) / basek;
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
&module,
rank,
tensor_key.size(),
));
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
sk_dft.dft(&module, &sk);
tensor_key.encrypt_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
scratch.borrow(),
);
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
(0..rank).for_each(|i| {
(0..rank).for_each(|j| {
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
let sk_ij: ScalarZnx<Vec<u8>> = module
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
.to_vec_znx_small()
.to_scalar_znx();
(0..tensor_key.rank_in()).for_each(|col_i| {
(0..tensor_key.rows()).for_each(|row_i| {
tensor_key
.at(i, j)
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
});
});
})
})
}

core/src/utils.rs Normal file
View File

@@ -0,0 +1,3 @@
pub(crate) fn derive_size(basek: usize, k: usize) -> usize {
(k + basek - 1) / basek
}
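// derive_size is plain ceiling division: the number of base-2^basek limbs
// needed to hold k bits of precision. A minimal usage sketch:
#[cfg(test)]
mod tests {
    use super::derive_size;

    #[test]
    fn derive_size_is_ceiling_division() {
        assert_eq!(derive_size(12, 60), 5); // 60 bits fit exactly in 5 limbs
        assert_eq!(derive_size(12, 61), 6); // one extra bit needs a 6th limb
        assert_eq!(derive_size(16, 54), 4); // 54 = 3*16 + 6 -> 4 limbs
    }
}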

View File

@@ -1,139 +0,0 @@
use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8};
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use rlwe::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::ElemCommon,
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes},
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
};
use sampling::source::Source;
fn bench_gadget_product_inplace(c: &mut Criterion) {
fn runner<'a>(
module: &'a Module,
res_dft_0: &'a mut VecZnxDft,
res_dft_1: &'a mut VecZnxDft,
a: &'a VecZnx,
b: &'a Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &'a mut [u8],
) -> Box<dyn FnMut() + 'a> {
Box::new(move || {
gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes);
})
}
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace");
for log_n in 10..11 {
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: log_n,
log_q: 32,
log_p: 0,
log_base2k: 16,
log_scale: 20,
xe: 3.2,
xs: 128,
};
let params: Parameters = Parameters::new(&params_lit);
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
| gadget_product_core_tmp_bytes(
params.module(),
params.log_base2k(),
params.log_q(),
params.log_q(),
params.cols_q(),
params.log_qp(),
)
| encrypt_grlwe_sk_tmp_bytes(
params.module(),
params.log_base2k(),
params.cols_qp(),
params.log_qp(),
),
);
let mut source: Source = Source::new([3; 32]);
let mut sk0: SecretKey = SecretKey::new(params.module());
let mut sk1: SecretKey = SecretKey::new(params.module());
sk0.fill_ternary_hw(params.xs(), &mut source);
sk1.fill_ternary_hw(params.xs(), &mut source);
let mut source_xe: Source = Source::new([4; 32]);
let mut source_xa: Source = Source::new([5; 32]);
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
params.module(),
params.log_base2k(),
params.cols_q(),
params.log_qp(),
);
encrypt_grlwe_sk(
params.module(),
&mut gadget_ct,
&sk0.0,
&sk1_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
params.encrypt_rlwe_sk(
&mut ct,
None,
&sk0_svp_ppol,
&mut source_xa,
&mut source_xe,
&mut tmp_bytes,
);
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
params
.module()
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
let b_cols: usize = gadget_ct.cols();
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
runner(
params.module(),
&mut res_dft_0,
&mut res_dft_1,
&mut a,
&gadget_ct,
b_cols,
&mut tmp_bytes,
)
})];
for (name, mut runner) in runners {
let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n));
b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
b.iter(&mut runner)
});
}
}
}
criterion_group!(benches, bench_gadget_product_inplace);
criterion_main!(benches);

View File

@@ -1,76 +0,0 @@
use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned};
use rlwe::{
ciphertext::Ciphertext,
elem::ElemCommon,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use sampling::source::Source;
fn main() {
let params_lit: ParametersLiteral = ParametersLiteral {
backend: base2k::BACKEND::FFT64,
log_n: 10,
log_q: 54,
log_p: 0,
log_base2k: 17,
log_scale: 20,
xe: 3.2,
xs: 128,
};
let params: Parameters = Parameters::new(&params_lit);
let mut tmp_bytes: Vec<u8> =
alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()));
let mut source: Source = Source::new([0; 32]);
let mut sk: SecretKey = SecretKey::new(params.module());
sk.fill_ternary_hw(params.xs(), &mut source);
let mut want = vec![i64::default(); params.n()];
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let mut pt: Plaintext = params.new_plaintext(params.log_q());
let log_base2k = pt.log_base2k();
let log_k: usize = params.log_q() - 20;
pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32);
pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
println!("log_k: {}", log_k);
pt.0.value[0].print(0, pt.cols(), 16);
println!();
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
let mut source_xe: Source = Source::new([1; 32]);
let mut source_xa: Source = Source::new([2; 32]);
let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
params.encrypt_rlwe_sk(
&mut ct,
Some(&pt),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
&mut tmp_bytes,
);
params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
pt.0.value[0].print(0, pt.cols(), 16);
let mut have = vec![i64::default(); params.n()];
println!("pt: {}", log_k);
pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have);
println!("want: {:?}", &want[..16]);
println!("have: {:?}", &have[..16]);
}

View File

@@ -1,349 +0,0 @@
use crate::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
elem::ElemCommon,
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes},
keys::SecretKey,
parameters::Parameters,
};
use base2k::{
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
VmpPMatOps, assert_alignement,
};
use sampling::source::Source;
use std::collections::HashMap;
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
pub struct AutomorphismKey {
pub value: Ciphertext<VmpPMat>,
pub p: i64,
}
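// Reference sketch of the AUTO map itself, independent of the base2k API
// and assuming the usual negacyclic convention: X^{i} -> X^{i*p}, exponents
// reduced mod 2N, with a sign flip whenever the reduced exponent lands in
// [N, 2N) because X^N = -1. p must be odd (a unit mod 2N) to be invertible.
#[cfg(test)]
mod auto_reference {
    fn automorphism(coeffs: &[i64], p: i64) -> Vec<i64> {
        let n = coeffs.len();
        assert!(p % 2 != 0, "p must be odd (a unit mod 2N)");
        let two_n = 2 * n as i64;
        let mut out = vec![0i64; n];
        for (i, &c) in coeffs.iter().enumerate() {
            let e = (i as i64 * p).rem_euclid(two_n);
            if (e as usize) < n {
                out[e as usize] = c;
            } else {
                out[e as usize - n] = -c; // X^{N + j} = -X^{j}
            }
        }
        out
    }

    #[test]
    fn auto_minus_one_reverses_and_negates() {
        // N = 4, p = -1: X -> X^{-1} = -X^{3}, so [1, 2, 3, 4] -> [1, -4, -3, -2]
        assert_eq!(automorphism(&[1, 2, 3, 4], -1), vec![1, -4, -3, -2]);
    }
}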
pub fn automorphism_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
impl Parameters {
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
automorphism_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
automorphism_tmp_bytes(
self.module(),
self.log_base2k(),
res_logq,
in_logq,
gct_logq,
)
}
}
impl AutomorphismKey {
pub fn new(
module: &Module,
p: i64,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> Self {
Self::new_many_core(
module,
&vec![p],
sk,
log_base2k,
rows,
log_q,
source_xa,
source_xe,
sigma,
tmp_bytes,
)
.into_iter()
.next()
.unwrap()
}
pub fn new_many(
module: &Module,
p: &Vec<i64>,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> HashMap<i64, AutomorphismKey> {
Self::new_many_core(
module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes,
)
.into_iter()
.zip(p.iter().cloned())
.map(|(key, pi)| (pi, key))
.collect()
}
fn new_many_core(
module: &Module,
p: &Vec<i64>,
sk: &SecretKey,
log_base2k: usize,
rows: usize,
log_q: usize,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) -> Vec<Self> {
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
let mut keys: Vec<AutomorphismKey> = Vec::new();
p.iter().for_each(|pi| {
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
let p_inv: i64 = module.galois_element_inv(*pi);
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
module.svp_prepare(&mut sk_out, &sk_auto);
encrypt_grlwe_sk(
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
keys.push(Self {
value: value,
p: *pi,
})
});
keys
}
}
pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
}
pub fn automorphism(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
// c[1] = AUTO(b, p)
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
}
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
}
pub fn automorphism_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &AutomorphismKey,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
// a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
// a[1] = AUTO(b, p)
module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
}
pub fn automorphism_big(
module: &Module,
c: &mut Ciphertext<VecZnxBig>,
a: &Ciphertext<VecZnx>,
b: &AutomorphismKey,
tmp_bytes: &mut [u8],
) {
let cols = std::cmp::min(c.cols(), a.cols());
#[cfg(debug_assertions)]
{
assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
// c[0] = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(e, p)]
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
// c[1] = AUTO(b, p)
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
}
#[cfg(test)]
mod test {
use super::{AutomorphismKey, automorphism};
use crate::{
ciphertext::Ciphertext,
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::encrypt_rlwe_sk,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
use sampling::source::{Source, new_seed};
#[test]
fn test_automorphism() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let p: i64 = -5;
let auto_key: AutomorphismKey = AutomorphismKey::new(
module,
p,
&sk,
log_base2k,
gct_rows,
log_qp,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut data: Vec<i64> = vec![0i64; params.n()];
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
let mut pt_auto: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
// ct <- AUTO(ct)
automorphism(
module,
&mut ct_auto,
&ct,
&auto_key,
gct_cols,
&mut tmp_bytes,
);
// pt = dec(auto(ct)) - auto(pt)
decrypt_rlwe(
module,
pt.elem_mut(),
ct_auto.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
// pt.at(0).print(pt.cols(), 16);
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
let var_a_err: f64 = 1f64 / 12f64;
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
assert!(noise_have <= noise_pred + 1.0);
}
}

View File

@@ -1,93 +0,0 @@
use crate::elem::{Elem, ElemCommon};
use crate::parameters::Parameters;
use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat};
pub struct Ciphertext<T>(pub Elem<T>);
impl Parameters {
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
}
}
impl<T> ElemCommon<T> for Ciphertext<T>
where
T: Infos,
{
fn n(&self) -> usize {
self.elem().n()
}
fn log_n(&self) -> usize {
self.elem().log_n()
}
fn log_q(&self) -> usize {
self.elem().log_q()
}
fn elem(&self) -> &Elem<T> {
&self.0
}
fn elem_mut(&mut self) -> &mut Elem<T> {
&mut self.0
}
fn size(&self) -> usize {
self.elem().size()
}
fn layout(&self) -> LAYOUT {
self.elem().layout()
}
fn rows(&self) -> usize {
self.elem().rows()
}
fn cols(&self) -> usize {
self.elem().cols()
}
fn at(&self, i: usize) -> &T {
self.elem().at(i)
}
fn at_mut(&mut self, i: usize) -> &mut T {
self.elem_mut().at_mut(i)
}
fn log_base2k(&self) -> usize {
self.elem().log_base2k()
}
fn log_scale(&self) -> usize {
self.elem().log_scale()
}
}
impl Ciphertext<VecZnx> {
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, rows))
}
}
pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext<VecZnx> {
let rows: usize = 2;
Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
}
pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
elem.log_q = log_q;
Ciphertext(elem)
}
pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
elem.log_q = log_q;
Ciphertext(elem)
}

View File

@@ -1,67 +0,0 @@
use crate::{
ciphertext::Ciphertext,
elem::{Elem, ElemCommon},
keys::SecretKey,
parameters::Parameters,
plaintext::Plaintext,
};
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
use std::cmp::min;
pub struct Decryptor {
sk: SvpPPol,
}
impl Decryptor {
pub fn new(params: &Parameters, sk: &SecretKey) -> Self {
let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
sk.prepare(params.module(), &mut sk_svp_ppol);
Self { sk: sk_svp_ppol }
}
}
pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize {
module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes()
}
impl Parameters {
pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize {
decrypt_rlwe_tmp_byte(
self.module(),
(log_q + self.log_base2k() - 1) / self.log_base2k(),
)
}
pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
}
}
pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
let cols: usize = a.cols();
assert!(
tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols),
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}",
tmp_bytes.len(),
decrypt_rlwe_tmp_byte(module, cols)
);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
// res_dft <- DFT(ct[1]) * DFT(sk)
module.svp_apply_dft(&mut res_dft, sk, a.at(1));
// res_big <- ct[1] x sk
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
// res_big <- ct[1] x sk + ct[0]
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
// res <- normalize(ct[1] x sk + ct[0])
module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize);
res.log_base2k = a.log_base2k();
res.log_q = min(res.log_q(), a.log_q());
res.log_scale = a.log_scale();
}
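// The function above is the standard RLWE decryption identity
// dec(ct) = ct[1]*s + ct[0] = m + e. A toy scalar analogue over Z_q,
// ignoring polynomials and the base-2^k limb decomposition (all values
// here are illustrative):
#[test]
fn toy_decryption_identity() {
    let q: i128 = 1 << 54;
    let (s, m, e): (i128, i128, i128) = (3, 123_456, 2);
    let c1: i128 = 987_654_321; // stands in for the uniform mask
    let c0: i128 = (-c1 * s + m + e).rem_euclid(q); // c0 = -c1*s + m + e
    assert_eq!((c1 * s + c0).rem_euclid(q), m + e); // c1*s + c0 = m + e
}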

View File

@@ -1,168 +0,0 @@
use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
pub struct Elem<T> {
pub value: Vec<T>,
pub log_base2k: usize,
pub log_q: usize,
pub log_scale: usize,
}
pub trait ElemVecZnx {
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
fn zero(&mut self);
}
impl ElemVecZnx for Elem<VecZnx> {
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
module.n() * cols * size * 8
}
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
assert!(size > 0);
let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
let mut value: Vec<VecZnx> = Vec::new();
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = VecZnx::bytes_of(n, 1, cols);
let mut ptr: usize = 0;
(0..size).for_each(|_| {
value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..]));
ptr += elem_size
});
Self {
value,
log_q,
log_base2k,
log_scale: 0,
}
}
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
assert!(size > 0);
let n: usize = module.n();
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
let mut value: Vec<VecZnx> = Vec::new();
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let elem_size = VecZnx::bytes_of(n, 1, cols);
let mut ptr: usize = 0;
(0..size).for_each(|_| {
value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..]));
ptr += elem_size
});
Self {
value,
log_q,
log_base2k,
log_scale: 0,
}
}
fn zero(&mut self) {
self.value.iter_mut().for_each(|i| i.zero());
}
}
pub trait ElemCommon<T> {
fn n(&self) -> usize;
fn log_n(&self) -> usize;
fn elem(&self) -> &Elem<T>;
fn elem_mut(&mut self) -> &mut Elem<T>;
fn size(&self) -> usize;
fn layout(&self) -> LAYOUT;
fn rows(&self) -> usize;
fn cols(&self) -> usize;
fn log_base2k(&self) -> usize;
fn log_q(&self) -> usize;
fn log_scale(&self) -> usize;
fn at(&self, i: usize) -> &T;
fn at_mut(&mut self, i: usize) -> &mut T;
}
impl<T: Infos> ElemCommon<T> for Elem<T> {
fn n(&self) -> usize {
self.value[0].n()
}
fn log_n(&self) -> usize {
self.value[0].log_n()
}
fn elem(&self) -> &Elem<T> {
self
}
fn elem_mut(&mut self) -> &mut Elem<T> {
self
}
fn size(&self) -> usize {
self.value.len()
}
fn layout(&self) -> LAYOUT {
self.value[0].layout()
}
fn rows(&self) -> usize {
self.value[0].rows()
}
fn cols(&self) -> usize {
self.value[0].cols()
}
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_q(&self) -> usize {
self.log_q
}
fn log_scale(&self) -> usize {
self.log_scale
}
fn at(&self, i: usize) -> &T {
assert!(i < self.size());
&self.value[i]
}
fn at_mut(&mut self, i: usize) -> &mut T {
assert!(i < self.size());
&mut self.value[i]
}
}
impl Elem<VecZnx> {
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
assert!(rows > 0);
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
let mut value: Vec<VecZnx> = Vec::new();
(0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols)));
Self {
value,
log_q,
log_base2k,
log_scale: 0,
}
}
}
impl Elem<VmpPMat> {
pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self {
assert!(rows > 0);
assert!(cols > 0);
let mut value: Vec<VmpPMat> = Vec::new();
(0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols)));
Self {
value: value,
log_q: 0,
log_base2k: log_base2k,
log_scale: 0,
}
}
}

View File

@@ -1,369 +0,0 @@
use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemCommon, ElemVecZnx};
use crate::keys::SecretKey;
use crate::parameters::Parameters;
use crate::plaintext::Plaintext;
use base2k::sampling::Sampling;
use base2k::{
Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
VmpPMatOps,
};
use sampling::source::{Source, new_seed};
impl Parameters {
pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
}
pub fn encrypt_rlwe_sk(
&self,
ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk(
self.module(),
&mut ct.0,
pt.map(|pt| pt.at(0)),
sk,
source_xa,
source_xe,
self.xe(),
tmp_bytes,
)
}
}
pub struct EncryptorSk {
sk: SvpPPol,
source_xa: Source,
source_xe: Source,
initialized: bool,
tmp_bytes: Vec<u8>,
}
impl EncryptorSk {
pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self {
let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
let mut initialized: bool = false;
if let Some(sk) = sk {
sk.prepare(params.module(), &mut sk_svp_ppol);
initialized = true;
}
Self {
sk: sk_svp_ppol,
initialized,
source_xa: Source::new(new_seed()),
source_xe: Source::new(new_seed()),
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())],
}
}
pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) {
sk.prepare(module, &mut self.sk);
self.initialized = true;
}
pub fn seed_source_xa(&mut self, seed: [u8; 32]) {
self.source_xa = Source::new(seed)
}
pub fn seed_source_xe(&mut self, seed: [u8; 32]) {
self.source_xe = Source::new(seed)
}
pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext<VecZnx>, pt: Option<&Plaintext>) {
assert!(
self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
);
params.encrypt_rlwe_sk(
ct,
pt,
&self.sk,
&mut self.source_xa,
&mut self.source_xe,
&mut self.tmp_bytes,
);
}
pub fn encrypt_rlwe_sk_core(
&self,
params: &Parameters,
ct: &mut Ciphertext<VecZnx>,
pt: Option<&Plaintext>,
source_xa: &mut Source,
source_xe: &mut Source,
tmp_bytes: &mut [u8],
) {
assert!(
self.initialized == true,
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
);
params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes);
}
}
pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
}
pub fn encrypt_rlwe_sk(
module: &Module,
ct: &mut Elem<VecZnx>,
pt: Option<&VecZnx>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes)
}
fn encrypt_rlwe_sk_core<const PT_POS: u8>(
module: &Module,
ct: &mut Elem<VecZnx>,
pt: Option<&VecZnx>,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let cols: usize = ct.cols();
let log_base2k: usize = ct.log_base2k();
let log_q: usize = ct.log_q();
assert!(
tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q),
"invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}",
tmp_bytes.len(),
encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
);
let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k();
let c1: &mut VecZnx = ct.at_mut(1);
// c1 <- Z_{2^prec}[X]/(X^{N}+1)
module.fill_uniform(log_base2k, c1, cols, source_xa);
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
// Scratch space for DFT values
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
// Applies buf_dft <- DFT(s) * DFT(c1)
module.svp_apply_dft(&mut buf_dft, sk, c1);
// Alias scratch space
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
// buf_big = s x c1
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
match PT_POS {
// c0 <- -s x c1 + m
0 => {
let c0: &mut VecZnx = ct.at_mut(0);
if let Some(pt) = pt {
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt);
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
} else {
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
module.vec_znx_negate_inplace(c0);
}
}
// c1 <- c1 + m
1 => {
if let Some(pt) = pt {
module.vec_znx_add_inplace(c1, pt);
c1.normalize(log_base2k, tmp_bytes_normalize);
}
let c0: &mut VecZnx = ct.at_mut(0);
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
module.vec_znx_negate_inplace(c0);
}
_ => panic!("PT_POS must be 1 or 2"),
}
// c0 <- -s x c1 + m + e
module.add_normal(
log_base2k,
ct.at_mut(0),
log_q,
source_xe,
sigma,
(sigma * 6.0).ceil(),
);
}
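// The PT_POS switch above places the plaintext either in c0 (an RLWE sample
// of m) or in c1 (used by RGSW, yielding an RLWE sample of m*s). A toy
// scalar analogue of the two placements (illustrative values only):
#[test]
fn toy_pt_pos_placements() {
    let q: i128 = 1 << 54;
    let (s, m, a): (i128, i128, i128) = (3, 42, 987_654_321);
    // PT_POS = 0: ct = (-a*s + m, a), so a*s + c0 = m
    let c0 = (-a * s + m).rem_euclid(q);
    assert_eq!((a * s + c0).rem_euclid(q), m);
    // PT_POS = 1: ct = (-a*s, a + m), so (a + m)*s + c0 = m*s
    let (c0b, c1b) = ((-a * s).rem_euclid(q), a + m);
    assert_eq!((c1b * s + c0b).rem_euclid(q), (m * s).rem_euclid(q));
}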
impl Parameters {
pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
}
pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, cols)
}
pub fn encrypt_grlwe_sk(
module: &Module,
ct: &mut Ciphertext<VmpPMat>,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k();
let (left, right) = ct.0.value.split_at_mut(1);
encrypt_grlwe_sk_core::<0>(
module,
log_base2k,
[&mut left[0], &mut right[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
)
}
impl Parameters {
pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
}
}
pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
let cols = (log_q + log_base2k - 1) / log_base2k;
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
+ Plaintext::bytes_of(module, log_base2k, log_q)
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
+ module.vmp_prepare_tmp_bytes(rows, cols)
}
pub fn encrypt_rgsw_sk(
module: &Module,
ct: &mut Ciphertext<VmpPMat>,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let log_q: usize = ct.log_q();
let log_base2k: usize = ct.log_base2k();
let (left, right) = ct.0.value.split_at_mut(2);
let (ll, lr) = left.split_at_mut(1);
let (rl, rr) = right.split_at_mut(1);
encrypt_grlwe_sk_core::<0>(
module,
log_base2k,
[&mut ll[0], &mut lr[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
);
encrypt_grlwe_sk_core::<1>(
module,
log_base2k,
[&mut rl[0], &mut rr[0]],
log_q,
m,
sk,
source_xa,
source_xe,
sigma,
tmp_bytes,
);
}
fn encrypt_grlwe_sk_core<const PT_POS: u8>(
module: &Module,
log_base2k: usize,
mut ct: [&mut VmpPMat; 2],
log_q: usize,
m: &Scalar,
sk: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
let rows: usize = ct[0].rows();
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
assert!(
tmp_bytes.len() >= min_tmp_bytes_len,
"invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}",
tmp_bytes.len(),
min_tmp_bytes_len
);
let bytes_of_elem: usize = Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2);
let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q);
let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q);
let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt);
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
let mut tmp_elem: Elem<VecZnx> = Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem);
let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
(0..rows).for_each(|row_i| {
// Sets the row_i-th limb of the plaintext to m (i.e. encodes m * 2^{-log_base2k*row_i})
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
// Encrypts RLWE(m * 2^{-log_base2k*i})
encrypt_rlwe_sk_core::<PT_POS>(
module,
&mut tmp_elem,
Some(tmp_pt.at(0)),
sk,
source_xa,
source_xe,
sigma,
tmp_bytes_enc_sk,
);
// Zeroes the row_i-th limb of tmp_pt
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
module.vmp_prepare_row(
ct[0],
tmp_elem.at(0).raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);
module.vmp_prepare_row(
&mut ct[1],
tmp_elem.at(1).raw(),
row_i,
tmp_bytes_vmp_prepare_row,
);
});
}

View File

@@ -1,383 +0,0 @@
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
use std::cmp::min;
pub fn gadget_product_core_tmp_bytes(
module: &Module,
log_base2k: usize,
res_log_q: usize,
in_log_q: usize,
gct_rows: usize,
gct_log_q: usize,
) -> usize {
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k;
let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k;
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols)
}
impl Parameters {
pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize {
gadget_product_core_tmp_bytes(
self.module(),
self.log_base2k(),
res_log_q,
in_log_q,
gct_rows,
gct_log_q,
)
}
}
pub fn gadget_product_core(
module: &Module,
res_dft_0: &mut VecZnxDft,
res_dft_1: &mut VecZnxDft,
a: &VecZnx,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
assert!(b_cols <= b.cols());
module.vec_znx_dft(res_dft_1, a);
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
}
pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols));
}
/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `c`: a [Ciphertext<VecZnxBig>] with cols_c cols.
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
pub fn gadget_product_big(
module: &Module,
c: &mut Ciphertext<VecZnxBig>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// c[i] = IDFT(DFT(a[1]) * b[i])
(0..2).for_each(|i| {
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
})
}
/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a.at(i)), b.at(i)>)
///
/// # Arguments
///
/// * `module`: backend support for operations mod (X^N + 1).
/// * `c`: a [Ciphertext<VecZnx>] with cols_c cols.
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
pub fn gadget_product(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
tmp_bytes: &mut [u8],
) {
let cols: usize = min(c.cols(), a.cols());
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
// a1_dft = DFT(a[1])
module.vec_znx_dft(&mut a1_dft, a.at(1));
// c[i] = IDFT(DFT(a[1]) * b[i])
(0..2).for_each(|i| {
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
})
}
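// Both products rely on the gadget identity <g^{-1}(a), g> = a, where
// g^{-1}(a) is the base-2^basek digit decomposition matched against the
// VmpPMat rows. The scheme itself uses negative powers 2^{-basek*i}; the
// integer sketch below is the same identity scaled by 2^{basek*rows}:
#[test]
fn gadget_identity_reference() {
    let basek: u32 = 12;
    let rows: u32 = 5;
    let a: u64 = 0x0DEB_51AB_CDEF; // fits in rows * basek = 60 bits
    // g^{-1}(a): little-endian base-2^basek digits of a
    let digits: Vec<u64> = (0..rows)
        .map(|i| (a >> (i * basek)) & ((1u64 << basek) - 1))
        .collect();
    // <g^{-1}(a), g> with g = (1, 2^basek, 2^{2*basek}, ...)
    let recombined: u64 = digits
        .iter()
        .enumerate()
        .map(|(i, &d)| d << (i as u32 * basek))
        .sum();
    assert_eq!(recombined, a);
}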
#[cfg(test)]
mod test {
use crate::{
ciphertext::{Ciphertext, new_gadget_ciphertext},
decryptor::decrypt_rlwe,
elem::{Elem, ElemCommon, ElemVecZnx},
encryptor::encrypt_grlwe_sk,
gadget_product::gadget_product_core,
keys::SecretKey,
parameters::{Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{
BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
alloc_aligned_u8,
};
use sampling::source::{Source, new_seed};
#[test]
fn test_gadget_product_core() {
let log_base2k: usize = 10;
let q_cols: usize = 7;
let p_cols: usize = 1;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: q_cols * log_base2k,
log_p: p_cols * log_base2k,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
params.decrypt_rlwe_tmp_byte(params.log_qp())
| params.gadget_product_tmp_bytes(
params.log_qp(),
params.log_qp(),
params.cols_qp(),
params.log_qp(),
)
| params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp()),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
// Two secret keys
let mut sk0: SecretKey = SecretKey::new(params.module());
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
let mut sk1: SecretKey = SecretKey::new(params.module());
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
// The gadget ciphertext
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
params.module(),
log_base2k,
params.cols_qp(),
params.log_qp(),
);
// gct = [-b*sk1 + g(sk0) + e, b]
encrypt_grlwe_sk(
params.module(),
&mut gadget_ct,
&sk0.0,
&sk1_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
// Intermediate buffers
        // Input polynomial, uniformly distributed
let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
params
.module()
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
// res = g^-1(a) * gct
let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
// Ideal output = a * s
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols());
let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols());
// a * sk0
params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
params
.module()
.vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes);
// Plaintext for decrypted output of gadget product
let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
        (1..=a.cols()).for_each(|a_cols| {
let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols);
a_trunc.copy_from(&a);
            (1..=gadget_ct.cols()).for_each(|b_cols| {
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
pt.elem_mut().zero();
elem_res.zero();
// let b_cols: usize = min(a_cols+1, gadget_ct.cols());
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
gadget_product_core(
params.module(),
&mut res_dft_0,
&mut res_dft_1,
&a_trunc,
&gadget_ct,
b_cols,
&mut tmp_bytes,
);
// res_big_0 = IDFT(res_dft_0)
params
.module()
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
// res_big_1 = IDFT(res_dft_1);
params
.module()
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
// res_big_0 = normalize(res_big_0)
params
.module()
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes);
// res_big_1 = normalize(res_big_1)
params
.module()
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes);
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
decrypt_rlwe(
params.module(),
pt.elem_mut(),
&elem_res,
&sk1_svp_ppol,
&mut tmp_bytes,
);
// a * sk0 + e - a*sk0 = e
params
.module()
.vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s);
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
// pt.at(0).print(pt.elem().cols(), 16);
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
                let var_a_err: f64 = if a_cols < a.cols() {
                    // Truncating `a` introduces a rounding error of variance 1/12.
                    1f64 / 12f64
                } else {
                    0f64
                };
let a_logq: usize = a_cols * log_base2k;
let b_logq: usize = b_cols * log_base2k;
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq);
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
// assert!(noise_have <= noise_pred + 1.0);
});
});
}
}
impl Parameters {
pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 {
let n: f64 = self.n() as f64;
let var_xs: f64 = self.xs() as f64;
        let (var_gct_err_lhs, var_gct_err_rhs): (f64, f64) = if b_logq < self.log_qp() {
            // The gadget ciphertext is truncated: both rows pick up a rounding error.
            let var_round: f64 = 1f64 / 12f64;
            (var_round, var_round)
        } else {
            (self.xe() * self.xe(), 0f64)
        };
noise_grlwe_product(
n,
self.log_base2k(),
var_xs,
var_msg,
var_a_err,
var_gct_err_lhs,
var_gct_err_rhs,
a_logq,
b_logq,
)
}
}
pub fn noise_grlwe_product(
n: f64,
log_base2k: usize,
var_xs: f64,
var_msg: f64,
var_a_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_logq: usize = min(a_logq, b_logq);
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (log_base2k)) as f64;
let var_base: f64 = base * base / 12f64;
    // gadget term: a_cols * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
    let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
    // message term: n * var_msg * var_a_err * a_scale^2
    noise += var_msg * var_a_err * a_scale * a_scale * n;
    noise = noise.sqrt();
    noise /= b_scale;
    noise.log2().min(-1.0) // a normalized noise polynomial lives in [-2^{-1}, 2^{-1}]
}
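// Illustrative check (added; parameter values are hypothetical): the returned
// value is a log2 magnitude clamped to at most -1.0, since a normalized noise
// polynomial lives in [-2^{-1}, 2^{-1}].
#[test]
fn noise_grlwe_product_is_clamped() {
    let bits: f64 = noise_grlwe_product(4096.0, 10, 2048.0, 0.5, 1.0 / 12.0, 3.2 * 3.2, 0.0, 70, 70);
    assert!(bits <= -1.0);
}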

View File

@@ -1,55 +0,0 @@
use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes};
use crate::keys::{PublicKey, SecretKey, SwitchingKey};
use crate::parameters::Parameters;
use base2k::{Module, SvpPPol};
use sampling::source::Source;
pub struct KeyGenerator {}
impl KeyGenerator {
pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey {
let mut sk: SecretKey = SecretKey::new(params.module());
sk.fill_ternary_hw(params.xs(), source);
sk
}
pub fn gen_public_key_thread_safe(
&self,
params: &Parameters,
sk_ppol: &SvpPPol,
source: &mut Source,
tmp_bytes: &mut [u8],
) -> PublicKey {
let mut xa_source: Source = source.branch();
let mut xe_source: Source = source.branch();
let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
pk.gen_thread_safe(
params.module(),
sk_ppol,
params.xe(),
&mut xa_source,
&mut xe_source,
tmp_bytes,
);
pk
}
}
pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
}
pub fn gen_switching_key(
module: &Module,
swk: &mut SwitchingKey,
sk_in: &SecretKey,
sk_out: &SvpPPol,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
tmp_bytes: &mut [u8],
) {
encrypt_grlwe_sk(
module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes,
);
}
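// Hedged usage sketch (added for illustration): generates a switching key from
// `sk_in` to `sk_out` with fresh randomness, sizing the scratch with
// `gen_switching_key_tmp_bytes` defined above.
#[allow(dead_code)]
mod switching_key_usage {
    use super::{gen_switching_key, gen_switching_key_tmp_bytes};
    use crate::keys::{SecretKey, SwitchingKey};
    use base2k::{Module, SvpPPol, alloc_aligned_u8};
    use sampling::source::{Source, new_seed};
    fn generate(
        module: &Module,
        log_base2k: usize,
        rows: usize,
        log_q: usize,
        sk_in: &SecretKey,
        sk_out: &SvpPPol,
        sigma: f64,
    ) -> SwitchingKey {
        let mut swk: SwitchingKey = SwitchingKey::new(module, log_base2k, rows, log_q);
        let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(gen_switching_key_tmp_bytes(module, log_base2k, rows, log_q));
        let mut source_xa: Source = Source::new(new_seed());
        let mut source_xe: Source = Source::new(new_seed());
        gen_switching_key(module, &mut swk, sk_in, sk_out, &mut source_xa, &mut source_xe, sigma, &mut tmp_bytes);
        swk
    }
}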

View File

@@ -1,79 +0,0 @@
use crate::ciphertext::Ciphertext;
use crate::elem::ElemCommon;
use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
use std::cmp::min;
pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+ module.bytes_of_vec_znx_dft(1, gct_cols);
}
pub fn key_switch_rlwe(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes);
}
pub fn key_switch_rlwe_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes);
}
fn key_switch_rlwe_core(
module: &Module,
c: *mut Ciphertext<VecZnx>,
a: *const Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
    // SAFETY: the caller must ensure `c` and `a` each point to a valid
    // Ciphertext<VecZnx>. Aliasing (`c == a`) is intended and safe here:
    // `a.at(1)` is consumed into the DFT domain, and `a.at(0)` is read,
    // before the corresponding columns of `c` are overwritten.
    let c: &mut Ciphertext<VecZnx> = unsafe { &mut *c };
    let a: &Ciphertext<VecZnx> = unsafe { &*a };
let cols: usize = min(min(c.cols(), a.cols()), b.rows());
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.cols());
        assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.log_base2k(), c.log_q(), a.log_q(), b.log_q()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
let mut res_big = res_dft.as_vec_znx_big();
module.vec_znx_dft(&mut a1_dft, a.at(1));
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
}
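// Hedged usage sketch (added for illustration): out-of-place key switch with
// the scratch sized by `key_switch_tmp_bytes` above.
#[allow(dead_code)]
mod key_switch_usage {
    use super::{key_switch_rlwe, key_switch_tmp_bytes};
    use crate::ciphertext::Ciphertext;
    use crate::elem::ElemCommon;
    use base2k::{Module, VecZnx, VmpPMat, alloc_aligned_u8};
    fn switch(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {
        let mut tmp_bytes: Vec<u8> =
            alloc_aligned_u8(key_switch_tmp_bytes(module, c.log_base2k(), c.log_q(), a.log_q(), b.log_q()));
        key_switch_rlwe(module, c, a, b, b.cols(), &mut tmp_bytes);
    }
}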
// Placeholders: not implemented yet.
pub fn key_switch_grlwe(_module: &Module, _c: &mut Ciphertext<VecZnx>, _a: &Ciphertext<VecZnx>, _b: &Ciphertext<VmpPMat>) {}
pub fn key_switch_rgsw(_module: &Module, _c: &mut Ciphertext<VecZnx>, _a: &Ciphertext<VecZnx>, _b: &Ciphertext<VmpPMat>) {}

View File

@@ -1,82 +0,0 @@
use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
use crate::elem::{Elem, ElemCommon};
use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes};
use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
use sampling::source::Source;
pub struct SecretKey(pub Scalar);
impl SecretKey {
pub fn new(module: &Module) -> Self {
SecretKey(Scalar::new(module.n()))
}
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
self.0.fill_ternary_prob(prob, source);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
self.0.fill_ternary_hw(hw, source);
}
pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) {
module.svp_prepare(sk_ppol, &self.0)
}
}
pub struct PublicKey(pub Elem<VecZnx>);
impl PublicKey {
pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey {
PublicKey(Elem::<VecZnx>::new(module, log_base2k, log_q, 2))
}
pub fn gen_thread_safe(
&mut self,
module: &Module,
sk: &SvpPPol,
xe: f64,
xa_source: &mut Source,
xe_source: &mut Source,
tmp_bytes: &mut [u8],
) {
encrypt_rlwe_sk(
module,
&mut self.0,
None,
sk,
xa_source,
xe_source,
xe,
tmp_bytes,
);
}
pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
}
}
pub struct SwitchingKey(pub Ciphertext<VmpPMat>);
impl SwitchingKey {
pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey {
SwitchingKey(new_gadget_ciphertext(module, log_base2k, rows, log_q))
}
pub fn n(&self) -> usize {
self.0.n()
}
pub fn rows(&self) -> usize {
self.0.rows()
}
pub fn cols(&self) -> usize {
self.0.cols()
}
pub fn log_base2k(&self) -> usize {
self.0.log_base2k()
}
}
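// Hedged usage sketch (added for illustration): the standard secret-key set-up
// flow, mirroring what the tests in this crate do.
#[allow(dead_code)]
mod keys_usage {
    use super::SecretKey;
    use base2k::{Module, SvpPPol, SvpPPolOps};
    use sampling::source::{Source, new_seed};
    fn setup(module: &Module, hw: usize) -> (SecretKey, SvpPPol) {
        let mut source_xs: Source = Source::new(new_seed());
        let mut sk: SecretKey = SecretKey::new(module);
        sk.fill_ternary_hw(hw, &mut source_xs);
        // Prepared (DFT-domain) form of the secret, consumed by encryption/decryption.
        let mut sk_ppol: SvpPPol = module.new_svp_ppol();
        sk.prepare(module, &mut sk_ppol);
        (sk, sk_ppol)
    }
}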

View File

@@ -1,13 +0,0 @@
pub mod automorphism;
pub mod ciphertext;
pub mod decryptor;
pub mod elem;
pub mod encryptor;
pub mod gadget_product;
pub mod key_generator;
pub mod key_switching;
pub mod keys;
pub mod parameters;
pub mod plaintext;
pub mod rgsw_product;
pub mod trace;

View File

@@ -1,88 +0,0 @@
use base2k::module::{BACKEND, Module};
pub const DEFAULT_SIGMA: f64 = 3.2;
pub struct ParametersLiteral {
pub backend: BACKEND,
pub log_n: usize,
pub log_q: usize,
pub log_p: usize,
pub log_base2k: usize,
pub log_scale: usize,
pub xe: f64,
pub xs: usize,
}
pub struct Parameters {
log_n: usize,
log_q: usize,
log_p: usize,
log_scale: usize,
log_base2k: usize,
xe: f64,
xs: usize,
module: Module,
}
impl Parameters {
pub fn new(p: &ParametersLiteral) -> Self {
assert!(
p.log_n + 2 * p.log_base2k <= 53,
"invalid parameters: p.log_n + 2*p.log_base2k > 53"
);
Self {
log_n: p.log_n,
log_q: p.log_q,
log_p: p.log_p,
log_scale: p.log_scale,
log_base2k: p.log_base2k,
xe: p.xe,
xs: p.xs,
module: Module::new(1 << p.log_n, p.backend),
}
}
pub fn n(&self) -> usize {
1 << self.log_n
}
pub fn log_scale(&self) -> usize {
self.log_scale
}
pub fn log_q(&self) -> usize {
self.log_q
}
pub fn log_p(&self) -> usize {
self.log_p
}
pub fn log_qp(&self) -> usize {
self.log_q + self.log_p
}
pub fn cols_q(&self) -> usize {
(self.log_q + self.log_base2k - 1) / self.log_base2k
}
pub fn cols_qp(&self) -> usize {
(self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k
}
pub fn log_base2k(&self) -> usize {
self.log_base2k
}
pub fn module(&self) -> &Module {
&self.module
}
pub fn xe(&self) -> f64 {
self.xe
}
pub fn xs(&self) -> usize {
self.xs
}
}
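// Hedged example (added for illustration): instantiating parameters and the
// ceil-division relation between modulus bits and the number of base-2^k limbs.
#[allow(dead_code)]
mod parameters_usage {
    use super::{Parameters, ParametersLiteral};
    use base2k::module::BACKEND;
    fn example() {
        let params: Parameters = Parameters::new(&ParametersLiteral {
            backend: BACKEND::FFT64,
            log_n: 12,
            log_q: 50,
            log_p: 15,
            log_base2k: 10,
            log_scale: 20,
            xe: 3.2,
            xs: 1 << 11,
        });
        // cols_q = ceil(50 / 10) = 5, cols_qp = ceil(65 / 10) = 7.
        assert_eq!(params.cols_q(), 5);
        assert_eq!(params.cols_qp(), 7);
    }
}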

View File

@@ -1,109 +0,0 @@
use crate::ciphertext::Ciphertext;
use crate::elem::{Elem, ElemCommon, ElemVecZnx};
use crate::parameters::Parameters;
use base2k::{LAYOUT, Module, VecZnx};
pub struct Plaintext(pub Elem<VecZnx>);
impl Parameters {
pub fn new_plaintext(&self, log_q: usize) -> Plaintext {
Plaintext::new(self.module(), self.log_base2k(), log_q)
}
    pub fn bytes_of_plaintext(&self, log_q: usize) -> usize {
Elem::<VecZnx>::bytes_of(self.module(), self.log_base2k(), log_q, 1)
}
pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext {
Plaintext(Elem::<VecZnx>::from_bytes(
self.module(),
self.log_base2k(),
log_q,
1,
bytes,
))
}
}
impl Plaintext {
pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self {
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, 1))
}
}
impl Plaintext {
pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize {
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 1)
}
pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
Self(Elem::<VecZnx>::from_bytes(
module, log_base2k, log_q, 1, bytes,
))
}
pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
Self(Elem::<VecZnx>::from_bytes_borrow(
module, log_base2k, log_q, 1, bytes,
))
}
    pub fn as_ciphertext(&self) -> Ciphertext<VecZnx> {
        // SAFETY: bitwise-copies the inner Elem without taking ownership.
        // The caller must ensure that only one of the two copies is ever
        // dropped, or the shared buffer will be freed twice.
        unsafe { Ciphertext::<VecZnx>(std::ptr::read(&self.0)) }
    }
}
impl ElemCommon<VecZnx> for Plaintext {
fn n(&self) -> usize {
self.0.n()
}
fn log_n(&self) -> usize {
self.elem().log_n()
}
fn log_q(&self) -> usize {
self.0.log_q
}
fn elem(&self) -> &Elem<VecZnx> {
&self.0
}
fn elem_mut(&mut self) -> &mut Elem<VecZnx> {
&mut self.0
}
fn size(&self) -> usize {
self.elem().size()
}
fn layout(&self) -> LAYOUT {
self.elem().layout()
}
fn rows(&self) -> usize {
self.0.rows()
}
fn cols(&self) -> usize {
self.0.cols()
}
fn at(&self, i: usize) -> &VecZnx {
self.0.at(i)
}
fn at_mut(&mut self, i: usize) -> &mut VecZnx {
self.0.at_mut(i)
}
fn log_base2k(&self) -> usize {
self.0.log_base2k()
}
fn log_scale(&self) -> usize {
self.0.log_scale()
}
}
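// Hedged usage sketch (added for illustration): allocating a plaintext through
// `Parameters` and reading its geometry via the `ElemCommon` impl above.
#[allow(dead_code)]
mod plaintext_usage {
    use super::Plaintext;
    use crate::elem::ElemCommon;
    use crate::parameters::Parameters;
    fn example(params: &Parameters) {
        // One polynomial with ceil(log_q / log_base2k) base-2^k limbs.
        let pt: Plaintext = params.new_plaintext(params.log_q());
        assert_eq!(pt.n(), params.n());
    }
}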

View File

@@ -1,300 +0,0 @@
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
use std::cmp::min;
impl Parameters {
pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
rgsw_product_tmp_bytes(
self.module(),
self.log_base2k(),
res_logq,
in_logq,
gct_logq,
)
}
}
pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
+ 2 * module.bytes_of_vec_znx_dft(1, gct_cols);
}
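// Sizing example (illustrative values): with log_base2k = 10, res_logq = in_logq = 50
// and gct_logq = 65, the column counts are 5, 5 and 7, so the scratch adds up to
// vmp_apply_dft_to_dft_tmp_bytes(5, 5, 5, 7) + bytes_of_vec_znx_dft(1, 5)
// + 2 * bytes_of_vec_znx_dft(1, 7) bytes.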
pub fn rgsw_product(
module: &Module,
c: &mut Ciphertext<VecZnx>,
a: &Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.cols());
assert_eq!(c.size(), 2);
assert_eq!(a.size(), 2);
assert_eq!(b.size(), 4);
        assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.log_base2k(), c.log_q(), a.log_q(), b_cols * c.log_base2k()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
module.vec_znx_dft(&mut ai_dft, a.at(0));
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
module.vec_znx_dft(&mut ai_dft, a.at(1));
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes);
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes);
}
pub fn rgsw_product_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
b: &Ciphertext<VmpPMat>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
#[cfg(debug_assertions)]
{
assert!(b_cols <= b.cols());
assert_eq!(a.size(), 2);
assert_eq!(b.size(), 4);
        assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.log_base2k(), a.log_q(), a.log_q(), b_cols * a.log_base2k()));
assert_alignement(tmp_bytes.as_ptr());
}
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
module.vec_znx_dft(&mut ai_dft, a.at(0));
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
module.vec_znx_dft(&mut ai_dft, a.at(1));
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes);
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes);
}
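// Hedged usage sketch (added for illustration): in-place external product
// ct <- ct x rgsw, mirroring the call sequence of `test_rgsw_product` below.
#[allow(dead_code)]
mod rgsw_product_usage {
    use super::rgsw_product_inplace;
    use crate::{ciphertext::Ciphertext, parameters::Parameters};
    use base2k::{VecZnx, VmpPMat, alloc_aligned};
    fn apply(params: &Parameters, ct: &mut Ciphertext<VecZnx>, ct_rgsw: &Ciphertext<VmpPMat>) {
        let mut tmp_bytes: Vec<u8> = alloc_aligned(params.rgsw_product_tmp_bytes(
            params.log_q(),
            params.log_q(),
            params.log_qp(),
        ));
        rgsw_product_inplace(params.module(), ct, ct_rgsw, params.cols_qp(), &mut tmp_bytes);
    }
}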
#[cfg(test)]
mod test {
use crate::{
ciphertext::{Ciphertext, new_rgsw_ciphertext},
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk},
keys::SecretKey,
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
plaintext::Plaintext,
rgsw_product::rgsw_product_inplace,
};
use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned};
use sampling::source::{Source, new_seed};
#[test]
fn test_rgsw_product() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.rgsw_product_tmp_bytes(log_q, log_q, log_qp)
| params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let mut ct_rgsw: Ciphertext<VmpPMat> = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp);
let k: i64 = 3;
// X^k
        let mut m: Scalar = module.new_scalar();
let data: &mut [i64] = m.raw_mut();
data[k as usize] = 1;
encrypt_rgsw_sk(
module,
&mut ct_rgsw,
&m,
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
DEFAULT_SIGMA,
&mut tmp_bytes,
);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0));
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes);
decrypt_rlwe(
module,
pt.elem_mut(),
ct.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0));
// pt.at(0).print(pt.cols(), 16);
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
let var_a0_err: f64 = params.xe() * params.xe();
let var_a1_err: f64 = 1f64 / 12f64;
let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q());
println!("noise_pred: {}", noise_pred);
println!("noise_have: {}", noise_have);
assert!(noise_have <= noise_pred + 1.0);
}
}
impl Parameters {
pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 {
let n: f64 = self.n() as f64;
let var_xs: f64 = self.xs() as f64;
        let (var_gct_err_lhs, var_gct_err_rhs): (f64, f64) = if b_logq < self.log_qp() {
            // The gadget ciphertext is truncated: both rows pick up a rounding error.
            let var_round: f64 = 1f64 / 12f64;
            (var_round, var_round)
        } else {
            (self.xe() * self.xe(), 0f64)
        };
noise_rgsw_product(
n,
self.log_base2k(),
var_xs,
var_msg,
var_a0_err,
var_a1_err,
var_gct_err_lhs,
var_gct_err_rhs,
a_logq,
b_logq,
)
}
}
pub fn noise_rgsw_product(
n: f64,
log_base2k: usize,
var_xs: f64,
var_msg: f64,
var_a0_err: f64,
var_a1_err: f64,
var_gct_err_lhs: f64,
var_gct_err_rhs: f64,
a_logq: usize,
b_logq: usize,
) -> f64 {
let a_logq: usize = min(a_logq, b_logq);
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
let b_scale = 2.0f64.powi(b_logq as i32);
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
let base: f64 = (1 << (log_base2k)) as f64;
let var_base: f64 = base * base / 12f64;
    // gadget terms (two rows): 2 * a_cols * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs)
    let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
    // message terms: n * var_msg * a_scale^2 * (var_a0_err + var_xs * var_a1_err)
    noise += var_msg * var_a0_err * a_scale * a_scale * n;
    noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
    noise = noise.sqrt();
    noise /= b_scale;
    noise.log2().min(-1.0) // a normalized noise polynomial lives in [-2^{-1}, 2^{-1}]
}

View File

@@ -1,113 +0,0 @@
use base2k::{BACKEND, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned};
use sampling::source::{Source, new_seed};
use crate::{
    ciphertext::Ciphertext,
    decryptor::decrypt_rlwe,
    elem::ElemCommon,
    encryptor::encrypt_rlwe_sk,
    keys::SecretKey,
    parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
    plaintext::Plaintext,
};
pub struct Context {
    pub params: Parameters,
    pub sk0: SecretKey,
    pub sk0_ppol: SvpPPol,
    pub sk1: SecretKey,
    pub sk1_ppol: SvpPPol,
    pub tmp_bytes: Vec<u8>,
}
impl Context {
    pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self {
        let params_lit: ParametersLiteral = ParametersLiteral {
            backend: BACKEND::FFT64,
            log_n,
            log_q,
            log_p,
            log_base2k,
            log_scale: 20,
            xe: DEFAULT_SIGMA,
            xs: 1 << (log_n - 1),
        };
        let params: Parameters = Parameters::new(&params_lit);
        let module = params.module();
        let log_q: usize = params.log_q();
        let mut source_xs: Source = Source::new(new_seed());
        let mut sk0: SecretKey = SecretKey::new(module);
        sk0.fill_ternary_hw(params.xs(), &mut source_xs);
        let mut sk0_ppol: SvpPPol = module.new_svp_ppol();
        module.svp_prepare(&mut sk0_ppol, &sk0.0);
        let mut sk1: SecretKey = SecretKey::new(module);
        sk1.fill_ternary_hw(params.xs(), &mut source_xs);
        let mut sk1_ppol: SvpPPol = module.new_svp_ppol();
        module.svp_prepare(&mut sk1_ppol, &sk1.0);
        let tmp_bytes: Vec<u8> = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q) | params.encrypt_rlwe_sk_tmp_bytes(log_q));
        Context {
            params,
            sk0,
            sk0_ppol,
            sk1,
            sk1_ppol,
            tmp_bytes,
        }
    }
    pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext<VecZnx>) {
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
encrypt_rlwe_sk(
self.params.module(),
ct.elem_mut(),
Some(pt.elem()),
&self.sk0_ppol,
&mut source_xa,
&mut source_xe,
self.params.xe(),
&mut self.tmp_bytes,
);
}
    pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext<VecZnx>, pt: &Plaintext) {
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
encrypt_rlwe_sk(
self.params.module(),
ct.elem_mut(),
Some(pt.elem()),
&self.sk1_ppol,
&mut source_xa,
&mut source_xe,
self.params.xe(),
&mut self.tmp_bytes,
);
}
    pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>) {
decrypt_rlwe(
self.params.module(),
pt.elem_mut(),
ct.elem(),
&self.sk0_ppol,
&mut self.tmp_bytes,
);
}
    pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>) {
decrypt_rlwe(
self.params.module(),
pt.elem_mut(),
ct.elem(),
&self.sk1_ppol,
&mut self.tmp_bytes,
);
}
}
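// Hedged usage sketch (added for illustration): an encrypt/decrypt round trip
// under sk0 using the helper Context above.
#[allow(dead_code)]
fn context_roundtrip_example() {
    let mut ctx: Context = Context::new(12, 10, 50, 15);
    let mut pt: Plaintext = ctx.params.new_plaintext(ctx.params.log_q());
    let mut ct: Ciphertext<VecZnx> = ctx.params.new_ciphertext(ctx.params.log_q());
    ctx.encrypt_rlwe_sk0(&pt, &mut ct);
    ctx.decrypt_sk0(&mut pt, &ct);
}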

View File

@@ -1,236 +0,0 @@
use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement};
use std::collections::HashMap;
pub fn trace_galois_elements(module: &Module) -> Vec<i64> {
let mut gal_els: Vec<i64> = Vec::new();
(0..module.log_n()).for_each(|i| {
if i == 0 {
gal_els.push(-1);
} else {
gal_els.push(module.galois_element(1 << (i - 1)));
}
});
gal_els
}
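// Hedged note (added for illustration): these are exactly the Galois elements
// needed to build the automorphism keys consumed by `trace_inplace`, one per
// halving step, as done in the test at the bottom of this file.
#[allow(dead_code)]
mod trace_keys_usage {
    use super::trace_galois_elements;
    use base2k::Module;
    fn keys_per_trace(module: &Module) -> usize {
        // One Galois element per halving step of the trace: log_n in total.
        trace_galois_elements(module).len()
    }
}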
impl Parameters {
pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq)
}
}
pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
}
pub fn trace_inplace(
module: &Module,
a: &mut Ciphertext<VecZnx>,
start: usize,
end: usize,
b: &HashMap<i64, AutomorphismKey>,
b_cols: usize,
tmp_bytes: &mut [u8],
) {
let cols: usize = a.cols();
    let b_rows: usize = if let Some((_, key)) = b.iter().next() {
        #[cfg(debug_assertions)]
        assert!(b_cols <= key.value.cols());
        key.value.rows()
    } else {
        panic!("b: HashMap<i64, AutomorphismKey> is empty")
    };
#[cfg(debug_assertions)]
{
assert!(start <= end);
assert!(end <= module.n());
assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols));
assert_alignement(tmp_bytes.as_ptr());
}
let cols: usize = std::cmp::min(b_cols, a.cols());
    let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
    let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
    let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
let log_base2k: usize = a.log_base2k();
(start..end).for_each(|i| {
a.at_mut(0).rsh(log_base2k, 1, tmp_bytes);
a.at_mut(1).rsh(log_base2k, 1, tmp_bytes);
        let p: i64 = if i == 0 {
            -1
        } else {
            module.galois_element(1 << (i - 1))
        };
if let Some(key) = b.get(&p) {
module.vec_znx_dft(&mut a1_dft, a.at(1));
// a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(<DFT(a[1]), key[0]>)))
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);
// a[1] = NORMALIZE(a[1] + AUTO(IDFT(<DFT(a[1]), key[1]>)))
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes);
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1));
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
} else {
panic!("b[{}] is empty", p)
}
})
}
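// Hedged usage sketch (added for illustration): running the trace over the
// full range [0, log_n), which folds in one automorphism per halving step;
// mirrors the two-stage call in the test below.
#[allow(dead_code)]
mod trace_usage {
    use super::trace_inplace;
    use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext};
    use base2k::{Module, VecZnx};
    use std::collections::HashMap;
    fn full_trace(
        module: &Module,
        ct: &mut Ciphertext<VecZnx>,
        auto_keys: &HashMap<i64, AutomorphismKey>,
        gct_cols: usize,
        tmp_bytes: &mut [u8],
    ) {
        trace_inplace(module, ct, 0, module.log_n(), auto_keys, gct_cols, tmp_bytes);
    }
}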
#[cfg(test)]
mod test {
use super::{trace_galois_elements, trace_inplace};
use crate::{
automorphism::AutomorphismKey,
ciphertext::Ciphertext,
decryptor::decrypt_rlwe,
elem::ElemCommon,
encryptor::encrypt_rlwe_sk,
keys::SecretKey,
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
plaintext::Plaintext,
};
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned};
use sampling::source::{Source, new_seed};
use std::collections::HashMap;
#[test]
fn test_trace_inplace() {
let log_base2k: usize = 10;
let log_q: usize = 50;
let log_p: usize = 15;
// Basic parameters with enough limbs to test edge cases
let params_lit: ParametersLiteral = ParametersLiteral {
backend: BACKEND::FFT64,
log_n: 12,
log_q: log_q,
log_p: log_p,
log_base2k: log_base2k,
log_scale: 20,
xe: 3.2,
xs: 1 << 11,
};
let params: Parameters = Parameters::new(&params_lit);
let module: &Module = params.module();
let log_q: usize = params.log_q();
let log_qp: usize = params.log_qp();
let gct_rows: usize = params.cols_q();
let gct_cols: usize = params.cols_qp();
// scratch space
let mut tmp_bytes: Vec<u8> = alloc_aligned(
params.decrypt_rlwe_tmp_byte(log_q)
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
);
// Samplers for public and private randomness
let mut source_xe: Source = Source::new(new_seed());
let mut source_xa: Source = Source::new(new_seed());
let mut source_xs: Source = Source::new(new_seed());
let mut sk: SecretKey = SecretKey::new(module);
sk.fill_ternary_hw(params.xs(), &mut source_xs);
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
let gal_els: Vec<i64> = trace_galois_elements(module);
let auto_keys: HashMap<i64, AutomorphismKey> = AutomorphismKey::new_many(
module,
&gal_els,
&sk,
log_base2k,
gct_rows,
log_qp,
&mut source_xa,
&mut source_xe,
DEFAULT_SIGMA,
&mut tmp_bytes,
);
let mut data: Vec<i64> = vec![0i64; params.n()];
data.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = 1 + i as i64);
let log_k: usize = 2 * log_base2k;
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
let mut pt: Plaintext = params.new_plaintext(log_q);
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
pt.at(0).print(0, pt.cols(), 16);
encrypt_rlwe_sk(
module,
&mut ct.elem_mut(),
Some(pt.at(0)),
&sk_svp_ppol,
&mut source_xa,
&mut source_xe,
params.xe(),
&mut tmp_bytes,
);
trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes);
trace_inplace(
module,
&mut ct,
4,
module.log_n(),
&auto_keys,
gct_cols,
&mut tmp_bytes,
);
// pt = dec(auto(ct)) - auto(pt)
decrypt_rlwe(
module,
pt.elem_mut(),
ct.elem(),
&sk_svp_ppol,
&mut tmp_bytes,
);
pt.at(0).print(0, pt.cols(), 16);
pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
println!("trace: {:?}", &data[..16]);
}
}