mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Merge pull request #20 from phantomzone-org/jay/restructure-base2k
Major refactoring on memory layout, memory safety & basic functionalities
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
[workspace]
|
||||
members = ["base2k", "rlwe", "sampling", "utils"]
|
||||
|
||||
members = ["base2k", "core", "sampling", "utils"]
|
||||
resolver = "3"
|
||||
|
||||
[workspace.dependencies]
|
||||
rug = "1.27"
|
||||
|
||||
11
base2k/.vscode/settings.json
vendored
Normal file
11
base2k/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"github.copilot.enable": {
|
||||
"*": false,
|
||||
"plaintext": false,
|
||||
"markdown": false,
|
||||
"scminput": false
|
||||
},
|
||||
"files.associations": {
|
||||
"random": "c"
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
[package]
|
||||
name = "base2k"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rug = {workspace = true}
|
||||
|
||||
@@ -3,10 +3,11 @@ use std::path::absolute;
|
||||
fn main() {
|
||||
println!(
|
||||
"cargo:rustc-link-search=native={}",
|
||||
absolute("./spqlios-arithmetic/build/spqlios")
|
||||
absolute("spqlios-arithmetic/build/spqlios")
|
||||
.unwrap()
|
||||
.to_str()
|
||||
.unwrap()
|
||||
);
|
||||
println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios"
|
||||
println!("cargo:rustc-link-lib=static=spqlios");
|
||||
// println!("cargo:rustc-link-lib=dylib=spqlios")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use base2k::{
|
||||
BACKEND, Encoding, Infos, Module, Sampling, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft,
|
||||
VecZnxDftOps, VecZnxOps, alloc_aligned,
|
||||
AddNormal, Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
@@ -8,89 +9,125 @@ use sampling::source::Source;
|
||||
fn main() {
|
||||
let n: usize = 16;
|
||||
let log_base2k: usize = 18;
|
||||
let cols: usize = 3;
|
||||
let msg_cols: usize = 2;
|
||||
let log_scale: usize = msg_cols * log_base2k - 5;
|
||||
let module: Module = Module::new(n, BACKEND::FFT64);
|
||||
let ct_size: usize = 3;
|
||||
let msg_size: usize = 2;
|
||||
let log_scale: usize = msg_size * log_base2k - 5;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
|
||||
let mut carry: Vec<u8> = alloc_aligned(module.vec_znx_big_normalize_tmp_bytes());
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(module.vec_znx_big_normalize_tmp_bytes());
|
||||
|
||||
let seed: [u8; 32] = [0; 32];
|
||||
let mut source: Source = Source::new(seed);
|
||||
|
||||
let mut res: VecZnx = module.new_vec_znx(1, cols);
|
||||
|
||||
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
|
||||
let mut s: Scalar = Scalar::new(n);
|
||||
s.fill_ternary_prob(0.5, &mut source);
|
||||
let mut s: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
s.fill_ternary_prob(0, 0.5, &mut source);
|
||||
|
||||
// Buffer to store s in the DFT domain
|
||||
let mut s_ppol: SvpPPol = module.new_svp_ppol();
|
||||
let mut s_dft: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(s.cols());
|
||||
|
||||
// s_ppol <- DFT(s)
|
||||
module.svp_prepare(&mut s_ppol, &s);
|
||||
// s_dft <- DFT(s)
|
||||
module.svp_prepare(&mut s_dft, 0, &s, 0);
|
||||
|
||||
// a <- Z_{2^prec}[X]/(X^{N}+1)
|
||||
let mut a: VecZnx = module.new_vec_znx(1, cols);
|
||||
module.fill_uniform(log_base2k, &mut a, cols, &mut source);
|
||||
// Allocates a VecZnx with two columns: ct=(0, 0)
|
||||
let mut ct: VecZnx<Vec<u8>> = module.new_vec_znx(
|
||||
2, // Number of columns
|
||||
ct_size, // Number of small poly per column
|
||||
);
|
||||
|
||||
// Scratch space for DFT values
|
||||
let mut buf_dft: VecZnxDft = module.new_vec_znx_dft(1, a.cols());
|
||||
// Fill the second column with random values: ct = (0, a)
|
||||
ct.fill_uniform(log_base2k, 1, ct_size, &mut source);
|
||||
|
||||
// Applies buf_dft <- s * a
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
|
||||
let mut buf_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_size);
|
||||
|
||||
// Alias scratch space
|
||||
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
|
||||
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
|
||||
|
||||
// buf_big <- IDFT(buf_dft) (not normalized)
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
// Applies DFT(ct[1]) * DFT(s)
|
||||
module.svp_apply_inplace(
|
||||
&mut buf_dft, // DFT(ct[1] * s)
|
||||
0, // Selects the first column of res
|
||||
&s_dft, // DFT(s)
|
||||
0, // Selects the first column of s_dft
|
||||
);
|
||||
|
||||
let mut m: VecZnx = module.new_vec_znx(1, msg_cols);
|
||||
// Alias scratch space (VecZnxDft<B> is always at least as big as VecZnxBig<B>)
|
||||
|
||||
// BIG(ct[1] * s) <- IDFT(DFT(ct[1] * s)) (not normalized)
|
||||
let mut buf_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_size);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// Creates a plaintext: VecZnx with 1 column
|
||||
let mut m = module.new_vec_znx(
|
||||
1, // Number of columns
|
||||
msg_size, // Number of small polynomials
|
||||
);
|
||||
let mut want: Vec<i64> = vec![0; n];
|
||||
want.iter_mut()
|
||||
.for_each(|x| *x = source.next_u64n(16, 15) as i64);
|
||||
|
||||
// m
|
||||
m.encode_vec_i64(0, log_base2k, log_scale, &want, 4);
|
||||
m.normalize(log_base2k, &mut carry);
|
||||
module.vec_znx_normalize_inplace(log_base2k, &mut m, 0, scratch.borrow());
|
||||
|
||||
// buf_big <- m - buf_big
|
||||
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, &m);
|
||||
|
||||
// b <- normalize(buf_big) + e
|
||||
let mut b: VecZnx = module.new_vec_znx(1, cols);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut b, &buf_big, &mut carry);
|
||||
module.add_normal(
|
||||
log_base2k,
|
||||
&mut b,
|
||||
log_base2k * cols,
|
||||
&mut source,
|
||||
3.2,
|
||||
19.0,
|
||||
// m - BIG(ct[1] * s)
|
||||
module.vec_znx_big_sub_small_b_inplace(
|
||||
&mut buf_big,
|
||||
0, // Selects the first column of the receiver
|
||||
&m,
|
||||
0, // Selects the first column of the message
|
||||
);
|
||||
|
||||
// Decrypt
|
||||
// Normalizes back to VecZnx
|
||||
// ct[0] <- m - BIG(c1 * s)
|
||||
module.vec_znx_big_normalize(
|
||||
log_base2k,
|
||||
&mut ct,
|
||||
0, // Selects the first column of ct (ct[0])
|
||||
&buf_big,
|
||||
0, // Selects the first column of buf_big
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// buf_big <- a * s
|
||||
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
// Add noise to ct[0]
|
||||
// ct[0] <- ct[0] + e
|
||||
ct.add_normal(
|
||||
log_base2k,
|
||||
0, // Selects the first column of ct (ct[0])
|
||||
log_base2k * ct_size, // Scaling of the noise: 2^{-log_base2k * limbs}
|
||||
&mut source,
|
||||
3.2, // Standard deviation
|
||||
19.0, // Truncatation bound
|
||||
);
|
||||
|
||||
// buf_big <- a * s + b
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, &b);
|
||||
// Final ciphertext: ct = (-a * s + m + e, a)
|
||||
|
||||
// res <- normalize(buf_big)
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res, &buf_big, &mut carry);
|
||||
// Decryption
|
||||
|
||||
// DFT(ct[1] * s)
|
||||
module.vec_znx_dft(&mut buf_dft, 0, &ct, 1);
|
||||
module.svp_apply_inplace(
|
||||
&mut buf_dft,
|
||||
0, // Selects the first column of res.
|
||||
&s_dft,
|
||||
0,
|
||||
);
|
||||
|
||||
// BIG(c1 * s) = IDFT(DFT(c1 * s))
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, 0, &mut buf_dft, 0);
|
||||
|
||||
// BIG(c1 * s) + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut buf_big, 0, &ct, 0);
|
||||
|
||||
// m + e <- BIG(ct[1] * s + ct[0])
|
||||
let mut res = module.new_vec_znx(1, ct_size);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res, 0, &buf_big, 0, scratch.borrow());
|
||||
|
||||
// have = m * 2^{log_scale} + e
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
res.decode_vec_i64(0, log_base2k, res.cols() * log_base2k, &mut have);
|
||||
res.decode_vec_i64(0, log_base2k, res.size() * log_base2k, &mut have);
|
||||
|
||||
let scale: f64 = (1 << (res.cols() * log_base2k - log_scale)) as f64;
|
||||
let scale: f64 = (1 << (res.size() * log_base2k - log_scale)) as f64;
|
||||
izip!(want.iter(), have.iter())
|
||||
.enumerate()
|
||||
.for_each(|(i, (a, b))| {
|
||||
println!("{}: {} {}", i, a, (*b as f64) / scale);
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
use base2k::{
|
||||
BACKEND, Encoding, Infos, Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps,
|
||||
alloc_aligned,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module = Module::new(n, BACKEND::FFT64);
|
||||
let log_base2k: usize = 15;
|
||||
let cols: usize = 5;
|
||||
let log_k: usize = log_base2k * cols - 5;
|
||||
|
||||
let rows: usize = cols;
|
||||
let cols: usize = cols + 1;
|
||||
|
||||
// Maximum size of the byte scratch needed
|
||||
let tmp_bytes: usize = module.vmp_prepare_tmp_bytes(rows, cols) | module.vmp_apply_dft_tmp_bytes(cols, cols, rows, cols);
|
||||
|
||||
let mut buf: Vec<u8> = alloc_aligned(tmp_bytes);
|
||||
|
||||
let mut a_values: Vec<i64> = vec![i64::default(); n];
|
||||
a_values[1] = (1 << log_base2k) + 1;
|
||||
|
||||
let mut a: VecZnx = module.new_vec_znx(1, rows);
|
||||
a.encode_vec_i64(0, log_base2k, log_k, &a_values, 32);
|
||||
a.normalize(log_base2k, &mut buf);
|
||||
|
||||
a.print(0, a.cols(), n);
|
||||
println!();
|
||||
|
||||
let mut vmp_pmat: VmpPMat = module.new_vmp_pmat(1, rows, cols);
|
||||
|
||||
(0..a.cols()).for_each(|row_i| {
|
||||
let mut tmp: VecZnx = module.new_vec_znx(1, cols);
|
||||
tmp.at_mut(row_i)[1] = 1 as i64;
|
||||
module.vmp_prepare_row(&mut vmp_pmat, tmp.raw(), row_i, &mut buf);
|
||||
});
|
||||
|
||||
let mut c_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
|
||||
module.vmp_apply_dft(&mut c_dft, &a, &vmp_pmat, &mut buf);
|
||||
|
||||
let mut c_big: VecZnxBig = c_dft.as_vec_znx_big();
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, &mut c_dft);
|
||||
|
||||
let mut res: VecZnx = module.new_vec_znx(1, rows);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res, &c_big, &mut buf);
|
||||
|
||||
let mut values_res: Vec<i64> = vec![i64::default(); n];
|
||||
res.decode_vec_i64(0, log_base2k, log_k, &mut values_res);
|
||||
|
||||
res.print(0, res.cols(), n);
|
||||
|
||||
module.free();
|
||||
|
||||
println!("{:?}", values_res)
|
||||
}
|
||||
Submodule base2k/spqlios-arithmetic updated: e3d3247335...b919282c9b
@@ -1,5 +1,6 @@
|
||||
use crate::ffi::znx::znx_zero_i64_ref;
|
||||
use crate::{Infos, VecZnx};
|
||||
use crate::znx_base::{ZnxView, ZnxViewMut};
|
||||
use crate::{VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use rug::{Assign, Float};
|
||||
use std::cmp::min;
|
||||
@@ -9,129 +10,141 @@ pub trait Encoding {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `poly_idx`: the index of the poly where to encode the data.
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `log_k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
|
||||
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `poly_idx`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `log_k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]);
|
||||
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `poly_idx`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]);
|
||||
fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize);
|
||||
|
||||
/// encodes a single i64 on the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `poly_idx`: the index of the poly where to encode the data.
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `log_k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient on which to encode the data.
|
||||
/// * `data`: data to encode on the receiver.
|
||||
/// * `log_max`: base two logarithm of the infinity norm of the input data.
|
||||
fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, data: i64, log_max: usize);
|
||||
}
|
||||
|
||||
pub trait Decoding {
|
||||
/// decode a vector of i64 from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `log_k`: base two logarithm of the scaling of the data.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]);
|
||||
|
||||
/// decode a vector of Float from the receiver.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]);
|
||||
|
||||
/// decode a single of i64 from the receiver at the given index.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `poly_idx`: the index of the poly where to encode the data.
|
||||
/// * `col_i`: the index of the poly where to encode the data.
|
||||
/// * `log_base2k`: base two negative logarithm decomposition of the receiver.
|
||||
/// * `log_k`: base two negative logarithm of the scaling of the data.
|
||||
/// * `i`: index of the coefficient to decode.
|
||||
/// * `data`: data to decode from the receiver.
|
||||
fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
|
||||
fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64;
|
||||
}
|
||||
|
||||
impl Encoding for VecZnx {
|
||||
fn encode_vec_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
|
||||
encode_vec_i64(self, poly_idx, log_base2k, log_k, data, log_max)
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> Encoding for VecZnx<D> {
|
||||
fn encode_vec_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
|
||||
encode_vec_i64(self, col_i, log_base2k, log_k, data, log_max)
|
||||
}
|
||||
|
||||
fn decode_vec_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
|
||||
decode_vec_i64(self, poly_idx, log_base2k, log_k, data)
|
||||
}
|
||||
|
||||
fn decode_vec_float(&self, poly_idx: usize, log_base2k: usize, data: &mut [Float]) {
|
||||
decode_vec_float(self, poly_idx, log_base2k, data)
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(&mut self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
|
||||
encode_coeff_i64(self, poly_idx, log_base2k, log_k, i, value, log_max)
|
||||
}
|
||||
|
||||
fn decode_coeff_i64(&self, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
|
||||
decode_coeff_i64(self, poly_idx, log_base2k, log_k, i)
|
||||
fn encode_coeff_i64(&mut self, col_i: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
|
||||
encode_coeff_i64(self, col_i, log_base2k, log_k, i, value, log_max)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_vec_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) {
|
||||
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
impl<D: AsRef<[u8]>> Decoding for VecZnx<D> {
|
||||
fn decode_vec_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
|
||||
decode_vec_i64(self, col_i, log_base2k, log_k, data)
|
||||
}
|
||||
|
||||
fn decode_vec_float(&self, col_i: usize, log_base2k: usize, data: &mut [Float]) {
|
||||
decode_vec_float(self, col_i, log_base2k, data)
|
||||
}
|
||||
|
||||
fn decode_coeff_i64(&self, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
|
||||
decode_coeff_i64(self, col_i, log_base2k, log_k, i)
|
||||
}
|
||||
}
|
||||
|
||||
fn encode_vec_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
log_base2k: usize,
|
||||
log_k: usize,
|
||||
data: &[i64],
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
cols <= a.cols(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
|
||||
cols,
|
||||
a.cols()
|
||||
size <= a.size(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(poly_idx < a.size);
|
||||
assert!(col_i < a.cols());
|
||||
assert!(data.len() <= a.n())
|
||||
}
|
||||
|
||||
let data_len: usize = data.len();
|
||||
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
||||
|
||||
(0..a.cols()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(a.n() as u64, a.at_poly_mut_ptr(poly_idx, i));
|
||||
// Zeroes coefficients of the i-th column
|
||||
(0..a.size()).for_each(|i| unsafe {
|
||||
znx_zero_i64_ref(a.n() as u64, a.at_mut_ptr(col_i, i));
|
||||
});
|
||||
|
||||
// If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||
a.at_poly_mut(poly_idx, cols - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
a.at_mut(col_i, size - 1)[..data_len].copy_from_slice(&data[..data_len]);
|
||||
} else {
|
||||
let mask: i64 = (1 << log_base2k) - 1;
|
||||
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
|
||||
(cols - steps..cols)
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(i, i_rev)| {
|
||||
let shift: usize = i * log_base2k;
|
||||
izip!(a.at_poly_mut(poly_idx, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
izip!(a.at_mut(col_i, i_rev).iter_mut(), data.iter()).for_each(|(y, x)| *y = (x >> shift) & mask);
|
||||
})
|
||||
}
|
||||
|
||||
// Case where self.prec % self.k != 0.
|
||||
if log_k_rem != log_base2k {
|
||||
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
|
||||
(cols - steps..cols).rev().for_each(|i| {
|
||||
a.at_poly_mut(poly_idx, i)[..data_len]
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size).rev().for_each(|i| {
|
||||
a.at_mut(col_i, i)[..data_len]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x <<= log_k_rem);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
|
||||
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, data: &mut [i64]) {
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -140,26 +153,26 @@ fn decode_vec_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize,
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(poly_idx < a.size());
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
data.copy_from_slice(a.at_poly(poly_idx, 0));
|
||||
data.copy_from_slice(a.at(col_i, 0));
|
||||
let rem: usize = log_base2k - (log_k % log_base2k);
|
||||
(1..cols).for_each(|i| {
|
||||
if i == cols - 1 && rem != log_base2k {
|
||||
(1..size).for_each(|i| {
|
||||
if i == size - 1 && rem != log_base2k {
|
||||
let k_rem: usize = log_base2k - rem;
|
||||
izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << k_rem) + (x >> rem);
|
||||
});
|
||||
} else {
|
||||
izip!(a.at_poly(poly_idx, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y = (*y << log_base2k) + x;
|
||||
});
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [Float]) {
|
||||
let cols: usize = a.cols();
|
||||
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, data: &mut [Float]) {
|
||||
let size: usize = a.size();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -168,23 +181,23 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [
|
||||
data.len(),
|
||||
a.n()
|
||||
);
|
||||
assert!(poly_idx < a.size());
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let prec: u32 = (log_base2k * cols) as u32;
|
||||
let prec: u32 = (log_base2k * size) as u32;
|
||||
|
||||
// 2^{log_base2k}
|
||||
let base = Float::with_val(prec, (1 << log_base2k) as f64);
|
||||
|
||||
// y[i] = sum x[j][i] * 2^{-log_base2k*j}
|
||||
(0..cols).for_each(|i| {
|
||||
(0..size).for_each(|i| {
|
||||
if i == 0 {
|
||||
izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
y.assign(*x);
|
||||
*y /= &base;
|
||||
});
|
||||
} else {
|
||||
izip!(a.at_poly(poly_idx, cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
izip!(a.at(col_i, size - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| {
|
||||
*y += Float::with_val(prec, *x);
|
||||
*y /= &base;
|
||||
});
|
||||
@@ -192,54 +205,62 @@ fn decode_vec_float(a: &VecZnx, poly_idx: usize, log_base2k: usize, data: &mut [
|
||||
});
|
||||
}
|
||||
|
||||
fn encode_coeff_i64(a: &mut VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize, value: i64, log_max: usize) {
|
||||
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
fn encode_coeff_i64<D: AsMut<[u8]> + AsRef<[u8]>>(
|
||||
a: &mut VecZnx<D>,
|
||||
col_i: usize,
|
||||
log_base2k: usize,
|
||||
log_k: usize,
|
||||
i: usize,
|
||||
value: i64,
|
||||
log_max: usize,
|
||||
) {
|
||||
let size: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(
|
||||
cols <= a.cols(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}",
|
||||
cols,
|
||||
a.cols()
|
||||
size <= a.size(),
|
||||
"invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.size()={}",
|
||||
size,
|
||||
a.size()
|
||||
);
|
||||
assert!(poly_idx < a.size());
|
||||
assert!(col_i < a.cols());
|
||||
}
|
||||
|
||||
let log_k_rem: usize = log_base2k - (log_k % log_base2k);
|
||||
(0..a.cols()).for_each(|j| a.at_poly_mut(poly_idx, j)[i] = 0);
|
||||
(0..a.size()).for_each(|j| a.at_mut(col_i, j)[i] = 0);
|
||||
|
||||
// If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy
|
||||
// values on the last limb.
|
||||
// Else we decompose values base2k.
|
||||
if log_max + log_k_rem < 63 || log_k_rem == log_base2k {
|
||||
a.at_poly_mut(poly_idx, cols - 1)[i] = value;
|
||||
a.at_mut(col_i, size - 1)[i] = value;
|
||||
} else {
|
||||
let mask: i64 = (1 << log_base2k) - 1;
|
||||
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
|
||||
(cols - steps..cols)
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size)
|
||||
.rev()
|
||||
.enumerate()
|
||||
.for_each(|(j, j_rev)| {
|
||||
a.at_poly_mut(poly_idx, j_rev)[i] = (value >> (j * log_base2k)) & mask;
|
||||
a.at_mut(col_i, j_rev)[i] = (value >> (j * log_base2k)) & mask;
|
||||
})
|
||||
}
|
||||
|
||||
// Case where prec % k != 0.
|
||||
if log_k_rem != log_base2k {
|
||||
let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k);
|
||||
(cols - steps..cols).rev().for_each(|j| {
|
||||
a.at_poly_mut(poly_idx, j)[i] <<= log_k_rem;
|
||||
let steps: usize = min(size, (log_max + log_base2k - 1) / log_base2k);
|
||||
(size - steps..size).rev().for_each(|j| {
|
||||
a.at_mut(col_i, j)[i] <<= log_k_rem;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
|
||||
fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, log_base2k: usize, log_k: usize, i: usize) -> i64 {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < a.n());
|
||||
assert!(poly_idx < a.size())
|
||||
assert!(col_i < a.cols())
|
||||
}
|
||||
|
||||
let cols: usize = (log_k + log_base2k - 1) / log_base2k;
|
||||
@@ -261,27 +282,30 @@ fn decode_coeff_i64(a: &VecZnx, poly_idx: usize, log_base2k: usize, log_k: usize
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{Encoding, Infos, VecZnx};
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{Decoding, Encoding, FFT64, Module, VecZnx, znx_base::ZnxInfos};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn test_set_get_i64_lo_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let cols: usize = 5;
|
||||
let log_k: usize = cols * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(n, 2, cols);
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.size()).for_each(|poly_idx| {
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut()
|
||||
.for_each(|x| *x = (source.next_i64() << 56) >> 56);
|
||||
a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 10);
|
||||
a.encode_vec_i64(col_i, log_base2k, log_k, &have, 10);
|
||||
let mut want: Vec<i64> = vec![i64::default(); n];
|
||||
a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want);
|
||||
a.decode_vec_i64(col_i, log_base2k, log_k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
});
|
||||
}
|
||||
@@ -289,19 +313,20 @@ mod tests {
|
||||
#[test]
|
||||
fn test_set_get_i64_hi_norm() {
|
||||
let n: usize = 8;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let cols: usize = 5;
|
||||
let log_k: usize = cols * log_base2k - 5;
|
||||
let mut a: VecZnx = VecZnx::new(n, 2, cols);
|
||||
let size: usize = 5;
|
||||
let log_k: usize = size * log_base2k - 5;
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
|
||||
let mut source = Source::new([0u8; 32]);
|
||||
let raw: &mut [i64] = a.raw_mut();
|
||||
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
(0..a.size()).for_each(|poly_idx| {
|
||||
(0..a.cols()).for_each(|col_i| {
|
||||
let mut have: Vec<i64> = vec![i64::default(); n];
|
||||
have.iter_mut().for_each(|x| *x = source.next_i64());
|
||||
a.encode_vec_i64(poly_idx, log_base2k, log_k, &have, 64);
|
||||
a.encode_vec_i64(col_i, log_base2k, log_k, &have, 64);
|
||||
let mut want = vec![i64::default(); n];
|
||||
a.decode_vec_i64(poly_idx, log_base2k, log_k, &mut want);
|
||||
a.decode_vec_i64(col_i, log_base2k, log_k, &mut want);
|
||||
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,8 +3,6 @@ pub struct module_info_t {
|
||||
}
|
||||
|
||||
pub type module_type_t = ::std::os::raw::c_uint;
|
||||
pub const module_type_t_FFT64: module_type_t = 0;
|
||||
pub const module_type_t_NTT120: module_type_t = 1;
|
||||
pub use self::module_type_t as MODULE_TYPE;
|
||||
|
||||
pub type MODULE = module_info_t;
|
||||
|
||||
@@ -33,3 +33,16 @@ unsafe extern "C" {
|
||||
a_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn svp_apply_dft_to_dft(
|
||||
module: *const MODULE,
|
||||
res: *const VEC_ZNX_DFT,
|
||||
res_size: u64,
|
||||
res_cols: u64,
|
||||
ppol: *const SVP_PPOL,
|
||||
a: *const VEC_ZNX_DFT,
|
||||
a_size: u64,
|
||||
a_cols: u64,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -8,17 +8,17 @@ pub struct vec_znx_big_t {
|
||||
pub type VEC_ZNX_BIG = vec_znx_big_t;
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
|
||||
pub unsafe fn bytes_of_vec_znx_big(module: *const MODULE, size: u64) -> u64;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
|
||||
pub unsafe fn new_vec_znx_big(module: *const MODULE, size: u64) -> *mut VEC_ZNX_BIG;
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
|
||||
pub unsafe fn delete_vec_znx_big(res: *mut VEC_ZNX_BIG);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_add(
|
||||
pub unsafe fn vec_znx_big_add(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -29,7 +29,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_add_small(
|
||||
pub unsafe fn vec_znx_big_add_small(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -41,7 +41,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_add_small2(
|
||||
pub unsafe fn vec_znx_big_add_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -54,7 +54,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_sub(
|
||||
pub unsafe fn vec_znx_big_sub(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -65,7 +65,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_sub_small_b(
|
||||
pub unsafe fn vec_znx_big_sub_small_b(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -77,7 +77,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_sub_small_a(
|
||||
pub unsafe fn vec_znx_big_sub_small_a(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -89,7 +89,7 @@ unsafe extern "C" {
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_sub_small2(
|
||||
pub unsafe fn vec_znx_big_sub_small2(
|
||||
module: *const MODULE,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
@@ -101,8 +101,13 @@ unsafe extern "C" {
|
||||
b_sl: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_normalize_base2k(
|
||||
pub unsafe fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
@@ -113,34 +118,9 @@ unsafe extern "C" {
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_automorphism(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_rotate(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_range_normalize_base2k(
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k(
|
||||
module: *const MODULE,
|
||||
log2_base2k: u64,
|
||||
res: *mut i64,
|
||||
@@ -153,6 +133,29 @@ unsafe extern "C" {
|
||||
tmp_space: *mut u8,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
pub unsafe fn vec_znx_big_range_normalize_base2k_tmp_bytes(module: *const MODULE) -> u64;
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_automorphism(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
unsafe extern "C" {
|
||||
pub unsafe fn vec_znx_big_rotate(
|
||||
module: *const MODULE,
|
||||
p: i64,
|
||||
res: *mut VEC_ZNX_BIG,
|
||||
res_size: u64,
|
||||
a: *const VEC_ZNX_BIG,
|
||||
a_size: u64,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
use crate::LAYOUT;
|
||||
|
||||
pub trait Infos {
|
||||
/// Returns the ring degree of the receiver.
|
||||
fn n(&self) -> usize;
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the receiver.
|
||||
fn log_n(&self) -> usize;
|
||||
|
||||
/// Returns the number of stacked polynomials.
|
||||
fn size(&self) -> usize;
|
||||
|
||||
/// Returns the memory layout of the stacked polynomials.
|
||||
fn layout(&self) -> LAYOUT;
|
||||
|
||||
/// Returns the number of columns of the receiver.
|
||||
/// This method is equivalent to [Infos::cols].
|
||||
fn cols(&self) -> usize;
|
||||
|
||||
/// Returns the number of rows of the receiver.
|
||||
fn rows(&self) -> usize;
|
||||
}
|
||||
@@ -2,39 +2,43 @@ pub mod encoding;
|
||||
#[allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, improper_ctypes)]
|
||||
// Other modules and exports
|
||||
pub mod ffi;
|
||||
pub mod infos;
|
||||
pub mod mat_znx_dft;
|
||||
pub mod mat_znx_dft_ops;
|
||||
pub mod module;
|
||||
pub mod sampling;
|
||||
pub mod scalar_znx;
|
||||
pub mod scalar_znx_dft;
|
||||
pub mod scalar_znx_dft_ops;
|
||||
pub mod stats;
|
||||
pub mod svp;
|
||||
pub mod vec_znx;
|
||||
pub mod vec_znx_big;
|
||||
pub mod vec_znx_big_ops;
|
||||
pub mod vec_znx_dft;
|
||||
pub mod vmp;
|
||||
pub mod vec_znx_dft_ops;
|
||||
pub mod vec_znx_ops;
|
||||
pub mod znx_base;
|
||||
|
||||
pub use encoding::*;
|
||||
pub use infos::*;
|
||||
pub use mat_znx_dft::*;
|
||||
pub use mat_znx_dft_ops::*;
|
||||
pub use module::*;
|
||||
pub use sampling::*;
|
||||
#[allow(unused_imports)]
|
||||
pub use scalar_znx::*;
|
||||
pub use scalar_znx_dft::*;
|
||||
pub use scalar_znx_dft_ops::*;
|
||||
pub use stats::*;
|
||||
pub use svp::*;
|
||||
pub use vec_znx::*;
|
||||
pub use vec_znx_big::*;
|
||||
pub use vec_znx_big_ops::*;
|
||||
pub use vec_znx_dft::*;
|
||||
pub use vmp::*;
|
||||
pub use vec_znx_dft_ops::*;
|
||||
pub use vec_znx_ops::*;
|
||||
pub use znx_base::*;
|
||||
|
||||
pub const GALOISGENERATOR: u64 = 5;
|
||||
pub const DEFAULTALIGN: usize = 64;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
pub enum LAYOUT {
|
||||
ROW,
|
||||
COL,
|
||||
}
|
||||
|
||||
pub fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
|
||||
fn is_aligned_custom<T>(ptr: *const T, align: usize) -> bool {
|
||||
(ptr as usize) % align == 0
|
||||
}
|
||||
|
||||
@@ -51,38 +55,35 @@ pub fn assert_alignement<T>(ptr: *const T) {
|
||||
|
||||
pub fn cast<T, V>(data: &[T]) -> &[V] {
|
||||
let ptr: *const V = data.as_ptr() as *const V;
|
||||
let len: usize = data.len() / std::mem::size_of::<V>();
|
||||
let len: usize = data.len() / size_of::<V>();
|
||||
unsafe { std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
pub fn cast_mut<T, V>(data: &[T]) -> &mut [V] {
|
||||
let ptr: *mut V = data.as_ptr() as *mut V;
|
||||
let len: usize = data.len() / std::mem::size_of::<V>();
|
||||
let len: usize = data.len() / size_of::<V>();
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
|
||||
use std::alloc::{Layout, alloc};
|
||||
use std::ptr;
|
||||
|
||||
/// Allocates a block of bytes with a custom alignement.
|
||||
/// Alignement must be a power of two and size a multiple of the alignement.
|
||||
/// Allocated memory is initialized to zero.
|
||||
pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
assert!(
|
||||
align.is_power_of_two(),
|
||||
"Alignment must be a power of two but is {}",
|
||||
align
|
||||
);
|
||||
assert_eq!(
|
||||
(size * std::mem::size_of::<u8>()) % align,
|
||||
(size * size_of::<u8>()) % align,
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
align
|
||||
);
|
||||
unsafe {
|
||||
let layout: Layout = Layout::from_size_align(size, align).expect("Invalid alignment");
|
||||
let ptr: *mut u8 = alloc(layout);
|
||||
let layout: std::alloc::Layout = std::alloc::Layout::from_size_align(size, align).expect("Invalid alignment");
|
||||
let ptr: *mut u8 = std::alloc::alloc(layout);
|
||||
if ptr.is_null() {
|
||||
panic!("Memory allocation failed");
|
||||
}
|
||||
@@ -93,36 +94,158 @@ pub fn alloc_aligned_custom_u8(size: usize, align: usize) -> Vec<u8> {
|
||||
align
|
||||
);
|
||||
// Init allocated memory to zero
|
||||
ptr::write_bytes(ptr, 0, size);
|
||||
std::ptr::write_bytes(ptr, 0, size);
|
||||
Vec::from_raw_parts(ptr, size, size)
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates a block of bytes aligned with [DEFAULTALIGN].
|
||||
/// Size must be amultiple of [DEFAULTALIGN].
|
||||
/// /// Allocated memory is initialized to zero.
|
||||
pub fn alloc_aligned_u8(size: usize) -> Vec<u8> {
|
||||
alloc_aligned_custom_u8(size, DEFAULTALIGN)
|
||||
}
|
||||
|
||||
/// Allocates a block of T aligned with [DEFAULTALIGN].
|
||||
/// Size of T * size msut be a multiple of [DEFAULTALIGN].
|
||||
pub fn alloc_aligned_custom<T>(size: usize, align: usize) -> Vec<T> {
|
||||
assert_eq!(
|
||||
(size * std::mem::size_of::<T>()) % align,
|
||||
(size * size_of::<T>()) % align,
|
||||
0,
|
||||
"size={} must be a multiple of align={}",
|
||||
size,
|
||||
align
|
||||
);
|
||||
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(std::mem::size_of::<T>() * size, align);
|
||||
let mut vec_u8: Vec<u8> = alloc_aligned_custom_u8(size_of::<T>() * size, align);
|
||||
let ptr: *mut T = vec_u8.as_mut_ptr() as *mut T;
|
||||
let len: usize = vec_u8.len() / std::mem::size_of::<T>();
|
||||
let cap: usize = vec_u8.capacity() / std::mem::size_of::<T>();
|
||||
let len: usize = vec_u8.len() / size_of::<T>();
|
||||
let cap: usize = vec_u8.capacity() / size_of::<T>();
|
||||
std::mem::forget(vec_u8);
|
||||
unsafe { Vec::from_raw_parts(ptr, len, cap) }
|
||||
}
|
||||
|
||||
/// Allocates an aligned vector of size equal to the smallest multiple
|
||||
/// of [DEFAULTALIGN]/size_of::<T>() that is equal or greater to `size`.
|
||||
pub fn alloc_aligned<T>(size: usize) -> Vec<T> {
|
||||
alloc_aligned_custom::<T>(size, DEFAULTALIGN)
|
||||
alloc_aligned_custom::<T>(
|
||||
size + (size % (DEFAULTALIGN / size_of::<T>())),
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
|
||||
// Scratch implementation below
|
||||
|
||||
pub struct ScratchOwned(Vec<u8>);
|
||||
|
||||
impl ScratchOwned {
|
||||
pub fn new(byte_count: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(byte_count);
|
||||
Self(data)
|
||||
}
|
||||
|
||||
pub fn borrow(&mut self) -> &mut Scratch {
|
||||
Scratch::new(&mut self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Scratch {
|
||||
data: [u8],
|
||||
}
|
||||
|
||||
impl Scratch {
|
||||
fn new(data: &mut [u8]) -> &mut Self {
|
||||
unsafe { &mut *(data as *mut [u8] as *mut Self) }
|
||||
}
|
||||
|
||||
pub fn available(&self) -> usize {
|
||||
let ptr: *const u8 = self.data.as_ptr();
|
||||
let self_len: usize = self.data.len();
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
self_len.saturating_sub(aligned_offset)
|
||||
}
|
||||
|
||||
fn take_slice_aligned(data: &mut [u8], take_len: usize) -> (&mut [u8], &mut [u8]) {
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
let self_len: usize = data.len();
|
||||
|
||||
let aligned_offset: usize = ptr.align_offset(DEFAULTALIGN);
|
||||
let aligned_len: usize = self_len.saturating_sub(aligned_offset);
|
||||
|
||||
if let Some(rem_len) = aligned_len.checked_sub(take_len) {
|
||||
unsafe {
|
||||
let rem_ptr: *mut u8 = ptr.add(aligned_offset).add(take_len);
|
||||
let rem_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(rem_ptr, rem_len);
|
||||
|
||||
let take_slice: &mut [u8] = &mut *std::ptr::slice_from_raw_parts_mut(ptr.add(aligned_offset), take_len);
|
||||
|
||||
return (take_slice, rem_slice);
|
||||
}
|
||||
} else {
|
||||
panic!(
|
||||
"Attempted to take {} from scratch with {} aligned bytes left",
|
||||
take_len,
|
||||
aligned_len,
|
||||
// type_name::<T>(),
|
||||
// aligned_len
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_slice<T>(&mut self, len: usize) -> (&mut [T], &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, len * std::mem::size_of::<T>());
|
||||
|
||||
unsafe {
|
||||
(
|
||||
&mut *(std::ptr::slice_from_raw_parts_mut(take_slice.as_mut_ptr() as *mut T, len)),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnx::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
|
||||
|
||||
(
|
||||
ScalarZnxDft::from_data(take_slice, module.n(), cols),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_dft<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxDft<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_dft(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxDft::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx_big<B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
) -> (VecZnxBig<&mut [u8], B>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_vec_znx_big(module, cols, size));
|
||||
|
||||
(
|
||||
VecZnxBig::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn tmp_vec_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize, size: usize) -> (VecZnx<&mut [u8]>, &mut Self) {
|
||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, module.bytes_of_vec_znx(cols, size));
|
||||
(
|
||||
VecZnx::from_data(take_slice, module.n(), cols, size),
|
||||
Self::new(rem_slice),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
232
base2k/src/mat_znx_dft.rs
Normal file
232
base2k/src/mat_znx_dft.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each col of the [MatZnxDft] can be seen as a collection of [VecZnxDft].
|
||||
///
|
||||
/// [MatZnxDft] is used to permform a vector matrix product between a [VecZnx]/[VecZnxDft] and a [MatZnxDft].
|
||||
/// See the trait [MatZnxDftOps] for additional information.
|
||||
pub struct MatZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
size: usize,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for MatZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for MatZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for MatZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for MatZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for MatZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
impl<D, B: Backend> MatZnxDft<D, B> {
|
||||
pub fn cols_in(&self) -> usize {
|
||||
self.cols_in
|
||||
}
|
||||
|
||||
pub fn cols_out(&self) -> usize {
|
||||
self.cols_out
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> MatZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
unsafe {
|
||||
crate::ffi::vmp::bytes_of_vmp_pmat(
|
||||
module.ptr,
|
||||
(rows * cols_in) as u64,
|
||||
(size * cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> Self {
|
||||
let data: Vec<u8> = alloc_aligned(Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(
|
||||
module: &Module<B>,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: impl Into<Vec<u8>>,
|
||||
) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, rows, cols_in, cols_out, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
size,
|
||||
rows,
|
||||
cols_in,
|
||||
cols_out,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> MatZnxDft<D, FFT64> {
|
||||
/// Returns a copy of the backend array at index (i, j) of the [MatZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row`: row index (i).
|
||||
/// * `col`: col index (j).
|
||||
#[allow(dead_code)]
|
||||
fn at(&self, row: usize, col: usize) -> Vec<f64> {
|
||||
let n: usize = self.n();
|
||||
|
||||
let mut res: Vec<f64> = alloc_aligned(n);
|
||||
|
||||
if n < 8 {
|
||||
res.copy_from_slice(&self.raw()[(row + col * self.rows()) * n..(row + col * self.rows()) * (n + 1)]);
|
||||
} else {
|
||||
(0..n >> 3).for_each(|blk| {
|
||||
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
|
||||
});
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn at_block(&self, row: usize, col: usize, blk: usize) -> &[f64] {
|
||||
let nrows: usize = self.rows();
|
||||
let nsize: usize = self.size();
|
||||
if col == (nsize - 1) && (nsize & 1 == 1) {
|
||||
&self.raw()[blk * nrows * nsize * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
&self.raw()[blk * nrows * nsize * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type MatZnxDftOwned<B> = MatZnxDft<Vec<u8>, B>;
|
||||
|
||||
pub trait MatZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
pub trait MatZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<Vec<u8>, B> {
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<Vec<u8>, B> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftToMut<B> for MatZnxDft<&mut [u8], B> {
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&mut [u8], B> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftToRef<B> for MatZnxDft<&[u8], B> {
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
MatZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
rows: self.rows,
|
||||
cols_in: self.cols_in,
|
||||
cols_out: self.cols_out,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
487
base2k/src/mat_znx_dft_ops.rs
Normal file
487
base2k/src/mat_znx_dft_ops.rs
Normal file
@@ -0,0 +1,487 @@
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOwned, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnxDft, VecZnxDftToMut,
|
||||
VecZnxDftToRef,
|
||||
};
|
||||
|
||||
pub trait MatZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a new [MatZnxDft] with the given number of rows and columns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `size`: number of size (number of size of each [VecZnxDft]).
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B>;
|
||||
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize;
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait MatZnxDftScratch {
|
||||
/// Returns the size of the stratch space necessary for [MatZnxDftOps::vmp_apply_dft_to_dft].
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize;
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [MatZnxDft].
|
||||
pub trait MatZnxDftOps<BACKEND: Backend> {
|
||||
/// Prepares the ith-row of [MatZnxDft] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [MatZnxDft].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [MatZnxDftOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxDftToMut<BACKEND>,
|
||||
A: VecZnxDftToRef<BACKEND>;
|
||||
|
||||
/// Extracts the ith-row of [MatZnxDft] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [MatZnxDft].
|
||||
/// * `a`: [MatZnxDft] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: MatZnxDftToRef<BACKEND>;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [MatZnxDft].
|
||||
/// The size of `buf` is given by [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [MatZnxDft].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` size and a [MatZnxDft] of `i` rows and
|
||||
/// `j` size, the output is a [VecZnx] of `j` size.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [MatZnxDft] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [MatZnxDftOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: VecZnxDftToRef<BACKEND>,
|
||||
B: MatZnxDftToRef<BACKEND>;
|
||||
|
||||
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: VecZnxDftToRef<BACKEND>,
|
||||
B: MatZnxDftToRef<BACKEND>;
|
||||
}
|
||||
|
||||
impl<B: Backend> MatZnxDftAlloc<B> for Module<B> {
|
||||
fn bytes_of_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
|
||||
MatZnxDftOwned::bytes_of(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new(self, rows, cols_in, cols_out, size)
|
||||
}
|
||||
|
||||
fn new_mat_znx_dft_from_bytes(
|
||||
&self,
|
||||
rows: usize,
|
||||
cols_in: usize,
|
||||
cols_out: usize,
|
||||
size: usize,
|
||||
bytes: Vec<u8>,
|
||||
) -> MatZnxDftOwned<B> {
|
||||
MatZnxDftOwned::new_from_bytes(self, rows, cols_in, cols_out, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> MatZnxDftScratch for Module<BACKEND> {
|
||||
fn vmp_apply_tmp_bytes(
|
||||
&self,
|
||||
res_size: usize,
|
||||
a_size: usize,
|
||||
b_rows: usize,
|
||||
b_cols_in: usize,
|
||||
b_cols_out: usize,
|
||||
b_size: usize,
|
||||
) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
(res_size * b_cols_out) as u64,
|
||||
(a_size * b_cols_in) as u64,
|
||||
(b_rows * b_cols_in) as u64,
|
||||
(b_size * b_cols_out) as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MatZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn vmp_prepare_row<R, A>(&self, res: &mut R, res_row: usize, res_col_in: usize, a: &A)
|
||||
where
|
||||
R: MatZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: MatZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
res.cols_out(),
|
||||
"a.cols(): {} != res.cols_out(): {}",
|
||||
a.cols(),
|
||||
res.cols_out()
|
||||
);
|
||||
assert!(
|
||||
res_row < res.rows(),
|
||||
"res_row: {} >= res.rows(): {}",
|
||||
res_row,
|
||||
res.rows()
|
||||
);
|
||||
assert!(
|
||||
res_col_in < res.cols_in(),
|
||||
"res_col_in: {} >= res.cols_in(): {}",
|
||||
res_col_in,
|
||||
res.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vmp::vmp_pmat_t,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(res_row * res.cols_in() + res_col_in) as u64,
|
||||
(res.rows() * res.cols_in()) as u64,
|
||||
(res.size() * res.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row<R, A>(&self, res: &mut R, a: &A, a_row: usize, a_col_in: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: MatZnxDft<&[u8], _> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
a.cols_out(),
|
||||
"res.cols(): {} != a.cols_out(): {}",
|
||||
res.cols(),
|
||||
a.cols_out()
|
||||
);
|
||||
assert!(
|
||||
a_row < a.rows(),
|
||||
"a_row: {} >= a.rows(): {}",
|
||||
a_row,
|
||||
a.rows()
|
||||
);
|
||||
assert!(
|
||||
a_col_in < a.cols_in(),
|
||||
"a_col_in: {} >= a.cols_in(): {}",
|
||||
a_col_in,
|
||||
a.cols_in()
|
||||
);
|
||||
assert_eq!(
|
||||
res.size(),
|
||||
a.size(),
|
||||
"res.size(): {} != a.size(): {}",
|
||||
res.size(),
|
||||
a.size()
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(a_row * a.cols_in() + a_col_in) as u64,
|
||||
(a.rows() * a.cols_in()) as u64,
|
||||
(a.size() * a.cols_out()) as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
B: MatZnxDftToRef<FFT64> {
|
||||
let mut res: VecZnxDft<&mut [u8], _> = res.to_mut();
|
||||
let a: VecZnxDft<&[u8], _> = a.to_ref();
|
||||
let b: MatZnxDft<&[u8], _> = b.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(
|
||||
res.cols(),
|
||||
b.cols_out(),
|
||||
"res.cols(): {} != b.cols_out: {}",
|
||||
res.cols(),
|
||||
b.cols_out()
|
||||
);
|
||||
assert_eq!(
|
||||
a.cols(),
|
||||
b.cols_in(),
|
||||
"a.cols(): {} != b.cols_in: {}",
|
||||
a.cols(),
|
||||
b.cols_in()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vmp_apply_tmp_bytes(
|
||||
res.size(),
|
||||
a.size(),
|
||||
b.rows(),
|
||||
b.cols_in(),
|
||||
b.cols_out(),
|
||||
b.size(),
|
||||
));
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
res.as_mut_ptr() as *mut vec_znx_dft_t,
|
||||
(res.size() * res.cols()) as u64,
|
||||
a.as_ptr() as *const vec_znx_dft_t,
|
||||
(a.size() * a.cols()) as u64,
|
||||
b.as_ptr() as *const vmp::vmp_pmat_t,
|
||||
(b.rows() * b.cols_in()) as u64,
|
||||
(b.size() * b.cols_out()) as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use super::{MatZnxDftAlloc, MatZnxDftScratch};
|
||||
|
||||
#[test]
|
||||
fn vmp_prepare_row() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(16);
|
||||
let log_base2k: usize = 8;
|
||||
let mat_rows: usize = 4;
|
||||
let mat_cols_in: usize = 2;
|
||||
let mat_cols_out: usize = 2;
|
||||
let mat_size: usize = 5;
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut b_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut mat: MatZnxDft<Vec<u8>, FFT64> = module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
for col_in in 0..mat_cols_in {
|
||||
for row_i in 0..mat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
(0..mat_cols_out).for_each(|col_out| {
|
||||
a.fill_uniform(log_base2k, col_out, mat_size, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, col_out, &a, col_out);
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat, row_i, col_in, &a_dft);
|
||||
module.vmp_extract_row(&mut b_dft, &mat, row_i, col_in);
|
||||
assert_eq!(a_dft.raw(), b_dft.raw());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vmp_apply() {
|
||||
let log_n: i32 = 5;
|
||||
let n: usize = 1 << log_n;
|
||||
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 15;
|
||||
let a_size: usize = 5;
|
||||
let mat_size: usize = 6;
|
||||
let res_size: usize = 5;
|
||||
|
||||
[1, 2].iter().for_each(|in_cols| {
|
||||
[1, 2].iter().for_each(|out_cols| {
|
||||
let a_cols: usize = *in_cols;
|
||||
let res_cols: usize = *out_cols;
|
||||
|
||||
let mat_rows: usize = a_size;
|
||||
let mat_cols_in: usize = a_cols;
|
||||
let mat_cols_out: usize = res_cols;
|
||||
let res_cols: usize = mat_cols_out;
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
module.vmp_apply_tmp_bytes(
|
||||
res_size,
|
||||
a_size,
|
||||
mat_rows,
|
||||
mat_cols_in,
|
||||
mat_cols_out,
|
||||
mat_size,
|
||||
) | module.vec_znx_big_normalize_tmp_bytes(),
|
||||
);
|
||||
|
||||
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
|
||||
|
||||
(0..a_cols).for_each(|i| {
|
||||
a.at_mut(i, 2)[i + 1] = 1;
|
||||
});
|
||||
|
||||
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
|
||||
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
|
||||
|
||||
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
|
||||
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
|
||||
|
||||
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
|
||||
|
||||
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
|
||||
(0..a.size()).for_each(|row_i| {
|
||||
(0..mat_cols_in).for_each(|col_in_i| {
|
||||
(0..mat_cols_out).for_each(|col_out_i| {
|
||||
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
|
||||
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
|
||||
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
|
||||
});
|
||||
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
|
||||
});
|
||||
});
|
||||
|
||||
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
|
||||
(0..a_cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut a_dft, i, &a, i);
|
||||
});
|
||||
|
||||
module.vmp_apply(&mut c_dft, &a_dft, &mat_znx_dft, scratch.borrow());
|
||||
|
||||
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
|
||||
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
|
||||
(0..mat_cols_out).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut res_have, i, &c_big, i, scratch.borrow());
|
||||
});
|
||||
|
||||
(0..mat_cols_out).for_each(|col_i| {
|
||||
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
|
||||
(0..a_cols).for_each(|i| {
|
||||
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
|
||||
});
|
||||
res_have.decode_vec_i64(col_i, log_base2k, log_base2k * 3, &mut res_have_vi64);
|
||||
assert_eq!(res_have_vi64, res_want_vi64);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::GALOISGENERATOR;
|
||||
use crate::ffi::module::{MODULE, delete_module_info, module_info_t, new_module_info};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[repr(u8)]
|
||||
@@ -8,37 +9,50 @@ pub enum BACKEND {
|
||||
NTT120,
|
||||
}
|
||||
|
||||
pub struct Module {
|
||||
pub ptr: *mut MODULE,
|
||||
pub n: usize,
|
||||
pub backend: BACKEND,
|
||||
pub trait Backend {
|
||||
const KIND: BACKEND;
|
||||
fn module_type() -> u32;
|
||||
}
|
||||
|
||||
impl Module {
|
||||
pub struct FFT64;
|
||||
pub struct NTT120;
|
||||
|
||||
impl Backend for FFT64 {
|
||||
const KIND: BACKEND = BACKEND::FFT64;
|
||||
fn module_type() -> u32 {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
impl Backend for NTT120 {
|
||||
const KIND: BACKEND = BACKEND::NTT120;
|
||||
fn module_type() -> u32 {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Module<B: Backend> {
|
||||
pub ptr: *mut MODULE,
|
||||
n: usize,
|
||||
_marker: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Module<B> {
|
||||
// Instantiates a new module.
|
||||
pub fn new(n: usize, module_type: BACKEND) -> Self {
|
||||
pub fn new(n: usize) -> Self {
|
||||
unsafe {
|
||||
let module_type_u32: u32;
|
||||
match module_type {
|
||||
BACKEND::FFT64 => module_type_u32 = 0,
|
||||
BACKEND::NTT120 => module_type_u32 = 1,
|
||||
}
|
||||
let m: *mut module_info_t = new_module_info(n as u64, module_type_u32);
|
||||
let m: *mut module_info_t = new_module_info(n as u64, B::module_type());
|
||||
if m.is_null() {
|
||||
panic!("Failed to create module.");
|
||||
}
|
||||
Self {
|
||||
ptr: m,
|
||||
n: n,
|
||||
backend: module_type,
|
||||
_marker: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
@@ -51,26 +65,27 @@ impl Module {
|
||||
(self.n() << 1) as _
|
||||
}
|
||||
|
||||
// Returns GALOISGENERATOR^|gen| * sign(gen)
|
||||
pub fn galois_element(&self, gen: i64) -> i64 {
|
||||
if gen == 0 {
|
||||
// Returns GALOISGENERATOR^|generator| * sign(generator)
|
||||
pub fn galois_element(&self, generator: i64) -> i64 {
|
||||
if generator == 0 {
|
||||
return 1;
|
||||
}
|
||||
((mod_exp_u64(GALOISGENERATOR, gen.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * gen.signum()
|
||||
((mod_exp_u64(GALOISGENERATOR, generator.abs() as usize) & (self.cyclotomic_order() - 1)) as i64) * generator.signum()
|
||||
}
|
||||
|
||||
// Returns gen^-1
|
||||
pub fn galois_element_inv(&self, gen: i64) -> i64 {
|
||||
if gen == 0 {
|
||||
pub fn galois_element_inv(&self, gal_el: i64) -> i64 {
|
||||
if gal_el == 0 {
|
||||
panic!("cannot invert 0")
|
||||
}
|
||||
((mod_exp_u64(gen.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
|
||||
* gen.signum()
|
||||
((mod_exp_u64(gal_el.abs() as u64, (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1)) as i64)
|
||||
* gal_el.signum()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn free(self) {
|
||||
impl<B: Backend> Drop for Module<B> {
|
||||
fn drop(&mut self) {
|
||||
unsafe { delete_module_info(self.ptr) }
|
||||
drop(self);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,56 +1,132 @@
|
||||
use crate::{Infos, Module, VecZnx};
|
||||
use crate::znx_base::ZnxViewMut;
|
||||
use crate::{FFT64, VecZnx, VecZnxBig, VecZnxBigToMut, VecZnxToMut};
|
||||
use rand_distr::{Distribution, Normal};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub trait Sampling {
|
||||
/// Fills the first `cols` cols with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source);
|
||||
pub trait FillUniform {
|
||||
/// Fills the first `size` size with uniform values in \[-2^{log_base2k-1}, 2^{log_base2k-1}\]
|
||||
fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source);
|
||||
}
|
||||
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&self,
|
||||
pub trait FillDistF64 {
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
a: &mut VecZnx,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
|
||||
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
impl Sampling for Module {
|
||||
fn fill_uniform(&self, log_base2k: usize, a: &mut VecZnx, cols: usize, source: &mut Source) {
|
||||
pub trait AddDistF64 {
|
||||
/// Adds vector sampled according to the provided distribution, scaled by 2^{-log_k} and bounded to \[-bound, bound\].
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
);
|
||||
}
|
||||
|
||||
pub trait FillNormal {
|
||||
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
pub trait AddNormal {
|
||||
/// Adds a discrete normal vector scaled by 2^{-log_k} with the provided standard deviation and bounded to \[-bound, bound\].
|
||||
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64);
|
||||
}
|
||||
|
||||
impl<T> FillUniform for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_uniform(&mut self, log_base2k: usize, col_i: usize, size: usize, source: &mut Source) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
let base2k: u64 = 1 << log_base2k;
|
||||
let mask: u64 = base2k - 1;
|
||||
let base2k_half: i64 = (base2k >> 1) as i64;
|
||||
let size: usize = a.n() * cols;
|
||||
a.raw_mut()[..size]
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
(0..size).for_each(|j| {
|
||||
a.at_mut(col_i, j)
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = (source.next_u64n(base2k, mask) as i64) - base2k_half);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&self,
|
||||
impl<T> FillDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
a: &mut VecZnx,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let log_base2k_rem: usize = log_k % log_base2k;
|
||||
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
|
||||
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
|
||||
|
||||
if log_base2k_rem != 0 {
|
||||
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << log_base2k_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AddDistF64 for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnx<&mut [u8]> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
|
||||
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
|
||||
|
||||
if log_base2k_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
@@ -58,7 +134,7 @@ impl Sampling for Module {
|
||||
*a += (dist_f64.round() as i64) << log_base2k_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(a.cols() - 1).iter_mut().for_each(|a| {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
@@ -67,11 +143,16 @@ impl Sampling for Module {
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn add_normal(&self, log_base2k: usize, a: &mut VecZnx, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
impl<T> FillNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
log_base2k,
|
||||
a,
|
||||
col_i,
|
||||
log_k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
@@ -79,3 +160,206 @@ impl Sampling for Module {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AddNormal for VecZnx<T>
|
||||
where
|
||||
VecZnx<T>: VecZnxToMut,
|
||||
{
|
||||
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
log_base2k,
|
||||
col_i,
|
||||
log_k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FillDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
|
||||
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
|
||||
|
||||
if log_base2k_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = (dist_f64.round() as i64) << log_base2k_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a = dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AddDistF64 for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_dist_f64<D: Distribution<f64>>(
|
||||
&mut self,
|
||||
log_base2k: usize,
|
||||
col_i: usize,
|
||||
log_k: usize,
|
||||
source: &mut Source,
|
||||
dist: D,
|
||||
bound: f64,
|
||||
) {
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
assert!(
|
||||
(bound.log2().ceil() as i64) < 64,
|
||||
"invalid bound: ceil(log2(bound))={} > 63",
|
||||
(bound.log2().ceil() as i64)
|
||||
);
|
||||
|
||||
let limb: usize = (log_k + log_base2k - 1) / log_base2k - 1;
|
||||
let log_base2k_rem: usize = (limb + 1) * log_base2k - log_k;
|
||||
|
||||
if log_base2k_rem != 0 {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += (dist_f64.round() as i64) << log_base2k_rem;
|
||||
});
|
||||
} else {
|
||||
a.at_mut(col_i, limb).iter_mut().for_each(|a| {
|
||||
let mut dist_f64: f64 = dist.sample(source);
|
||||
while dist_f64.abs() > bound {
|
||||
dist_f64 = dist.sample(source)
|
||||
}
|
||||
*a += dist_f64.round() as i64
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FillNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn fill_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.fill_dist_f64(
|
||||
log_base2k,
|
||||
col_i,
|
||||
log_k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AddNormal for VecZnxBig<T, FFT64>
|
||||
where
|
||||
VecZnxBig<T, FFT64>: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
fn add_normal(&mut self, log_base2k: usize, col_i: usize, log_k: usize, source: &mut Source, sigma: f64, bound: f64) {
|
||||
self.add_dist_f64(
|
||||
log_base2k,
|
||||
col_i,
|
||||
log_k,
|
||||
source,
|
||||
Normal::new(0.0, sigma).unwrap(),
|
||||
bound,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{AddNormal, FillUniform};
|
||||
use crate::vec_znx_ops::*;
|
||||
use crate::znx_base::*;
|
||||
use crate::{FFT64, Module, Stats, VecZnx};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vec_znx_fill_uniform() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let size: usize = 5;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let one_12_sqrt: f64 = 0.28867513459481287;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.fill_uniform(log_base2k, col_i, size, &mut source);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, log_base2k);
|
||||
assert!(
|
||||
(std - one_12_sqrt).abs() < 0.01,
|
||||
"std={} ~!= {}",
|
||||
std,
|
||||
one_12_sqrt
|
||||
);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn vec_znx_add_normal() {
|
||||
let n: usize = 4096;
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(n);
|
||||
let log_base2k: usize = 17;
|
||||
let log_k: usize = 2 * 17;
|
||||
let size: usize = 5;
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = 6.0 * sigma;
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
let cols: usize = 2;
|
||||
let zero: Vec<i64> = vec![0; n];
|
||||
let k_f64: f64 = (1u64 << log_k as u64) as f64;
|
||||
(0..cols).for_each(|col_i| {
|
||||
let mut a: VecZnx<_> = module.new_vec_znx(cols, size);
|
||||
a.add_normal(log_base2k, col_i, log_k, &mut source, sigma, bound);
|
||||
(0..cols).for_each(|col_j| {
|
||||
if col_j != col_i {
|
||||
(0..size).for_each(|limb_i| {
|
||||
assert_eq!(a.at(col_j, limb_i), zero);
|
||||
})
|
||||
} else {
|
||||
let std: f64 = a.std(col_i, log_base2k) * k_f64;
|
||||
assert!((std - sigma).abs() < 0.1, "std={} ~!= {}", std, sigma);
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
306
base2k/src/scalar_znx.rs
Normal file
306
base2k/src/scalar_znx.rs
Normal file
@@ -0,0 +1,306 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned,
|
||||
};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub struct ScalarZnx<D> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
}
|
||||
|
||||
impl<D> ZnxInfos for ScalarZnx<D> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for ScalarZnx<D> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataView for ScalarZnx<D> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> DataViewMut for ScalarZnx<D> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for ScalarZnx<D> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
impl<D: AsMut<[u8]> + AsRef<[u8]>> ScalarZnx<D> {
|
||||
pub fn fill_ternary_prob(&mut self, col: usize, prob: f64, source: &mut Source) {
|
||||
let choices: [i64; 3] = [-1, 0, 1];
|
||||
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
||||
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
|
||||
self.at_mut(col, 0)
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, col: usize, hw: usize, source: &mut Source) {
|
||||
assert!(hw <= self.n());
|
||||
self.at_mut(col, 0)[..hw]
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||
self.at_mut(col, 0).shuffle(source);
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>> ScalarZnx<D> {
|
||||
pub(crate) fn bytes_of<S: Sized>(n: usize, cols: usize) -> usize {
|
||||
n * cols * size_of::<S>()
|
||||
}
|
||||
|
||||
pub(crate) fn new<S: Sized>(n: usize, cols: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of::<S>(n, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes<S: Sized>(n: usize, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of::<S>(n, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n,
|
||||
cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxOwned = ScalarZnx<Vec<u8>>;
|
||||
|
||||
pub(crate) fn bytes_of_scalar_znx<B: Backend>(module: &Module<B>, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of::<i64>(module.n(), cols)
|
||||
}
|
||||
|
||||
pub trait ScalarZnxAlloc {
|
||||
fn bytes_of_scalar_znx(&self, cols: usize) -> usize;
|
||||
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned;
|
||||
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxAlloc for Module<B> {
|
||||
fn bytes_of_scalar_znx(&self, cols: usize) -> usize {
|
||||
ScalarZnxOwned::bytes_of::<i64>(self.n(), cols)
|
||||
}
|
||||
fn new_scalar_znx(&self, cols: usize) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new::<i64>(self.n(), cols)
|
||||
}
|
||||
fn new_scalar_znx_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxOwned {
|
||||
ScalarZnxOwned::new_from_bytes::<i64>(self.n(), cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxOps {
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxOps for Module<B> {
|
||||
fn scalar_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let a: ScalarZnx<&[u8]> = a.to_ref();
|
||||
let mut res: ScalarZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn scalar_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: ScalarZnxToMut,
|
||||
{
|
||||
let mut a: ScalarZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ScalarZnx<D> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self { data, n, cols }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarZnxToRef {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]>;
|
||||
}
|
||||
|
||||
pub trait ScalarZnxToMut {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]>;
|
||||
}
|
||||
|
||||
impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxToMut for ScalarZnx<Vec<u8>> {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxToRef for ScalarZnx<Vec<u8>> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
ScalarZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxToRef for ScalarZnx<&[u8]> {
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
233
base2k/src/scalar_znx_dft.rs
Normal file
233
base2k/src/scalar_znx_dft.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::svp;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView,
|
||||
alloc_aligned,
|
||||
};
|
||||
|
||||
pub struct ScalarZnxDft<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ZnxInfos for ScalarZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> ZnxSliceSize for ScalarZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n()
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataView for ScalarZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> DataViewMut for ScalarZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for ScalarZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_scalar_znx_dft<B: Backend>(module: &Module<B>, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(module, cols)
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn bytes_of(module: &Module<B>, cols: usize) -> usize {
|
||||
unsafe { svp::bytes_of_svp_ppol(module.ptr) as usize * cols }
|
||||
}
|
||||
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == Self::bytes_of(module, cols));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> ScalarZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
|
||||
|
||||
pub trait ScalarZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
pub trait ScalarZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
ScalarZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: 1,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
103
base2k/src/scalar_znx_dft_ops.rs
Normal file
103
base2k/src/scalar_znx_dft_ops.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use crate::ffi::svp;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
|
||||
VecZnxDftToMut, VecZnxDftToRef,
|
||||
};
|
||||
|
||||
pub trait ScalarZnxDftAlloc<B: Backend> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
|
||||
}
|
||||
|
||||
pub trait ScalarZnxDftOps<BACKEND: Backend> {
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxToRef;
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>,
|
||||
B: VecZnxDftToRef<FFT64>;
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<BACKEND>,
|
||||
A: ScalarZnxDftToRef<BACKEND>;
|
||||
}
|
||||
|
||||
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
|
||||
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new(self, cols)
|
||||
}
|
||||
|
||||
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
|
||||
ScalarZnxDftOwned::bytes_of(self, cols)
|
||||
}
|
||||
|
||||
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
|
||||
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: ScalarZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
svp::svp_prepare(
|
||||
self.ptr,
|
||||
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
|
||||
a.to_ref().at_ptr(a_col, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
B: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
|
||||
b.size() as u64,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
unsafe {
|
||||
svp::svp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
|
||||
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
|
||||
res.size() as u64,
|
||||
res.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,19 @@
|
||||
use crate::{Encoding, Infos, VecZnx};
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{Decoding, VecZnx};
|
||||
use rug::Float;
|
||||
use rug::float::Round;
|
||||
use rug::ops::{AddAssignRound, DivAssignRound, SubAssignRound};
|
||||
|
||||
impl VecZnx {
|
||||
pub fn std(&self, poly_idx: usize, log_base2k: usize) -> f64 {
|
||||
let prec: u32 = (self.cols() * log_base2k) as u32;
|
||||
pub trait Stats {
|
||||
/// Returns the standard devaition of the i-th polynomial.
|
||||
fn std(&self, col_i: usize, log_base2k: usize) -> f64;
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> Stats for VecZnx<D> {
|
||||
fn std(&self, col_i: usize, log_base2k: usize) -> f64 {
|
||||
let prec: u32 = (self.size() * log_base2k) as u32;
|
||||
let mut data: Vec<Float> = (0..self.n()).map(|_| Float::with_val(prec, 0)).collect();
|
||||
self.decode_vec_float(poly_idx, log_base2k, &mut data);
|
||||
self.decode_vec_float(col_i, log_base2k, &mut data);
|
||||
// std = sqrt(sum((xi - avg)^2) / n)
|
||||
let mut avg: Float = Float::with_val(prec, 0);
|
||||
data.iter().for_each(|x| {
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
use crate::ffi::svp::{self, svp_ppol_t};
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::{BACKEND, LAYOUT, Module, VecZnx, VecZnxDft, assert_alignement};
|
||||
|
||||
use crate::{Infos, alloc_aligned, cast_mut};
|
||||
use rand::seq::SliceRandom;
|
||||
use rand_core::RngCore;
|
||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub struct Scalar {
|
||||
pub n: usize,
|
||||
pub data: Vec<i64>,
|
||||
pub ptr: *mut i64,
|
||||
}
|
||||
|
||||
impl Module {
|
||||
pub fn new_scalar(&self) -> Scalar {
|
||||
Scalar::new(self.n())
|
||||
}
|
||||
}
|
||||
|
||||
impl Scalar {
|
||||
pub fn new(n: usize) -> Self {
|
||||
let mut data: Vec<i64> = alloc_aligned::<i64>(n);
|
||||
let ptr: *mut i64 = data.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
pub fn bytes_of(n: usize) -> usize {
|
||||
n * std::mem::size_of::<i64>()
|
||||
}
|
||||
|
||||
pub fn from_bytes(n: usize, bytes: &mut [u8]) -> Self {
|
||||
let size: usize = Self::bytes_of(n);
|
||||
debug_assert!(
|
||||
bytes.len() == size,
|
||||
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
|
||||
bytes.len(),
|
||||
n,
|
||||
size
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
|
||||
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
data: Vec::from_raw_parts(bytes_i64.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(n: usize, bytes: &mut [u8]) -> Self {
|
||||
let size: usize = Self::bytes_of(n);
|
||||
debug_assert!(
|
||||
bytes.len() == size,
|
||||
"invalid buffer: bytes.len()={} < self.bytes_of(n={})={}",
|
||||
bytes.len(),
|
||||
n,
|
||||
size
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
let bytes_i64: &mut [i64] = cast_mut::<u8, i64>(bytes);
|
||||
let ptr: *mut i64 = bytes_i64.as_mut_ptr();
|
||||
Self {
|
||||
n: n,
|
||||
data: Vec::new(),
|
||||
ptr: ptr,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_ptr(&self) -> *const i64 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn raw(&self) -> &[i64] {
|
||||
unsafe { std::slice::from_raw_parts(self.ptr, self.n) }
|
||||
}
|
||||
|
||||
pub fn raw_mut(&self) -> &mut [i64] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.ptr, self.n) }
|
||||
}
|
||||
|
||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
||||
let choices: [i64; 3] = [-1, 0, 1];
|
||||
let weights: [f64; 3] = [prob / 2.0, 1.0 - prob, prob / 2.0];
|
||||
let dist: WeightedIndex<f64> = WeightedIndex::new(&weights).unwrap();
|
||||
self.data
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = choices[dist.sample(source)]);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||
assert!(hw <= self.n());
|
||||
self.data[..hw]
|
||||
.iter_mut()
|
||||
.for_each(|x: &mut i64| *x = (((source.next_u32() & 1) as i64) << 1) - 1);
|
||||
self.data.shuffle(source);
|
||||
}
|
||||
|
||||
pub fn as_vec_znx(&self) -> VecZnx {
|
||||
VecZnx {
|
||||
n: self.n,
|
||||
size: 1, // TODO REVIEW IF NEED TO ADD size TO SCALAR
|
||||
cols: 1,
|
||||
layout: LAYOUT::COL,
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ScalarOps {
|
||||
fn bytes_of_scalar(&self) -> usize;
|
||||
fn new_scalar(&self) -> Scalar;
|
||||
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar;
|
||||
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar;
|
||||
}
|
||||
impl ScalarOps for Module {
|
||||
fn bytes_of_scalar(&self) -> usize {
|
||||
Scalar::bytes_of(self.n())
|
||||
}
|
||||
fn new_scalar(&self) -> Scalar {
|
||||
Scalar::new(self.n())
|
||||
}
|
||||
fn new_scalar_from_bytes(&self, bytes: &mut [u8]) -> Scalar {
|
||||
Scalar::from_bytes(self.n(), bytes)
|
||||
}
|
||||
fn new_scalar_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> Scalar {
|
||||
Scalar::from_bytes_borrow(self.n(), tmp_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SvpPPol {
|
||||
pub n: usize,
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub backend: BACKEND,
|
||||
}
|
||||
|
||||
/// A prepared [crate::Scalar] for [SvpPPolOps::svp_apply_dft].
|
||||
/// An [SvpPPol] an be seen as a [VecZnxDft] of one limb.
|
||||
impl SvpPPol {
|
||||
pub fn new(module: &Module) -> Self {
|
||||
module.new_svp_ppol()
|
||||
}
|
||||
|
||||
/// Returns the ring degree of the [SvpPPol].
|
||||
pub fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
pub fn bytes_of(module: &Module) -> usize {
|
||||
module.bytes_of_svp_ppol()
|
||||
}
|
||||
|
||||
pub fn from_bytes(module: &Module, bytes: &mut [u8]) -> SvpPPol {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(bytes.as_ptr());
|
||||
assert_eq!(bytes.len(), module.bytes_of_svp_ppol());
|
||||
}
|
||||
unsafe {
|
||||
Self {
|
||||
n: module.n(),
|
||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
backend: module.backend(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(module: &Module, tmp_bytes: &mut [u8]) -> SvpPPol {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
assert_eq!(tmp_bytes.len(), module.bytes_of_svp_ppol());
|
||||
}
|
||||
Self {
|
||||
n: module.n(),
|
||||
data: Vec::new(),
|
||||
ptr: tmp_bytes.as_mut_ptr(),
|
||||
backend: module.backend(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of cols of the [SvpPPol], which is always 1.
|
||||
pub fn cols(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SvpPPolOps {
|
||||
/// Allocates a new [SvpPPol].
|
||||
fn new_svp_ppol(&self) -> SvpPPol;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [SvpPPol] through [SvpPPol::from_bytes] ro.
|
||||
fn bytes_of_svp_ppol(&self) -> usize;
|
||||
|
||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
||||
/// The array of bytes is owned by the [SvpPPol].
|
||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
||||
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol;
|
||||
|
||||
/// Allocates a new [SvpPPol] from an array of bytes.
|
||||
/// The array of bytes is borrowed by the [SvpPPol].
|
||||
/// The method will panic if bytes.len() < [SvpPPolOps::bytes_of_svp_ppol]
|
||||
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol;
|
||||
|
||||
/// Prepares a [crate::Scalar] for a [SvpPPolOps::svp_apply_dft].
|
||||
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar);
|
||||
|
||||
/// Applies the [SvpPPol] x [VecZnxDft] product, where each limb of
|
||||
/// the [VecZnxDft] is multiplied with [SvpPPol].
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx);
|
||||
}
|
||||
|
||||
impl SvpPPolOps for Module {
|
||||
fn new_svp_ppol(&self) -> SvpPPol {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_svp_ppol());
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
SvpPPol {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: self.n(),
|
||||
backend: self.backend(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bytes_of_svp_ppol(&self) -> usize {
|
||||
unsafe { svp::bytes_of_svp_ppol(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn new_svp_ppol_from_bytes(&self, bytes: &mut [u8]) -> SvpPPol {
|
||||
SvpPPol::from_bytes(self, bytes)
|
||||
}
|
||||
|
||||
fn new_svp_ppol_from_bytes_borrow(&self, tmp_bytes: &mut [u8]) -> SvpPPol {
|
||||
SvpPPol::from_bytes_borrow(self, tmp_bytes)
|
||||
}
|
||||
|
||||
fn svp_prepare(&self, svp_ppol: &mut SvpPPol, a: &Scalar) {
|
||||
unsafe { svp::svp_prepare(self.ptr, svp_ppol.ptr as *mut svp_ppol_t, a.as_ptr()) }
|
||||
}
|
||||
|
||||
fn svp_apply_dft(&self, c: &mut VecZnxDft, a: &SvpPPol, b: &VecZnx) {
|
||||
unsafe {
|
||||
svp::svp_apply_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.cols() as u64,
|
||||
a.ptr as *const svp_ppol_t,
|
||||
b.as_ptr(),
|
||||
b.cols() as u64,
|
||||
b.n() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,90 +1,26 @@
|
||||
use crate::ffi::vec_znx_big::{self, vec_znx_big_t};
|
||||
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxDft, alloc_aligned, assert_alignement};
|
||||
use crate::ffi::vec_znx_big;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView};
|
||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, VecZnx, ZnxSliceSize, ZnxViewMut, ZnxZero, alloc_aligned};
|
||||
use std::fmt;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub struct VecZnxBig {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub size: usize,
|
||||
pub cols: usize,
|
||||
pub layout: LAYOUT,
|
||||
pub backend: BACKEND,
|
||||
pub struct VecZnxBig<D, B: Backend> {
|
||||
data: D,
|
||||
n: usize,
|
||||
cols: usize,
|
||||
size: usize,
|
||||
_phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl VecZnxBig {
|
||||
/// Returns a new [VecZnxBig] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_big].
|
||||
pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
};
|
||||
unsafe {
|
||||
Self {
|
||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: module.backend,
|
||||
}
|
||||
}
|
||||
impl<D, B: Backend> ZnxInfos for VecZnxBig<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_big(size, cols));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
Self {
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: module.backend,
|
||||
}
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn as_vec_znx_dft(&mut self) -> VecZnxDft {
|
||||
VecZnxDft {
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
n: self.n,
|
||||
size: self.size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: self.cols,
|
||||
backend: self.backend,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is cols * n.
|
||||
pub fn raw<T>(&self, module: &Module) -> &[T] {
|
||||
let ptr: *const T = self.ptr as *const T;
|
||||
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Infos for VecZnxBig {
|
||||
/// Returns the base 2 logarithm of the [VecZnx] degree.
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
/// Returns the [VecZnx] degree.
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
@@ -92,270 +28,217 @@ impl Infos for VecZnxBig {
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.layout
|
||||
}
|
||||
|
||||
/// Returns the number of cols of the [VecZnx].
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
/// Returns the number of rows of the [VecZnx].
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
impl<D> ZnxSliceSize for VecZnxBig<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
||||
fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxBig].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: the backing array is only borrowed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxBig].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize;
|
||||
|
||||
/// b <- b - a
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
|
||||
|
||||
/// c <- b - a
|
||||
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
|
||||
|
||||
/// c <- b + a
|
||||
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig);
|
||||
|
||||
/// b <- b + a
|
||||
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx);
|
||||
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
|
||||
/// b <- normalize(a)
|
||||
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize;
|
||||
|
||||
fn vec_znx_big_range_normalize_base2k(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut VecZnx,
|
||||
a: &VecZnxBig,
|
||||
a_range_begin: usize,
|
||||
a_range_xend: usize,
|
||||
a_range_step: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
);
|
||||
|
||||
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig);
|
||||
|
||||
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig);
|
||||
impl<D, B: Backend> DataView for VecZnxBig<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxBigOps for Module {
|
||||
fn new_vec_znx_big(&self, size: usize, cols: usize) -> VecZnxBig {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_big(size, cols));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
impl<D, B: Backend> DataViewMut for VecZnxBig<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnxBig<D, FFT64> {
|
||||
type Scalar = i64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_vec_znx_big<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == bytes_of_vec_znx_big(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D, B: Backend> VecZnxBig<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxBig<D, FFT64>: VecZnxBigToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
// Consumes the VecZnxBig to return a VecZnx.
|
||||
// Useful when no normalization is needed.
|
||||
pub fn to_vec_znx_small(self) -> VecZnx<D> {
|
||||
VecZnx {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxBig<C, FFT64>, a_col: usize)
|
||||
where
|
||||
VecZnxBig<C, FFT64>: VecZnxBigToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
let mut self_mut: VecZnxBig<&mut [u8], FFT64> = self.to_mut();
|
||||
let a_ref: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxBigOwned<B> = VecZnxBig<Vec<u8>, B>;
|
||||
|
||||
pub trait VecZnxBigToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<Vec<u8>, B> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: self.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: self.backend(),
|
||||
}
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxBig {
|
||||
VecZnxBig::from_bytes(self, size, cols, bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxBig {
|
||||
VecZnxBig::from_bytes_borrow(self, size, cols, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, size: usize, cols: usize) -> usize {
|
||||
unsafe { vec_znx_big::bytes_of_vec_znx_big(self.ptr, cols as u64) as usize * size }
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_sub_small_a(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_sub_small_a(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_big_t,
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small(&self, c: &mut VecZnxBig, a: &VecZnx, b: &VecZnxBig) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_big_t,
|
||||
c.cols() as u64,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small_inplace(&self, b: &mut VecZnxBig, a: &VecZnx) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_add_small(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_big::vec_znx_big_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn vec_znx_big_normalize(&self, log_base2k: usize, b: &mut VecZnx, a: &VecZnxBig, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_normalize_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxBigOps>::vec_znx_big_normalize_tmp_bytes(self)
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
b.as_mut_ptr(),
|
||||
b.cols() as u64,
|
||||
b.n() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_range_normalize_base2k_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_big::vec_znx_big_range_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
fn vec_znx_big_range_normalize_base2k(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut VecZnx,
|
||||
a: &VecZnxBig,
|
||||
a_range_begin: usize,
|
||||
a_range_xend: usize,
|
||||
a_range_step: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
debug_assert!(
|
||||
tmp_bytes.len() >= <Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} <= self.vec_znx_big_range_normalize_base2k_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
<Module as VecZnxBigOps>::vec_znx_big_range_normalize_base2k_tmp_bytes(self)
|
||||
);
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_range_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
res.as_mut_ptr(),
|
||||
res.cols() as u64,
|
||||
res.n() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a_range_begin as u64,
|
||||
a_range_xend as u64,
|
||||
a_range_step as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism(&self, gal_el: i64, b: &mut VecZnxBig, a: &VecZnxBig) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
self.ptr,
|
||||
gal_el,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism_inplace(&self, gal_el: i64, a: &mut VecZnxBig) {
|
||||
unsafe {
|
||||
vec_znx_big::vec_znx_big_automorphism(
|
||||
self.ptr,
|
||||
gal_el,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
a.ptr as *mut vec_znx_big_t,
|
||||
a.cols() as u64,
|
||||
);
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<Vec<u8>, B> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigToMut<B> for VecZnxBig<&mut [u8], B> {
|
||||
fn to_mut(&mut self) -> VecZnxBig<&mut [u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&mut [u8], B> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigToRef<B> for VecZnxBig<&[u8], B> {
|
||||
fn to_ref(&self) -> VecZnxBig<&[u8], B> {
|
||||
VecZnxBig {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnxBig<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxBig(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
632
base2k/src/vec_znx_big_ops.rs
Normal file
632
base2k/src/vec_znx_big_ops.rs
Normal file
@@ -0,0 +1,632 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
|
||||
use crate::{
|
||||
Backend, FFT64, Module, Scratch, VecZnx, VecZnxBig, VecZnxBigOwned, VecZnxBigToMut, VecZnxBigToRef, VecZnxScratch,
|
||||
VecZnxToMut, VecZnxToRef, ZnxSliceSize, bytes_of_vec_znx_big,
|
||||
};
|
||||
|
||||
pub trait VecZnxBigAlloc<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores not normalized values.
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials..
|
||||
/// * `size`: the number of polynomials per column.
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B>;
|
||||
|
||||
// /// Returns a new [VecZnxBig] with the provided bytes array as backing array.
|
||||
// ///
|
||||
// /// Behavior: the backing array is only borrowed.
|
||||
// ///
|
||||
// /// # Arguments
|
||||
// ///
|
||||
// /// * `cols`: the number of polynomials..
|
||||
// /// * `size`: the number of polynomials per column.
|
||||
// /// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_big].
|
||||
// ///
|
||||
// /// # Panics
|
||||
// /// If `bytes.len()` < [Module::bytes_of_vec_znx_big].
|
||||
// fn new_vec_znx_big_from_bytes_borrow(&self, cols: usize, size: usize, tmp_bytes: &mut [u8]) -> VecZnxBig<B>;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxBig] through [VecZnxBig::from_bytes].
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigOps<BACKEND: Backend> {
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Adds `a` to `b` and stores the result on `b`.
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `a` to `b` and stores the result on `c`.
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `b` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `b`.
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Subtracts `a` from `res` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts `b` from `a` and stores the result on `c`.
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Subtracts `res` from `a` and stores the result on `res`.
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Negates `a` inplace.
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
|
||||
/// Normalizes `a` and stores the result on `b`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `log_base2k`: normalization basis.
|
||||
/// * `tmp_bytes`: scratch space of size at least [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `b`.
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<BACKEND>,
|
||||
A: VecZnxBigToRef<BACKEND>;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on `a` and stores the result on `a`.
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<BACKEND>;
|
||||
}
|
||||
|
||||
pub trait VecZnxBigScratch {
|
||||
/// Returns the minimum number of bytes to apply [VecZnxBigOps::vec_znx_big_normalize].
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_big(&self, cols: usize, size: usize) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new(self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_big_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxBigOwned<B> {
|
||||
VecZnxBig::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_big(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_big(self, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_big_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_b_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnxBig<&[u8], FFT64> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_sub_small_a_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_add_small_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.at_mut_ptr(res_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(res_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_normalize<R, A>(
|
||||
&self,
|
||||
log_base2k: usize,
|
||||
res: &mut R,
|
||||
res_col: usize,
|
||||
a: &A,
|
||||
a_col: usize,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
//(Jay)Note: This is calling VezZnxOps::vec_znx_normalize_tmp_bytes and not VecZnxBigOps::vec_znx_big_normalize_tmp_bytes.
|
||||
// In the FFT backend the tmp sizes are same but will be different in the NTT backend
|
||||
// assert!(tmp_bytes.len() >= <Self as VecZnxOps<&mut [u8], & [u8]>>::vec_znx_normalize_tmp_bytes(&self));
|
||||
// assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(<Self as VecZnxBigScratch>::vec_znx_big_normalize_tmp_bytes(
|
||||
&self,
|
||||
));
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxBigToRef<FFT64>,
|
||||
{
|
||||
let a: VecZnxBig<&[u8], FFT64> = a.to_ref();
|
||||
let mut res: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_big_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxBigToMut<FFT64>,
|
||||
{
|
||||
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxBigScratch for Module<B> {
|
||||
fn vec_znx_big_normalize_tmp_bytes(&self) -> usize {
|
||||
<Self as VecZnxScratch>::vec_znx_normalize_tmp_bytes(self)
|
||||
}
|
||||
}
|
||||
@@ -1,114 +1,35 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::ffi::vec_znx_dft;
|
||||
use crate::ffi::vec_znx_dft::{bytes_of_vec_znx_dft, vec_znx_dft_t};
|
||||
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnxBig, assert_alignement};
|
||||
use crate::{DEFAULTALIGN, VecZnx, alloc_aligned};
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, DataView, DataViewMut, FFT64, Module, VecZnxBig, ZnxSliceSize, ZnxView, ZnxViewMut, ZnxZero, alloc_aligned,
|
||||
};
|
||||
use std::fmt;
|
||||
|
||||
pub struct VecZnxDft {
|
||||
pub data: Vec<u8>,
|
||||
pub ptr: *mut u8,
|
||||
pub n: usize,
|
||||
pub size: usize,
|
||||
pub layout: LAYOUT,
|
||||
pub cols: usize,
|
||||
pub backend: BACKEND,
|
||||
pub struct VecZnxDft<D, B: Backend> {
|
||||
pub(crate) data: D,
|
||||
pub(crate) n: usize,
|
||||
pub(crate) cols: usize,
|
||||
pub(crate) size: usize,
|
||||
pub(crate) _phantom: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl VecZnxDft {
|
||||
/// Returns a new [VecZnxDft] with the provided data as backing array.
|
||||
/// User must ensure that data is properly alligned and that
|
||||
/// the size of data is at least equal to [Module::bytes_of_vec_znx_dft].
|
||||
pub fn from_bytes(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols));
|
||||
assert_alignement(bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
VecZnxDft {
|
||||
data: Vec::from_raw_parts(bytes.as_mut_ptr(), bytes.len(), bytes.len()),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: module.backend,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(module: &Module, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(bytes.len(), module.bytes_of_vec_znx_dft(size, cols));
|
||||
assert_alignement(bytes.as_ptr());
|
||||
}
|
||||
VecZnxDft {
|
||||
data: Vec::new(),
|
||||
ptr: bytes.as_mut_ptr(),
|
||||
n: module.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: module.backend,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cast a [VecZnxDft] into a [VecZnxBig].
|
||||
/// The returned [VecZnxBig] shares the backing array
|
||||
/// with the original [VecZnxDft].
|
||||
pub fn as_vec_znx_big(&mut self) -> VecZnxBig {
|
||||
VecZnxBig {
|
||||
data: Vec::new(),
|
||||
ptr: self.ptr,
|
||||
n: self.n,
|
||||
layout: LAYOUT::COL,
|
||||
size: self.size,
|
||||
cols: self.cols,
|
||||
backend: self.backend,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is cols * n.
|
||||
pub fn raw<T>(&self, module: &Module) -> &[T] {
|
||||
let ptr: *const T = self.ptr as *const T;
|
||||
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
pub fn at<T>(&self, module: &Module, col_i: usize) -> &[T] {
|
||||
&self.raw::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
|
||||
}
|
||||
|
||||
/// Returns a mutable reference of `T` of the entire contiguous array of the [VecZnxDft].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is cols * n.
|
||||
pub fn raw_mut<T>(&self, module: &Module) -> &mut [T] {
|
||||
let ptr: *mut T = self.ptr as *mut T;
|
||||
let len: usize = (self.cols() * module.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
|
||||
pub fn at_mut<T>(&self, module: &Module, col_i: usize) -> &mut [T] {
|
||||
&mut self.raw_mut::<T>(module)[col_i * module.n()..(col_i + 1) * module.n()]
|
||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||
pub fn into_big(self) -> VecZnxBig<D, B> {
|
||||
VecZnxBig::<D, B>::from_data(self.data, self.n, self.cols, self.size)
|
||||
}
|
||||
}
|
||||
|
||||
impl Infos for VecZnxDft {
|
||||
/// Returns the base 2 logarithm of the [VecZnx] degree.
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n - 1).leading_zeros()) as _
|
||||
impl<D, B: Backend> ZnxInfos for VecZnxDft<D, B> {
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
/// Returns the [VecZnx] degree.
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
@@ -116,254 +37,206 @@ impl Infos for VecZnxDft {
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.layout
|
||||
}
|
||||
|
||||
/// Returns the number of cols of the [VecZnx].
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
|
||||
/// Returns the number of rows of the [VecZnx].
|
||||
fn rows(&self) -> usize {
|
||||
1
|
||||
impl<D> ZnxSliceSize for VecZnxDft<D, FFT64> {
|
||||
fn sl(&self) -> usize {
|
||||
self.n() * self.cols()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: the backing array is only borrowed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, bytes: &mut [u8]) -> VecZnxDft;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize;
|
||||
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft);
|
||||
|
||||
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx);
|
||||
|
||||
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft);
|
||||
|
||||
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]);
|
||||
|
||||
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize;
|
||||
impl<D, B: Backend> DataView for VecZnxDft<D, B> {
|
||||
type D = D;
|
||||
fn data(&self) -> &Self::D {
|
||||
&self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxDftOps for Module {
|
||||
fn new_vec_znx_dft(&self, size: usize, cols: usize) -> VecZnxDft {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vec_znx_dft(size, cols));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
VecZnxDft {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: self.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
backend: self.backend(),
|
||||
impl<D, B: Backend> DataViewMut for VecZnxDft<D, B> {
|
||||
fn data_mut(&mut self) -> &mut Self::D {
|
||||
&mut self.data
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> ZnxView for VecZnxDft<D, FFT64> {
|
||||
type Scalar = f64;
|
||||
}
|
||||
|
||||
pub(crate) fn bytes_of_vec_znx_dft<B: Backend>(module: &Module<B>, cols: usize, size: usize) -> usize {
|
||||
unsafe { vec_znx_dft::bytes_of_vec_znx_dft(module.ptr, size as u64) as usize * cols }
|
||||
}
|
||||
|
||||
impl<D: From<Vec<u8>>, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn new(module: &Module<B>, cols: usize, size: usize) -> Self {
|
||||
let data = alloc_aligned::<u8>(bytes_of_vec_znx_dft(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
|
||||
VecZnxDft::from_bytes(self, size, cols, tmp_bytes)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes_borrow(&self, size: usize, cols: usize, tmp_bytes: &mut [u8]) -> VecZnxDft {
|
||||
VecZnxDft::from_bytes_borrow(self, size, cols, tmp_bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, size: usize, cols: usize) -> usize {
|
||||
unsafe { bytes_of_vec_znx_dft(self.ptr, cols as u64) as usize * size }
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a(&self, b: &mut VecZnxBig, a: &mut VecZnxDft) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
)
|
||||
pub(crate) fn new_from_bytes(module: &Module<B>, cols: usize, size: usize, bytes: impl Into<Vec<u8>>) -> Self {
|
||||
let data: Vec<u8> = bytes.into();
|
||||
assert!(data.len() == bytes_of_vec_znx_dft(module, cols, size));
|
||||
Self {
|
||||
data: data.into(),
|
||||
n: module.n(),
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.cols < a_cols
|
||||
fn vec_znx_dft(&self, b: &mut VecZnxDft, a: &VecZnx) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft(&self, b: &mut VecZnxBig, a: &VecZnxDft, tmp_bytes: &mut [u8]) {
|
||||
impl<D> VecZnxDft<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
/// Extracts the a_col-th column of 'a' and stores it on the self_col-th column [Self].
|
||||
pub fn extract_column<C>(&mut self, self_col: usize, a: &VecZnxDft<C, FFT64>, a_col: usize)
|
||||
where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= Self::vec_znx_idft_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_idft_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
Self::vec_znx_idft_tmp_bytes(self)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
assert!(self_col < self.cols());
|
||||
assert!(a_col < a.cols());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_automorphism(&self, k: i64, b: &mut VecZnxDft, a: &VecZnxDft) {
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.cols() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
[0u8; 0].as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
let min_size: usize = self.size.min(a.size());
|
||||
let max_size: usize = self.size;
|
||||
|
||||
fn vec_znx_dft_automorphism_inplace(&self, k: i64, a: &mut VecZnxDft, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
tmp_bytes.len() >= Self::vec_znx_dft_automorphism_tmp_bytes(self),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < self.vec_znx_dft_automorphism_tmp_bytes()={}",
|
||||
tmp_bytes.len(),
|
||||
Self::vec_znx_dft_automorphism_tmp_bytes(self)
|
||||
);
|
||||
assert_alignement(tmp_bytes.as_ptr())
|
||||
}
|
||||
unsafe {
|
||||
vec_znx_dft::vec_znx_dft_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.ptr as *mut vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
let mut self_mut: VecZnxDft<&mut [u8], FFT64> = self.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
fn vec_znx_dft_automorphism_tmp_bytes(&self) -> usize {
|
||||
unsafe {
|
||||
std::cmp::max(
|
||||
vec_znx_dft::vec_znx_dft_automorphism_tmp_bytes(self.ptr) as usize,
|
||||
DEFAULTALIGN,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{BACKEND, Module, Sampling, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, alloc_aligned};
|
||||
use itertools::izip;
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_automorphism_dft() {
|
||||
let module: Module = Module::new(128, BACKEND::FFT64);
|
||||
|
||||
let cols: usize = 2;
|
||||
let log_base2k: usize = 17;
|
||||
let mut a: VecZnx = module.new_vec_znx(1, cols);
|
||||
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
|
||||
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, cols);
|
||||
|
||||
let mut source: Source = Source::new(new_seed());
|
||||
module.fill_uniform(log_base2k, &mut a, cols, &mut source);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vec_znx_dft_automorphism_tmp_bytes());
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
// a_dft <- DFT(a)
|
||||
module.vec_znx_dft(&mut a_dft, &a);
|
||||
|
||||
// a_dft <- AUTO(a_dft)
|
||||
module.vec_znx_dft_automorphism_inplace(p, &mut a_dft, &mut tmp_bytes);
|
||||
|
||||
// a <- AUTO(a)
|
||||
module.vec_znx_automorphism_inplace(p, &mut a);
|
||||
|
||||
// b_dft <- DFT(AUTO(a))
|
||||
module.vec_znx_dft(&mut b_dft, &a);
|
||||
|
||||
let a_f64: &[f64] = a_dft.raw(&module);
|
||||
let b_f64: &[f64] = b_dft.raw(&module);
|
||||
izip!(a_f64.iter(), b_f64.iter()).for_each(|(ai, bi)| {
|
||||
assert!((ai - bi).abs() <= 1e-9, "{:+e} > 1e-9", (ai - bi).abs());
|
||||
(0..min_size).for_each(|i: usize| {
|
||||
self_mut
|
||||
.at_mut(self_col, i)
|
||||
.copy_from_slice(a_ref.at(a_col, i));
|
||||
});
|
||||
|
||||
module.free()
|
||||
(min_size..max_size).for_each(|i| {
|
||||
self_mut.zero_at(self_col, i);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub type VecZnxDftOwned<B> = VecZnxDft<Vec<u8>, B>;
|
||||
|
||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||
pub(crate) fn from_data(data: D, n: usize, cols: usize, size: usize) -> Self {
|
||||
Self {
|
||||
data,
|
||||
n,
|
||||
cols,
|
||||
size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToRef<B: Backend> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B>;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftToMut<B: Backend> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B>;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<Vec<u8>, B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_mut_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<Vec<u8>, B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data.as_slice(),
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToMut<B> for VecZnxDft<&mut [u8], B> {
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&mut [u8], B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftToRef<B> for VecZnxDft<&[u8], B> {
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
VecZnxDft {
|
||||
data: self.data,
|
||||
n: self.n,
|
||||
cols: self.cols,
|
||||
size: self.size,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> fmt::Display for VecZnxDft<D, FFT64> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
writeln!(
|
||||
f,
|
||||
"VecZnxDft(n={}, cols={}, size={})",
|
||||
self.n, self.cols, self.size
|
||||
)?;
|
||||
|
||||
for col in 0..self.cols {
|
||||
writeln!(f, "Column {}:", col)?;
|
||||
for size in 0..self.size {
|
||||
let coeffs = self.at(col, size);
|
||||
write!(f, " Size {}: [", size)?;
|
||||
|
||||
let max_show = 100;
|
||||
let show_count = coeffs.len().min(max_show);
|
||||
|
||||
for (i, &coeff) in coeffs.iter().take(show_count).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "{}", coeff)?;
|
||||
}
|
||||
|
||||
if coeffs.len() > max_show {
|
||||
write!(f, ", ... ({} more)", coeffs.len() - max_show)?;
|
||||
}
|
||||
|
||||
writeln!(f, "]")?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
287
base2k/src/vec_znx_dft_ops.rs
Normal file
287
base2k/src/vec_znx_dft_ops.rs
Normal file
@@ -0,0 +1,287 @@
|
||||
use crate::ffi::{vec_znx_big, vec_znx_dft};
|
||||
use crate::vec_znx_dft::bytes_of_vec_znx_dft;
|
||||
use crate::znx_base::ZnxInfos;
|
||||
use crate::{
|
||||
Backend, Scratch, VecZnxBig, VecZnxBigToMut, VecZnxDft, VecZnxDftOwned, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef,
|
||||
ZnxSliceSize,
|
||||
};
|
||||
use crate::{FFT64, Module, ZnxView, ZnxViewMut, ZnxZero};
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait VecZnxDftAlloc<B: Backend> {
|
||||
/// Allocates a vector Z[X]/(X^N+1) that stores normalized in the DFT space.
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// Behavior: takes ownership of the backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B>;
|
||||
|
||||
/// Returns a new [VecZnxDft] with the provided bytes array as backing array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of cols of the [VecZnxDft].
|
||||
/// * `bytes`: a byte array of size at least [Module::bytes_of_vec_znx_dft].
|
||||
///
|
||||
/// # Panics
|
||||
/// If `bytes.len()` < [Module::bytes_of_vec_znx_dft].
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxDftOps<B: Backend> {
|
||||
/// Returns the minimum number of bytes necessary to allocate
|
||||
/// a new [VecZnxDft] through [VecZnxDft::from_bytes].
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize;
|
||||
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>,
|
||||
D: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
/// b <- IDFT(a), uses a as scratch space.
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToMut<B>;
|
||||
|
||||
/// Consumes a to return IDFT(a) in big coeff space.
|
||||
fn vec_znx_idft_consume<D>(&self, a: VecZnxDft<D, B>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>;
|
||||
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxBigToMut<B>,
|
||||
A: VecZnxDftToRef<B>;
|
||||
|
||||
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<B>,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxDftAlloc<B> for Module<B> {
|
||||
fn new_vec_znx_dft(&self, cols: usize, size: usize) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new(&self, cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_dft_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxDftOwned<B> {
|
||||
VecZnxDftOwned::new_from_bytes(self, cols, size, bytes)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx_dft(&self, cols: usize, size: usize) -> usize {
|
||||
bytes_of_vec_znx_dft(self, cols, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl VecZnxDftOps<FFT64> for Module<FFT64> {
|
||||
fn vec_znx_dft_add<R, A, D>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &D, b_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
D: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
let b_ref: VecZnxDft<&[u8], FFT64> = b.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size()).min(b_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
b_ref.at_ptr(b_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_dft_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = res_mut.size().min(a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_dft_add(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
res_mut.at_ptr(res_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1,
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_dft_copy<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
|
||||
(0..min_size).for_each(|j| {
|
||||
res_mut
|
||||
.at_mut(res_col, j)
|
||||
.copy_from_slice(a_ref.at(a_col, j));
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_a<R, A>(&self, res: &mut R, res_col: usize, a: &mut A, a_col: usize)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_mut.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(a_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_idft_consume<D>(&self, mut a: VecZnxDft<D, FFT64>) -> VecZnxBig<D, FFT64>
|
||||
where
|
||||
VecZnxDft<D, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let mut a_mut: VecZnxDft<&mut [u8], FFT64> = a.to_mut();
|
||||
|
||||
unsafe {
|
||||
// Rev col and rows because ZnxDft.sl() >= ZnxBig.sl()
|
||||
(0..a_mut.size()).for_each(|j| {
|
||||
(0..a_mut.cols()).for_each(|i| {
|
||||
vec_znx_dft::vec_znx_idft_tmp_a(
|
||||
self.ptr,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_mut.at_mut_ptr(i, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
)
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
a.into_big()
|
||||
}
|
||||
|
||||
fn vec_znx_idft_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx_dft::vec_znx_idft_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
|
||||
/// b <- DFT(a)
|
||||
///
|
||||
/// # Panics
|
||||
/// If b.cols < a_col
|
||||
fn vec_znx_dft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res_mut: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: crate::VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_dft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, j),
|
||||
1 as u64,
|
||||
a_ref.sl() as u64,
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// b <- IDFT(a), scratch space size obtained with [vec_znx_idft_tmp_bytes].
|
||||
fn vec_znx_idft<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxBigToMut<FFT64>,
|
||||
A: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
let mut res_mut: VecZnxBig<&mut [u8], FFT64> = res.to_mut();
|
||||
let a_ref: VecZnxDft<&[u8], FFT64> = a.to_ref();
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_idft_tmp_bytes());
|
||||
|
||||
let min_size: usize = min(res_mut.size(), a_ref.size());
|
||||
|
||||
unsafe {
|
||||
(0..min_size).for_each(|j| {
|
||||
vec_znx_dft::vec_znx_idft(
|
||||
self.ptr,
|
||||
res_mut.at_mut_ptr(res_col, j) as *mut vec_znx_big::vec_znx_big_t,
|
||||
1 as u64,
|
||||
a_ref.at_ptr(a_col, j) as *const vec_znx_dft::vec_znx_dft_t,
|
||||
1 as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
});
|
||||
(min_size..res_mut.size()).for_each(|j| {
|
||||
res_mut.zero_at(res_col, j);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
694
base2k/src/vec_znx_ops.rs
Normal file
694
base2k/src/vec_znx_ops.rs
Normal file
@@ -0,0 +1,694 @@
|
||||
use crate::ffi::vec_znx;
|
||||
use crate::{
|
||||
Backend, Module, ScalarZnxToRef, Scratch, VecZnx, VecZnxOwned, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxSliceSize, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
use std::cmp::min;
|
||||
|
||||
pub trait VecZnxAlloc {
|
||||
/// Allocates a new [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned;
|
||||
|
||||
/// Instantiates a new [VecZnx] from a slice of bytes.
|
||||
/// The returned [VecZnx] takes ownership of the slice of bytes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `cols`: the number of polynomials.
|
||||
/// * `size`: the number small polynomials per column.
|
||||
///
|
||||
/// # Panic
|
||||
/// Requires the slice of bytes to be equal to [VecZnxOps::bytes_of_vec_znx].
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned;
|
||||
|
||||
/// Returns the number of bytes necessary to allocate
|
||||
/// a new [VecZnx] through [VecZnxOps::new_vec_znx_from_bytes]
|
||||
/// or [VecZnxOps::new_vec_znx_from_bytes_borrow].
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize;
|
||||
}
|
||||
|
||||
pub trait VecZnxOps {
|
||||
/// Normalizes the selected column of `a` and stores the result into the selected column of `res`.
|
||||
fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Normalizes the selected column of `a`.
|
||||
fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Adds the selected column of `a` to the selected column of `b` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Adds the selected column of `a` to the selected column of `res` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Adds the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `b` from the selected column of `a` and writes the result on the selected column of `res`.
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `a` from the selected column of `res` inplace.
|
||||
///
|
||||
/// res[res_col] -= a[a_col]
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `res` from the selected column of `a` and inplace mutates `res`
|
||||
///
|
||||
/// res[res_col] = a[a_col] - res[res_col]
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Subtracts the selected column of `a` on the selected column and limb of `res`.
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef;
|
||||
|
||||
// Negates the selected column of `a` and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Negates the selected column of `a`.
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Multiplies the selected column of `a` by X^k and stores the result in `res_col` of `res`.
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Multiplies the selected column of `a` by X^k.
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a` and stores the result in `res_col` column of `res`.
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Applies the automorphism X^i -> X^ik on the selected column of `a`.
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut;
|
||||
|
||||
/// Splits the selected columns of `b` into subrings and copies them them into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of b have the same ring degree
|
||||
/// and that b.n() * b.len() <= a.n()
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
/// Merges the subrings of the selected column of `a` into the selected column of `res`.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This method requires that all [VecZnx] of a have the same ring degree
|
||||
/// and that a.n() * a.len() <= b.n()
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
|
||||
fn switch_degree<R, A>(&self, r: &mut R, col_b: usize, a: &A, col_a: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef;
|
||||
}
|
||||
|
||||
pub trait VecZnxScratch {
|
||||
/// Returns the minimum number of bytes necessary for normalization.
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxAlloc for Module<B> {
|
||||
fn new_vec_znx(&self, cols: usize, size: usize) -> VecZnxOwned {
|
||||
VecZnxOwned::new::<i64>(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn bytes_of_vec_znx(&self, cols: usize, size: usize) -> usize {
|
||||
VecZnxOwned::bytes_of::<i64>(self.n(), cols, size)
|
||||
}
|
||||
|
||||
fn new_vec_znx_from_bytes(&self, cols: usize, size: usize, bytes: Vec<u8>) -> VecZnxOwned {
|
||||
VecZnxOwned::new_from_bytes::<i64>(self.n(), cols, size, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<BACKEND: Backend> VecZnxOps for Module<BACKEND> {
|
||||
fn vec_znx_normalize<R, A>(&self, log_base2k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_normalize_inplace<A>(&self, log_base2k: usize, a: &mut A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
|
||||
let (tmp_bytes, _) = scratch.tmp_slice(self.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_normalize_base2k(
|
||||
self.ptr,
|
||||
log_base2k as u64,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_add_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_add(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
B: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let b: VecZnx<&[u8]> = b.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(b.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
assert_ne!(a.as_ptr(), b.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
b.at_ptr(b_col, 0),
|
||||
b.size() as u64,
|
||||
b.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_scalar_inplace<R, A>(&self, res: &mut R, res_col: usize, res_limb: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: ScalarZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
let a: crate::ScalarZnx<&[u8]> = a.to_ref();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, res_limb),
|
||||
1 as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ab_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_sub_ba_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_sub(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
res.at_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_negate_inplace<A>(&self, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_negate(
|
||||
self.ptr,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_rotate_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_rotate(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_automorphism<R, A>(&self, k: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert_eq!(res.n(), self.n());
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
res.at_mut_ptr(res_col, 0),
|
||||
res.size() as u64,
|
||||
res.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_automorphism_inplace<A>(&self, k: i64, a: &mut A, a_col: usize)
|
||||
where
|
||||
A: VecZnxToMut,
|
||||
{
|
||||
let mut a: VecZnx<&mut [u8]> = a.to_mut();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), self.n());
|
||||
assert!(
|
||||
k & 1 != 0,
|
||||
"invalid galois element: must be odd but is {}",
|
||||
k
|
||||
);
|
||||
}
|
||||
unsafe {
|
||||
vec_znx::vec_znx_automorphism(
|
||||
self.ptr,
|
||||
k,
|
||||
a.at_mut_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
a.at_ptr(a_col, 0),
|
||||
a.size() as u64,
|
||||
a.sl() as u64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vec_znx_split<R, A>(&self, res: &mut Vec<R>, res_col: usize, a: &A, a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res[0].to_mut().n());
|
||||
|
||||
let (mut buf, _) = scratch.tmp_vec_znx(self, 1, a.size());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
res[1..].iter_mut().for_each(|bi| {
|
||||
debug_assert_eq!(
|
||||
bi.to_mut().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
res.iter_mut().enumerate().for_each(|(i, bi)| {
|
||||
if i == 0 {
|
||||
self.switch_degree(bi, res_col, &a, a_col);
|
||||
self.vec_znx_rotate(-1, &mut buf, 0, &a, a_col);
|
||||
} else {
|
||||
self.switch_degree(bi, res_col, &mut buf, a_col);
|
||||
self.vec_znx_rotate_inplace(-1, &mut buf, a_col);
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn vec_znx_merge<R, A>(&self, res: &mut R, res_col: usize, a: Vec<A>, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (res.n(), a[0].to_ref().n());
|
||||
|
||||
debug_assert!(
|
||||
n_out < n_in,
|
||||
"invalid a: output ring degree should be smaller"
|
||||
);
|
||||
a[1..].iter().for_each(|ai| {
|
||||
debug_assert_eq!(
|
||||
ai.to_ref().n(),
|
||||
n_out,
|
||||
"invalid input a: all VecZnx must have the same degree"
|
||||
)
|
||||
});
|
||||
|
||||
a.iter().enumerate().for_each(|(_, ai)| {
|
||||
self.switch_degree(&mut res, res_col, ai, a_col);
|
||||
self.vec_znx_rotate_inplace(-1, &mut res, res_col);
|
||||
});
|
||||
|
||||
self.vec_znx_rotate_inplace(a.len() as i64, &mut res, res_col);
|
||||
}
|
||||
|
||||
fn switch_degree<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
|
||||
where
|
||||
R: VecZnxToMut,
|
||||
A: VecZnxToRef,
|
||||
{
|
||||
let a: VecZnx<&[u8]> = a.to_ref();
|
||||
let mut res: VecZnx<&mut [u8]> = res.to_mut();
|
||||
|
||||
let (n_in, n_out) = (a.n(), res.n());
|
||||
let (gap_in, gap_out): (usize, usize);
|
||||
|
||||
if n_in > n_out {
|
||||
(gap_in, gap_out) = (n_in / n_out, 1)
|
||||
} else {
|
||||
(gap_in, gap_out) = (1, n_out / n_in);
|
||||
res.zero();
|
||||
}
|
||||
|
||||
let size: usize = min(a.size(), res.size());
|
||||
|
||||
(0..size).for_each(|i| {
|
||||
izip!(
|
||||
a.at(a_col, i).iter().step_by(gap_in),
|
||||
res.at_mut(res_col, i).iter_mut().step_by(gap_out)
|
||||
)
|
||||
.for_each(|(x_in, x_out)| *x_out = *x_in);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> VecZnxScratch for Module<B> {
|
||||
fn vec_znx_normalize_tmp_bytes(&self) -> usize {
|
||||
unsafe { vec_znx::vec_znx_normalize_base2k_tmp_bytes(self.ptr) as usize }
|
||||
}
|
||||
}
|
||||
@@ -1,694 +0,0 @@
|
||||
use crate::ffi::vec_znx_big::vec_znx_big_t;
|
||||
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
|
||||
use crate::ffi::vmp::{self, vmp_pmat_t};
|
||||
use crate::{BACKEND, Infos, LAYOUT, Module, VecZnx, VecZnxBig, VecZnxDft, alloc_aligned, assert_alignement};
|
||||
|
||||
/// Vector Matrix Product Prepared Matrix: a vector of [VecZnx],
|
||||
/// stored as a 3D matrix in the DFT domain in a single contiguous array.
|
||||
/// Each row of the [VmpPMat] can be seen as a [VecZnxDft].
|
||||
///
|
||||
/// The backend array of [VmpPMat] is allocate in C,
|
||||
/// and thus must be manually freed.
|
||||
///
|
||||
/// [VmpPMat] is used to permform a vector matrix product between a [VecZnx] and a [VmpPMat].
|
||||
/// See the trait [VmpPMatOps] for additional information.
|
||||
pub struct VmpPMat {
|
||||
/// Raw data, is empty if borrowing scratch space.
|
||||
data: Vec<u8>,
|
||||
/// Pointer to data. Can point to scratch space.
|
||||
ptr: *mut u8,
|
||||
/// The number of [VecZnxDft].
|
||||
rows: usize,
|
||||
/// The number of cols in each [VecZnxDft].
|
||||
cols: usize,
|
||||
/// The ring degree of each [VecZnxDft].
|
||||
n: usize,
|
||||
/// The number of stacked [VmpPMat], must be a square.
|
||||
size: usize,
|
||||
/// The memory layout of the stacked [VmpPMat].
|
||||
layout: LAYOUT,
|
||||
/// The backend fft or ntt.
|
||||
backend: BACKEND,
|
||||
}
|
||||
|
||||
impl Infos for VmpPMat {
|
||||
/// Returns the ring dimension of the [VmpPMat].
|
||||
fn n(&self) -> usize {
|
||||
self.n
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.size
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.layout
|
||||
}
|
||||
|
||||
/// Returns the number of rows (i.e. of [VecZnxDft]) of the [VmpPMat]
|
||||
fn rows(&self) -> usize {
|
||||
self.rows
|
||||
}
|
||||
|
||||
/// Returns the number of cols of the [VmpPMat].
|
||||
/// The number of cols refers to the number of cols
|
||||
/// of each [VecZnxDft].
|
||||
/// This method is equivalent to [Self::cols].
|
||||
fn cols(&self) -> usize {
|
||||
self.cols
|
||||
}
|
||||
}
|
||||
|
||||
impl VmpPMat {
|
||||
pub fn as_ptr(&self) -> *const u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn as_mut_ptr(&self) -> *mut u8 {
|
||||
self.ptr
|
||||
}
|
||||
|
||||
pub fn borrowed(&self) -> bool {
|
||||
self.data.len() == 0
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is rows * cols * n.
|
||||
pub fn raw<T>(&self) -> &[T] {
|
||||
let ptr: *const T = self.ptr as *const T;
|
||||
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { &std::slice::from_raw_parts(ptr, len) }
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference of `T` of the entire contiguous array of the [VmpPMat].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
/// The length of the returned array is rows * cols * n.
|
||||
pub fn raw_mut<T>(&self) -> &mut [T] {
|
||||
let ptr: *mut T = self.ptr as *mut T;
|
||||
let len: usize = (self.rows() * self.cols() * self.n() * 8) / std::mem::size_of::<T>();
|
||||
unsafe { std::slice::from_raw_parts_mut(ptr, len) }
|
||||
}
|
||||
|
||||
/// Returns a copy of the backend array at index (i, j) of the [VmpPMat].
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `row`: row index (i).
|
||||
/// * `col`: col index (j).
|
||||
pub fn at<T: Default + Copy>(&self, row: usize, col: usize) -> Vec<T> {
|
||||
let mut res: Vec<T> = alloc_aligned(self.n);
|
||||
|
||||
if self.n < 8 {
|
||||
res.copy_from_slice(
|
||||
&self.raw::<T>()[(row + col * self.rows()) * self.n()..(row + col * self.rows()) * (self.n() + 1)],
|
||||
);
|
||||
} else {
|
||||
(0..self.n >> 3).for_each(|blk| {
|
||||
res[blk * 8..(blk + 1) * 8].copy_from_slice(&self.at_block(row, col, blk)[..8]);
|
||||
});
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
/// When using [`crate::FFT64`] as backend, `T` should be [f64].
|
||||
/// When using [`crate::NTT120`] as backend, `T` should be [i64].
|
||||
fn at_block<T>(&self, row: usize, col: usize, blk: usize) -> &[T] {
|
||||
let nrows: usize = self.rows();
|
||||
let ncols: usize = self.cols();
|
||||
if col == (ncols - 1) && (ncols & 1 == 1) {
|
||||
&self.raw::<T>()[blk * nrows * ncols * 8 + col * nrows * 8 + row * 8..]
|
||||
} else {
|
||||
&self.raw::<T>()[blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8 + row * 2 * 8 + (col % 2) * 8..]
|
||||
}
|
||||
}
|
||||
|
||||
fn backend(&self) -> BACKEND {
|
||||
self.backend
|
||||
}
|
||||
}
|
||||
|
||||
/// This trait implements methods for vector matrix product,
|
||||
/// that is, multiplying a [VecZnx] with a [VmpPMat].
|
||||
pub trait VmpPMatOps {
|
||||
fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Allocates a new [VmpPMat] with the given number of rows and columns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows (number of [VecZnxDft]).
|
||||
/// * `cols`: number of cols (number of cols of each [VecZnxDft]).
|
||||
fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat;
|
||||
|
||||
/// Returns the number of bytes needed as scratch space for [VmpPMatOps::vmp_prepare_contiguous].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rows`: number of rows of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
|
||||
/// * `cols`: number of cols of the [VmpPMat] used in [VmpPMatOps::vmp_prepare_contiguous].
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Prepares a [VmpPMat] from a contiguous array of [i64].
|
||||
/// The helper struct [Matrix3D] can be used to contruct and populate
|
||||
/// the appropriate contiguous array.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the contiguous array of [i64] of the 3D matrix to encode on the [VmpPMat].
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], buf: &mut [u8]);
|
||||
|
||||
/// Prepares a [VmpPMat] from a vector of [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], buf: &mut [u8]);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnx].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the vector of [VecZnx] to encode on the [VmpPMat].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
/// * `buf`: scratch space, the size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxBig].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxBig] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Prepares the ith-row of [VmpPMat] from a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: [VmpPMat] on which the values are encoded.
|
||||
/// * `a`: the [VecZnxDft] to encode on the [VmpPMat].
|
||||
/// * `row_i`: the index of the row to prepare.
|
||||
///
|
||||
/// The size of buf can be obtained with [VmpPMatOps::vmp_prepare_tmp_bytes].
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize);
|
||||
|
||||
/// Extracts the ith-row of [VmpPMat] into a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the [VecZnxDft] to on which to extract the row of the [VmpPMat].
|
||||
/// * `a`: [VmpPMat] on which the values are encoded.
|
||||
/// * `row_i`: the index of the row to extract.
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_cols`: number of cols of the output [VecZnxDft].
|
||||
/// * `a_cols`: number of cols of the input [VecZnx].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnx] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on the receiver.
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnx] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, buf: &mut [u8]);
|
||||
|
||||
/// Returns the size of the stratch space necessary for [VmpPMatOps::vmp_apply_dft_to_dft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c_cols`: number of cols of the output [VecZnxDft].
|
||||
/// * `a_cols`: number of cols of the input [VecZnxDft].
|
||||
/// * `rows`: number of rows of the input [VmpPMat].
|
||||
/// * `cols`: number of cols of the input [VmpPMat].
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(&self, c_cols: usize, a_cols: usize, rows: usize, cols: usize) -> usize;
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat].
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] and adds on top of the receiver instead of overwritting it.
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `c`: the operand on which the output of the vector matrix product is added, as a [VecZnxDft].
|
||||
/// * `a`: the left operand [VecZnxDft] of the vector matrix product.
|
||||
/// * `b`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, buf: &mut [u8]);
|
||||
|
||||
/// Applies the vector matrix product [VecZnxDft] x [VmpPMat] in place.
|
||||
/// The size of `buf` is given by [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
///
|
||||
/// A vector matrix product is equivalent to a sum of [crate::SvpPPolOps::svp_apply_dft]
|
||||
/// where each [crate::Scalar] is a limb of the input [VecZnxDft] (equivalent to an [crate::SvpPPol])
|
||||
/// and each vector a [VecZnxDft] (row) of the [VmpPMat].
|
||||
///
|
||||
/// As such, given an input [VecZnx] of `i` cols and a [VmpPMat] of `i` rows and
|
||||
/// `j` cols, the output is a [VecZnx] of `j` cols.
|
||||
///
|
||||
/// If there is a mismatch between the dimensions the largest valid ones are used.
|
||||
///
|
||||
/// ```text
|
||||
/// |a b c d| x |e f g| = (a * |e f g| + b * |h i j| + c * |k l m|) = |n o p|
|
||||
/// |h i j|
|
||||
/// |k l m|
|
||||
/// ```
|
||||
/// where each element is a [VecZnxDft].
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `b`: the input and output of the vector matrix product, as a [VecZnxDft].
|
||||
/// * `a`: the right operand [VmpPMat] of the vector matrix product.
|
||||
/// * `buf`: scratch space, the size can be obtained with [VmpPMatOps::vmp_apply_dft_to_dft_tmp_bytes].
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, buf: &mut [u8]);
|
||||
}
|
||||
|
||||
impl VmpPMatOps for Module {
|
||||
fn bytes_of_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> usize {
|
||||
unsafe { vmp::bytes_of_vmp_pmat(self.ptr, rows as u64, cols as u64) as usize * size }
|
||||
}
|
||||
|
||||
fn new_vmp_pmat(&self, size: usize, rows: usize, cols: usize) -> VmpPMat {
|
||||
let mut data: Vec<u8> = alloc_aligned::<u8>(self.bytes_of_vmp_pmat(size, rows, cols));
|
||||
let ptr: *mut u8 = data.as_mut_ptr();
|
||||
VmpPMat {
|
||||
data: data,
|
||||
ptr: ptr,
|
||||
n: self.n(),
|
||||
size: size,
|
||||
layout: LAYOUT::COL,
|
||||
cols: cols,
|
||||
rows: rows,
|
||||
backend: self.backend(),
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_tmp_bytes(&self, rows: usize, cols: usize) -> usize {
|
||||
unsafe { vmp::vmp_prepare_tmp_bytes(self.ptr, rows as u64, cols as u64) as usize }
|
||||
}
|
||||
|
||||
fn vmp_prepare_contiguous(&self, b: &mut VmpPMat, a: &[i64], tmp_bytes: &mut [u8]) {
|
||||
debug_assert_eq!(a.len(), b.n * b.rows * b.cols);
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_contiguous(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_dblptr(&self, b: &mut VmpPMat, a: &[&[i64]], tmp_bytes: &mut [u8]) {
|
||||
let ptrs: Vec<*const i64> = a.iter().map(|v| v.as_ptr()).collect();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
debug_assert_eq!(a.len(), b.rows);
|
||||
a.iter().for_each(|ai| {
|
||||
debug_assert_eq!(ai.len(), b.n * b.cols);
|
||||
});
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_dblptr(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
ptrs.as_ptr(),
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row(&self, b: &mut VmpPMat, a: &[i64], row_i: usize, tmp_bytes: &mut [u8]) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.len(), b.cols() * self.n());
|
||||
assert!(tmp_bytes.len() >= self.vmp_prepare_tmp_bytes(b.rows(), b.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.as_ptr(),
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row(&self, b: &mut VecZnxBig, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_big_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_prepare_row_dft(&self, b: &mut VmpPMat, a: &VecZnxDft, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_prepare_row_dft(
|
||||
self.ptr,
|
||||
b.as_mut_ptr() as *mut vmp_pmat_t,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
row_i as u64,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_extract_row_dft(&self, b: &mut VecZnxDft, a: &VmpPMat, row_i: usize) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(a.n(), b.n());
|
||||
assert_eq!(a.cols(), b.cols());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_extract_row_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
row_i as u64,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
res_cols as u64,
|
||||
a_cols as u64,
|
||||
gct_rows as u64,
|
||||
gct_cols as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_add(&self, c: &mut VecZnxDft, a: &VecZnx, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.cols() as u64,
|
||||
a.as_ptr(),
|
||||
a.cols() as u64,
|
||||
a.n() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_tmp_bytes(&self, res_cols: usize, a_cols: usize, gct_rows: usize, gct_cols: usize) -> usize {
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_tmp_bytes(
|
||||
self.ptr,
|
||||
res_cols as u64,
|
||||
a_cols as u64,
|
||||
gct_rows as u64,
|
||||
gct_cols as u64,
|
||||
) as usize
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.cols() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_add(&self, c: &mut VecZnxDft, a: &VecZnxDft, b: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft_add(
|
||||
self.ptr,
|
||||
c.ptr as *mut vec_znx_dft_t,
|
||||
c.cols() as u64,
|
||||
a.ptr as *const vec_znx_dft_t,
|
||||
a.cols() as u64,
|
||||
b.as_ptr() as *const vmp_pmat_t,
|
||||
b.rows() as u64,
|
||||
b.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn vmp_apply_dft_to_dft_inplace(&self, b: &mut VecZnxDft, a: &VmpPMat, tmp_bytes: &mut [u8]) {
|
||||
debug_assert!(tmp_bytes.len() >= self.vmp_apply_dft_to_dft_tmp_bytes(b.cols(), b.cols(), a.rows(), a.cols()));
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
unsafe {
|
||||
vmp::vmp_apply_dft_to_dft(
|
||||
self.ptr,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.cols() as u64,
|
||||
b.ptr as *mut vec_znx_dft_t,
|
||||
b.cols() as u64,
|
||||
a.as_ptr() as *const vmp_pmat_t,
|
||||
a.rows() as u64,
|
||||
a.cols() as u64,
|
||||
tmp_bytes.as_mut_ptr(),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{
|
||||
Module, Sampling, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, alloc_aligned,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn vmp_prepare_row_dft() {
|
||||
let module: Module = Module::new(32, crate::BACKEND::FFT64);
|
||||
let vpmat_rows: usize = 4;
|
||||
let vpmat_cols: usize = 5;
|
||||
let log_base2k: usize = 8;
|
||||
let mut a: VecZnx = module.new_vec_znx(1, vpmat_cols);
|
||||
let mut a_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols);
|
||||
let mut a_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols);
|
||||
let mut b_big: VecZnxBig = module.new_vec_znx_big(1, vpmat_cols);
|
||||
let mut b_dft: VecZnxDft = module.new_vec_znx_dft(1, vpmat_cols);
|
||||
let mut vmpmat_0: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols);
|
||||
let mut vmpmat_1: VmpPMat = module.new_vmp_pmat(1, vpmat_rows, vpmat_cols);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(module.vmp_prepare_tmp_bytes(vpmat_rows, vpmat_cols));
|
||||
|
||||
for row_i in 0..vpmat_rows {
|
||||
let mut source: Source = Source::new([0u8; 32]);
|
||||
module.fill_uniform(log_base2k, &mut a, vpmat_cols, &mut source);
|
||||
module.vec_znx_dft(&mut a_dft, &a);
|
||||
module.vmp_prepare_row(&mut vmpmat_0, &a.raw(), row_i, &mut tmp_bytes);
|
||||
|
||||
// Checks that prepare(vmp_pmat, a) = prepare_dft(vmp_pmat, a_dft)
|
||||
module.vmp_prepare_row_dft(&mut vmpmat_1, &a_dft, row_i);
|
||||
assert_eq!(vmpmat_0.raw::<u8>(), vmpmat_1.raw::<u8>());
|
||||
|
||||
// Checks that a_dft = extract_dft(prepare(vmp_pmat, a), b_dft)
|
||||
module.vmp_extract_row_dft(&mut b_dft, &vmpmat_0, row_i);
|
||||
assert_eq!(a_dft.raw::<u8>(&module), b_dft.raw::<u8>(&module));
|
||||
|
||||
// Checks that a_big = extract(prepare_dft(vmp_pmat, a_dft), b_big)
|
||||
module.vmp_extract_row(&mut b_big, &vmpmat_0, row_i);
|
||||
module.vec_znx_idft(&mut a_big, &a_dft, &mut tmp_bytes);
|
||||
assert_eq!(a_big.raw::<i64>(&module), b_big.raw::<i64>(&module));
|
||||
}
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
199
base2k/src/znx_base.rs
Normal file
199
base2k/src/znx_base.rs
Normal file
@@ -0,0 +1,199 @@
|
||||
use itertools::izip;
|
||||
use rand_distr::num_traits::Zero;
|
||||
|
||||
pub trait ZnxInfos {
|
||||
/// Returns the ring degree of the polynomials.
|
||||
fn n(&self) -> usize;
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the polynomials.
|
||||
fn log_n(&self) -> usize {
|
||||
(usize::BITS - (self.n() - 1).leading_zeros()) as _
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
fn rows(&self) -> usize;
|
||||
|
||||
/// Returns the number of polynomials in each row.
|
||||
fn cols(&self) -> usize;
|
||||
|
||||
/// Returns the number of size per polynomial.
|
||||
fn size(&self) -> usize;
|
||||
|
||||
/// Returns the total number of small polynomials.
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols() * self.size()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxSliceSize {
|
||||
/// Returns the slice size, which is the offset between
|
||||
/// two size of the same column.
|
||||
fn sl(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait DataView {
|
||||
type D;
|
||||
fn data(&self) -> &Self::D;
|
||||
}
|
||||
|
||||
pub trait DataViewMut: DataView {
|
||||
fn data_mut(&mut self) -> &mut Self::D;
|
||||
}
|
||||
|
||||
pub trait ZnxView: ZnxInfos + DataView<D: AsRef<[u8]>> {
|
||||
type Scalar: Copy;
|
||||
|
||||
/// Returns a non-mutable pointer to the underlying coefficients array.
|
||||
fn as_ptr(&self) -> *const Self::Scalar {
|
||||
self.data().as_ref().as_ptr() as *const Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a non-mutable reference to the entire underlying coefficient array.
|
||||
fn raw(&self) -> &[Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts(self.as_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
/// Returns a non-mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_ptr(&self, i: usize, j: usize) -> *const Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols());
|
||||
assert!(j < self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_ptr().add(offset) }
|
||||
}
|
||||
|
||||
/// Returns non-mutable reference to the (i, j)-th small polynomial.
|
||||
fn at(&self, i: usize, j: usize) -> &[Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts(self.at_ptr(i, j), self.n()) }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ZnxViewMut: ZnxView + DataViewMut<D: AsMut<[u8]>> {
|
||||
/// Returns a mutable pointer to the underlying coefficients array.
|
||||
fn as_mut_ptr(&mut self) -> *mut Self::Scalar {
|
||||
self.data_mut().as_mut().as_mut_ptr() as *mut Self::Scalar
|
||||
}
|
||||
|
||||
/// Returns a mutable reference to the entire underlying coefficient array.
|
||||
fn raw_mut(&mut self) -> &mut [Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.as_mut_ptr(), self.n() * self.poly_count()) }
|
||||
}
|
||||
|
||||
/// Returns a mutable pointer starting at the j-th small polynomial of the i-th column.
|
||||
fn at_mut_ptr(&mut self, i: usize, j: usize) -> *mut Self::Scalar {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(i < self.cols());
|
||||
assert!(j < self.size());
|
||||
}
|
||||
let offset: usize = self.n() * (j * self.cols() + i);
|
||||
unsafe { self.as_mut_ptr().add(offset) }
|
||||
}
|
||||
|
||||
/// Returns mutable reference to the (i, j)-th small polynomial.
|
||||
fn at_mut(&mut self, i: usize, j: usize) -> &mut [Self::Scalar] {
|
||||
unsafe { std::slice::from_raw_parts_mut(self.at_mut_ptr(i, j), self.n()) }
|
||||
}
|
||||
}
|
||||
|
||||
//(Jay)Note: Can't provide blanket impl. of ZnxView because Scalar is not known
|
||||
impl<T> ZnxViewMut for T where T: ZnxView + DataViewMut<D: AsMut<[u8]>> {}
|
||||
|
||||
pub trait ZnxZero: ZnxViewMut + ZnxSliceSize
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
fn zero(&mut self) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
|
||||
}
|
||||
}
|
||||
|
||||
fn zero_at(&mut self, i: usize, j: usize) {
|
||||
unsafe {
|
||||
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blanket implementations
|
||||
impl<T> ZnxZero for T where T: ZnxViewMut + ZnxSliceSize {} // WARNING should not work for mat_znx_dft but it does
|
||||
|
||||
use std::ops::{Add, AddAssign, Div, Mul, Neg, Shl, Shr, Sub};
|
||||
|
||||
use crate::Scratch;
|
||||
pub trait Integer:
|
||||
Copy
|
||||
+ Default
|
||||
+ PartialEq
|
||||
+ PartialOrd
|
||||
+ Add<Output = Self>
|
||||
+ Sub<Output = Self>
|
||||
+ Mul<Output = Self>
|
||||
+ Div<Output = Self>
|
||||
+ Neg<Output = Self>
|
||||
+ Shl<Output = Self>
|
||||
+ Shr<Output = Self>
|
||||
+ AddAssign
|
||||
{
|
||||
const BITS: u32;
|
||||
}
|
||||
|
||||
impl Integer for i64 {
|
||||
const BITS: u32 = 64;
|
||||
}
|
||||
|
||||
impl Integer for i128 {
|
||||
const BITS: u32 = 128;
|
||||
}
|
||||
|
||||
//(Jay)Note: `rsh` impl. ignores the column
|
||||
pub fn rsh<V: ZnxZero>(k: usize, log_base2k: usize, a: &mut V, _a_col: usize, scratch: &mut Scratch)
|
||||
where
|
||||
V::Scalar: From<usize> + Integer + Zero,
|
||||
{
|
||||
let n: usize = a.n();
|
||||
let _size: usize = a.size();
|
||||
let cols: usize = a.cols();
|
||||
|
||||
let size: usize = a.size();
|
||||
let steps: usize = k / log_base2k;
|
||||
|
||||
a.raw_mut().rotate_right(n * steps * cols);
|
||||
(0..cols).for_each(|i| {
|
||||
(0..steps).for_each(|j| {
|
||||
a.zero_at(i, j);
|
||||
})
|
||||
});
|
||||
|
||||
let k_rem: usize = k % log_base2k;
|
||||
|
||||
if k_rem != 0 {
|
||||
let (carry, _) = scratch.tmp_slice::<V::Scalar>(rsh_tmp_bytes::<V::Scalar>(n));
|
||||
|
||||
unsafe {
|
||||
std::ptr::write_bytes(carry.as_mut_ptr(), 0, n * size_of::<V::Scalar>());
|
||||
}
|
||||
|
||||
let log_base2k_t = V::Scalar::from(log_base2k);
|
||||
let shift = V::Scalar::from(V::Scalar::BITS as usize - k_rem);
|
||||
let k_rem_t = V::Scalar::from(k_rem);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
(steps..size).for_each(|j| {
|
||||
izip!(carry.iter_mut(), a.at_mut(i, j).iter_mut()).for_each(|(ci, xi)| {
|
||||
*xi += *ci << log_base2k_t;
|
||||
*ci = (*xi << shift) >> shift;
|
||||
*xi = (*xi - *ci) >> k_rem_t;
|
||||
});
|
||||
});
|
||||
carry.iter_mut().for_each(|r| *r = V::Scalar::zero());
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rsh_tmp_bytes<T>(n: usize) -> usize {
|
||||
n * std::mem::size_of::<T>()
|
||||
}
|
||||
@@ -1,5 +1,3 @@
|
||||
cargo-features = ["edition2024"]
|
||||
|
||||
[package]
|
||||
name = "rlwe"
|
||||
version = "0.1.0"
|
||||
@@ -14,5 +12,9 @@ rand_distr = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
|
||||
[[bench]]
|
||||
name = "gadget_product"
|
||||
name = "external_product_glwe_fft64"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "keyswitch_glwe_fft64"
|
||||
harness = false
|
||||
202
core/benches/external_product_glwe_fft64.rs
Normal file
202
core/benches/external_product_glwe_fft64.rs
Normal file
@@ -0,0 +1,202 @@
|
||||
use base2k::{FFT64, Module, ScalarZnxAlloc, ScratchOwned};
|
||||
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use rlwe::{
|
||||
elem::Infos,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
fn bench_external_product_glwe_fft64(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("external_product_glwe_fft64");
|
||||
|
||||
struct Params {
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ct_in: usize,
|
||||
k_ct_out: usize,
|
||||
k_ggsw: usize,
|
||||
rank: usize,
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
|
||||
let basek: usize = p.basek;
|
||||
let k_ct_in: usize = p.k_ct_in;
|
||||
let k_ct_out: usize = p.k_ct_out;
|
||||
let k_ggsw: usize = p.k_ggsw;
|
||||
let rank: usize = p.rank;
|
||||
|
||||
let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek;
|
||||
let sigma: f64 = 3.2;
|
||||
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
|
||||
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
|
||||
let pt_rgsw: base2k::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut scratch = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
|
||||
| GLWECiphertext::external_product_scratch_space(
|
||||
&module,
|
||||
ct_rlwe_out.size(),
|
||||
ct_rlwe_in.size(),
|
||||
ct_rgsw.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let mut source_xs = Source::new([0u8; 32]);
|
||||
let mut source_xe = Source::new([0u8; 32]);
|
||||
let mut source_xa = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
move || {
|
||||
ct_rlwe_out.external_product(
|
||||
black_box(&module),
|
||||
black_box(&ct_rlwe_in),
|
||||
black_box(&ct_rgsw),
|
||||
black_box(scratch.borrow()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let params_set: Vec<Params> = vec![Params {
|
||||
log_n: 10,
|
||||
basek: 7,
|
||||
k_ct_in: 27,
|
||||
k_ct_out: 27,
|
||||
k_ggsw: 27,
|
||||
rank: 1,
|
||||
}];
|
||||
|
||||
for params in params_set {
|
||||
let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_FFT64", "");
|
||||
let mut runner = runner(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_external_product_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("external_product_glwe_inplace_fft64");
|
||||
|
||||
struct Params {
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ct: usize,
|
||||
k_ggsw: usize,
|
||||
rank: usize,
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
|
||||
let basek: usize = p.basek;
|
||||
let k_glwe: usize = p.k_ct;
|
||||
let k_ggsw: usize = p.k_ggsw;
|
||||
let rank: usize = p.rank;
|
||||
|
||||
let rows: usize = (p.k_ct + p.basek - 1) / p.basek;
|
||||
let sigma: f64 = 3.2;
|
||||
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_glwe, rank);
|
||||
let pt_rgsw: base2k::ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut scratch = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
|
||||
| GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank),
|
||||
);
|
||||
|
||||
let mut source_xs = Source::new([0u8; 32]);
|
||||
let mut source_xe = Source::new([0u8; 32]);
|
||||
let mut source_xa = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
move || {
|
||||
let scratch_borrow = scratch.borrow();
|
||||
(0..687).for_each(|_| {
|
||||
ct_rlwe.external_product_inplace(
|
||||
black_box(&module),
|
||||
black_box(&ct_rgsw),
|
||||
black_box(scratch_borrow),
|
||||
);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let params_set: Vec<Params> = vec![Params {
|
||||
log_n: 12,
|
||||
basek: 18,
|
||||
k_ct: 54,
|
||||
k_ggsw: 54,
|
||||
rank: 1,
|
||||
}];
|
||||
|
||||
for params in params_set {
|
||||
let id = BenchmarkId::new("EXTERNAL_PRODUCT_GLWE_INPLACE_FFT64", "");
|
||||
let mut runner = runner(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_external_product_glwe_fft64,
|
||||
bench_external_product_glwe_inplace_fft64
|
||||
);
|
||||
criterion_main!(benches);
|
||||
211
core/benches/keyswitch_glwe_fft64.rs
Normal file
211
core/benches/keyswitch_glwe_fft64.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
use base2k::{FFT64, Module, ScratchOwned};
|
||||
use criterion::{BenchmarkId, Criterion, black_box, criterion_group, criterion_main};
|
||||
use rlwe::{
|
||||
elem::Infos,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
fn bench_keyswitch_glwe_fft64(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("keyswitch_glwe_fft64");
|
||||
|
||||
struct Params {
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ct_in: usize,
|
||||
k_ct_out: usize,
|
||||
k_ksk: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
|
||||
let basek: usize = p.basek;
|
||||
let k_rlwe_in: usize = p.k_ct_in;
|
||||
let k_rlwe_out: usize = p.k_ct_out;
|
||||
let k_grlwe: usize = p.k_ksk;
|
||||
let rank_in: usize = p.rank_in;
|
||||
let rank_out: usize = p.rank_out;
|
||||
|
||||
let rows: usize = (p.k_ct_in + p.basek - 1) / p.basek;
|
||||
let sigma: f64 = 3.2;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_grlwe, rows, rank_in, rank_out);
|
||||
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_rlwe_in, rank_in);
|
||||
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_rlwe_out, rank_out);
|
||||
|
||||
let mut scratch = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
|
||||
| GLWECiphertext::keyswitch_scratch_space(
|
||||
&module,
|
||||
ct_out.size(),
|
||||
ct_in.size(),
|
||||
ksk.size(),
|
||||
rank_in,
|
||||
rank_out,
|
||||
),
|
||||
);
|
||||
|
||||
let mut source_xs = Source::new([0u8; 32]);
|
||||
let mut source_xe = Source::new([0u8; 32]);
|
||||
let mut source_xa = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_in.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
move || {
|
||||
ct_out.keyswitch(
|
||||
black_box(&module),
|
||||
black_box(&ct_in),
|
||||
black_box(&ksk),
|
||||
black_box(scratch.borrow()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let params_set: Vec<Params> = vec![Params {
|
||||
log_n: 16,
|
||||
basek: 50,
|
||||
k_ct_in: 1250,
|
||||
k_ct_out: 1250,
|
||||
k_ksk: 1250 + 66,
|
||||
rank_in: 1,
|
||||
rank_out: 1,
|
||||
}];
|
||||
|
||||
for params in params_set {
|
||||
let id = BenchmarkId::new("KEYSWITCH_GLWE_FFT64", "");
|
||||
let mut runner = runner(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
fn bench_keyswitch_glwe_inplace_fft64(c: &mut Criterion) {
|
||||
let mut group = c.benchmark_group("keyswitch_glwe_inplace_fft64");
|
||||
|
||||
struct Params {
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ct: usize,
|
||||
k_ksk: usize,
|
||||
rank: usize,
|
||||
}
|
||||
|
||||
fn runner(p: Params) -> impl FnMut() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << p.log_n);
|
||||
|
||||
let basek: usize = p.basek;
|
||||
let k_ct: usize = p.k_ct;
|
||||
let k_ksk: usize = p.k_ksk;
|
||||
let rank: usize = p.rank;
|
||||
|
||||
let rows: usize = (p.k_ct + p.basek - 1) / p.basek;
|
||||
let sigma: f64 = 3.2;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
|
||||
let mut scratch = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), rank),
|
||||
);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
move || {
|
||||
ct.keyswitch_inplace(
|
||||
black_box(&module),
|
||||
black_box(&ksk),
|
||||
black_box(scratch.borrow()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let params_set: Vec<Params> = vec![Params {
|
||||
log_n: 9,
|
||||
basek: 18,
|
||||
k_ct: 27,
|
||||
k_ksk: 27,
|
||||
rank: 1,
|
||||
}];
|
||||
|
||||
for params in params_set {
|
||||
let id = BenchmarkId::new("KEYSWITCH_GLWE_INPLACE_FFT64", "");
|
||||
let mut runner = runner(params);
|
||||
group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
|
||||
}
|
||||
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
benches,
|
||||
bench_keyswitch_glwe_fft64,
|
||||
bench_keyswitch_glwe_inplace_fft64
|
||||
);
|
||||
criterion_main!(benches);
|
||||
386
core/src/automorphism.rs
Normal file
386
core/src/automorphism.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps,
|
||||
ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
};
|
||||
|
||||
pub struct AutomorphismKey<Data, B: Backend> {
|
||||
pub(crate) key: GLWESwitchingKey<Data, B>,
|
||||
pub(crate) p: i64,
|
||||
}
|
||||
|
||||
impl AutomorphismKey<Vec<u8>, FFT64> {
|
||||
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
AutomorphismKey {
|
||||
key: GLWESwitchingKey::new(module, basek, k, rows, rank, rank),
|
||||
p: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for AutomorphismKey<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.key.inner()
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.key.basek()
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.key.k()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> AutomorphismKey<T, B> {
|
||||
pub fn p(&self) -> i64 {
|
||||
self.p
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.key.rank()
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.key.rank_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.key.rank_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for AutomorphismKey<DataSelf, B>
|
||||
where
|
||||
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.key.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> MatZnxDftToRef<B> for AutomorphismKey<DataSelf, B>
|
||||
where
|
||||
MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.key.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GetRow<FFT64> for AutomorphismKey<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> SetRow<FFT64> for AutomorphismKey<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
|
||||
where
|
||||
R: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl AutomorphismKey<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size)
|
||||
}
|
||||
|
||||
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
|
||||
GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWESwitchingKey::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
GLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(rank + 1, out_size);
|
||||
let idft: usize = module.vec_znx_idft_tmp_bytes();
|
||||
let keyswitch: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, out_size, rank, ksk_size);
|
||||
tmp_dft + tmp_idft + idft + keyswitch
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, ksk_size: usize, rank: usize) -> usize {
|
||||
AutomorphismKey::automorphism_scratch_space(module, out_size, out_size, ksk_size, rank)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWESwitchingKey::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank)
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWESwitchingKey::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_sk<DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
p: i64,
|
||||
sk: &SecretKey<DataSk>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnx<DataSk>: ScalarZnxToRef,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(sk.n(), module.n());
|
||||
assert_eq!(self.rank_out(), self.rank_in());
|
||||
assert_eq!(sk.rank(), self.rank());
|
||||
}
|
||||
|
||||
let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank());
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier {
|
||||
data: sk_out_dft_data,
|
||||
dist: sk.dist,
|
||||
};
|
||||
|
||||
{
|
||||
(0..self.rank()).for_each(|i| {
|
||||
let (mut sk_inv_auto, _) = scratch_1.tmp_scalar_znx(module, 1);
|
||||
module.scalar_znx_automorphism(module.galois_element_inv(p), &mut sk_inv_auto, 0, sk, i);
|
||||
module.svp_prepare(&mut sk_out_dft, i, &sk_inv_auto, 0);
|
||||
});
|
||||
}
|
||||
|
||||
self.key.encrypt_sk(
|
||||
module,
|
||||
&sk,
|
||||
&sk_out_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch_1,
|
||||
);
|
||||
|
||||
self.p = p;
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> AutomorphismKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn automorphism<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &AutomorphismKey<DataLhs, FFT64>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size());
|
||||
|
||||
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_dft_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
// Extracts relevant row
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
|
||||
|
||||
// Get a VecZnxBig from scratch space
|
||||
let (mut tmp_idft_data, scratch2) = scratch1.tmp_vec_znx_big(module, cols_out, self.size());
|
||||
|
||||
// Switches input outside of DFT
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_idft(&mut tmp_idft_data, i, &tmp_dft.data, i, scratch2);
|
||||
});
|
||||
|
||||
// Consumes to small vec znx
|
||||
let mut tmp_idft_small_data: VecZnx<&mut [u8]> = tmp_idft_data.to_vec_znx_small();
|
||||
|
||||
// Reverts the automorphis key from (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft_small_data, i);
|
||||
});
|
||||
|
||||
// Wraps into ciphertext
|
||||
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: tmp_idft_small_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
// Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
|
||||
tmp_idft.keyswitch_inplace(module, &rhs.key, scratch2);
|
||||
|
||||
// Applies back the automorphism X^{k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) -> (-pi^{-1}_{k'+k}(s)a + s, a)
|
||||
// and switches back to DFT domain
|
||||
(0..self.rank_out() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(lhs.p(), &mut tmp_idft, i);
|
||||
module.vec_znx_dft(&mut tmp_dft, i, &tmp_idft, i);
|
||||
});
|
||||
|
||||
// Sets back the relevant row
|
||||
self.set_row(module, row_j, col_i, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_dft.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_dft);
|
||||
});
|
||||
});
|
||||
|
||||
self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut AutomorphismKey<DataSelf, FFT64> = self as *mut AutomorphismKey<DataSelf, FFT64>;
|
||||
self.automorphism(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &AutomorphismKey<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.key.keyswitch(module, &lhs.key, rhs, scratch);
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.key.keyswitch_inplace(module, &rhs, scratch);
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &AutomorphismKey<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.key.external_product(module, &lhs.key, rhs, scratch);
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.key.external_product_inplace(module, rhs, scratch);
|
||||
}
|
||||
}
|
||||
59
core/src/elem.rs
Normal file
59
core/src/elem.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
use base2k::{Backend, Module, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos};
|
||||
|
||||
use crate::utils::derive_size;
|
||||
|
||||
pub trait Infos {
|
||||
type Inner: ZnxInfos;
|
||||
|
||||
fn inner(&self) -> &Self::Inner;
|
||||
|
||||
/// Returns the ring degree of the polynomials.
|
||||
fn n(&self) -> usize {
|
||||
self.inner().n()
|
||||
}
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the polynomials.
|
||||
fn log_n(&self) -> usize {
|
||||
self.inner().log_n()
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
fn rows(&self) -> usize {
|
||||
self.inner().rows()
|
||||
}
|
||||
|
||||
/// Returns the number of polynomials in each row.
|
||||
fn cols(&self) -> usize {
|
||||
self.inner().cols()
|
||||
}
|
||||
|
||||
/// Returns the number of size per polynomial.
|
||||
fn size(&self) -> usize {
|
||||
let size: usize = self.inner().size();
|
||||
debug_assert_eq!(size, derive_size(self.basek(), self.k()));
|
||||
size
|
||||
}
|
||||
|
||||
/// Returns the total number of small polynomials.
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols() * self.size()
|
||||
}
|
||||
|
||||
/// Returns the base 2 logarithm of the ciphertext base.
|
||||
fn basek(&self) -> usize;
|
||||
|
||||
/// Returns the bit precision of the ciphertext.
|
||||
fn k(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait GetRow<B: Backend> {
|
||||
fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait SetRow<B: Backend> {
|
||||
fn set_row<R>(&mut self, module: &Module<B>, row_i: usize, col_j: usize, a: &R)
|
||||
where
|
||||
R: VecZnxDftToRef<B>;
|
||||
}
|
||||
211
core/src/gglwe_ciphertext.rs
Normal file
211
core/src/gglwe_ciphertext.rs
Normal file
@@ -0,0 +1,211 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GGLWECiphertext<C, B: Backend> {
|
||||
pub(crate) data: MatZnxDft<C, B>,
|
||||
pub(crate) basek: usize,
|
||||
pub(crate) k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GGLWECiphertext<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GGLWECiphertext<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GGLWECiphertext<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.data.cols_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToMut<B> for GGLWECiphertext<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToRef<B> for GGLWECiphertext<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GGLWECiphertext<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, size)
|
||||
+ module.bytes_of_vec_znx(rank + 1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub fn encrypt_pk_scratch_space(_module: &Module<FFT64>, _rank: usize, _pk_size: usize) -> usize {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GGLWECiphertext<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
pub fn encrypt_sk<DataPt, DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &ScalarZnx<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnx<DataPt>: ScalarZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank_in(), pt.cols());
|
||||
assert_eq!(self.rank_out(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
}
|
||||
|
||||
let rows: usize = self.rows();
|
||||
let size: usize = self.size();
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
|
||||
let cols_in: usize = self.rank_in();
|
||||
let cols_out: usize = self.rank_out() + 1;
|
||||
|
||||
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, cols_out, size);
|
||||
let (tmp_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, cols_out, size);
|
||||
|
||||
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
||||
data: tmp_znx_pt,
|
||||
basek,
|
||||
k,
|
||||
};
|
||||
|
||||
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
|
||||
data: tmp_znx_ct,
|
||||
basek,
|
||||
k,
|
||||
};
|
||||
|
||||
let mut vec_znx_ct_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier {
|
||||
data: tmp_znx_dft_ct,
|
||||
basek,
|
||||
k,
|
||||
};
|
||||
|
||||
// For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
|
||||
//
|
||||
// Example for ksk rank 2 to rank 3:
|
||||
//
|
||||
// (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
|
||||
// (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
|
||||
//
|
||||
// Example ksk rank 2 to rank 1
|
||||
//
|
||||
// (-(a*s) + s0, a)
|
||||
// (-(b*s) + s1, b)
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, col_i); // Selects the i-th
|
||||
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scratch_3);
|
||||
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
vec_znx_ct.encrypt_sk(
|
||||
module,
|
||||
&vec_znx_pt,
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch_3,
|
||||
);
|
||||
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
vec_znx_ct.dft(module, &mut vec_znx_ct_dft);
|
||||
|
||||
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
|
||||
module.vmp_prepare_row(self, row_i, col_i, &vec_znx_ct_dft);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GetRow<FFT64> for GGLWECiphertext<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> SetRow<FFT64> for GGLWECiphertext<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
|
||||
where
|
||||
R: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
684
core/src/ggsw_ciphertext.rs
Normal file
684
core/src/ggsw_ciphertext.rs
Normal file
@@ -0,0 +1,684 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps,
|
||||
VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut,
|
||||
VecZnxToRef, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
tensor_key::TensorKey,
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GGSWCiphertext<C, B: Backend> {
|
||||
pub data: MatZnxDft<C, B>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k: k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GGSWCiphertext<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GGSWCiphertext<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols_out() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToMut<B> for GGSWCiphertext<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToRef<B> for GGSWCiphertext<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, size)
|
||||
+ module.bytes_of_vec_znx(rank + 1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub(crate) fn expand_row_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
self_size: usize,
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size);
|
||||
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
|
||||
let vmp: usize =
|
||||
tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size);
|
||||
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm))
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_internal_col0_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, in_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank);
|
||||
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
res_znx + ci_dft + (ks | expand_rows | res_dft)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ksk_size: usize,
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::keyswitch_scratch_space(module, out_size, out_size, ksk_size, tensor_key_size, rank)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
auto_key_size: usize,
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
out_size,
|
||||
in_size,
|
||||
auto_key_size,
|
||||
tensor_key_size,
|
||||
rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
auto_key_size: usize,
|
||||
tensor_key_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::automorphism_scratch_space(
|
||||
module,
|
||||
out_size,
|
||||
out_size,
|
||||
auto_key_size,
|
||||
tensor_key_size,
|
||||
rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_sk<DataPt, DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &ScalarZnx<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnx<DataPt>: ScalarZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
}
|
||||
|
||||
let size: usize = self.size();
|
||||
let basek: usize = self.basek();
|
||||
let k: usize = self.k();
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, cols, size);
|
||||
|
||||
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
|
||||
data: tmp_znx_pt,
|
||||
basek: basek,
|
||||
k: k,
|
||||
};
|
||||
|
||||
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
|
||||
data: tmp_znx_ct,
|
||||
basek: basek,
|
||||
k,
|
||||
};
|
||||
|
||||
(0..self.rows()).for_each(|row_i| {
|
||||
vec_znx_pt.data.zero();
|
||||
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt, 0, scrach_2);
|
||||
|
||||
(0..cols).for_each(|col_j| {
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
|
||||
vec_znx_ct.encrypt_sk_private(
|
||||
module,
|
||||
Some((&vec_znx_pt, col_j)),
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scrach_2,
|
||||
);
|
||||
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
{
|
||||
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct, i);
|
||||
});
|
||||
|
||||
self.set_row(module, row_i, col_j, &vec_znx_dft_ct);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn expand_row<R, DataCi, DataTsk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
col_j: usize,
|
||||
res: &mut R,
|
||||
ci_dft: &VecZnxDft<DataCi, FFT64>,
|
||||
tsk: &TensorKey<DataTsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
R: VecZnxToMut,
|
||||
VecZnxDft<DataCi, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
// Example for rank 3:
|
||||
//
|
||||
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
|
||||
// actually composed of that many rows and we focus on a specific row here
|
||||
// implicitely given ci_dft.
|
||||
//
|
||||
// # Input
|
||||
//
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 )
|
||||
// col 1: (0, 0, 0, 0)
|
||||
// col 2: (0, 0, 0, 0)
|
||||
// col 3: (0, 0, 0, 0)
|
||||
//
|
||||
// # Output
|
||||
//
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0 , a1 , a2 )
|
||||
// col 1: (-(b0s0 + b1s1 + b2s2) , b0 + M[i], b1 , b2 )
|
||||
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M[i], c2 )
|
||||
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M[i])
|
||||
|
||||
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
|
||||
{
|
||||
let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size());
|
||||
|
||||
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
|
||||
//
|
||||
// # Example for col=1
|
||||
//
|
||||
// a0 * (-(f0s0 + f1s1 + f1s2) + s0^2, f0, f1, f2) = (-(a0f0s0 + a0f1s1 + a0f1s2) + a0s0^2, a0f0, a0f1, a0f2)
|
||||
// +
|
||||
// a1 * (-(g0s0 + g1s1 + g1s2) + s0s1, g0, g1, g2) = (-(a1g0s0 + a1g1s1 + a1g1s2) + a1s0s1, a1g0, a1g1, a1g2)
|
||||
// +
|
||||
// a2 * (-(h0s0 + h1s1 + h1s2) + s0s2, h0, h1, h2) = (-(a2h0s0 + a2h1s1 + a2h1s2) + a2s0s2, a2h0, a2h1, a2h2)
|
||||
// =
|
||||
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
|
||||
(1..cols).for_each(|col_i| {
|
||||
// Extracts a[i] and multipies with Enc(s[i]s[j])
|
||||
tmp_dft_col_data.extract_column(0, ci_dft, col_i);
|
||||
|
||||
if col_i == 1 {
|
||||
module.vmp_apply(
|
||||
&mut tmp_dft_i,
|
||||
&tmp_dft_col_data,
|
||||
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j])
|
||||
scratch2,
|
||||
);
|
||||
} else {
|
||||
module.vmp_apply_add(
|
||||
&mut tmp_dft_i,
|
||||
&tmp_dft_col_data,
|
||||
tsk.at(col_i - 1, col_j - 1), // Selects Enc(s[i]s[j])
|
||||
scratch2,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i
|
||||
//
|
||||
// (-(x0s0 + x1s1 + x2s2) + a0s0s0 + a1s0s1 + a2s0s2, x0, x1, x2)
|
||||
// +
|
||||
// (0, -(a0s0 + a1s1 + a2s2) + M[i], 0, 0)
|
||||
// =
|
||||
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0 -(a0s0 + a1s1 + a2s2) + M[i], x1, x2)
|
||||
// =
|
||||
// (-(x0s0 + x1s1 + x2s2), x0 + M[i], x1, x2)
|
||||
module.vec_znx_dft_add_inplace(&mut tmp_dft_i, col_j, ci_dft, 0);
|
||||
let (mut tmp_idft, scratch2) = scratch1.tmp_vec_znx_big(module, 1, tsk.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
|
||||
module.vec_znx_big_normalize(self.basek(), res, i, &tmp_idft, 0, scratch2);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataKsk, DataTsk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
|
||||
tsk: &TensorKey<DataTsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size());
|
||||
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: res_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
|
||||
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
|
||||
lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2);
|
||||
|
||||
// Isolates DFT(a[i])
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i);
|
||||
});
|
||||
|
||||
self.set_row(module, row_i, 0, &ci_dft);
|
||||
|
||||
// Generates
|
||||
//
|
||||
// col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M[i], b1 , b2 )
|
||||
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 )
|
||||
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i])
|
||||
(1..cols).for_each(|col_j| {
|
||||
self.expand_row(module, col_j, &mut res, &ci_dft, tsk, scratch2);
|
||||
|
||||
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut res_dft, i, &res, i);
|
||||
});
|
||||
|
||||
self.set_row(module, row_i, col_j, &res_dft);
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataKsk, DataTsk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
|
||||
tsk: &TensorKey<DataTsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
|
||||
self.keyswitch(module, &*self_ptr, ksk, tsk, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn automorphism<DataLhs, DataAk, DataTsk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||
auto_key: &AutomorphismKey<DataAk, FFT64>,
|
||||
tensor_key: &TensorKey<DataTsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataAk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
lhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||
self.rank(),
|
||||
lhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
auto_key.rank(),
|
||||
"ggsw_in rank: {} != auto_key rank: {}",
|
||||
self.rank(),
|
||||
auto_key.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
tensor_key.rank(),
|
||||
"ggsw_in rank: {} != tensor_key rank: {}",
|
||||
self.rank(),
|
||||
tensor_key.rank()
|
||||
);
|
||||
};
|
||||
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size());
|
||||
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: res_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
|
||||
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
// Key-switch column 0, i.e.
|
||||
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
|
||||
lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2);
|
||||
|
||||
// Isolates DFT(AUTO(a[i]))
|
||||
(0..cols).for_each(|col_i| {
|
||||
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
|
||||
module.vec_znx_automorphism_inplace(auto_key.p(), &mut res, col_i);
|
||||
module.vec_znx_dft(&mut ci_dft, col_i, &res, col_i);
|
||||
});
|
||||
|
||||
self.set_row(module, row_i, 0, &ci_dft);
|
||||
|
||||
// Generates
|
||||
//
|
||||
// col 1: (-(b0s0 + b1s1 + b2s2) , b0 + pi(M[i]), b1 , b2 )
|
||||
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 )
|
||||
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i]))
|
||||
(1..cols).for_each(|col_j| {
|
||||
self.expand_row(module, col_j, &mut res, &ci_dft, tensor_key, scratch2);
|
||||
|
||||
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft(&mut res_dft, i, &res, i);
|
||||
});
|
||||
|
||||
self.set_row(module, row_i, col_j, &res_dft);
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace<DataKsk, DataTsk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
auto_key: &AutomorphismKey<DataKsk, FFT64>,
|
||||
tensor_key: &TensorKey<DataTsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataTsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GGSWCiphertext<DataSelf, FFT64> = self as *mut GGSWCiphertext<DataSelf, FFT64>;
|
||||
self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
lhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||
self.rank(),
|
||||
lhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank(),
|
||||
"ggsw_in rank: {} != ggsw_apply rank: {}",
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank() + 1).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank() + 1).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank(),
|
||||
rhs.rank(),
|
||||
"ggsw_out rank: {} != ggsw_apply: {}",
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank() + 1).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.external_product_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GGSWCiphertext<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub(crate) fn keyswitch_internal_col0<DataRes, DataKsk>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
row_i: usize,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
ksk: &GLWESwitchingKey<DataKsk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataRes>: VecZnxToMut + VecZnxToRef,
|
||||
MatZnxDft<DataKsk, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), ksk.rank());
|
||||
assert_eq!(res.rank(), ksk.rank());
|
||||
}
|
||||
|
||||
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_dft_in_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
self.get_row(module, row_i, 0, &mut tmp_dft_in);
|
||||
res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> SetRow<FFT64> for GGSWCiphertext<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
|
||||
where
|
||||
R: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
696
core/src/glwe_ciphertext.rs
Normal file
696
core/src/glwe_ciphertext.rs
Normal file
@@ -0,0 +1,696 @@
|
||||
use base2k::{
|
||||
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxAlloc,
|
||||
ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc,
|
||||
VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps,
|
||||
VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
SIX_SIGMA,
|
||||
automorphism::AutomorphismKey,
|
||||
elem::Infos,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{GLWEPublicKey, SecretDistribution, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GLWECiphertext<C> {
|
||||
pub data: VecZnx<C>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl GLWECiphertext<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx(rank + 1, derive_size(basek, k)),
|
||||
basek,
|
||||
k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Infos for GLWECiphertext<T> {
|
||||
type Inner = VecZnx<T>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> GLWECiphertext<T> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.cols() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToMut for GLWECiphertext<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToRef for GLWECiphertext<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GLWECiphertext<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn dft<R>(&self, module: &Module<FFT64>, res: &mut GLWECiphertextFourier<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
assert_eq!(self.basek(), res.basek())
|
||||
}
|
||||
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_dft(res, i, self, i);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWECiphertext<Vec<u8>> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
|
||||
module.vec_znx_big_normalize_tmp_bytes()
|
||||
+ module.bytes_of_vec_znx_dft(1, ct_size)
|
||||
+ module.bytes_of_vec_znx_big(1, ct_size)
|
||||
}
|
||||
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, pk_size: usize) -> usize {
|
||||
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
|
||||
+ module.bytes_of_scalar_znx_dft(1)
|
||||
+ module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, ct_size))
|
||||
+ module.bytes_of_vec_znx_big(1, ct_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
|
||||
+ module.bytes_of_vec_znx_dft(in_rank, in_size);
|
||||
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
return res_dft + (vmp | normalize);
|
||||
}
|
||||
|
||||
pub fn keyswitch_from_fourier_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
let res_dft = module.bytes_of_vec_znx_dft(out_rank + 1, ksk_size);
|
||||
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
|
||||
+ module.bytes_of_vec_znx_dft(in_rank, in_size);
|
||||
|
||||
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
res_dft + (vmp | norm)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
autokey_size: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, in_size, out_rank, autokey_size)
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
autokey_size: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, autokey_size)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
) -> usize {
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(out_rank + 1, ggsw_size);
|
||||
let vmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, in_size)
|
||||
+ module.vmp_apply_tmp_bytes(
|
||||
out_size,
|
||||
in_size,
|
||||
in_size, // rows
|
||||
out_rank + 1, // cols in
|
||||
out_rank + 1, // cols out
|
||||
ggsw_size,
|
||||
);
|
||||
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
res_dft + (vmp | normalize)
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
ggsw_size: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::external_product_scratch_space(module, out_size, out_rank, out_size, ggsw_size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertext<DataSelf>
|
||||
where
|
||||
VecZnx<DataSelf>: VecZnxToMut + VecZnxToRef,
|
||||
{
|
||||
pub fn encrypt_sk<DataPt, DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &GLWEPlaintext<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.encrypt_sk_private(
|
||||
module,
|
||||
Some((pt, 0)),
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn encrypt_zero_sk<DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.encrypt_sk_private(module, None, sk_dft, source_xa, source_xe, sigma, scratch);
|
||||
}
|
||||
|
||||
pub fn encrypt_pk<DataPt, DataPk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &GLWEPlaintext<DataPt>,
|
||||
pk: &GLWEPublicKey<DataPk, FFT64>,
|
||||
source_xu: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToRef,
|
||||
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.encrypt_pk_private(
|
||||
module,
|
||||
Some((pt, 0)),
|
||||
pk,
|
||||
source_xu,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn encrypt_zero_pk<DataPk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pk: &GLWEPublicKey<DataPk, FFT64>,
|
||||
source_xu: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.encrypt_pk_private(module, None, pk, source_xu, source_xe, sigma, scratch);
|
||||
}
|
||||
|
||||
pub fn automorphism<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertext<DataLhs>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataLhs>: VecZnxToRef,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.keyswitch(module, lhs, &rhs.key, scratch);
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), self, i);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.keyswitch_inplace(module, &rhs.key, scratch);
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_automorphism_inplace(rhs.p(), self, i);
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_from_fourier<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertext::keyswitch_from_fourier_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Buffer of the result of VMP in DFT
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
// Applies VMP
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft_copy(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
module.vec_znx_dft_add_inplace(&mut res_dft, 0, lhs, 0);
|
||||
|
||||
// Switches result of VMP outside of DFT
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertext<DataLhs>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataLhs>: VecZnxToRef,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(lhs.rank(), rhs.rank_in());
|
||||
assert_eq!(self.rank(), rhs.rank_out());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GLWECiphertext::keyswitch_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
self.rank(),
|
||||
lhs.size(),
|
||||
lhs.rank(),
|
||||
rhs.size(),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
let cols_in: usize = rhs.rank_in();
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut ai_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols_in, lhs.size());
|
||||
(0..cols_in).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ai_dft, col_i, lhs, col_i + 1);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &ai_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, lhs, 0);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertext<DataLhs>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataLhs>: VecZnxToRef,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(rhs.rank(), lhs.rank());
|
||||
assert_eq!(rhs.rank(), self.rank());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut a_dft, col_i, lhs, col_i);
|
||||
});
|
||||
module.vmp_apply(&mut res_dft, &a_dft, rhs, scratch2);
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
|
||||
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_sk_private<DataPt, DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
if let Some((pt, col)) = pt {
|
||||
assert_eq!(pt.n(), module.n());
|
||||
assert!(col < self.rank() + 1);
|
||||
}
|
||||
}
|
||||
|
||||
let log_base2k: usize = self.basek();
|
||||
let log_k: usize = self.k();
|
||||
let size: usize = self.size();
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
c0_big.zero();
|
||||
|
||||
{
|
||||
// c[i] = uniform
|
||||
// c[0] -= c[i] * s[i],
|
||||
(1..cols).for_each(|i| {
|
||||
let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
||||
|
||||
// c[i] = uniform
|
||||
self.data.fill_uniform(log_base2k, i, size, source_xa);
|
||||
|
||||
// c[i] = norm(IDFT(DFT(c[i]) * DFT(s[i])))
|
||||
module.vec_znx_dft(&mut ci_dft, 0, self, i);
|
||||
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1);
|
||||
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
|
||||
|
||||
// use c[0] as buffer, which is overwritten later by the normalization step
|
||||
module.vec_znx_big_normalize(log_base2k, self, 0, &ci_big, 0, scratch_2);
|
||||
|
||||
// c0_tmp = -c[i] * s[i] (use c[0] as buffer)
|
||||
module.vec_znx_sub_ab_inplace(&mut c0_big, 0, self, 0);
|
||||
|
||||
// c[i] += m if col = i
|
||||
if let Some((pt, col)) = pt {
|
||||
if i == col {
|
||||
module.vec_znx_add_inplace(self, i, pt, 0);
|
||||
module.vec_znx_normalize_inplace(log_base2k, self, i, scratch_2);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// c[0] += e
|
||||
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, sigma * SIX_SIGMA);
|
||||
|
||||
// c[0] += m if col = 0
|
||||
if let Some((pt, col)) = pt {
|
||||
if col == 0 {
|
||||
module.vec_znx_add_inplace(&mut c0_big, 0, pt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// c[0] = norm(c[0])
|
||||
module.vec_znx_normalize(log_base2k, self, 0, &c0_big, 0, scratch_1);
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_pk_private<DataPt, DataPk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: Option<(&GLWEPlaintext<DataPt>, usize)>,
|
||||
pk: &GLWEPublicKey<DataPk, FFT64>,
|
||||
source_xu: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToRef,
|
||||
VecZnxDft<DataPk, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.basek(), pk.basek());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(pk.n(), module.n());
|
||||
assert_eq!(self.rank(), pk.rank());
|
||||
if let Some((pt, _)) = pt {
|
||||
assert_eq!(pt.basek(), pk.basek());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
}
|
||||
}
|
||||
|
||||
let log_base2k: usize = pk.basek();
|
||||
let size_pk: usize = pk.size();
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
// Generates u according to the underlying secret distribution.
|
||||
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
|
||||
|
||||
{
|
||||
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
|
||||
match pk.dist {
|
||||
SecretDistribution::NONE => panic!(
|
||||
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through \
|
||||
Self::generate"
|
||||
),
|
||||
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
|
||||
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
|
||||
SecretDistribution::ZERO => {}
|
||||
}
|
||||
|
||||
module.svp_prepare(&mut u_dft, 0, &u, 0);
|
||||
}
|
||||
|
||||
// ct[i] = pk[i] * u + ei (+ m if col = i)
|
||||
(0..cols).for_each(|i| {
|
||||
let (mut ci_dft, scratch_2) = scratch_1.tmp_vec_znx_dft(module, 1, size_pk);
|
||||
// ci_dft = DFT(u) * DFT(pk[i])
|
||||
module.svp_apply(&mut ci_dft, 0, &u_dft, 0, pk, i);
|
||||
|
||||
// ci_big = u * p[i]
|
||||
let mut ci_big = module.vec_znx_idft_consume(ci_dft);
|
||||
|
||||
// ci_big = u * pk[i] + e
|
||||
ci_big.add_normal(log_base2k, 0, pk.k(), source_xe, sigma, sigma * SIX_SIGMA);
|
||||
|
||||
// ci_big = u * pk[i] + e + m (if col = i)
|
||||
if let Some((pt, col)) = pt {
|
||||
if col == i {
|
||||
module.vec_znx_big_add_small_inplace(&mut ci_big, 0, pt, 0);
|
||||
}
|
||||
}
|
||||
|
||||
// ct[i] = norm(ci_big)
|
||||
module.vec_znx_big_normalize(log_base2k, self, i, &ci_big, 0, scratch_2);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertext<DataSelf>
|
||||
where
|
||||
VecZnx<DataSelf>: VecZnxToRef,
|
||||
{
|
||||
pub fn decrypt<DataPt, DataSk>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut GLWEPlaintext<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToMut,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
c0_big.zero();
|
||||
|
||||
{
|
||||
(1..cols).for_each(|i| {
|
||||
// ci_dft = DFT(a[i]) * DFT(s[i])
|
||||
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
module.vec_znx_dft(&mut ci_dft, 0, self, i);
|
||||
module.svp_apply_inplace(&mut ci_dft, 0, sk_dft, i - 1);
|
||||
let ci_big = module.vec_znx_idft_consume(ci_dft);
|
||||
|
||||
// c0_big += a[i] * s[i]
|
||||
module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
|
||||
});
|
||||
}
|
||||
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, self, 0);
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut c0_big, 0, scratch_1);
|
||||
|
||||
pt.basek = self.basek();
|
||||
pt.k = pt.k().min(self.k());
|
||||
}
|
||||
}
|
||||
323
core/src/glwe_ciphertext_fourier.rs
Normal file
323
core/src/glwe_ciphertext_fourier.rs
Normal file
@@ -0,0 +1,323 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps,
|
||||
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GLWECiphertextFourier<C, B: Backend> {
|
||||
pub data: VecZnxDft<C, B>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k: k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GLWECiphertextFourier<T, B> {
|
||||
type Inner = VecZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GLWECiphertextFourier<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.cols() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
|
||||
}
|
||||
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, ct_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, ct_size)
|
||||
}
|
||||
|
||||
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes()
|
||||
| module.bytes_of_vec_znx_dft(1, ct_size)
|
||||
| (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes()))
|
||||
+ module.bytes_of_vec_znx_big(1, ct_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
module.bytes_of_vec_znx(out_rank + 1, out_size)
|
||||
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
|
||||
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
|
||||
res_dft + (vmp | (res_small + normalize))
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_zero_sk<DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
|
||||
let mut ct_idft = GLWECiphertext {
|
||||
data: vec_znx_tmp,
|
||||
basek: self.basek,
|
||||
k: self.k,
|
||||
};
|
||||
ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1);
|
||||
|
||||
ct_idft.dft(module, self);
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let cols_out: usize = rhs.rank_out() + 1;
|
||||
|
||||
// Space fr normalized VMP result outside of DFT domain
|
||||
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
|
||||
|
||||
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||
data: res_idft_data,
|
||||
basek: lhs.basek,
|
||||
k: lhs.k,
|
||||
};
|
||||
|
||||
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
|
||||
|
||||
(0..cols_out).for_each(|i| {
|
||||
module.vec_znx_dft(self, i, &res_idft, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.keyswitch(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
let basek: usize = self.basek();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(rhs.rank(), lhs.rank());
|
||||
assert_eq!(rhs.rank(), self.rank());
|
||||
assert_eq!(self.basek(), basek);
|
||||
assert_eq!(lhs.basek(), basek);
|
||||
assert_eq!(rhs.n(), module.n());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(lhs.n(), module.n());
|
||||
}
|
||||
|
||||
let cols: usize = rhs.rank() + 1;
|
||||
|
||||
// Space for VMP result in DFT domain and high precision
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
|
||||
|
||||
{
|
||||
module.vmp_apply(&mut res_dft, lhs, rhs, scratch1);
|
||||
}
|
||||
|
||||
// VMP result in high precision
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume::<&mut [u8]>(res_dft);
|
||||
|
||||
// Space for VMP result normalized
|
||||
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
|
||||
module.vec_znx_dft(self, i, &res_small, i);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
unsafe {
|
||||
let self_ptr: *mut GLWECiphertextFourier<DataSelf, FFT64> = self as *mut GLWECiphertextFourier<DataSelf, FFT64>;
|
||||
self.external_product(&module, &*self_ptr, rhs, scratch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn decrypt<DataPt, DataSk>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut GLWEPlaintext<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToMut + VecZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
}
|
||||
|
||||
let cols = self.rank() + 1;
|
||||
|
||||
let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
pt_big.zero();
|
||||
|
||||
{
|
||||
(1..cols).for_each(|i| {
|
||||
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i);
|
||||
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
|
||||
}
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1);
|
||||
|
||||
pt.basek = self.basek();
|
||||
pt.k = pt.k().min(self.k());
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
||||
where
|
||||
GLWECiphertext<DataRes>: VecZnxToMut,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
assert_eq!(self.basek(), res.basek())
|
||||
}
|
||||
|
||||
let min_size: usize = self.size().min(res.size());
|
||||
|
||||
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
|
||||
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
53
core/src/glwe_plaintext.rs
Normal file
53
core/src/glwe_plaintext.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use base2k::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef};
|
||||
|
||||
use crate::{elem::Infos, utils::derive_size};
|
||||
|
||||
pub struct GLWEPlaintext<C> {
|
||||
pub data: VecZnx<C>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl<T> Infos for GLWEPlaintext<T> {
|
||||
type Inner = VecZnx<T>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToMut for GLWEPlaintext<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToRef for GLWEPlaintext<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWEPlaintext<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx(1, derive_size(basek, k)),
|
||||
basek: basek,
|
||||
k,
|
||||
}
|
||||
}
|
||||
}
|
||||
247
core/src/keys.rs
Normal file
247
core/src/keys.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
|
||||
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{elem::Infos, glwe_ciphertext_fourier::GLWECiphertextFourier};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum SecretDistribution {
|
||||
TernaryFixed(usize), // Ternary with fixed Hamming weight
|
||||
TernaryProb(f64), // Ternary with probabilistic Hamming weight
|
||||
ZERO, // Debug mod
|
||||
NONE,
|
||||
}
|
||||
|
||||
pub struct SecretKey<T> {
|
||||
pub data: ScalarZnx<T>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl SecretKey<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_scalar_znx(rank),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> SecretKey<DataSelf> {
|
||||
pub fn n(&self) -> usize {
|
||||
self.data.n()
|
||||
}
|
||||
|
||||
pub fn log_n(&self) -> usize {
|
||||
self.data.log_n()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> SecretKey<S>
|
||||
where
|
||||
S: AsMut<[u8]> + AsRef<[u8]>,
|
||||
{
|
||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
||||
(0..self.rank()).for_each(|i| {
|
||||
self.data.fill_ternary_prob(i, prob, source);
|
||||
});
|
||||
self.dist = SecretDistribution::TernaryProb(prob);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||
(0..self.rank()).for_each(|i| {
|
||||
self.data.fill_ternary_hw(i, hw, source);
|
||||
});
|
||||
self.dist = SecretDistribution::TernaryFixed(hw);
|
||||
}
|
||||
|
||||
pub fn fill_zero(&mut self) {
|
||||
self.data.zero();
|
||||
self.dist = SecretDistribution::ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> ScalarZnxToMut for SecretKey<C>
|
||||
where
|
||||
ScalarZnx<C>: ScalarZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> ScalarZnxToRef for SecretKey<C>
|
||||
where
|
||||
ScalarZnx<C>: ScalarZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecretKeyFourier<T, B: Backend> {
|
||||
pub data: ScalarZnxDft<T, B>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> SecretKeyFourier<DataSelf, B> {
|
||||
pub fn n(&self) -> usize {
|
||||
self.data.n()
|
||||
}
|
||||
|
||||
pub fn log_n(&self) -> usize {
|
||||
self.data.log_n()
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
self.data.cols()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> SecretKeyFourier<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_scalar_znx_dft(rank),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
|
||||
where
|
||||
SecretKeyFourier<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
|
||||
SecretKey<S>: ScalarZnxToRef,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
match sk.dist {
|
||||
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(sk.n(), module.n());
|
||||
assert_eq!(self.rank(), sk.rank());
|
||||
}
|
||||
|
||||
(0..self.rank()).for_each(|i| {
|
||||
module.svp_prepare(self, i, sk, i);
|
||||
});
|
||||
self.dist = sk.dist;
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyFourier<C, B>
|
||||
where
|
||||
ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyFourier<C, B>
|
||||
where
|
||||
ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GLWEPublicKey<D, B: Backend> {
|
||||
pub data: GLWECiphertextFourier<D, B>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl<B: Backend> GLWEPublicKey<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: GLWECiphertextFourier::new(module, log_base2k, log_k, rank),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GLWEPublicKey<T, B> {
|
||||
type Inner = VecZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.data.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.data.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GLWEPublicKey<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.cols() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for GLWEPublicKey<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for GLWEPublicKey<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GLWEPublicKey<C, FFT64> {
|
||||
pub fn generate<S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
) where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
match sk_dft.dist {
|
||||
SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Its ok to allocate scratch space here since pk is usually generated only once.
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(GLWECiphertextFourier::encrypt_sk_scratch_space(
|
||||
module,
|
||||
self.rank(),
|
||||
self.size(),
|
||||
));
|
||||
self.data.encrypt_zero_sk(
|
||||
module,
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
self.dist = sk_dft.dist;
|
||||
}
|
||||
}
|
||||
385
core/src/keyswitch_key.rs
Normal file
385
core/src/keyswitch_key.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
|
||||
ScalarZnxToRef, Scratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, SetRow},
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
};
|
||||
|
||||
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
|
||||
|
||||
impl GLWESwitchingKey<Vec<u8>, FFT64> {
|
||||
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
|
||||
GLWESwitchingKey(GGLWECiphertext::new(
|
||||
module, basek, k, rows, rank_in, rank_out,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
self.0.inner()
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.0.basek()
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.0.k()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GLWESwitchingKey<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.0.data.cols_out() - 1
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.0.data.cols_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.0.data.cols_out() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
|
||||
where
|
||||
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.0.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf, B: Backend> MatZnxDftToRef<B> for GLWESwitchingKey<DataSelf, B>
|
||||
where
|
||||
MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.0.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GetRow<FFT64> for GLWESwitchingKey<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut R)
|
||||
where
|
||||
R: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> SetRow<FFT64> for GLWESwitchingKey<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &R)
|
||||
where
|
||||
R: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWESwitchingKey<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size)
|
||||
}
|
||||
|
||||
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
|
||||
GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
out_rank: usize,
|
||||
in_size: usize,
|
||||
in_rank: usize,
|
||||
ksk_size: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(in_rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||
let ksk: usize = GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size);
|
||||
tmp_in + tmp_out + ksk
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(out_rank + 1, out_size);
|
||||
let ksk: usize = GLWECiphertextFourier::keyswitch_inplace_scratch_space(module, out_size, out_rank, ksk_size);
|
||||
tmp + ksk
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_sk<DataSkIn, DataSkOut>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_in: &SecretKey<DataSkIn>,
|
||||
sk_out_dft: &SecretKeyFourier<DataSkOut, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnx<DataSkIn>: ScalarZnxToRef,
|
||||
ScalarZnxDft<DataSkOut, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
self.0.encrypt_sk(
|
||||
module,
|
||||
&sk_in.data,
|
||||
sk_out_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
scratch,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank_in(),
|
||||
"ksk_in output rank: {} != ksk_apply input rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.keyswitch(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut base2k::Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank_out(),
|
||||
"ksk_out output rank: {} != ksk_apply output rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank_out()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.keyswitch_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_in(),
|
||||
lhs.rank_in(),
|
||||
"ksk_out input rank: {} != ksk_in input rank: {}",
|
||||
self.rank_in(),
|
||||
lhs.rank_in()
|
||||
);
|
||||
assert_eq!(
|
||||
lhs.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_in output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_out output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size());
|
||||
|
||||
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_in_data,
|
||||
basek: lhs.basek(),
|
||||
k: lhs.k(),
|
||||
};
|
||||
|
||||
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_out_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
lhs.get_row(module, row_j, col_i, &mut tmp_in);
|
||||
tmp_out.external_product(module, &tmp_in, rhs, scratch2);
|
||||
self.set_row(module, row_j, col_i, &tmp_out);
|
||||
});
|
||||
});
|
||||
|
||||
tmp_out.data.zero();
|
||||
|
||||
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||
(0..self.rank_in()).for_each(|col_j| {
|
||||
self.set_row(module, row_i, col_j, &tmp_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(
|
||||
self.rank_out(),
|
||||
rhs.rank(),
|
||||
"ksk_out output rank: {} != ggsw rank: {}",
|
||||
self.rank_out(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
|
||||
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
|
||||
|
||||
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_data,
|
||||
basek: self.basek(),
|
||||
k: self.k(),
|
||||
};
|
||||
|
||||
(0..self.rank_in()).for_each(|col_i| {
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
self.get_row(module, row_j, col_i, &mut tmp);
|
||||
tmp.external_product_inplace(module, rhs, scratch1);
|
||||
self.set_row(module, row_j, col_i, &tmp);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
15
core/src/lib.rs
Normal file
15
core/src/lib.rs
Normal file
@@ -0,0 +1,15 @@
|
||||
pub mod automorphism;
|
||||
pub mod elem;
|
||||
pub mod gglwe_ciphertext;
|
||||
pub mod ggsw_ciphertext;
|
||||
pub mod glwe_ciphertext;
|
||||
pub mod glwe_ciphertext_fourier;
|
||||
pub mod glwe_plaintext;
|
||||
pub mod keys;
|
||||
pub mod keyswitch_key;
|
||||
pub mod tensor_key;
|
||||
#[cfg(test)]
|
||||
mod test_fft64;
|
||||
mod utils;
|
||||
|
||||
pub(crate) const SIX_SIGMA: f64 = 6.0;
|
||||
130
core/src/tensor_key.rs
Normal file
130
core/src/tensor_key.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
};
|
||||
|
||||
pub struct TensorKey<C, B: Backend> {
|
||||
pub(crate) keys: Vec<GLWESwitchingKey<C, B>>,
|
||||
}
|
||||
|
||||
impl TensorKey<Vec<u8>, FFT64> {
|
||||
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
|
||||
let pairs: usize = ((rank + 1) * rank) >> 1;
|
||||
(0..pairs).for_each(|_| {
|
||||
keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank));
|
||||
});
|
||||
Self { keys: keys }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for TensorKey<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.keys[0].inner()
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.keys[0].basek()
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.keys[0].k()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> TensorKey<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.keys[0].rank()
|
||||
}
|
||||
|
||||
pub fn rank_in(&self) -> usize {
|
||||
self.keys[0].rank_in()
|
||||
}
|
||||
|
||||
pub fn rank_out(&self) -> usize {
|
||||
self.keys[0].rank_out()
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorKey<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> TensorKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_sk<DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnxDft<DataSk, FFT64>: VecZnxDftToRef<FFT64> + ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
}
|
||||
|
||||
let rank: usize = self.rank();
|
||||
|
||||
(0..rank).for_each(|i| {
|
||||
(i..rank).for_each(|j| {
|
||||
let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1);
|
||||
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
|
||||
let sk_ij: ScalarZnx<&mut [u8]> = module
|
||||
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
|
||||
.to_vec_znx_small()
|
||||
.to_scalar_znx();
|
||||
let sk_ij: SecretKey<&mut [u8]> = SecretKey {
|
||||
data: sk_ij,
|
||||
dist: sk_dft.dist,
|
||||
};
|
||||
|
||||
self.at_mut(i, j).encrypt_sk(
|
||||
module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1,
|
||||
);
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
// Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j])
|
||||
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey<DataSelf, FFT64> {
|
||||
if i > j {
|
||||
std::mem::swap(&mut i, &mut j);
|
||||
};
|
||||
let rank: usize = self.rank();
|
||||
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> TensorKey<DataSelf, FFT64>
|
||||
where
|
||||
MatZnxDft<DataSelf, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
// Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j])
|
||||
pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey<DataSelf, FFT64> {
|
||||
if i > j {
|
||||
std::mem::swap(&mut i, &mut j);
|
||||
};
|
||||
let rank: usize = self.rank();
|
||||
&self.keys[i * rank + j - (i * (i + 1) / 2)]
|
||||
}
|
||||
}
|
||||
216
core/src/test_fft64/automorphism_key.rs
Normal file
216
core/src/test_fft64/automorphism_key.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use base2k::{FFT64, Module, ScalarZnxOps, ScratchOwned, Stats, VecZnxOps};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::{GetRow, Infos},
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
test_fft64::gglwe::log2_std_noise_gglwe_product,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn automorphism() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism rank: {}", rank);
|
||||
test_automorphism(-1, 5, 12, 12, 60, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut auto_key_in: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
let mut auto_key_out: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key_in.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key_out.size())
|
||||
| AutomorphismKey::automorphism_scratch_space(
|
||||
&module,
|
||||
auto_key_out.size(),
|
||||
auto_key_in.size(),
|
||||
auto_key_apply.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
auto_key_in.encrypt_sk(
|
||||
&module,
|
||||
p0,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s2}(s1) -> s1 -> s2
|
||||
auto_key_apply.encrypt_sk(
|
||||
&module,
|
||||
p1,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||
auto_key_out.automorphism(&module, &auto_key_in, &auto_key_apply, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
|
||||
(0..rank).for_each(|i| {
|
||||
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
|
||||
});
|
||||
|
||||
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_auto_dft.dft(&module, &sk_auto);
|
||||
|
||||
(0..auto_key_out.rank_in()).for_each(|col_i| {
|
||||
(0..auto_key_out.rows()).for_each(|row_i| {
|
||||
auto_key_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ksk,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
|
||||
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
auto_key.encrypt_sk(
|
||||
&module,
|
||||
p0,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s2}(s1) -> s1 -> s2
|
||||
auto_key_apply.encrypt_sk(
|
||||
&module,
|
||||
p1,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
|
||||
(0..rank).for_each(|i| {
|
||||
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
|
||||
});
|
||||
|
||||
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_auto_dft.dft(&module, &sk_auto);
|
||||
|
||||
(0..auto_key.rank_in()).for_each(|col_i| {
|
||||
(0..auto_key.rows()).for_each(|row_i| {
|
||||
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ksk,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
630
core/src/test_fft64/gglwe.rs
Normal file
630
core/src/test_fft64/gglwe.rs
Normal file
@@ -0,0 +1,630 @@
|
||||
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToMut, ScratchOwned, Stats, VecZnxOps, ZnxViewMut};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos},
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
test_fft64::ggsw::noise_ggsw_product,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
(1..4).for_each(|rank_in| {
|
||||
(1..4).for_each(|rank_out| {
|
||||
println!("test encrypt_sk rank_in rank_out: {} {}", rank_in, rank_out);
|
||||
test_encrypt_sk(12, 8, 54, 3.2, rank_in, rank_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_switch() {
|
||||
(1..4).for_each(|rank_in_s0s1| {
|
||||
(1..4).for_each(|rank_out_s0s1| {
|
||||
(1..4).for_each(|rank_out_s1s2| {
|
||||
println!(
|
||||
"test key_switch : ({},{},{})",
|
||||
rank_in_s0s1, rank_out_s0s1, rank_out_s1s2
|
||||
);
|
||||
test_key_switch(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1, rank_out_s1s2);
|
||||
})
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn key_switch_inplace() {
|
||||
(1..4).for_each(|rank_in_s0s1| {
|
||||
(1..4).for_each(|rank_out_s0s1| {
|
||||
println!(
|
||||
"test key_switch_inplace : ({},{})",
|
||||
rank_in_s0s1, rank_out_s0s1
|
||||
);
|
||||
test_key_switch_inplace(12, 15, 60, 3.2, rank_in_s0s1, rank_out_s0s1);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product() {
|
||||
(1..4).for_each(|rank_in| {
|
||||
(1..4).for_each(|rank_out| {
|
||||
println!("test external_product rank: {} {}", rank_in, rank_out);
|
||||
test_external_product(12, 12, 60, 3.2, rank_in, rank_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product_inplace() {
|
||||
(1..4).for_each(|rank_in| {
|
||||
(1..4).for_each(|rank_out| {
|
||||
println!(
|
||||
"test external_product_inplace rank: {} {}",
|
||||
rank_in, rank_out
|
||||
);
|
||||
test_external_product_inplace(12, 12, 60, 3.2, rank_in, rank_out);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in: usize, rank_out: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ksk.size()),
|
||||
);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out);
|
||||
|
||||
(0..ksk.rank_in()).for_each(|col_i| {
|
||||
(0..ksk.rows()).for_each(|row_i| {
|
||||
ksk.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
|
||||
let std_pt: f64 = pt.data.std(0, basek) * (k_ksk as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_key_switch(
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ksk: usize,
|
||||
sigma: f64,
|
||||
rank_in_s0s1: usize,
|
||||
rank_out_s0s1: usize,
|
||||
rank_out_s1s2: usize,
|
||||
) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut ct_gglwe_s0s1: GLWESwitchingKey<Vec<u8>, FFT64> =
|
||||
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1);
|
||||
let mut ct_gglwe_s1s2: GLWESwitchingKey<Vec<u8>, FFT64> =
|
||||
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s1s2);
|
||||
let mut ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> =
|
||||
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s1s2);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in_s0s1 | rank_out_s0s1, ct_gglwe_s0s1.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s2.size())
|
||||
| GLWESwitchingKey::keyswitch_scratch_space(
|
||||
&module,
|
||||
ct_gglwe_s0s2.size(),
|
||||
ct_gglwe_s0s2.rank(),
|
||||
ct_gglwe_s0s1.size(),
|
||||
ct_gglwe_s0s1.rank(),
|
||||
ct_gglwe_s1s2.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in_s0s1);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s1s2);
|
||||
sk2.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk2_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s1s2);
|
||||
sk2_dft.dft(&module, &sk2);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
ct_gglwe_s0s1.encrypt_sk(
|
||||
&module,
|
||||
&sk0,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s2}(s1) -> s1 -> s2
|
||||
ct_gglwe_s1s2.encrypt_sk(
|
||||
&module,
|
||||
&sk1,
|
||||
&sk2_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||
ct_gglwe_s0s2.keyswitch(&module, &ct_gglwe_s0s1, &ct_gglwe_s1s2, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s1s2);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {
|
||||
(0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
|
||||
ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank_out_s0s1 as f64,
|
||||
k_ksk,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_key_switch_inplace(log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank_in_s0s1: usize, rank_out_s0s1: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ksk + basek - 1) / basek;
|
||||
|
||||
let mut ct_gglwe_s0s1: GLWESwitchingKey<Vec<u8>, FFT64> =
|
||||
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in_s0s1, rank_out_s0s1);
|
||||
let mut ct_gglwe_s1s2: GLWESwitchingKey<Vec<u8>, FFT64> =
|
||||
GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_out_s0s1, rank_out_s0s1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out_s0s1, ct_gglwe_s0s1.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_s0s1.size())
|
||||
| GLWESwitchingKey::keyswitch_inplace_scratch_space(
|
||||
&module,
|
||||
ct_gglwe_s0s1.size(),
|
||||
ct_gglwe_s0s1.rank(),
|
||||
ct_gglwe_s1s2.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in_s0s1);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in_s0s1);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out_s0s1);
|
||||
sk2.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk2_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out_s0s1);
|
||||
sk2_dft.dft(&module, &sk2);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
ct_gglwe_s0s1.encrypt_sk(
|
||||
&module,
|
||||
&sk0,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s2}(s1) -> s1 -> s2
|
||||
ct_gglwe_s1s2.encrypt_sk(
|
||||
&module,
|
||||
&sk1,
|
||||
&sk2_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||
ct_gglwe_s0s1.keyswitch_inplace(&module, &ct_gglwe_s1s2, scratch.borrow());
|
||||
|
||||
let ct_gglwe_s0s2: GLWESwitchingKey<Vec<u8>, FFT64> = ct_gglwe_s0s1;
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank_out_s0s1);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||
|
||||
(0..ct_gglwe_s0s2.rank_in()).for_each(|col_i| {
|
||||
(0..ct_gglwe_s0s2.rows()).for_each(|row_i| {
|
||||
ct_gglwe_s0s2.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank_out_s0s1 as f64,
|
||||
k_ksk,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_external_product(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct_gglwe_in: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
|
||||
let mut ct_gglwe_out: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out);
|
||||
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size())
|
||||
| GLWESwitchingKey::external_product_scratch_space(
|
||||
&module,
|
||||
ct_gglwe_out.size(),
|
||||
ct_gglwe_in.size(),
|
||||
ct_rgsw.size(),
|
||||
rank_out,
|
||||
)
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
|
||||
);
|
||||
|
||||
let r: usize = 1;
|
||||
|
||||
pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r}
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
ct_gglwe_in.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k)
|
||||
ct_gglwe_out.external_product(&module, &ct_gglwe_in, &ct_rgsw, scratch.borrow());
|
||||
|
||||
scratch = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe_in.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe_out.size())
|
||||
| GLWESwitchingKey::external_product_scratch_space(
|
||||
&module,
|
||||
ct_gglwe_out.size(),
|
||||
ct_gglwe_in.size(),
|
||||
ct_rgsw.size(),
|
||||
rank_out,
|
||||
)
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
|
||||
);
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
|
||||
(0..rank_in).for_each(|i| {
|
||||
module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r}
|
||||
});
|
||||
|
||||
(0..rank_in).for_each(|col_i| {
|
||||
(0..ct_gglwe_out.rows()).for_each(|row_i| {
|
||||
ct_gglwe_out.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank_out as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_external_product_inplace(log_n: usize, basek: usize, k: usize, sigma: f64, rank_in: usize, rank_out: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct_gglwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank_in, rank_out);
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank_out);
|
||||
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ct_gglwe.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_gglwe.size())
|
||||
| GLWESwitchingKey::external_product_inplace_scratch_space(&module, ct_gglwe.size(), ct_rgsw.size(), rank_out)
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank_out, ct_rgsw.size()),
|
||||
);
|
||||
|
||||
let r: usize = 1;
|
||||
|
||||
pt_rgsw.to_mut().raw_mut()[r] = 1; // X^{r}
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
// gglwe_{s1}(s0) = s0 -> s1
|
||||
ct_gglwe.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// gglwe_(m) (x) RGSW_(X^k) = gglwe_(m * X^k)
|
||||
ct_gglwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank_out);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
|
||||
(0..rank_in).for_each(|i| {
|
||||
module.vec_znx_rotate_inplace(r as i64, &mut sk_in.data, i); // * X^{r}
|
||||
});
|
||||
|
||||
(0..rank_in).for_each(|col_i| {
|
||||
(0..ct_gglwe.rows()).for_each(|row_i| {
|
||||
ct_gglwe.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||
ct_glwe_dft.decrypt(&module, &mut pt, &sk_out_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_in, col_i);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank_out as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn var_noise_gglwe_product(
|
||||
n: f64,
|
||||
basek: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
rank_in: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = a_logq.min(b_logq);
|
||||
let a_cols: usize = (a_logq + basek - 1) / basek;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (basek)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a_err * a_scale * a_scale * n;
|
||||
noise *= rank_in;
|
||||
noise /= b_scale * b_scale;
|
||||
noise
|
||||
}
|
||||
|
||||
pub(crate) fn log2_std_noise_gglwe_product(
|
||||
n: f64,
|
||||
basek: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
rank_in: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let mut noise: f64 = var_noise_gglwe_product(
|
||||
n,
|
||||
basek,
|
||||
var_xs,
|
||||
var_msg,
|
||||
var_a_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank_in,
|
||||
a_logq,
|
||||
b_logq,
|
||||
);
|
||||
noise = noise.sqrt();
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
934
core/src/test_fft64/ggsw.rs
Normal file
934
core/src/test_fft64/ggsw.rs
Normal file
@@ -0,0 +1,934 @@
|
||||
use base2k::{
|
||||
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScalarZnxOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc,
|
||||
VecZnxBigOps, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::{GetRow, Infos},
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
tensor_key::TensorKey,
|
||||
};
|
||||
|
||||
use super::gglwe::var_noise_gglwe_product;
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_sk rank: {}", rank);
|
||||
test_encrypt_sk(11, 8, 54, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyswitch() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test keyswitch rank: {}", rank);
|
||||
test_keyswitch(12, 15, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyswitch_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test keyswitch_inplace rank: {}", rank);
|
||||
test_keyswitch_inplace(12, 15, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism rank: {}", rank);
|
||||
test_automorphism(-5, 12, 15, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(-5, 12, 15, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product(12, 12, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product_inplace(12, 15, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||
|
||||
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||
|
||||
(0..ct.rank() + 1).for_each(|col_j| {
|
||||
(0..ct.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
// mul with sk[col_j-1]
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let std_pt: f64 = pt_have.data.std(0, basek) * (k_ggsw as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_keyswitch(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut ct_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut tsk: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size())
|
||||
| GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| GGSWCiphertext::keyswitch_scratch_space(
|
||||
&module,
|
||||
ct_out.size(),
|
||||
ct_in.size(),
|
||||
ksk.size(),
|
||||
tsk.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let var_xs: f64 = 0.5;
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_in.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_out.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
tsk.encrypt_sk(
|
||||
&module,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
ct_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_out.keyswitch(&module, &ct_in, &ksk, &tsk, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_out.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_out.size());
|
||||
|
||||
(0..ct_out.rank() + 1).for_each(|col_j| {
|
||||
(0..ct_out.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
// mul with sk[col_j-1]
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = noise_ggsw_keyswitch(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
col_j,
|
||||
var_xs,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
println!("{} {}", noise_have, noise_want);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_keyswitch_inplace(log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut tsk: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k, rows, rank, rank);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size())
|
||||
| GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| TensorKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| GGSWCiphertext::keyswitch_inplace_scratch_space(&module, ct.size(), ksk.size(), tsk.size(), rank),
|
||||
);
|
||||
|
||||
let var_xs: f64 = 0.5;
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_in.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_out.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
tsk.encrypt_sk(
|
||||
&module,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.keyswitch_inplace(&module, &ksk, &tsk, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||
|
||||
(0..ct.rank() + 1).for_each(|col_j| {
|
||||
(0..ct.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
// mul with sk[col_j-1]
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_out_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = noise_ggsw_keyswitch(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
col_j,
|
||||
var_xs,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
println!("{} {}", noise_have, noise_want);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn noise_ggsw_keyswitch(
|
||||
n: f64,
|
||||
basek: usize,
|
||||
col: usize,
|
||||
var_xs: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
rank: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let var_si_x_sj: f64 = n * var_xs * var_xs;
|
||||
|
||||
// Initial KS for col = 0
|
||||
let mut noise: f64 = var_noise_gglwe_product(
|
||||
n,
|
||||
basek,
|
||||
var_xs,
|
||||
var_xs,
|
||||
var_a_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank,
|
||||
a_logq,
|
||||
b_logq,
|
||||
);
|
||||
|
||||
// Other GGSW reconstruction for col > 0
|
||||
if col > 0 {
|
||||
noise += var_noise_gglwe_product(
|
||||
n,
|
||||
basek,
|
||||
var_xs,
|
||||
var_si_x_sj,
|
||||
var_a_err + 1f64 / 12.0,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank,
|
||||
a_logq,
|
||||
b_logq,
|
||||
);
|
||||
noise += n * noise * var_xs * 0.5;
|
||||
}
|
||||
|
||||
noise = noise.sqrt();
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
|
||||
fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut ct_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_in.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_out.size())
|
||||
| AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||
| TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size())
|
||||
| GGSWCiphertext::automorphism_scratch_space(
|
||||
&module,
|
||||
ct_out.size(),
|
||||
ct_in.size(),
|
||||
auto_key.size(),
|
||||
tensor_key.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let var_xs: f64 = 0.5;
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
auto_key.encrypt_sk(
|
||||
&module,
|
||||
p,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
tensor_key.encrypt_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
ct_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_out.automorphism(&module, &ct_in, &auto_key, &tensor_key, scratch.borrow());
|
||||
|
||||
module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_out.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_out.size());
|
||||
|
||||
(0..ct_out.rank() + 1).for_each(|col_j| {
|
||||
(0..ct_out.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
// mul with sk[col_j-1]
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = noise_ggsw_keyswitch(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
col_j,
|
||||
var_xs,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_automorphism_inplace(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut ct: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct.size())
|
||||
| AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||
| TensorKey::encrypt_sk_scratch_space(&module, rank, tensor_key.size())
|
||||
| GGSWCiphertext::automorphism_inplace_scratch_space(&module, ct.size(), auto_key.size(), tensor_key.size(), rank),
|
||||
);
|
||||
|
||||
let var_xs: f64 = 0.5;
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(var_xs, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
auto_key.encrypt_sk(
|
||||
&module,
|
||||
p,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
tensor_key.encrypt_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.automorphism_inplace(&module, &auto_key, &tensor_key, scratch.borrow());
|
||||
|
||||
module.scalar_znx_automorphism_inplace(p, &mut pt_scalar, 0);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||
|
||||
(0..ct.rank() + 1).for_each(|col_j| {
|
||||
(0..ct.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
// mul with sk[col_j-1]
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = noise_ggsw_keyswitch(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
col_j,
|
||||
var_xs,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k,
|
||||
k,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||
|
||||
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_ggsw_lhs_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_ggsw_lhs_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
|
||||
| GGSWCiphertext::external_product_scratch_space(
|
||||
&module,
|
||||
ct_ggsw_lhs_out.size(),
|
||||
ct_ggsw_lhs_in.size(),
|
||||
ct_ggsw_rhs.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_ggsw_rhs.encrypt_sk(
|
||||
&module,
|
||||
&pt_ggsw_rhs,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_ggsw_lhs_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_ggsw_lhs,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_ggsw_lhs_out.external_product(&module, &ct_ggsw_lhs_in, &ct_ggsw_rhs, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||
|
||||
(0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
|
||||
(0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ggsw,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"have: {} want: {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||
|
||||
let mut ct_ggsw_rhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_ggsw_lhs: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut pt_ggsw_lhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_ggsw_rhs: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_ggsw_lhs.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
|
||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs.size())
|
||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs.size())
|
||||
| GGSWCiphertext::external_product_inplace_scratch_space(&module, ct_ggsw_lhs.size(), ct_ggsw_rhs.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_ggsw_rhs.encrypt_sk(
|
||||
&module,
|
||||
&pt_ggsw_rhs,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_ggsw_lhs.encrypt_sk(
|
||||
&module,
|
||||
&pt_ggsw_lhs,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_ggsw_lhs.external_product_inplace(&module, &ct_ggsw_rhs, scratch.borrow());
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs.size());
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||
|
||||
(0..ct_ggsw_lhs.rank() + 1).for_each(|col_j| {
|
||||
(0..ct_ggsw_lhs.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||
|
||||
if col_j > 0 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct_ggsw_lhs.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ggsw,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"have: {} want: {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
}
|
||||
pub(crate) fn noise_ggsw_product(
|
||||
n: f64,
|
||||
basek: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a0_err: f64,
|
||||
var_a1_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
rank: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = a_logq.min(b_logq);
|
||||
let a_cols: usize = (a_logq + basek - 1) / basek;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (basek)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = (rank + 1.0) * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a0_err * a_scale * a_scale * n;
|
||||
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs * rank;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
805
core/src/test_fft64/glwe.rs
Normal file
805
core/src/test_fft64/glwe.rs
Normal file
@@ -0,0 +1,805 @@
|
||||
use base2k::{
|
||||
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
elem::Infos,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{GLWEPublicKey, SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_sk rank: {}", rank);
|
||||
test_encrypt_sk(11, 8, 54, 30, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_zero_sk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_zero_sk rank: {}", rank);
|
||||
test_encrypt_zero_sk(11, 8, 64, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_pk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_pk rank: {}", rank);
|
||||
test_encrypt_pk(11, 8, 64, 64, 3.2, rank)
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyswitch() {
|
||||
(1..4).for_each(|rank_in| {
|
||||
(1..4).for_each(|rank_out| {
|
||||
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
|
||||
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyswitch_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test keyswitch_inplace rank: {}", rank);
|
||||
test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product(12, 12, 60, 45, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism_inplace rank: {}", rank);
|
||||
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn automorphism() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test automorphism rank: {}", rank);
|
||||
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_pt);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size()) | GLWECiphertext::decrypt_scratch_space(&module, ct.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut data_want: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
data_want
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
|
||||
|
||||
pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt.data.zero();
|
||||
|
||||
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
let mut data_have: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
pt.data
|
||||
.decode_vec_i64(0, basek, pt.size() * basek, &mut data_have);
|
||||
|
||||
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
|
||||
let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64;
|
||||
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
|
||||
let b_scaled = (*b as f64) / scale;
|
||||
assert!(
|
||||
(*a as f64 - b_scaled).abs() < 0.1,
|
||||
"{} {}",
|
||||
*a as f64,
|
||||
b_scaled
|
||||
)
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([1u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut ct_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_dft.size())
|
||||
| GLWECiphertextFourier::encrypt_sk_scratch_space(&module, rank, ct_dft.size()),
|
||||
);
|
||||
|
||||
ct_dft.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
assert!((sigma - pt.data.std(0, basek) * (k_ct as f64).exp2()) <= 0.2);
|
||||
}
|
||||
|
||||
fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
let mut source_xu: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut pk: GLWEPublicKey<Vec<u8>, FFT64> = GLWEPublicKey::new(&module, basek, k_pk, rank);
|
||||
pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::encrypt_pk_scratch_space(&module, pk.size()),
|
||||
);
|
||||
|
||||
let mut data_want: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
data_want
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = source_xa.next_i64() & 0);
|
||||
|
||||
pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10);
|
||||
|
||||
ct.encrypt_pk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&pk,
|
||||
&mut source_xu,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
|
||||
|
||||
let noise_have: f64 = pt_want.data.std(0, basek).log2();
|
||||
let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() < 0.2,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_keyswitch(
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_keyswitch: usize,
|
||||
k_ct_in: usize,
|
||||
k_ct_out: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
sigma: f64,
|
||||
) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct_in + basek - 1) / basek;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_keyswitch, rows, rank_in, rank_out);
|
||||
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in);
|
||||
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_in, ksk.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
|
||||
| GLWECiphertext::keyswitch_scratch_space(
|
||||
&module,
|
||||
ct_out.size(),
|
||||
rank_out,
|
||||
ct_in.size(),
|
||||
rank_in,
|
||||
ksk.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_out.keyswitch(&module, &ct_in, &ksk, scratch.borrow());
|
||||
|
||||
ct_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank_in as f64,
|
||||
k_ct_in,
|
||||
k_keyswitch,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
let mut ct_grlwe: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
|
||||
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_grlwe.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
|
||||
| GLWECiphertext::keyswitch_inplace_scratch_space(&module, ct_rlwe.size(), rank, ct_grlwe.size()),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
ct_grlwe.encrypt_sk(
|
||||
&module,
|
||||
&sk0,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk0_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.keyswitch_inplace(&module, &ct_grlwe, scratch.borrow());
|
||||
|
||||
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ct,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_automorphism(
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
p: i64,
|
||||
k_autokey: usize,
|
||||
k_ct_in: usize,
|
||||
k_ct_out: usize,
|
||||
rank: usize,
|
||||
sigma: f64,
|
||||
) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct_in + basek - 1) / basek;
|
||||
|
||||
let mut autokey: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank);
|
||||
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
|
||||
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
|
||||
| GLWECiphertext::automorphism_scratch_space(&module, ct_out.size(), rank, ct_in.size(), autokey.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
autokey.encrypt_sk(
|
||||
&module,
|
||||
p,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_out.automorphism(&module, &ct_in, &autokey, scratch.borrow());
|
||||
ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
|
||||
println!("{}", noise_have);
|
||||
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ct_in,
|
||||
k_autokey,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
let mut autokey: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_autokey, rows, rank);
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, autokey.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::automorphism_inplace_scratch_space(&module, ct.size(), rank, autokey.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
autokey.encrypt_sk(
|
||||
&module,
|
||||
p,
|
||||
&sk,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.automorphism_inplace(&module, &autokey, scratch.borrow());
|
||||
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
module.vec_znx_automorphism_inplace(p, &mut pt_want, 0);
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut pt_have, 0, scratch.borrow());
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ct,
|
||||
k_autokey,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k_ct_in + basek - 1) / basek;
|
||||
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_rlwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
|
||||
let mut ct_rlwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
pt_want.to_mut().at_mut(0, 0)[1] = 1;
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_rgsw.raw_mut()[k] = 1; // X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe_out.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
|
||||
| GLWECiphertext::external_product_scratch_space(
|
||||
&module,
|
||||
ct_rlwe_out.size(),
|
||||
ct_rlwe_in.size(),
|
||||
ct_rgsw.size(),
|
||||
rank,
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_out.external_product(&module, &ct_rlwe_in, &ct_rgsw, scratch.borrow());
|
||||
|
||||
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ct_in,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_rlwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
pt_want.to_mut().at_mut(0, 0)[1] = 1;
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_rgsw.raw_mut()[k] = 1; // X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_rlwe.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_rlwe.size())
|
||||
| GLWECiphertext::external_product_inplace_scratch_space(&module, ct_rlwe.size(), ct_rgsw.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.external_product_inplace(&module, &ct_rgsw, scratch.borrow());
|
||||
|
||||
ct_rlwe.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ct,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
445
core/src/test_fft64/glwe_fourier.rs
Normal file
445
core/src/test_fft64/glwe_fourier.rs
Normal file
@@ -0,0 +1,445 @@
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
test_fft64::{gglwe::log2_std_noise_gglwe_product, ggsw::noise_ggsw_product},
|
||||
};
|
||||
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut, ZnxViewMut};
|
||||
use sampling::source::Source;
|
||||
|
||||
#[test]
|
||||
fn keyswitch() {
|
||||
(1..4).for_each(|rank_in| {
|
||||
(1..4).for_each(|rank_out| {
|
||||
println!("test keyswitch rank_in: {} rank_out: {}", rank_in, rank_out);
|
||||
test_keyswitch(12, 12, 60, 45, 60, rank_in, rank_out, 3.2);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyswitch_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test keyswitch_inplace rank: {}", rank);
|
||||
test_keyswitch_inplace(12, 12, 60, 45, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product(12, 12, 60, 45, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external_product_inplace() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test external_product rank: {}", rank);
|
||||
test_external_product_inplace(12, 15, 60, 60, rank, 3.2);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_keyswitch(
|
||||
log_n: usize,
|
||||
basek: usize,
|
||||
k_ksk: usize,
|
||||
k_ct_in: usize,
|
||||
k_ct_out: usize,
|
||||
rank_in: usize,
|
||||
rank_out: usize,
|
||||
sigma: f64,
|
||||
) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k_ct_in + basek - 1) / basek;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank_in, rank_out);
|
||||
let mut ct_glwe_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank_in);
|
||||
let mut ct_glwe_dft_in: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank_in);
|
||||
let mut ct_glwe_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank_out);
|
||||
let mut ct_glwe_dft_out: GLWECiphertextFourier<Vec<u8>, FFT64> =
|
||||
GLWECiphertextFourier::new(&module, basek, k_ct_out, rank_out);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank_out, ksk.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe_out.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe_in.size())
|
||||
| GLWECiphertextFourier::keyswitch_scratch_space(
|
||||
&module,
|
||||
ct_glwe_out.size(),
|
||||
rank_out,
|
||||
ct_glwe_in.size(),
|
||||
rank_in,
|
||||
ksk.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_in);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_in);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank_out);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank_out);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_glwe_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_glwe_in.dft(&module, &mut ct_glwe_dft_in);
|
||||
ct_glwe_dft_out.keyswitch(&module, &ct_glwe_dft_in, &ksk, scratch.borrow());
|
||||
ct_glwe_dft_out.idft(&module, &mut ct_glwe_out, scratch.borrow());
|
||||
|
||||
ct_glwe_out.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank_in as f64,
|
||||
k_ct_in,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
let mut ksk: GLWESwitchingKey<Vec<u8>, FFT64> = GLWESwitchingKey::new(&module, basek, k_ksk, rows, rank, rank);
|
||||
let mut ct_glwe: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ksk.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_glwe.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_glwe.size())
|
||||
| GLWECiphertextFourier::keyswitch_inplace_scratch_space(&module, ct_rlwe_dft.size(), ksk.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk_in: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_in.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_in_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_in_dft.dft(&module, &sk_in);
|
||||
|
||||
let mut sk_out: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk_out.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_out_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_out_dft.dft(&module, &sk_out);
|
||||
|
||||
ksk.encrypt_sk(
|
||||
&module,
|
||||
&sk_in,
|
||||
&sk_out_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_glwe.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_in_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_glwe.dft(&module, &mut ct_rlwe_dft);
|
||||
ct_rlwe_dft.keyswitch_inplace(&module, &ksk, scratch.borrow());
|
||||
ct_rlwe_dft.idft(&module, &mut ct_glwe, scratch.borrow());
|
||||
|
||||
ct_glwe.decrypt(&module, &mut pt_have, &sk_out_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
let noise_want: f64 = log2_std_noise_gglwe_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
rank as f64,
|
||||
k_ct,
|
||||
k_ksk,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, k_ct_in: usize, k_ct_out: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k_ct_in + basek - 1) / basek;
|
||||
|
||||
let mut ct_rgsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct_in: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_in, rank);
|
||||
let mut ct_out: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct_out, rank);
|
||||
let mut ct_in_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_in, rank);
|
||||
let mut ct_out_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct_out, rank);
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_in);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
pt_want.to_mut().at_mut(0, 0)[1] = 1;
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_rgsw.raw_mut()[k] = 1; // X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_rgsw.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct_out.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct_in.size())
|
||||
| GLWECiphertextFourier::external_product_scratch_space(&module, ct_out.size(), ct_in.size(), ct_rgsw.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_in.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_in.dft(&module, &mut ct_in_dft);
|
||||
ct_out_dft.external_product(&module, &ct_in_dft, &ct_rgsw, scratch.borrow());
|
||||
ct_out_dft.idft(&module, &mut ct_out, scratch.borrow());
|
||||
|
||||
ct_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ct_in,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
|
||||
fn test_external_product_inplace(log_n: usize, basek: usize, k_ggsw: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
let rows: usize = (k_ct + basek - 1) / basek;
|
||||
|
||||
let mut ct_ggsw: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k_ggsw, rows, rank);
|
||||
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::new(&module, basek, k_ct, rank);
|
||||
let mut ct_rlwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ct, rank);
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
pt_want.to_mut().at_mut(0, 0)[1] = 1;
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_rgsw.raw_mut()[k] = 1; // X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw.size())
|
||||
| GLWECiphertext::decrypt_scratch_space(&module, ct.size())
|
||||
| GLWECiphertext::encrypt_sk_scratch_space(&module, ct.size())
|
||||
| GLWECiphertextFourier::external_product_inplace_scratch_space(&module, ct.size(), ct_ggsw.size(), rank),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_ggsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_want,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct.dft(&module, &mut ct_rlwe_dft);
|
||||
ct_rlwe_dft.external_product_inplace(&module, &ct_ggsw, scratch.borrow());
|
||||
ct_rlwe_dft.idft(&module, &mut ct, scratch.borrow());
|
||||
|
||||
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, basek).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_ggsw_product(
|
||||
module.n() as f64,
|
||||
basek,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
rank as f64,
|
||||
k_ct,
|
||||
k_ggsw,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
}
|
||||
6
core/src/test_fft64/mod.rs
Normal file
6
core/src/test_fft64/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
||||
mod automorphism_key;
|
||||
mod gglwe;
|
||||
mod ggsw;
|
||||
mod glwe;
|
||||
mod glwe_fourier;
|
||||
mod tensor_key;
|
||||
77
core/src/test_fft64/tensor_key.rs
Normal file
77
core/src/test_fft64/tensor_key.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos},
|
||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::{SecretKey, SecretKeyFourier},
|
||||
tensor_key::TensorKey,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
(1..4).for_each(|rank| {
|
||||
println!("test encrypt_sk rank: {}", rank);
|
||||
test_encrypt_sk(12, 16, 54, 3.2, rank);
|
||||
});
|
||||
}
|
||||
|
||||
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||
|
||||
let rows: usize = (k + basek - 1) / basek;
|
||||
|
||||
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
|
||||
&module,
|
||||
rank,
|
||||
tensor_key.size(),
|
||||
));
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
tensor_key.encrypt_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||
|
||||
(0..rank).for_each(|i| {
|
||||
(0..rank).for_each(|j| {
|
||||
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
|
||||
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
|
||||
let sk_ij: ScalarZnx<Vec<u8>> = module
|
||||
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
|
||||
.to_vec_znx_small()
|
||||
.to_scalar_znx();
|
||||
|
||||
(0..tensor_key.rank_in()).for_each(|col_i| {
|
||||
(0..tensor_key.rows()).for_each(|row_i| {
|
||||
tensor_key
|
||||
.at(i, j)
|
||||
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
|
||||
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
|
||||
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
});
|
||||
});
|
||||
})
|
||||
})
|
||||
}
|
||||
3
core/src/utils.rs
Normal file
3
core/src/utils.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub(crate) fn derive_size(basek: usize, k: usize) -> usize {
|
||||
(k + basek - 1) / basek
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
use base2k::{BACKEND, Module, Sampling, SvpPPolOps, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, alloc_aligned_u8};
|
||||
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
|
||||
use rlwe::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
gadget_product::{gadget_product_core, gadget_product_core_tmp_bytes},
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
fn bench_gadget_product_inplace(c: &mut Criterion) {
|
||||
fn runner<'a>(
|
||||
module: &'a Module,
|
||||
res_dft_0: &'a mut VecZnxDft,
|
||||
res_dft_1: &'a mut VecZnxDft,
|
||||
a: &'a VecZnx,
|
||||
b: &'a Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &'a mut [u8],
|
||||
) -> Box<dyn FnMut() + 'a> {
|
||||
Box::new(move || {
|
||||
gadget_product_core(module, res_dft_0, res_dft_1, a, b, b_cols, tmp_bytes);
|
||||
})
|
||||
}
|
||||
|
||||
let mut b: criterion::BenchmarkGroup<'_, criterion::measurement::WallTime> = c.benchmark_group("gadget_product_inplace");
|
||||
|
||||
for log_n in 10..11 {
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: log_n,
|
||||
log_q: 32,
|
||||
log_p: 0,
|
||||
log_base2k: 16,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 128,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||
params.encrypt_rlwe_sk_tmp_bytes(params.log_q())
|
||||
| gadget_product_core_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.log_q(),
|
||||
params.log_q(),
|
||||
params.cols_q(),
|
||||
params.log_qp(),
|
||||
)
|
||||
| encrypt_grlwe_sk_tmp_bytes(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut source: Source = Source::new([3; 32]);
|
||||
|
||||
let mut sk0: SecretKey = SecretKey::new(params.module());
|
||||
let mut sk1: SecretKey = SecretKey::new(params.module());
|
||||
sk0.fill_ternary_hw(params.xs(), &mut source);
|
||||
sk1.fill_ternary_hw(params.xs(), &mut source);
|
||||
|
||||
let mut source_xe: Source = Source::new([4; 32]);
|
||||
let mut source_xa: Source = Source::new([5; 32]);
|
||||
|
||||
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
|
||||
|
||||
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
|
||||
|
||||
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
|
||||
params.module(),
|
||||
params.log_base2k(),
|
||||
params.cols_q(),
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
encrypt_grlwe_sk(
|
||||
params.module(),
|
||||
&mut gadget_ct,
|
||||
&sk0.0,
|
||||
&sk1_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
|
||||
|
||||
params.encrypt_rlwe_sk(
|
||||
&mut ct,
|
||||
None,
|
||||
&sk0_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
|
||||
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, gadget_ct.cols());
|
||||
|
||||
let mut a: VecZnx = params.module().new_vec_znx(0, params.cols_q());
|
||||
params
|
||||
.module()
|
||||
.fill_uniform(params.log_base2k(), &mut a, params.cols_q(), &mut source_xa);
|
||||
|
||||
let b_cols: usize = gadget_ct.cols();
|
||||
|
||||
let runners: [(String, Box<dyn FnMut()>); 1] = [(format!("gadget_product"), {
|
||||
runner(
|
||||
params.module(),
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
&mut a,
|
||||
&gadget_ct,
|
||||
b_cols,
|
||||
&mut tmp_bytes,
|
||||
)
|
||||
})];
|
||||
|
||||
for (name, mut runner) in runners {
|
||||
let id: BenchmarkId = BenchmarkId::new(name, format!("n={}", 1 << log_n));
|
||||
b.bench_with_input(id, &(), |b: &mut criterion::Bencher<'_>, _| {
|
||||
b.iter(&mut runner)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
criterion_group!(benches, bench_gadget_product_inplace);
|
||||
criterion_main!(benches);
|
||||
@@ -1,76 +0,0 @@
|
||||
use base2k::{Encoding, SvpPPolOps, VecZnx, alloc_aligned};
|
||||
use rlwe::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::ElemCommon,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
fn main() {
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: base2k::BACKEND::FFT64,
|
||||
log_n: 10,
|
||||
log_q: 54,
|
||||
log_p: 0,
|
||||
log_base2k: 17,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 128,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> =
|
||||
alloc_aligned(params.decrypt_rlwe_tmp_byte(params.log_q()) | params.encrypt_rlwe_sk_tmp_bytes(params.log_q()));
|
||||
|
||||
let mut source: Source = Source::new([0; 32]);
|
||||
let mut sk: SecretKey = SecretKey::new(params.module());
|
||||
sk.fill_ternary_hw(params.xs(), &mut source);
|
||||
|
||||
let mut want = vec![i64::default(); params.n()];
|
||||
|
||||
want.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
let mut pt: Plaintext = params.new_plaintext(params.log_q());
|
||||
|
||||
let log_base2k = pt.log_base2k();
|
||||
|
||||
let log_k: usize = params.log_q() - 20;
|
||||
|
||||
pt.0.value[0].encode_vec_i64(0, log_base2k, log_k, &want, 32);
|
||||
pt.0.value[0].normalize(log_base2k, &mut tmp_bytes);
|
||||
|
||||
println!("log_k: {}", log_k);
|
||||
pt.0.value[0].print(0, pt.cols(), 16);
|
||||
println!();
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(params.log_q());
|
||||
|
||||
let mut source_xe: Source = Source::new([1; 32]);
|
||||
let mut source_xa: Source = Source::new([2; 32]);
|
||||
|
||||
let mut sk_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||
|
||||
params.encrypt_rlwe_sk(
|
||||
&mut ct,
|
||||
Some(&pt),
|
||||
&sk_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
params.decrypt_rlwe(&mut pt, &ct, &sk_svp_ppol, &mut tmp_bytes);
|
||||
pt.0.value[0].print(0, pt.cols(), 16);
|
||||
|
||||
let mut have = vec![i64::default(); params.n()];
|
||||
|
||||
println!("pt: {}", log_k);
|
||||
pt.0.value[0].decode_vec_i64(0, pt.log_base2k(), log_k, &mut have);
|
||||
|
||||
println!("want: {:?}", &want[..16]);
|
||||
println!("have: {:?}", &have[..16]);
|
||||
}
|
||||
@@ -1,349 +0,0 @@
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes},
|
||||
key_switching::{key_switch_rlwe, key_switch_rlwe_inplace, key_switch_tmp_bytes},
|
||||
keys::SecretKey,
|
||||
parameters::Parameters,
|
||||
};
|
||||
use base2k::{
|
||||
Module, Scalar, ScalarOps, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
VmpPMatOps, assert_alignement,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Stores DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E, A]) where AUTO(X, p): X^{i} -> X^{i*p}
|
||||
pub struct AutomorphismKey {
|
||||
pub value: Ciphertext<VmpPMat>,
|
||||
pub p: i64,
|
||||
}
|
||||
|
||||
pub fn automorphis_key_new_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
module.bytes_of_scalar() + module.bytes_of_svp_ppol() + encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn automorphism_key_new_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||
automorphis_key_new_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
|
||||
pub fn automorphism_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
automorphism_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
res_logq,
|
||||
in_logq,
|
||||
gct_logq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl AutomorphismKey {
|
||||
pub fn new(
|
||||
module: &Module,
|
||||
p: i64,
|
||||
sk: &SecretKey,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> Self {
|
||||
Self::new_many_core(
|
||||
module,
|
||||
&vec![p],
|
||||
sk,
|
||||
log_base2k,
|
||||
rows,
|
||||
log_q,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
tmp_bytes,
|
||||
)
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
pub fn new_many(
|
||||
module: &Module,
|
||||
p: &Vec<i64>,
|
||||
sk: &SecretKey,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> HashMap<i64, AutomorphismKey> {
|
||||
Self::new_many_core(
|
||||
module, p, sk, log_base2k, rows, log_q, source_xa, source_xe, sigma, tmp_bytes,
|
||||
)
|
||||
.into_iter()
|
||||
.zip(p.iter().cloned())
|
||||
.map(|(key, pi)| (pi, key))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn new_many_core(
|
||||
module: &Module,
|
||||
p: &Vec<i64>,
|
||||
sk: &SecretKey,
|
||||
log_base2k: usize,
|
||||
rows: usize,
|
||||
log_q: usize,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> Vec<Self> {
|
||||
let (sk_auto_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_scalar());
|
||||
let (sk_out_bytes, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_svp_ppol());
|
||||
|
||||
let sk_auto: Scalar = module.new_scalar_from_bytes_borrow(sk_auto_bytes);
|
||||
let mut sk_out: SvpPPol = module.new_svp_ppol_from_bytes_borrow(sk_out_bytes);
|
||||
|
||||
let mut keys: Vec<AutomorphismKey> = Vec::new();
|
||||
|
||||
p.iter().for_each(|pi| {
|
||||
let mut value: Ciphertext<VmpPMat> = new_gadget_ciphertext(module, log_base2k, rows, log_q);
|
||||
|
||||
let p_inv: i64 = module.galois_element_inv(*pi);
|
||||
|
||||
module.vec_znx_automorphism(p_inv, &mut sk_auto.as_vec_znx(), &sk.0.as_vec_znx());
|
||||
module.svp_prepare(&mut sk_out, &sk_auto);
|
||||
encrypt_grlwe_sk(
|
||||
module, &mut value, &sk.0, &sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||
);
|
||||
|
||||
keys.push(Self {
|
||||
value: value,
|
||||
p: *pi,
|
||||
})
|
||||
});
|
||||
|
||||
keys
|
||||
}
|
||||
}
|
||||
|
||||
pub fn automorphism_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
key_switch_tmp_bytes(module, log_base2k, res_logq, in_logq, gct_logq)
|
||||
}
|
||||
|
||||
pub fn automorphism(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &AutomorphismKey,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
key_switch_rlwe(module, c, a, &b.value, b_cols, tmp_bytes);
|
||||
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(0));
|
||||
// c[1] = AUTO(b, p)
|
||||
module.vec_znx_automorphism_inplace(b.p, c.at_mut(1));
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace(
|
||||
module: &Module,
|
||||
a: &mut Ciphertext<VecZnx>,
|
||||
b: &AutomorphismKey,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
key_switch_rlwe_inplace(module, a, &b.value, b_cols, tmp_bytes);
|
||||
// a[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_automorphism_inplace(b.p, a.at_mut(0));
|
||||
// a[1] = AUTO(b, p)
|
||||
module.vec_znx_automorphism_inplace(b.p, a.at_mut(1));
|
||||
}
|
||||
|
||||
pub fn automorphism_big(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnxBig>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &AutomorphismKey,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols = std::cmp::min(c.cols(), a.cols());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(tmp_bytes.len() >= automorphism_tmp_bytes(module, c.cols(), a.cols(), b.value.rows(), b.value.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([-A*AUTO(s, -p) + 2^{-K*i}*s + E])>) = [-b*AUTO(s, -p) + a * s + e]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(0), &mut res_dft);
|
||||
|
||||
// res_dft = [-b*AUTO(s, -p) + a * s + e] + [-a * s + m + e] = [-b*AUTO(s, -p) + m + e]
|
||||
module.vec_znx_big_add_small_inplace(c.at_mut(0), a.at(0));
|
||||
|
||||
// c[0] = AUTO([-b*AUTO(s, -p) + m + e], p) = [-AUTO(b, p)*s + AUTO(m, p) + AUTO(b, e)]
|
||||
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(0));
|
||||
|
||||
// res_dft = IDFT(<DFT(a), DFT([A])>) = [b]
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(1), &mut res_dft);
|
||||
|
||||
// c[1] = AUTO(b, p)
|
||||
module.vec_znx_big_automorphism_inplace(b.p, c.at_mut(1));
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{AutomorphismKey, automorphism};
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::ElemCommon,
|
||||
encryptor::encrypt_rlwe_sk,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, alloc_aligned};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_automorphism() {
|
||||
let log_base2k: usize = 10;
|
||||
let log_q: usize = 50;
|
||||
let log_p: usize = 15;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: 12,
|
||||
log_q: log_q,
|
||||
log_p: log_p,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let module: &Module = params.module();
|
||||
let log_q: usize = params.log_q();
|
||||
let log_qp: usize = params.log_qp();
|
||||
let gct_rows: usize = params.cols_q();
|
||||
let gct_cols: usize = params.cols_qp();
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||
params.decrypt_rlwe_tmp_byte(log_q)
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
|
||||
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
|
||||
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
|
||||
);
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
let mut sk: SecretKey = SecretKey::new(module);
|
||||
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
|
||||
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||
|
||||
let p: i64 = -5;
|
||||
|
||||
let auto_key: AutomorphismKey = AutomorphismKey::new(
|
||||
module,
|
||||
p,
|
||||
&sk,
|
||||
log_base2k,
|
||||
gct_rows,
|
||||
log_qp,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut data: Vec<i64> = vec![0i64; params.n()];
|
||||
|
||||
data.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
|
||||
|
||||
let log_k: usize = 2 * log_base2k;
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||
let mut pt: Plaintext = params.new_plaintext(log_q);
|
||||
let mut pt_auto: Plaintext = params.new_plaintext(log_q);
|
||||
|
||||
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
|
||||
module.vec_znx_automorphism(p, pt_auto.at_mut(0), pt.at(0));
|
||||
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
&mut ct.elem_mut(),
|
||||
Some(pt.at(0)),
|
||||
&sk_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut ct_auto: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||
|
||||
// ct <- AUTO(ct)
|
||||
automorphism(
|
||||
module,
|
||||
&mut ct_auto,
|
||||
&ct,
|
||||
&auto_key,
|
||||
gct_cols,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// pt = dec(auto(ct)) - auto(pt)
|
||||
decrypt_rlwe(
|
||||
module,
|
||||
pt.elem_mut(),
|
||||
ct_auto.elem(),
|
||||
&sk_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_auto.at(0));
|
||||
|
||||
// pt.at(0).print(pt.cols(), 16);
|
||||
|
||||
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
|
||||
|
||||
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
|
||||
let var_a_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, ct_auto.log_q(), auto_key.value.log_q());
|
||||
|
||||
println!("noise_pred: {}", noise_pred);
|
||||
println!("noise_have: {}", noise_have);
|
||||
|
||||
assert!(noise_have <= noise_pred + 1.0);
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
use crate::elem::{Elem, ElemCommon};
|
||||
use crate::parameters::Parameters;
|
||||
use base2k::{Infos, LAYOUT, Module, VecZnx, VmpPMat};
|
||||
|
||||
pub struct Ciphertext<T>(pub Elem<T>);
|
||||
|
||||
impl Parameters {
|
||||
pub fn new_ciphertext(&self, log_q: usize) -> Ciphertext<VecZnx> {
|
||||
Ciphertext::new(self.module(), self.log_base2k(), log_q, 2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ElemCommon<T> for Ciphertext<T>
|
||||
where
|
||||
T: Infos,
|
||||
{
|
||||
fn n(&self) -> usize {
|
||||
self.elem().n()
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
self.elem().log_n()
|
||||
}
|
||||
|
||||
fn log_q(&self) -> usize {
|
||||
self.elem().log_q()
|
||||
}
|
||||
|
||||
fn elem(&self) -> &Elem<T> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<T> {
|
||||
&mut self.0
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.elem().size()
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.elem().layout()
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.elem().rows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.elem().cols()
|
||||
}
|
||||
|
||||
fn at(&self, i: usize) -> &T {
|
||||
self.elem().at(i)
|
||||
}
|
||||
|
||||
fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
self.elem_mut().at_mut(i)
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.elem().log_base2k()
|
||||
}
|
||||
|
||||
fn log_scale(&self) -> usize {
|
||||
self.elem().log_scale()
|
||||
}
|
||||
}
|
||||
|
||||
impl Ciphertext<VecZnx> {
|
||||
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
|
||||
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, rows))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_rlwe_ciphertext(module: &Module, log_base2k: usize, log_q: usize) -> Ciphertext<VecZnx> {
|
||||
let rows: usize = 2;
|
||||
Ciphertext::<VecZnx>::new(module, log_base2k, log_q, rows)
|
||||
}
|
||||
|
||||
pub fn new_gadget_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 2, rows, cols);
|
||||
elem.log_q = log_q;
|
||||
Ciphertext(elem)
|
||||
}
|
||||
|
||||
pub fn new_rgsw_ciphertext(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> Ciphertext<VmpPMat> {
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let mut elem: Elem<VmpPMat> = Elem::<VmpPMat>::new(module, log_base2k, 4, rows, cols);
|
||||
elem.log_q = log_q;
|
||||
Ciphertext(elem)
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
use crate::{
|
||||
ciphertext::Ciphertext,
|
||||
elem::{Elem, ElemCommon},
|
||||
keys::SecretKey,
|
||||
parameters::Parameters,
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{Module, SvpPPol, SvpPPolOps, VecZnx, VecZnxBigOps, VecZnxDft, VecZnxDftOps};
|
||||
use std::cmp::min;
|
||||
|
||||
pub struct Decryptor {
|
||||
sk: SvpPPol,
|
||||
}
|
||||
|
||||
impl Decryptor {
|
||||
pub fn new(params: &Parameters, sk: &SecretKey) -> Self {
|
||||
let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
|
||||
sk.prepare(params.module(), &mut sk_svp_ppol);
|
||||
Self { sk: sk_svp_ppol }
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe_tmp_byte(module: &Module, cols: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(1, cols) + module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn decrypt_rlwe_tmp_byte(&self, log_q: usize) -> usize {
|
||||
decrypt_rlwe_tmp_byte(
|
||||
self.module(),
|
||||
(log_q + self.log_base2k() - 1) / self.log_base2k(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe(&self, res: &mut Plaintext, ct: &Ciphertext<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
|
||||
decrypt_rlwe(self.module(), &mut res.0, &ct.0, sk, tmp_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe(module: &Module, res: &mut Elem<VecZnx>, a: &Elem<VecZnx>, sk: &SvpPPol, tmp_bytes: &mut [u8]) {
|
||||
let cols: usize = a.cols();
|
||||
|
||||
assert!(
|
||||
tmp_bytes.len() >= decrypt_rlwe_tmp_byte(module, cols),
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < decrypt_rlwe_tmp_byte={}",
|
||||
tmp_bytes.len(),
|
||||
decrypt_rlwe_tmp_byte(module, cols)
|
||||
);
|
||||
|
||||
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
|
||||
let mut res_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
|
||||
let mut res_big: base2k::VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// res_dft <- DFT(ct[1]) * DFT(sk)
|
||||
module.svp_apply_dft(&mut res_dft, sk, a.at(1));
|
||||
// res_big <- ct[1] x sk
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
// res_big <- ct[1] x sk + ct[0]
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
// res <- normalize(ct[1] x sk + ct[0])
|
||||
module.vec_znx_big_normalize(a.log_base2k(), res.at_mut(0), &res_big, tmp_bytes_normalize);
|
||||
|
||||
res.log_base2k = a.log_base2k();
|
||||
res.log_q = min(res.log_q(), a.log_q());
|
||||
res.log_scale = a.log_scale();
|
||||
}
|
||||
168
rlwe/src/elem.rs
168
rlwe/src/elem.rs
@@ -1,168 +0,0 @@
|
||||
use base2k::{Infos, LAYOUT, Module, VecZnx, VecZnxOps, VmpPMat, VmpPMatOps};
|
||||
|
||||
pub struct Elem<T> {
|
||||
pub value: Vec<T>,
|
||||
pub log_base2k: usize,
|
||||
pub log_q: usize,
|
||||
pub log_scale: usize,
|
||||
}
|
||||
|
||||
pub trait ElemVecZnx {
|
||||
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
|
||||
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx>;
|
||||
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize;
|
||||
fn zero(&mut self);
|
||||
}
|
||||
|
||||
impl ElemVecZnx for Elem<VecZnx> {
|
||||
fn bytes_of(module: &Module, log_base2k: usize, log_q: usize, size: usize) -> usize {
|
||||
let cols = (log_q + log_base2k - 1) / log_base2k;
|
||||
module.n() * cols * size * 8
|
||||
}
|
||||
|
||||
fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
|
||||
assert!(size > 0);
|
||||
let n: usize = module.n();
|
||||
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
||||
let mut value: Vec<VecZnx> = Vec::new();
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let elem_size = VecZnx::bytes_of(n, size, cols);
|
||||
let mut ptr: usize = 0;
|
||||
(0..size).for_each(|_| {
|
||||
value.push(VecZnx::from_bytes(n, 1, cols, &mut bytes[ptr..]));
|
||||
ptr += elem_size
|
||||
});
|
||||
Self {
|
||||
value,
|
||||
log_q,
|
||||
log_base2k,
|
||||
log_scale: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, size: usize, bytes: &mut [u8]) -> Elem<VecZnx> {
|
||||
assert!(size > 0);
|
||||
let n: usize = module.n();
|
||||
assert!(bytes.len() >= Self::bytes_of(module, log_base2k, log_q, size));
|
||||
let mut value: Vec<VecZnx> = Vec::new();
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let elem_size = VecZnx::bytes_of(n, 1, cols);
|
||||
let mut ptr: usize = 0;
|
||||
(0..size).for_each(|_| {
|
||||
value.push(VecZnx::from_bytes_borrow(n, 1, cols, &mut bytes[ptr..]));
|
||||
ptr += elem_size
|
||||
});
|
||||
Self {
|
||||
value,
|
||||
log_q,
|
||||
log_base2k,
|
||||
log_scale: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn zero(&mut self) {
|
||||
self.value.iter_mut().for_each(|i| i.zero());
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ElemCommon<T> {
|
||||
fn n(&self) -> usize;
|
||||
fn log_n(&self) -> usize;
|
||||
fn elem(&self) -> &Elem<T>;
|
||||
fn elem_mut(&mut self) -> &mut Elem<T>;
|
||||
fn size(&self) -> usize;
|
||||
fn layout(&self) -> LAYOUT;
|
||||
fn rows(&self) -> usize;
|
||||
fn cols(&self) -> usize;
|
||||
fn log_base2k(&self) -> usize;
|
||||
fn log_q(&self) -> usize;
|
||||
fn log_scale(&self) -> usize;
|
||||
fn at(&self, i: usize) -> &T;
|
||||
fn at_mut(&mut self, i: usize) -> &mut T;
|
||||
}
|
||||
|
||||
impl<T: Infos> ElemCommon<T> for Elem<T> {
|
||||
fn n(&self) -> usize {
|
||||
self.value[0].n()
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
self.value[0].log_n()
|
||||
}
|
||||
|
||||
fn elem(&self) -> &Elem<T> {
|
||||
self
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<T> {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.value.len()
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.value[0].layout()
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.value[0].rows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.value[0].cols()
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_q(&self) -> usize {
|
||||
self.log_q
|
||||
}
|
||||
|
||||
fn log_scale(&self) -> usize {
|
||||
self.log_scale
|
||||
}
|
||||
|
||||
fn at(&self, i: usize) -> &T {
|
||||
assert!(i < self.size());
|
||||
&self.value[i]
|
||||
}
|
||||
|
||||
fn at_mut(&mut self, i: usize) -> &mut T {
|
||||
assert!(i < self.size());
|
||||
&mut self.value[i]
|
||||
}
|
||||
}
|
||||
|
||||
impl Elem<VecZnx> {
|
||||
pub fn new(module: &Module, log_base2k: usize, log_q: usize, rows: usize) -> Self {
|
||||
assert!(rows > 0);
|
||||
let cols: usize = (log_q + log_base2k - 1) / log_base2k;
|
||||
let mut value: Vec<VecZnx> = Vec::new();
|
||||
(0..rows).for_each(|_| value.push(module.new_vec_znx(1, cols)));
|
||||
Self {
|
||||
value,
|
||||
log_q,
|
||||
log_base2k,
|
||||
log_scale: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Elem<VmpPMat> {
|
||||
pub fn new(module: &Module, log_base2k: usize, size: usize, rows: usize, cols: usize) -> Self {
|
||||
assert!(rows > 0);
|
||||
assert!(cols > 0);
|
||||
let mut value: Vec<VmpPMat> = Vec::new();
|
||||
(0..size).for_each(|_| value.push(module.new_vmp_pmat(1, rows, cols)));
|
||||
Self {
|
||||
value: value,
|
||||
log_q: 0,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,369 +0,0 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::elem::{Elem, ElemCommon, ElemVecZnx};
|
||||
use crate::keys::SecretKey;
|
||||
use crate::parameters::Parameters;
|
||||
use crate::plaintext::Plaintext;
|
||||
use base2k::sampling::Sampling;
|
||||
use base2k::{
|
||||
Infos, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
VmpPMatOps,
|
||||
};
|
||||
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
impl Parameters {
|
||||
pub fn encrypt_rlwe_sk_tmp_bytes(&self, log_q: usize) -> usize {
|
||||
encrypt_rlwe_sk_tmp_bytes(self.module(), self.log_base2k(), log_q)
|
||||
}
|
||||
pub fn encrypt_rlwe_sk(
|
||||
&self,
|
||||
ct: &mut Ciphertext<VecZnx>,
|
||||
pt: Option<&Plaintext>,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
encrypt_rlwe_sk(
|
||||
self.module(),
|
||||
&mut ct.0,
|
||||
pt.map(|pt| pt.at(0)),
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
self.xe(),
|
||||
tmp_bytes,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct EncryptorSk {
|
||||
sk: SvpPPol,
|
||||
source_xa: Source,
|
||||
source_xe: Source,
|
||||
initialized: bool,
|
||||
tmp_bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
impl EncryptorSk {
|
||||
pub fn new(params: &Parameters, sk: Option<&SecretKey>) -> Self {
|
||||
let mut sk_svp_ppol: SvpPPol = params.module().new_svp_ppol();
|
||||
let mut initialized: bool = false;
|
||||
if let Some(sk) = sk {
|
||||
sk.prepare(params.module(), &mut sk_svp_ppol);
|
||||
initialized = true;
|
||||
}
|
||||
Self {
|
||||
sk: sk_svp_ppol,
|
||||
initialized,
|
||||
source_xa: Source::new(new_seed()),
|
||||
source_xe: Source::new(new_seed()),
|
||||
tmp_bytes: vec![0u8; params.encrypt_rlwe_sk_tmp_bytes(params.cols_qp())],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_sk(&mut self, module: &Module, sk: &SecretKey) {
|
||||
sk.prepare(module, &mut self.sk);
|
||||
self.initialized = true;
|
||||
}
|
||||
|
||||
pub fn seed_source_xa(&mut self, seed: [u8; 32]) {
|
||||
self.source_xa = Source::new(seed)
|
||||
}
|
||||
|
||||
pub fn seed_source_xe(&mut self, seed: [u8; 32]) {
|
||||
self.source_xe = Source::new(seed)
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk(&mut self, params: &Parameters, ct: &mut Ciphertext<VecZnx>, pt: Option<&Plaintext>) {
|
||||
assert!(
|
||||
self.initialized == true,
|
||||
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
|
||||
);
|
||||
params.encrypt_rlwe_sk(
|
||||
ct,
|
||||
pt,
|
||||
&self.sk,
|
||||
&mut self.source_xa,
|
||||
&mut self.source_xe,
|
||||
&mut self.tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk_core(
|
||||
&self,
|
||||
params: &Parameters,
|
||||
ct: &mut Ciphertext<VecZnx>,
|
||||
pt: Option<&Plaintext>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(
|
||||
self.initialized == true,
|
||||
"invalid call to [EncryptorSk.encrypt_rlwe_sk]: [EncryptorSk] has not been initialized with a [SecretKey]"
|
||||
);
|
||||
params.encrypt_rlwe_sk(ct, pt, &self.sk, source_xa, source_xe, tmp_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(1, (log_q + log_base2k - 1) / log_base2k) + module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
pub fn encrypt_rlwe_sk(
|
||||
module: &Module,
|
||||
ct: &mut Elem<VecZnx>,
|
||||
pt: Option<&VecZnx>,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
encrypt_rlwe_sk_core::<0>(module, ct, pt, sk, source_xa, source_xe, sigma, tmp_bytes)
|
||||
}
|
||||
|
||||
fn encrypt_rlwe_sk_core<const PT_POS: u8>(
|
||||
module: &Module,
|
||||
ct: &mut Elem<VecZnx>,
|
||||
pt: Option<&VecZnx>,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = ct.cols();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
let log_q: usize = ct.log_q();
|
||||
|
||||
assert!(
|
||||
tmp_bytes.len() >= encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q),
|
||||
"invalid tmp_bytes: tmp_bytes={} < encrypt_rlwe_sk_tmp_bytes={}",
|
||||
tmp_bytes.len(),
|
||||
encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||
);
|
||||
|
||||
let log_q: usize = ct.log_q();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
let c1: &mut VecZnx = ct.at_mut(1);
|
||||
|
||||
// c1 <- Z_{2^prec}[X]/(X^{N}+1)
|
||||
module.fill_uniform(log_base2k, c1, cols, source_xa);
|
||||
|
||||
let (tmp_bytes_vec_znx_dft, tmp_bytes_normalize) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
|
||||
// Scratch space for DFT values
|
||||
let mut buf_dft: VecZnxDft = VecZnxDft::from_bytes_borrow(module, 1, cols, tmp_bytes_vec_znx_dft);
|
||||
|
||||
// Applies buf_dft <- DFT(s) * DFT(c1)
|
||||
module.svp_apply_dft(&mut buf_dft, sk, c1);
|
||||
|
||||
// Alias scratch space
|
||||
let mut buf_big: VecZnxBig = buf_dft.as_vec_znx_big();
|
||||
|
||||
// buf_big = s x c1
|
||||
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft);
|
||||
|
||||
match PT_POS {
|
||||
// c0 <- -s x c1 + m
|
||||
0 => {
|
||||
let c0: &mut VecZnx = ct.at_mut(0);
|
||||
if let Some(pt) = pt {
|
||||
module.vec_znx_big_sub_small_a_inplace(&mut buf_big, pt);
|
||||
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
|
||||
} else {
|
||||
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
|
||||
module.vec_znx_negate_inplace(c0);
|
||||
}
|
||||
}
|
||||
// c1 <- c1 + m
|
||||
1 => {
|
||||
if let Some(pt) = pt {
|
||||
module.vec_znx_add_inplace(c1, pt);
|
||||
c1.normalize(log_base2k, tmp_bytes_normalize);
|
||||
}
|
||||
let c0: &mut VecZnx = ct.at_mut(0);
|
||||
module.vec_znx_big_normalize(log_base2k, c0, &buf_big, tmp_bytes_normalize);
|
||||
module.vec_znx_negate_inplace(c0);
|
||||
}
|
||||
_ => panic!("PT_POS must be 1 or 2"),
|
||||
}
|
||||
|
||||
// c0 <- -s x c1 + m + e
|
||||
module.add_normal(
|
||||
log_base2k,
|
||||
ct.at_mut(0),
|
||||
log_q,
|
||||
source_xe,
|
||||
sigma,
|
||||
(sigma * 6.0).ceil(),
|
||||
);
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn encrypt_grlwe_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||
encrypt_grlwe_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_grlwe_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
let cols = (log_q + log_base2k - 1) / log_base2k;
|
||||
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
|
||||
+ Plaintext::bytes_of(module, log_base2k, log_q)
|
||||
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||
+ module.vmp_prepare_tmp_bytes(rows, cols)
|
||||
}
|
||||
|
||||
pub fn encrypt_grlwe_sk(
|
||||
module: &Module,
|
||||
ct: &mut Ciphertext<VmpPMat>,
|
||||
m: &Scalar,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let log_q: usize = ct.log_q();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
let (left, right) = ct.0.value.split_at_mut(1);
|
||||
encrypt_grlwe_sk_core::<0>(
|
||||
module,
|
||||
log_base2k,
|
||||
[&mut left[0], &mut right[0]],
|
||||
log_q,
|
||||
m,
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
tmp_bytes,
|
||||
)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn encrypt_rgsw_sk_tmp_bytes(&self, rows: usize, log_q: usize) -> usize {
|
||||
encrypt_rgsw_sk_tmp_bytes(self.module(), self.log_base2k(), rows, log_q)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_rgsw_sk_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
let cols = (log_q + log_base2k - 1) / log_base2k;
|
||||
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2)
|
||||
+ Plaintext::bytes_of(module, log_base2k, log_q)
|
||||
+ encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||
+ module.vmp_prepare_tmp_bytes(rows, cols)
|
||||
}
|
||||
|
||||
pub fn encrypt_rgsw_sk(
|
||||
module: &Module,
|
||||
ct: &mut Ciphertext<VmpPMat>,
|
||||
m: &Scalar,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let log_q: usize = ct.log_q();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
|
||||
let (left, right) = ct.0.value.split_at_mut(2);
|
||||
let (ll, lr) = left.split_at_mut(1);
|
||||
let (rl, rr) = right.split_at_mut(1);
|
||||
|
||||
encrypt_grlwe_sk_core::<0>(
|
||||
module,
|
||||
log_base2k,
|
||||
[&mut ll[0], &mut lr[0]],
|
||||
log_q,
|
||||
m,
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
tmp_bytes,
|
||||
);
|
||||
encrypt_grlwe_sk_core::<1>(
|
||||
module,
|
||||
log_base2k,
|
||||
[&mut rl[0], &mut rr[0]],
|
||||
log_q,
|
||||
m,
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
fn encrypt_grlwe_sk_core<const PT_POS: u8>(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
mut ct: [&mut VmpPMat; 2],
|
||||
log_q: usize,
|
||||
m: &Scalar,
|
||||
sk: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let rows: usize = ct[0].rows();
|
||||
|
||||
let min_tmp_bytes_len = encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q);
|
||||
|
||||
assert!(
|
||||
tmp_bytes.len() >= min_tmp_bytes_len,
|
||||
"invalid tmp_bytes: tmp_bytes.len()={} < encrypt_grlwe_sk_tmp_bytes={}",
|
||||
tmp_bytes.len(),
|
||||
min_tmp_bytes_len
|
||||
);
|
||||
|
||||
let bytes_of_elem: usize = Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 2);
|
||||
let bytes_of_pt: usize = Plaintext::bytes_of(module, log_base2k, log_q);
|
||||
let bytes_of_enc_sk: usize = encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q);
|
||||
|
||||
let (tmp_bytes_pt, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_pt);
|
||||
let (tmp_bytes_enc_sk, tmp_bytes) = tmp_bytes.split_at_mut(bytes_of_enc_sk);
|
||||
let (tmp_bytes_elem, tmp_bytes_vmp_prepare_row) = tmp_bytes.split_at_mut(bytes_of_elem);
|
||||
|
||||
let mut tmp_elem: Elem<VecZnx> = Elem::<VecZnx>::from_bytes_borrow(module, log_base2k, log_q, 2, tmp_bytes_elem);
|
||||
let mut tmp_pt: Plaintext = Plaintext::from_bytes_borrow(module, log_base2k, log_q, tmp_bytes_pt);
|
||||
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Sets the i-th row of the RLWE sample to m (i.e. m * 2^{-log_base2k*i})
|
||||
tmp_pt.at_mut(0).at_mut(row_i).copy_from_slice(&m.raw());
|
||||
|
||||
// Encrypts RLWE(m * 2^{-log_base2k*i})
|
||||
encrypt_rlwe_sk_core::<PT_POS>(
|
||||
module,
|
||||
&mut tmp_elem,
|
||||
Some(tmp_pt.at(0)),
|
||||
sk,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
tmp_bytes_enc_sk,
|
||||
);
|
||||
|
||||
// Zeroes the ith-row of tmp_pt
|
||||
tmp_pt.at_mut(0).at_mut(row_i).fill(0);
|
||||
|
||||
// GRLWE[row_i][0||1] = [-as + m * 2^{-i*log_base2k} + e*2^{-log_q} || a]
|
||||
module.vmp_prepare_row(
|
||||
ct[0],
|
||||
tmp_elem.at(0).raw(),
|
||||
row_i,
|
||||
tmp_bytes_vmp_prepare_row,
|
||||
);
|
||||
module.vmp_prepare_row(
|
||||
&mut ct[1],
|
||||
tmp_elem.at(1).raw(),
|
||||
row_i,
|
||||
tmp_bytes_vmp_prepare_row,
|
||||
);
|
||||
});
|
||||
}
|
||||
@@ -1,383 +0,0 @@
|
||||
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn gadget_product_core_tmp_bytes(
|
||||
module: &Module,
|
||||
log_base2k: usize,
|
||||
res_log_q: usize,
|
||||
in_log_q: usize,
|
||||
gct_rows: usize,
|
||||
gct_log_q: usize,
|
||||
) -> usize {
|
||||
let gct_cols: usize = (gct_log_q + log_base2k - 1) / log_base2k;
|
||||
let in_cols: usize = (in_log_q + log_base2k - 1) / log_base2k;
|
||||
let out_cols: usize = (res_log_q + log_base2k - 1) / log_base2k;
|
||||
module.vmp_apply_dft_to_dft_tmp_bytes(out_cols, in_cols, gct_rows, gct_cols)
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn gadget_product_tmp_bytes(&self, res_log_q: usize, in_log_q: usize, gct_rows: usize, gct_log_q: usize) -> usize {
|
||||
gadget_product_core_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
res_log_q,
|
||||
in_log_q,
|
||||
gct_rows,
|
||||
gct_log_q,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gadget_product_core(
|
||||
module: &Module,
|
||||
res_dft_0: &mut VecZnxDft,
|
||||
res_dft_1: &mut VecZnxDft,
|
||||
a: &VecZnx,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
assert!(b_cols <= b.cols());
|
||||
module.vec_znx_dft(res_dft_1, a);
|
||||
module.vmp_apply_dft_to_dft(res_dft_0, res_dft_1, b.at(0), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft_inplace(res_dft_1, b.at(1), tmp_bytes);
|
||||
}
|
||||
|
||||
pub fn gadget_product_big_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(1, min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product: c.at(i) = IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `c`: a [Ciphertext<VecZnxBig>] with cols_c cols.
|
||||
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||
pub fn gadget_product_big(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnxBig>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(c.cols(), a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||
(0..2).for_each(|i| {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(c.at_mut(i), &mut res_dft);
|
||||
})
|
||||
}
|
||||
|
||||
/// Evaluates the gadget product: c.at(i) = NORMALIZE(IDFT(<DFT(a.at(i)), b.at(i)>)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module`: backend support for operations mod (X^N + 1).
|
||||
/// * `c`: a [Ciphertext<VecZnx>] with cols_c cols.
|
||||
/// * `a`: a [Ciphertext<VecZnx>] with cols_a cols.
|
||||
/// * `b`: a [Ciphertext<VmpPMat>] with at least min(cols_c, cols_a) rows.
|
||||
pub fn gadget_product(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = min(c.cols(), a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
// a1_dft = DFT(a[1])
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// c[i] = IDFT(DFT(a[1]) * b[i])
|
||||
(0..2).for_each(|i| {
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(i), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(i), &mut res_big, tmp_bytes);
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_gadget_ciphertext},
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::{Elem, ElemCommon, ElemVecZnx},
|
||||
encryptor::encrypt_grlwe_sk,
|
||||
gadget_product::gadget_product_core,
|
||||
keys::SecretKey,
|
||||
parameters::{Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{
|
||||
BACKEND, Infos, Sampling, SvpPPolOps, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat,
|
||||
alloc_aligned_u8,
|
||||
};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_gadget_product_core() {
|
||||
let log_base2k: usize = 10;
|
||||
let q_cols: usize = 7;
|
||||
let p_cols: usize = 1;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: 12,
|
||||
log_q: q_cols * log_base2k,
|
||||
log_p: p_cols * log_base2k,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned_u8(
|
||||
params.decrypt_rlwe_tmp_byte(params.log_qp())
|
||||
| params.gadget_product_tmp_bytes(
|
||||
params.log_qp(),
|
||||
params.log_qp(),
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
)
|
||||
| params.encrypt_grlwe_sk_tmp_bytes(params.cols_qp(), params.log_qp()),
|
||||
);
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
// Two secret keys
|
||||
let mut sk0: SecretKey = SecretKey::new(params.module());
|
||||
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk0_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk0_svp_ppol, &sk0.0);
|
||||
|
||||
let mut sk1: SecretKey = SecretKey::new(params.module());
|
||||
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk1_svp_ppol: base2k::SvpPPol = params.module().new_svp_ppol();
|
||||
params.module().svp_prepare(&mut sk1_svp_ppol, &sk1.0);
|
||||
|
||||
// The gadget ciphertext
|
||||
let mut gadget_ct: Ciphertext<VmpPMat> = new_gadget_ciphertext(
|
||||
params.module(),
|
||||
log_base2k,
|
||||
params.cols_qp(),
|
||||
params.log_qp(),
|
||||
);
|
||||
|
||||
// gct = [-b*sk1 + g(sk0) + e, b]
|
||||
encrypt_grlwe_sk(
|
||||
params.module(),
|
||||
&mut gadget_ct,
|
||||
&sk0.0,
|
||||
&sk1_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// Intermediate buffers
|
||||
|
||||
// Input polynopmial, uniformly distributed
|
||||
let mut a: VecZnx = params.module().new_vec_znx(1, params.cols_q());
|
||||
params
|
||||
.module()
|
||||
.fill_uniform(log_base2k, &mut a, params.cols_q(), &mut source_xa);
|
||||
|
||||
// res = g^-1(a) * gct
|
||||
let mut elem_res: Elem<VecZnx> = Elem::<VecZnx>::new(params.module(), log_base2k, params.log_qp(), 2);
|
||||
|
||||
// Ideal output = a * s
|
||||
let mut a_dft: VecZnxDft = params.module().new_vec_znx_dft(1, a.cols());
|
||||
let mut a_big: VecZnxBig = a_dft.as_vec_znx_big();
|
||||
let mut a_times_s: VecZnx = params.module().new_vec_znx(1, a.cols());
|
||||
|
||||
// a * sk0
|
||||
params.module().svp_apply_dft(&mut a_dft, &sk0_svp_ppol, &a);
|
||||
params.module().vec_znx_idft_tmp_a(&mut a_big, &mut a_dft);
|
||||
params
|
||||
.module()
|
||||
.vec_znx_big_normalize(params.log_base2k(), &mut a_times_s, &a_big, &mut tmp_bytes);
|
||||
|
||||
// Plaintext for decrypted output of gadget product
|
||||
let mut pt: Plaintext = Plaintext::new(params.module(), params.log_base2k(), params.log_qp());
|
||||
|
||||
// Iterates over all possible cols values for input/output polynomials and gadget ciphertext.
|
||||
|
||||
(1..a.cols() + 1).for_each(|a_cols| {
|
||||
let mut a_trunc: VecZnx = params.module().new_vec_znx(1, a_cols);
|
||||
a_trunc.copy_from(&a);
|
||||
|
||||
(1..gadget_ct.cols() + 1).for_each(|b_cols| {
|
||||
let mut res_dft_0: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
|
||||
let mut res_dft_1: VecZnxDft = params.module().new_vec_znx_dft(1, b_cols);
|
||||
let mut res_big_0: VecZnxBig = res_dft_0.as_vec_znx_big();
|
||||
let mut res_big_1: VecZnxBig = res_dft_1.as_vec_znx_big();
|
||||
|
||||
pt.elem_mut().zero();
|
||||
elem_res.zero();
|
||||
|
||||
// let b_cols: usize = min(a_cols+1, gadget_ct.cols());
|
||||
|
||||
println!("a_cols: {} b_cols: {}", a_cols, b_cols);
|
||||
|
||||
// res_dft_0 = DFT(gct_[0] * ct[1] = a * (-bs' + s + e) = -cs' + as + e')
|
||||
// res_dft_1 = DFT(gct_[1] * ct[1] = a * b = c)
|
||||
gadget_product_core(
|
||||
params.module(),
|
||||
&mut res_dft_0,
|
||||
&mut res_dft_1,
|
||||
&a_trunc,
|
||||
&gadget_ct,
|
||||
b_cols,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// res_big_0 = IDFT(res_dft_0)
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_0, &mut res_dft_0);
|
||||
// res_big_1 = IDFT(res_dft_1);
|
||||
params
|
||||
.module()
|
||||
.vec_znx_idft_tmp_a(&mut res_big_1, &mut res_dft_1);
|
||||
|
||||
// res_big_0 = normalize(res_big_0)
|
||||
params
|
||||
.module()
|
||||
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(0), &res_big_0, &mut tmp_bytes);
|
||||
|
||||
// res_big_1 = normalize(res_big_1)
|
||||
params
|
||||
.module()
|
||||
.vec_znx_big_normalize(log_base2k, elem_res.at_mut(1), &res_big_1, &mut tmp_bytes);
|
||||
|
||||
// <(-c*sk1 + a*sk0 + e, a), (1, sk1)> = a*sk0 + e
|
||||
decrypt_rlwe(
|
||||
params.module(),
|
||||
pt.elem_mut(),
|
||||
&elem_res,
|
||||
&sk1_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// a * sk0 + e - a*sk0 = e
|
||||
params
|
||||
.module()
|
||||
.vec_znx_sub_ab_inplace(pt.at_mut(0), &mut a_times_s);
|
||||
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
|
||||
|
||||
// pt.at(0).print(pt.elem().cols(), 16);
|
||||
|
||||
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
|
||||
|
||||
let var_a_err: f64;
|
||||
|
||||
if a_cols < a.cols() {
|
||||
var_a_err = 1f64 / 12f64;
|
||||
} else {
|
||||
var_a_err = 0f64;
|
||||
}
|
||||
|
||||
let a_logq: usize = a_cols * log_base2k;
|
||||
let b_logq: usize = b_cols * log_base2k;
|
||||
let var_msg: f64 = (params.xs() as f64) / params.n() as f64;
|
||||
|
||||
println!("{} {} {} {}", var_msg, var_a_err, a_logq, b_logq);
|
||||
|
||||
let noise_pred: f64 = params.noise_grlwe_product(var_msg, var_a_err, a_logq, b_logq);
|
||||
|
||||
println!("noise_pred: {}", noise_pred);
|
||||
println!("noise_have: {}", noise_have);
|
||||
|
||||
// assert!(noise_have <= noise_pred + 1.0);
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn noise_grlwe_product(&self, var_msg: f64, var_a_err: f64, a_logq: usize, b_logq: usize) -> f64 {
|
||||
let n: f64 = self.n() as f64;
|
||||
let var_xs: f64 = self.xs() as f64;
|
||||
|
||||
let var_gct_err_lhs: f64;
|
||||
let var_gct_err_rhs: f64;
|
||||
if b_logq < self.log_qp() {
|
||||
let var_round: f64 = 1f64 / 12f64;
|
||||
var_gct_err_lhs = var_round;
|
||||
var_gct_err_rhs = var_round;
|
||||
} else {
|
||||
var_gct_err_lhs = self.xe() * self.xe();
|
||||
var_gct_err_rhs = 0f64;
|
||||
}
|
||||
|
||||
noise_grlwe_product(
|
||||
n,
|
||||
self.log_base2k(),
|
||||
var_xs,
|
||||
var_msg,
|
||||
var_a_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
a_logq,
|
||||
b_logq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noise_grlwe_product(
|
||||
n: f64,
|
||||
log_base2k: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = min(a_logq, b_logq);
|
||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (log_base2k)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a_err * a_scale * a_scale * n;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
use crate::encryptor::{encrypt_grlwe_sk, encrypt_grlwe_sk_tmp_bytes};
|
||||
use crate::keys::{PublicKey, SecretKey, SwitchingKey};
|
||||
use crate::parameters::Parameters;
|
||||
use base2k::{Module, SvpPPol};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub struct KeyGenerator {}
|
||||
|
||||
impl KeyGenerator {
|
||||
pub fn gen_secret_key_thread_safe(&self, params: &Parameters, source: &mut Source) -> SecretKey {
|
||||
let mut sk: SecretKey = SecretKey::new(params.module());
|
||||
sk.fill_ternary_hw(params.xs(), source);
|
||||
sk
|
||||
}
|
||||
|
||||
pub fn gen_public_key_thread_safe(
|
||||
&self,
|
||||
params: &Parameters,
|
||||
sk_ppol: &SvpPPol,
|
||||
source: &mut Source,
|
||||
tmp_bytes: &mut [u8],
|
||||
) -> PublicKey {
|
||||
let mut xa_source: Source = source.branch();
|
||||
let mut xe_source: Source = source.branch();
|
||||
let mut pk: PublicKey = PublicKey::new(params.module(), params.log_base2k(), params.log_qp());
|
||||
pk.gen_thread_safe(
|
||||
params.module(),
|
||||
sk_ppol,
|
||||
params.xe(),
|
||||
&mut xa_source,
|
||||
&mut xe_source,
|
||||
tmp_bytes,
|
||||
);
|
||||
pk
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gen_switching_key_tmp_bytes(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> usize {
|
||||
encrypt_grlwe_sk_tmp_bytes(module, log_base2k, rows, log_q)
|
||||
}
|
||||
|
||||
pub fn gen_switching_key(
|
||||
module: &Module,
|
||||
swk: &mut SwitchingKey,
|
||||
sk_in: &SecretKey,
|
||||
sk_out: &SvpPPol,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
encrypt_grlwe_sk(
|
||||
module, &mut swk.0, &sk_in.0, sk_out, source_xa, source_xe, sigma, tmp_bytes,
|
||||
);
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::elem::ElemCommon;
|
||||
use base2k::{Module, VecZnx, VecZnxBigOps, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
|
||||
use std::cmp::min;
|
||||
|
||||
pub fn key_switch_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
|
||||
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
|
||||
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
|
||||
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
|
||||
+ module.bytes_of_vec_znx_dft(1, gct_cols);
|
||||
}
|
||||
|
||||
pub fn key_switch_rlwe(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
key_switch_rlwe_core(module, c, a, b, b_cols, tmp_bytes);
|
||||
}
|
||||
|
||||
pub fn key_switch_rlwe_inplace(
|
||||
module: &Module,
|
||||
a: &mut Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
key_switch_rlwe_core(module, a, a, b, b_cols, tmp_bytes);
|
||||
}
|
||||
|
||||
fn key_switch_rlwe_core(
|
||||
module: &Module,
|
||||
c: *mut Ciphertext<VecZnx>,
|
||||
a: *const Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
// SAFETY WARNING: must ensure `c` and `a` are valid for read/write
|
||||
let c: &mut Ciphertext<VecZnx> = unsafe { &mut *c };
|
||||
let a: &Ciphertext<VecZnx> = unsafe { &*a };
|
||||
|
||||
let cols: usize = min(min(c.cols(), a.cols()), b.rows());
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(b_cols <= b.cols());
|
||||
assert!(tmp_bytes.len() >= key_switch_tmp_bytes(module, c.cols(), a.cols(), b.rows(), b.cols()));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_a1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
|
||||
let mut a1_dft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_a1_dft);
|
||||
let mut res_dft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
|
||||
let mut res_big = res_dft.as_vec_znx_big();
|
||||
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, b.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut res_big, tmp_bytes);
|
||||
}
|
||||
|
||||
pub fn key_switch_grlwe(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}
|
||||
|
||||
pub fn key_switch_rgsw(module: &Module, c: &mut Ciphertext<VecZnx>, a: &Ciphertext<VecZnx>, b: &Ciphertext<VmpPMat>) {}
|
||||
@@ -1,82 +0,0 @@
|
||||
use crate::ciphertext::{Ciphertext, new_gadget_ciphertext};
|
||||
use crate::elem::{Elem, ElemCommon};
|
||||
use crate::encryptor::{encrypt_rlwe_sk, encrypt_rlwe_sk_tmp_bytes};
|
||||
use base2k::{Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VmpPMat};
|
||||
use sampling::source::Source;
|
||||
|
||||
pub struct SecretKey(pub Scalar);
|
||||
|
||||
impl SecretKey {
|
||||
pub fn new(module: &Module) -> Self {
|
||||
SecretKey(Scalar::new(module.n()))
|
||||
}
|
||||
|
||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
||||
self.0.fill_ternary_prob(prob, source);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||
self.0.fill_ternary_hw(hw, source);
|
||||
}
|
||||
|
||||
pub fn prepare(&self, module: &Module, sk_ppol: &mut SvpPPol) {
|
||||
module.svp_prepare(sk_ppol, &self.0)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PublicKey(pub Elem<VecZnx>);
|
||||
|
||||
impl PublicKey {
|
||||
pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> PublicKey {
|
||||
PublicKey(Elem::<VecZnx>::new(module, log_base2k, log_q, 2))
|
||||
}
|
||||
|
||||
pub fn gen_thread_safe(
|
||||
&mut self,
|
||||
module: &Module,
|
||||
sk: &SvpPPol,
|
||||
xe: f64,
|
||||
xa_source: &mut Source,
|
||||
xe_source: &mut Source,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
&mut self.0,
|
||||
None,
|
||||
sk,
|
||||
xa_source,
|
||||
xe_source,
|
||||
xe,
|
||||
tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn gen_thread_safe_tmp_bytes(module: &Module, log_base2k: usize, log_q: usize) -> usize {
|
||||
encrypt_rlwe_sk_tmp_bytes(module, log_base2k, log_q)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SwitchingKey(pub Ciphertext<VmpPMat>);
|
||||
|
||||
impl SwitchingKey {
|
||||
pub fn new(module: &Module, log_base2k: usize, rows: usize, log_q: usize) -> SwitchingKey {
|
||||
SwitchingKey(new_gadget_ciphertext(module, log_base2k, rows, log_q))
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
self.0.n()
|
||||
}
|
||||
|
||||
pub fn rows(&self) -> usize {
|
||||
self.0.rows()
|
||||
}
|
||||
|
||||
pub fn cols(&self) -> usize {
|
||||
self.0.cols()
|
||||
}
|
||||
|
||||
pub fn log_base2k(&self) -> usize {
|
||||
self.0.log_base2k()
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
pub mod automorphism;
|
||||
pub mod ciphertext;
|
||||
pub mod decryptor;
|
||||
pub mod elem;
|
||||
pub mod encryptor;
|
||||
pub mod gadget_product;
|
||||
pub mod key_generator;
|
||||
pub mod key_switching;
|
||||
pub mod keys;
|
||||
pub mod parameters;
|
||||
pub mod plaintext;
|
||||
pub mod rgsw_product;
|
||||
pub mod trace;
|
||||
@@ -1,88 +0,0 @@
|
||||
use base2k::module::{BACKEND, Module};
|
||||
|
||||
pub const DEFAULT_SIGMA: f64 = 3.2;
|
||||
|
||||
pub struct ParametersLiteral {
|
||||
pub backend: BACKEND,
|
||||
pub log_n: usize,
|
||||
pub log_q: usize,
|
||||
pub log_p: usize,
|
||||
pub log_base2k: usize,
|
||||
pub log_scale: usize,
|
||||
pub xe: f64,
|
||||
pub xs: usize,
|
||||
}
|
||||
|
||||
pub struct Parameters {
|
||||
log_n: usize,
|
||||
log_q: usize,
|
||||
log_p: usize,
|
||||
log_scale: usize,
|
||||
log_base2k: usize,
|
||||
xe: f64,
|
||||
xs: usize,
|
||||
module: Module,
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn new(p: &ParametersLiteral) -> Self {
|
||||
assert!(
|
||||
p.log_n + 2 * p.log_base2k <= 53,
|
||||
"invalid parameters: p.log_n + 2*p.log_base2k > 53"
|
||||
);
|
||||
Self {
|
||||
log_n: p.log_n,
|
||||
log_q: p.log_q,
|
||||
log_p: p.log_p,
|
||||
log_scale: p.log_scale,
|
||||
log_base2k: p.log_base2k,
|
||||
xe: p.xe,
|
||||
xs: p.xs,
|
||||
module: Module::new(1 << p.log_n, p.backend),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn n(&self) -> usize {
|
||||
1 << self.log_n
|
||||
}
|
||||
|
||||
pub fn log_scale(&self) -> usize {
|
||||
self.log_scale
|
||||
}
|
||||
|
||||
pub fn log_q(&self) -> usize {
|
||||
self.log_q
|
||||
}
|
||||
|
||||
pub fn log_p(&self) -> usize {
|
||||
self.log_p
|
||||
}
|
||||
|
||||
pub fn log_qp(&self) -> usize {
|
||||
self.log_q + self.log_p
|
||||
}
|
||||
|
||||
pub fn cols_q(&self) -> usize {
|
||||
(self.log_q + self.log_base2k - 1) / self.log_base2k
|
||||
}
|
||||
|
||||
pub fn cols_qp(&self) -> usize {
|
||||
(self.log_q + self.log_p + self.log_base2k - 1) / self.log_base2k
|
||||
}
|
||||
|
||||
pub fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
pub fn module(&self) -> &Module {
|
||||
&self.module
|
||||
}
|
||||
|
||||
pub fn xe(&self) -> f64 {
|
||||
self.xe
|
||||
}
|
||||
|
||||
pub fn xs(&self) -> usize {
|
||||
self.xs
|
||||
}
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
use crate::ciphertext::Ciphertext;
|
||||
use crate::elem::{Elem, ElemCommon, ElemVecZnx};
|
||||
use crate::parameters::Parameters;
|
||||
use base2k::{LAYOUT, Module, VecZnx};
|
||||
|
||||
pub struct Plaintext(pub Elem<VecZnx>);
|
||||
|
||||
impl Parameters {
|
||||
pub fn new_plaintext(&self, log_q: usize) -> Plaintext {
|
||||
Plaintext::new(self.module(), self.log_base2k(), log_q)
|
||||
}
|
||||
|
||||
pub fn bytes_of_plaintext(&self, log_q: usize) -> usize
|
||||
where {
|
||||
Elem::<VecZnx>::bytes_of(self.module(), self.log_base2k(), log_q, 1)
|
||||
}
|
||||
|
||||
pub fn plaintext_from_bytes(&self, log_q: usize, bytes: &mut [u8]) -> Plaintext {
|
||||
Plaintext(Elem::<VecZnx>::from_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
log_q,
|
||||
1,
|
||||
bytes,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Plaintext {
|
||||
pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self {
|
||||
Self(Elem::<VecZnx>::new(module, log_base2k, log_q, 1))
|
||||
}
|
||||
}
|
||||
|
||||
impl Plaintext {
|
||||
pub fn bytes_of(module: &Module, log_base2k: usize, log_q: usize) -> usize {
|
||||
Elem::<VecZnx>::bytes_of(module, log_base2k, log_q, 1)
|
||||
}
|
||||
|
||||
pub fn from_bytes(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
|
||||
Self(Elem::<VecZnx>::from_bytes(
|
||||
module, log_base2k, log_q, 1, bytes,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn from_bytes_borrow(module: &Module, log_base2k: usize, log_q: usize, bytes: &mut [u8]) -> Self {
|
||||
Self(Elem::<VecZnx>::from_bytes_borrow(
|
||||
module, log_base2k, log_q, 1, bytes,
|
||||
))
|
||||
}
|
||||
|
||||
pub fn as_ciphertext(&self) -> Ciphertext<VecZnx> {
|
||||
unsafe { Ciphertext::<VecZnx>(std::ptr::read(&self.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemCommon<VecZnx> for Plaintext {
|
||||
fn n(&self) -> usize {
|
||||
self.0.n()
|
||||
}
|
||||
|
||||
fn log_n(&self) -> usize {
|
||||
self.elem().log_n()
|
||||
}
|
||||
|
||||
fn log_q(&self) -> usize {
|
||||
self.0.log_q
|
||||
}
|
||||
|
||||
fn elem(&self) -> &Elem<VecZnx> {
|
||||
&self.0
|
||||
}
|
||||
|
||||
fn elem_mut(&mut self) -> &mut Elem<VecZnx> {
|
||||
&mut self.0
|
||||
}
|
||||
|
||||
fn size(&self) -> usize {
|
||||
self.elem().size()
|
||||
}
|
||||
|
||||
fn layout(&self) -> LAYOUT {
|
||||
self.elem().layout()
|
||||
}
|
||||
|
||||
fn rows(&self) -> usize {
|
||||
self.0.rows()
|
||||
}
|
||||
|
||||
fn cols(&self) -> usize {
|
||||
self.0.cols()
|
||||
}
|
||||
|
||||
fn at(&self, i: usize) -> &VecZnx {
|
||||
self.0.at(i)
|
||||
}
|
||||
|
||||
fn at_mut(&mut self, i: usize) -> &mut VecZnx {
|
||||
self.0.at_mut(i)
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.0.log_base2k()
|
||||
}
|
||||
|
||||
fn log_scale(&self) -> usize {
|
||||
self.0.log_scale()
|
||||
}
|
||||
}
|
||||
@@ -1,300 +0,0 @@
|
||||
use crate::{ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMat, VmpPMatOps, assert_alignement};
|
||||
use std::cmp::min;
|
||||
|
||||
impl Parameters {
|
||||
pub fn rgsw_product_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
rgsw_product_tmp_bytes(
|
||||
self.module(),
|
||||
self.log_base2k(),
|
||||
res_logq,
|
||||
in_logq,
|
||||
gct_logq,
|
||||
)
|
||||
}
|
||||
}
|
||||
pub fn rgsw_product_tmp_bytes(module: &Module, log_base2k: usize, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
let gct_cols: usize = (gct_logq + log_base2k - 1) / log_base2k;
|
||||
let in_cols: usize = (in_logq + log_base2k - 1) / log_base2k;
|
||||
let res_cols: usize = (res_logq + log_base2k - 1) / log_base2k;
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(res_cols, in_cols, in_cols, gct_cols)
|
||||
+ module.bytes_of_vec_znx_dft(1, std::cmp::min(res_cols, in_cols))
|
||||
+ 2 * module.bytes_of_vec_znx_dft(1, gct_cols);
|
||||
}
|
||||
|
||||
pub fn rgsw_product(
|
||||
module: &Module,
|
||||
c: &mut Ciphertext<VecZnx>,
|
||||
a: &Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(b_cols <= b.cols());
|
||||
assert_eq!(c.size(), 2);
|
||||
assert_eq!(a.size(), 2);
|
||||
assert_eq!(b.size(), 4);
|
||||
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, c.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
|
||||
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
|
||||
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
|
||||
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
|
||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
|
||||
|
||||
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
|
||||
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
|
||||
|
||||
module.vec_znx_dft(&mut ai_dft, a.at(0));
|
||||
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
|
||||
|
||||
module.vec_znx_dft(&mut ai_dft, a.at(1));
|
||||
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
|
||||
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
|
||||
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
|
||||
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(0), &mut c0_big, tmp_bytes);
|
||||
module.vec_znx_big_normalize(c.log_base2k(), c.at_mut(1), &mut c1_big, tmp_bytes);
|
||||
}
|
||||
|
||||
pub fn rgsw_product_inplace(
|
||||
module: &Module,
|
||||
a: &mut Ciphertext<VecZnx>,
|
||||
b: &Ciphertext<VmpPMat>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(b_cols <= b.cols());
|
||||
assert_eq!(a.size(), 2);
|
||||
assert_eq!(b.size(), 4);
|
||||
assert!(tmp_bytes.len() >= rgsw_product_tmp_bytes(module, a.cols(), a.cols(), min(b.rows(), a.cols()), b_cols));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let (tmp_bytes_ai_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, a.cols()));
|
||||
let (tmp_bytes_c0_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
let (tmp_bytes_c1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
|
||||
let mut ai_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, a.cols(), tmp_bytes_ai_dft);
|
||||
let mut c0_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c0_dft);
|
||||
let mut c1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_c1_dft);
|
||||
|
||||
let mut c0_big: VecZnxBig = c0_dft.as_vec_znx_big();
|
||||
let mut c1_big: VecZnxBig = c1_dft.as_vec_znx_big();
|
||||
|
||||
module.vec_znx_dft(&mut ai_dft, a.at(0));
|
||||
module.vmp_apply_dft_to_dft(&mut c0_dft, &ai_dft, b.at(0), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft(&mut c1_dft, &ai_dft, b.at(1), tmp_bytes);
|
||||
|
||||
module.vec_znx_dft(&mut ai_dft, a.at(1));
|
||||
module.vmp_apply_dft_to_dft_add(&mut c0_dft, &ai_dft, b.at(2), tmp_bytes);
|
||||
module.vmp_apply_dft_to_dft_add(&mut c1_dft, &ai_dft, b.at(3), tmp_bytes);
|
||||
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, &mut c0_dft);
|
||||
module.vec_znx_idft_tmp_a(&mut c1_big, &mut c1_dft);
|
||||
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut c0_big, tmp_bytes);
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut c1_big, tmp_bytes);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::{
|
||||
ciphertext::{Ciphertext, new_rgsw_ciphertext},
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::ElemCommon,
|
||||
encryptor::{encrypt_rgsw_sk, encrypt_rlwe_sk},
|
||||
keys::SecretKey,
|
||||
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
rgsw_product::rgsw_product_inplace,
|
||||
};
|
||||
use base2k::{BACKEND, Encoding, Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxOps, VmpPMat, alloc_aligned};
|
||||
use sampling::source::{Source, new_seed};
|
||||
|
||||
#[test]
|
||||
fn test_rgsw_product() {
|
||||
let log_base2k: usize = 10;
|
||||
let log_q: usize = 50;
|
||||
let log_p: usize = 15;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: 12,
|
||||
log_q: log_q,
|
||||
log_p: log_p,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let module: &Module = params.module();
|
||||
let log_q: usize = params.log_q();
|
||||
let log_qp: usize = params.log_qp();
|
||||
let gct_rows: usize = params.cols_q();
|
||||
let gct_cols: usize = params.cols_qp();
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||
params.decrypt_rlwe_tmp_byte(log_q)
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
|
||||
| params.rgsw_product_tmp_bytes(log_q, log_q, log_qp)
|
||||
| params.encrypt_rgsw_sk_tmp_bytes(gct_rows, log_qp),
|
||||
);
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
let mut sk: SecretKey = SecretKey::new(module);
|
||||
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
|
||||
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||
|
||||
let mut ct_rgsw: Ciphertext<VmpPMat> = new_rgsw_ciphertext(module, log_base2k, gct_rows, log_qp);
|
||||
|
||||
let k: i64 = 3;
|
||||
|
||||
// X^k
|
||||
let m: Scalar = module.new_scalar();
|
||||
let data: &mut [i64] = m.raw_mut();
|
||||
data[k as usize] = 1;
|
||||
|
||||
encrypt_rgsw_sk(
|
||||
module,
|
||||
&mut ct_rgsw,
|
||||
&m,
|
||||
&sk_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
DEFAULT_SIGMA,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let log_k: usize = 2 * log_base2k;
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||
let mut pt: Plaintext = params.new_plaintext(log_q);
|
||||
let mut pt_rotate: Plaintext = params.new_plaintext(log_q);
|
||||
|
||||
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
|
||||
|
||||
module.vec_znx_rotate(k, pt_rotate.at_mut(0), pt.at_mut(0));
|
||||
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
&mut ct.elem_mut(),
|
||||
Some(pt.at(0)),
|
||||
&sk_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
rgsw_product_inplace(module, &mut ct, &ct_rgsw, gct_cols, &mut tmp_bytes);
|
||||
|
||||
decrypt_rlwe(
|
||||
module,
|
||||
pt.elem_mut(),
|
||||
ct.elem(),
|
||||
&sk_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
module.vec_znx_sub_ba_inplace(pt.at_mut(0), pt_rotate.at(0));
|
||||
|
||||
// pt.at(0).print(pt.cols(), 16);
|
||||
|
||||
let noise_have: f64 = pt.at(0).std(0, log_base2k).log2();
|
||||
|
||||
let var_msg: f64 = 1f64 / params.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = params.xe() * params.xe();
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_pred: f64 = params.noise_rgsw_product(var_msg, var_a0_err, var_a1_err, ct.log_q(), ct_rgsw.log_q());
|
||||
|
||||
println!("noise_pred: {}", noise_pred);
|
||||
println!("noise_have: {}", noise_have);
|
||||
|
||||
assert!(noise_have <= noise_pred + 1.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn noise_rgsw_product(&self, var_msg: f64, var_a0_err: f64, var_a1_err: f64, a_logq: usize, b_logq: usize) -> f64 {
|
||||
let n: f64 = self.n() as f64;
|
||||
let var_xs: f64 = self.xs() as f64;
|
||||
|
||||
let var_gct_err_lhs: f64;
|
||||
let var_gct_err_rhs: f64;
|
||||
if b_logq < self.log_qp() {
|
||||
let var_round: f64 = 1f64 / 12f64;
|
||||
var_gct_err_lhs = var_round;
|
||||
var_gct_err_rhs = var_round;
|
||||
} else {
|
||||
var_gct_err_lhs = self.xe() * self.xe();
|
||||
var_gct_err_rhs = 0f64;
|
||||
}
|
||||
|
||||
noise_rgsw_product(
|
||||
n,
|
||||
self.log_base2k(),
|
||||
var_xs,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
a_logq,
|
||||
b_logq,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn noise_rgsw_product(
|
||||
n: f64,
|
||||
log_base2k: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a0_err: f64,
|
||||
var_a1_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = min(a_logq, b_logq);
|
||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (log_base2k)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a0_err * a_scale * a_scale * n;
|
||||
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
113
rlwe/src/test.rs
113
rlwe/src/test.rs
@@ -1,113 +0,0 @@
|
||||
use base2k::{alloc_aligned, SvpPPol, SvpPPolOps, VecZnx, BACKEND};
|
||||
use sampling::source::{Source, new_seed};
|
||||
use crate::{ciphertext::Ciphertext, decryptor::decrypt_rlwe, elem::ElemCommon, encryptor::encrypt_rlwe_sk, keys::SecretKey, parameters::{Parameters, ParametersLiteral, DEFAULT_SIGMA}, plaintext::Plaintext};
|
||||
|
||||
|
||||
|
||||
pub struct Context{
|
||||
pub params: Parameters,
|
||||
pub sk0: SecretKey,
|
||||
pub sk0_ppol:SvpPPol,
|
||||
pub sk1: SecretKey,
|
||||
pub sk1_ppol: SvpPPol,
|
||||
pub tmp_bytes: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Context{
|
||||
pub fn new(log_n: usize, log_base2k: usize, log_q: usize, log_p: usize) -> Self{
|
||||
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: log_n,
|
||||
log_q: log_q,
|
||||
log_p: log_p,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: DEFAULT_SIGMA,
|
||||
xs: 1 << (log_n-1),
|
||||
};
|
||||
|
||||
let params: Parameters =Parameters::new(¶ms_lit);
|
||||
let module = params.module();
|
||||
|
||||
let log_q: usize = params.log_q();
|
||||
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
let mut sk0: SecretKey = SecretKey::new(module);
|
||||
sk0.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk0_ppol: base2k::SvpPPol = module.new_svp_ppol();
|
||||
module.svp_prepare(&mut sk0_ppol, &sk0.0);
|
||||
|
||||
let mut sk1: SecretKey = SecretKey::new(module);
|
||||
sk1.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk1_ppol: base2k::SvpPPol = module.new_svp_ppol();
|
||||
module.svp_prepare(&mut sk1_ppol, &sk1.0);
|
||||
|
||||
let tmp_bytes: Vec<u8> = alloc_aligned(params.decrypt_rlwe_tmp_byte(log_q)| params.encrypt_rlwe_sk_tmp_bytes(log_q));
|
||||
|
||||
Context{
|
||||
params: params,
|
||||
sk0: sk0,
|
||||
sk0_ppol: sk0_ppol,
|
||||
sk1: sk1,
|
||||
sk1_ppol: sk1_ppol,
|
||||
tmp_bytes: tmp_bytes,
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk0(&mut self, pt: &Plaintext, ct: &mut Ciphertext<VecZnx>){
|
||||
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
|
||||
encrypt_rlwe_sk(
|
||||
self.params.module(),
|
||||
ct.elem_mut(),
|
||||
Some(pt.elem()),
|
||||
&self.sk0_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
self.params.xe(),
|
||||
&mut self.tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk1(&mut self, ct: &mut Ciphertext<VecZnx>, pt: &Plaintext){
|
||||
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
|
||||
encrypt_rlwe_sk(
|
||||
self.params.module(),
|
||||
ct.elem_mut(),
|
||||
Some(pt.elem()),
|
||||
&self.sk1_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
self.params.xe(),
|
||||
&mut self.tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn decrypt_sk0(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
|
||||
decrypt_rlwe(
|
||||
self.params.module(),
|
||||
pt.elem_mut(),
|
||||
ct.elem(),
|
||||
&self.sk0_ppol,
|
||||
&mut self.tmp_bytes,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn decrypt_sk1(&mut self, pt: &mut Plaintext, ct: &Ciphertext<VecZnx>){
|
||||
decrypt_rlwe(
|
||||
self.params.module(),
|
||||
pt.elem_mut(),
|
||||
ct.elem(),
|
||||
&self.sk1_ppol,
|
||||
&mut self.tmp_bytes,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,236 +0,0 @@
|
||||
use crate::{automorphism::AutomorphismKey, ciphertext::Ciphertext, elem::ElemCommon, parameters::Parameters};
|
||||
use base2k::{Module, VecZnx, VecZnxBig, VecZnxBigOps, VecZnxDft, VecZnxDftOps, VmpPMatOps, assert_alignement};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub fn trace_galois_elements(module: &Module) -> Vec<i64> {
|
||||
let mut gal_els: Vec<i64> = Vec::new();
|
||||
(0..module.log_n()).for_each(|i| {
|
||||
if i == 0 {
|
||||
gal_els.push(-1);
|
||||
} else {
|
||||
gal_els.push(module.galois_element(1 << (i - 1)));
|
||||
}
|
||||
});
|
||||
gal_els
|
||||
}
|
||||
|
||||
impl Parameters {
|
||||
pub fn trace_tmp_bytes(&self, res_logq: usize, in_logq: usize, gct_logq: usize) -> usize {
|
||||
self.automorphism_tmp_bytes(res_logq, in_logq, gct_logq)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn trace_tmp_bytes(module: &Module, c_cols: usize, a_cols: usize, b_rows: usize, b_cols: usize) -> usize {
|
||||
return module.vmp_apply_dft_to_dft_tmp_bytes(c_cols, a_cols, b_rows, b_cols)
|
||||
+ 2 * module.bytes_of_vec_znx_dft(1, std::cmp::min(c_cols, a_cols));
|
||||
}
|
||||
|
||||
pub fn trace_inplace(
|
||||
module: &Module,
|
||||
a: &mut Ciphertext<VecZnx>,
|
||||
start: usize,
|
||||
end: usize,
|
||||
b: &HashMap<i64, AutomorphismKey>,
|
||||
b_cols: usize,
|
||||
tmp_bytes: &mut [u8],
|
||||
) {
|
||||
let cols: usize = a.cols();
|
||||
|
||||
let b_rows: usize;
|
||||
|
||||
if let Some((_, key)) = b.iter().next() {
|
||||
b_rows = key.value.rows();
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
println!("{} {}", b_cols, key.value.cols());
|
||||
assert!(b_cols <= key.value.cols())
|
||||
}
|
||||
} else {
|
||||
panic!("b: HashMap<i64, AutomorphismKey>, is empty")
|
||||
}
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(start <= end);
|
||||
assert!(end <= module.n());
|
||||
assert!(tmp_bytes.len() >= trace_tmp_bytes(module, cols, cols, b_rows, b_cols));
|
||||
assert_alignement(tmp_bytes.as_ptr());
|
||||
}
|
||||
|
||||
let cols: usize = std::cmp::min(b_cols, a.cols());
|
||||
|
||||
let (tmp_bytes_b1_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, cols));
|
||||
let (tmp_bytes_res_dft, tmp_bytes) = tmp_bytes.split_at_mut(module.bytes_of_vec_znx_dft(1, b_cols));
|
||||
|
||||
let mut a1_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, cols, tmp_bytes_b1_dft);
|
||||
let mut res_dft: VecZnxDft = module.new_vec_znx_dft_from_bytes_borrow(1, b_cols, tmp_bytes_res_dft);
|
||||
let mut res_big: VecZnxBig = res_dft.as_vec_znx_big();
|
||||
|
||||
let log_base2k: usize = a.log_base2k();
|
||||
|
||||
(start..end).for_each(|i| {
|
||||
a.at_mut(0).rsh(log_base2k, 1, tmp_bytes);
|
||||
a.at_mut(1).rsh(log_base2k, 1, tmp_bytes);
|
||||
|
||||
let p: i64;
|
||||
if i == 0 {
|
||||
p = -1;
|
||||
} else {
|
||||
p = module.galois_element(1 << (i - 1));
|
||||
}
|
||||
|
||||
if let Some(key) = b.get(&p) {
|
||||
module.vec_znx_dft(&mut a1_dft, a.at(1));
|
||||
|
||||
// a[0] = NORMALIZE(a[0] + AUTO(a[0] + IDFT(<DFT(a[1]), key[0]>)))
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(0), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(0));
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(0), &mut res_big, tmp_bytes);
|
||||
|
||||
// a[1] = NORMALIZE(a[1] + AUTO(IDFT(<DFT(a[1]), key[1]>)))
|
||||
module.vmp_apply_dft_to_dft(&mut res_dft, &a1_dft, key.value.at(1), tmp_bytes);
|
||||
module.vec_znx_idft_tmp_a(&mut res_big, &mut res_dft);
|
||||
module.vec_znx_big_automorphism_inplace(p, &mut res_big);
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, a.at(1));
|
||||
module.vec_znx_big_normalize(a.log_base2k(), a.at_mut(1), &mut res_big, tmp_bytes);
|
||||
} else {
|
||||
panic!("b[{}] is empty", p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::{trace_galois_elements, trace_inplace};
|
||||
use crate::{
|
||||
automorphism::AutomorphismKey,
|
||||
ciphertext::Ciphertext,
|
||||
decryptor::decrypt_rlwe,
|
||||
elem::ElemCommon,
|
||||
encryptor::encrypt_rlwe_sk,
|
||||
keys::SecretKey,
|
||||
parameters::{DEFAULT_SIGMA, Parameters, ParametersLiteral},
|
||||
plaintext::Plaintext,
|
||||
};
|
||||
use base2k::{BACKEND, Encoding, Module, SvpPPol, SvpPPolOps, VecZnx, alloc_aligned};
|
||||
use sampling::source::{Source, new_seed};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_trace_inplace() {
|
||||
let log_base2k: usize = 10;
|
||||
let log_q: usize = 50;
|
||||
let log_p: usize = 15;
|
||||
|
||||
// Basic parameters with enough limbs to test edge cases
|
||||
let params_lit: ParametersLiteral = ParametersLiteral {
|
||||
backend: BACKEND::FFT64,
|
||||
log_n: 12,
|
||||
log_q: log_q,
|
||||
log_p: log_p,
|
||||
log_base2k: log_base2k,
|
||||
log_scale: 20,
|
||||
xe: 3.2,
|
||||
xs: 1 << 11,
|
||||
};
|
||||
|
||||
let params: Parameters = Parameters::new(¶ms_lit);
|
||||
|
||||
let module: &Module = params.module();
|
||||
let log_q: usize = params.log_q();
|
||||
let log_qp: usize = params.log_qp();
|
||||
let gct_rows: usize = params.cols_q();
|
||||
let gct_cols: usize = params.cols_qp();
|
||||
|
||||
// scratch space
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(
|
||||
params.decrypt_rlwe_tmp_byte(log_q)
|
||||
| params.encrypt_rlwe_sk_tmp_bytes(log_q)
|
||||
| params.automorphism_key_new_tmp_bytes(gct_rows, log_qp)
|
||||
| params.automorphism_tmp_bytes(log_q, log_q, log_qp),
|
||||
);
|
||||
|
||||
// Samplers for public and private randomness
|
||||
let mut source_xe: Source = Source::new(new_seed());
|
||||
let mut source_xa: Source = Source::new(new_seed());
|
||||
let mut source_xs: Source = Source::new(new_seed());
|
||||
|
||||
let mut sk: SecretKey = SecretKey::new(module);
|
||||
sk.fill_ternary_hw(params.xs(), &mut source_xs);
|
||||
let mut sk_svp_ppol: SvpPPol = module.new_svp_ppol();
|
||||
module.svp_prepare(&mut sk_svp_ppol, &sk.0);
|
||||
|
||||
let gal_els: Vec<i64> = trace_galois_elements(module);
|
||||
|
||||
let auto_keys: HashMap<i64, AutomorphismKey> = AutomorphismKey::new_many(
|
||||
module,
|
||||
&gal_els,
|
||||
&sk,
|
||||
log_base2k,
|
||||
gct_rows,
|
||||
log_qp,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
DEFAULT_SIGMA,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
let mut data: Vec<i64> = vec![0i64; params.n()];
|
||||
|
||||
data.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = 1 + i as i64);
|
||||
|
||||
let log_k: usize = 2 * log_base2k;
|
||||
|
||||
let mut ct: Ciphertext<VecZnx> = params.new_ciphertext(log_q);
|
||||
let mut pt: Plaintext = params.new_plaintext(log_q);
|
||||
|
||||
pt.at_mut(0).encode_vec_i64(0, log_base2k, log_k, &data, 32);
|
||||
pt.at_mut(0).normalize(log_base2k, &mut tmp_bytes);
|
||||
|
||||
pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
|
||||
|
||||
pt.at(0).print(0, pt.cols(), 16);
|
||||
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
&mut ct.elem_mut(),
|
||||
Some(pt.at(0)),
|
||||
&sk_svp_ppol,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
params.xe(),
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
trace_inplace(module, &mut ct, 0, 4, &auto_keys, gct_cols, &mut tmp_bytes);
|
||||
trace_inplace(
|
||||
module,
|
||||
&mut ct,
|
||||
4,
|
||||
module.log_n(),
|
||||
&auto_keys,
|
||||
gct_cols,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
// pt = dec(auto(ct)) - auto(pt)
|
||||
decrypt_rlwe(
|
||||
module,
|
||||
pt.elem_mut(),
|
||||
ct.elem(),
|
||||
&sk_svp_ppol,
|
||||
&mut tmp_bytes,
|
||||
);
|
||||
|
||||
pt.at(0).print(0, pt.cols(), 16);
|
||||
|
||||
pt.at(0).decode_vec_i64(0, log_base2k, log_k, &mut data);
|
||||
|
||||
println!("trace: {:?}", &data[..16]);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user