Added API in poulpy for updated vmp_add (+tests)

This commit is contained in:
Jean-Philippe Bossuat
2025-06-04 11:39:11 +02:00
parent fcdc8f53d3
commit 159cd8025f
14 changed files with 216 additions and 82 deletions

View File

@@ -47,6 +47,7 @@ unsafe extern "C" {
pmat: *const VMP_PMAT, pmat: *const VMP_PMAT,
nrows: u64, nrows: u64,
ncols: u64, ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8, tmp_space: *mut u8,
); );
} }
@@ -79,6 +80,7 @@ unsafe extern "C" {
pmat: *const VMP_PMAT, pmat: *const VMP_PMAT,
nrows: u64, nrows: u64,
ncols: u64, ncols: u64,
pmat_scale: u64,
tmp_space: *mut u8, tmp_space: *mut u8,
); );
} }

View File

@@ -101,7 +101,7 @@ pub trait MatZnxDftOps<BACKEND: Backend> {
B: MatZnxToRef<FFT64>; B: MatZnxToRef<FFT64>;
// Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R.
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
where where
R: VecZnxDftToMut<FFT64>, R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>, A: VecZnxDftToRef<FFT64>,
@@ -309,7 +309,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
} }
} }
fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) fn vmp_apply_add<R, A, B>(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch)
where where
R: VecZnxDftToMut<FFT64>, R: VecZnxDftToMut<FFT64>,
A: VecZnxDftToRef<FFT64>, A: VecZnxDftToRef<FFT64>,
@@ -358,6 +358,7 @@ impl MatZnxDftOps<FFT64> for Module<FFT64> {
b.as_ptr() as *const vmp::vmp_pmat_t, b.as_ptr() as *const vmp::vmp_pmat_t,
(b.rows() * b.cols_in()) as u64, (b.rows() * b.cols_in()) as u64,
(b.size() * b.cols_out()) as u64, (b.size() * b.cols_out()) as u64,
scale as u64,
tmp_bytes.as_mut_ptr(), tmp_bytes.as_mut_ptr(),
) )
} }
@@ -368,6 +369,7 @@ mod tests {
use crate::{ use crate::{
Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut,
ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -409,7 +411,7 @@ mod tests {
let basek: usize = 15; let basek: usize = 15;
let a_size: usize = 5; let a_size: usize = 5;
let mat_size: usize = 6; let mat_size: usize = 6;
let res_size: usize = 5; let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| { [1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| { [1, 2].iter().for_each(|out_cols| {
@@ -419,7 +421,6 @@ mod tests {
let mat_rows: usize = a_size; let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols; let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols; let mat_cols_out: usize = res_cols;
let res_cols: usize = mat_cols_out;
let mut scratch: ScratchOwned = ScratchOwned::new( let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes( module.vmp_apply_tmp_bytes(
@@ -435,7 +436,7 @@ mod tests {
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size); let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| { (0..a_cols).for_each(|i| {
a.at_mut(i, 2)[i + 1] = 1; a.at_mut(i, a_size - 1)[i + 1] = 1;
}); });
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> = let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
@@ -479,7 +480,100 @@ mod tests {
(0..a_cols).for_each(|i| { (0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
}); });
res_have.decode_vec_i64(col_i, basek, basek * 3, &mut res_have_vi64); res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64);
});
});
});
}
#[test]
fn vmp_apply_add() {
let log_n: i32 = 5;
let n: usize = 1 << log_n;
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 15;
let a_size: usize = 5;
let mat_size: usize = 6;
let res_size: usize = a_size;
[1, 2].iter().for_each(|in_cols| {
[1, 2].iter().for_each(|out_cols| {
let a_cols: usize = *in_cols;
let res_cols: usize = *out_cols;
let mat_rows: usize = a_size;
let mat_cols_in: usize = a_cols;
let mat_cols_out: usize = res_cols;
let mut scratch: ScratchOwned = ScratchOwned::new(
module.vmp_apply_tmp_bytes(
res_size,
a_size,
mat_rows,
mat_cols_in,
mat_cols_out,
mat_size,
) | module.vec_znx_big_normalize_tmp_bytes(),
);
let mut a: VecZnx<Vec<u8>> = module.new_vec_znx(a_cols, a_size);
(0..a_cols).for_each(|i| {
a.at_mut(i, a_size - 1)[i + 1] = 1;
});
let mut mat_znx_dft: MatZnxDft<Vec<u8>, FFT64> =
module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size);
let mut c_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size);
let mut c_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size);
let mut tmp: VecZnx<Vec<u8>> = module.new_vec_znx(mat_cols_out, mat_size);
// Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix.
(0..a.size()).for_each(|row_i| {
(0..mat_cols_in).for_each(|col_in_i| {
(0..mat_cols_out).for_each(|col_out_i| {
let idx = 1 + col_in_i * mat_cols_out + col_out_i;
tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx}
module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i);
tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64;
});
module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft);
});
});
let mut a_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(a_cols, a_size);
(0..a_cols).for_each(|i| {
module.vec_znx_dft(&mut a_dft, i, &a, i);
});
c_dft.zero();
(0..c_dft.cols()).for_each(|i| {
module.vec_znx_dft(&mut c_dft, i, &a, 0);
});
module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow());
let mut res_have_vi64: Vec<i64> = vec![i64::default(); n];
let mut res_have: VecZnx<Vec<u8>> = module.new_vec_znx(res_cols, res_size);
(0..mat_cols_out).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i);
module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow());
});
(0..mat_cols_out).for_each(|col_i| {
let mut res_want_vi64: Vec<i64> = vec![i64::default(); n];
(0..a_cols).for_each(|i| {
res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1;
});
res_want_vi64[1] += 1;
res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64);
assert_eq!(res_have_vi64, res_want_vi64); assert_eq!(res_have_vi64, res_want_vi64);
}); });
}); });

View File

@@ -12,15 +12,15 @@ pub struct AutomorphismKey<Data, B: Backend> {
} }
impl AutomorphismKey<Vec<u8>, FFT64> { impl AutomorphismKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self { pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
AutomorphismKey { AutomorphismKey {
key: GLWESwitchingKey::alloc(module, basek, k, rows, rank, rank), key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank),
p: 0, p: 0,
} }
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, rank, rank) GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits,rank, rank)
} }
} }
@@ -45,6 +45,10 @@ impl<T, B: Backend> AutomorphismKey<T, B> {
self.p self.p
} }
pub fn digits(&self) -> usize {
self.key.digits()
}
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
self.key.rank() self.key.rank()
} }

View File

@@ -1,6 +1,6 @@
use backend::{Backend, Module, ZnxInfos}; use backend::{Backend, Module, ZnxInfos};
use crate::{GLWECiphertextFourier, derive_size}; use crate::{GLWECiphertextFourier, div_ceil};
pub trait Infos { pub trait Infos {
type Inner: ZnxInfos; type Inner: ZnxInfos;
@@ -34,7 +34,7 @@ pub trait Infos {
/// Returns the number of size per polynomial. /// Returns the number of size per polynomial.
fn size(&self) -> usize { fn size(&self) -> usize {
let size: usize = self.inner().size(); let size: usize = self.inner().size();
debug_assert_eq!(size, derive_size(self.basek(), self.k())); debug_assert_eq!(size, div_ceil(self.basek(), self.k()));
size size
} }

View File

@@ -4,25 +4,27 @@ use backend::{
}; };
use sampling::source::Source; use sampling::source::Source;
use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, derive_size}; use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, div_ceil};
pub struct GGLWECiphertext<C, B: Backend> { pub struct GGLWECiphertext<C, B: Backend> {
pub(crate) data: MatZnxDft<C, B>, pub(crate) data: MatZnxDft<C, B>,
pub(crate) basek: usize, pub(crate) basek: usize,
pub(crate) k: usize, pub(crate) k: usize,
pub(crate) digits: usize,
} }
impl<B: Backend> GGLWECiphertext<Vec<u8>, B> { impl<B: Backend> GGLWECiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self {
Self { Self {
data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)), data: module.new_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)),
basek: basek, basek: basek,
k, k,
digits,
} }
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize {
module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)) module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k))
} }
} }
@@ -47,6 +49,10 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
self.data.cols_out() - 1 self.data.cols_out() - 1
} }
pub fn digits(&self) -> usize{
self.digits
}
pub fn rank_in(&self) -> usize { pub fn rank_in(&self) -> usize {
self.data.cols_in() self.data.cols_in()
} }
@@ -58,7 +64,7 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
impl GGLWECiphertext<Vec<u8>, FFT64> { impl GGLWECiphertext<Vec<u8>, FFT64> {
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize { pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = derive_size(basek, k); let size = div_ceil(basek, k);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx(1, size)
@@ -101,6 +107,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGLWECiphertext<DataSelf, FFT64> {
} }
let rows: usize = self.rows(); let rows: usize = self.rows();
let digits: usize = self.digits();
let basek: usize = self.basek(); let basek: usize = self.basek();
let k: usize = self.k(); let k: usize = self.k();
let rank_in: usize = self.rank_in(); let rank_in: usize = self.rank_in();
@@ -125,7 +132,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGLWECiphertext<DataSelf, FFT64> {
(0..rows).for_each(|row_i| { (0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
tmp_pt.data.zero(); // zeroes for next iteration tmp_pt.data.zero(); // zeroes for next iteration
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, col_i); // Selects the i-th module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, col_i);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct // rlwe encrypt of vec_znx_pt into vec_znx_ct

View File

@@ -7,26 +7,28 @@ use sampling::source::Source;
use crate::{ use crate::{
AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow,
TensorKey, derive_size, TensorKey, div_ceil,
}; };
pub struct GGSWCiphertext<C, B: Backend> { pub struct GGSWCiphertext<C, B: Backend> {
pub data: MatZnxDft<C, B>, pub(crate) data: MatZnxDft<C, B>,
pub basek: usize, pub(crate) basek: usize,
pub k: usize, pub(crate) k: usize,
pub(crate) digits: usize,
} }
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> { impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self { pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
Self { Self {
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)), data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
basek: basek, basek,
k: k, k: k,
digits,
} }
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)) module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
} }
} }
@@ -50,11 +52,15 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
self.data.cols_out() - 1 self.data.cols_out() - 1
} }
pub fn digits(&self) -> usize {
self.digits
}
} }
impl GGSWCiphertext<Vec<u8>, FFT64> { impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize { pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = derive_size(basek, k); let size = div_ceil(basek, k);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size) + module.bytes_of_vec_znx(1, size)
@@ -68,8 +74,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize, tsk_k: usize,
rank: usize, rank: usize,
) -> usize { ) -> usize {
let tsk_size: usize = derive_size(basek, tsk_k); let tsk_size: usize = div_ceil(basek, tsk_k);
let self_size: usize = derive_size(basek, self_k); let self_size: usize = div_ceil(basek, self_k);
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
@@ -87,7 +93,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize, rank: usize,
) -> usize { ) -> usize {
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k) GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
+ module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k)) + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
} }
pub fn keyswitch_scratch_space( pub fn keyswitch_scratch_space(
@@ -99,7 +105,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize, tsk_k: usize,
rank: usize, rank: usize,
) -> usize { ) -> usize {
let out_size: usize = derive_size(basek, out_k); let out_size: usize = div_ceil(basek, out_k);
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
@@ -130,7 +136,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize, rank: usize,
) -> usize { ) -> usize {
let cols: usize = rank + 1; let cols: usize = rank + 1;
let out_size: usize = derive_size(basek, out_k); let out_size: usize = div_ceil(basek, out_k);
let res: usize = module.bytes_of_vec_znx(cols, out_size); let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
@@ -199,6 +205,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let basek: usize = self.basek(); let basek: usize = self.basek();
let k: usize = self.k(); let k: usize = self.k();
let rank: usize = self.rank(); let rank: usize = self.rank();
let digits: usize = self.digits();
let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k); let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k);
let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank); let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank);
@@ -207,7 +214,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
tmp_pt.data.zero(); tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0); module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
(0..rank + 1).for_each(|col_j| { (0..rank + 1).for_each(|col_j| {

View File

@@ -7,7 +7,7 @@ use sampling::source::Source;
use crate::{ use crate::{
AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey, AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey,
Infos, SIX_SIGMA, SecretDistribution, SetMetaData, derive_size, Infos, SIX_SIGMA, SecretDistribution, SetMetaData, div_ceil,
}; };
pub struct GLWECiphertext<C> { pub struct GLWECiphertext<C> {
@@ -19,14 +19,14 @@ pub struct GLWECiphertext<C> {
impl GLWECiphertext<Vec<u8>> { impl GLWECiphertext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self { pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self { Self {
data: module.new_vec_znx(rank + 1, derive_size(basek, k)), data: module.new_vec_znx(rank + 1, div_ceil(basek, k)),
basek, basek,
k, k,
} }
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k))
} }
} }
@@ -69,18 +69,18 @@ impl<C: AsRef<[u8]>> GLWECiphertext<C> {
impl GLWECiphertext<Vec<u8>> { impl GLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k); let size: usize = div_ceil(basek, k);
module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size)
} }
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k); let size: usize = div_ceil(basek, k);
((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1))
+ module.bytes_of_scalar_znx_dft(1) + module.bytes_of_scalar_znx_dft(1)
+ module.vec_znx_big_normalize_tmp_bytes() + module.vec_znx_big_normalize_tmp_bytes()
} }
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k); let size: usize = div_ceil(basek, k);
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
} }
@@ -94,9 +94,9 @@ impl GLWECiphertext<Vec<u8>> {
ksk_k: usize, ksk_k: usize,
) -> usize { ) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank); let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank);
let in_size: usize = derive_size(basek, in_k); let in_size: usize = div_ceil(basek, in_k);
let out_size: usize = derive_size(basek, out_k); let out_size: usize = div_ceil(basek, out_k);
let ksk_size: usize = derive_size(basek, ksk_k); let ksk_size: usize = div_ceil(basek, ksk_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
+ module.bytes_of_vec_znx_dft(in_rank, in_size); + module.bytes_of_vec_znx_dft(in_rank, in_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
@@ -155,9 +155,9 @@ impl GLWECiphertext<Vec<u8>> {
rank: usize, rank: usize,
) -> usize { ) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let in_size: usize = derive_size(basek, in_k); let in_size: usize = div_ceil(basek, in_k);
let out_size: usize = derive_size(basek, out_k); let out_size: usize = div_ceil(basek, out_k);
let ggsw_size: usize = derive_size(basek, ggsw_k); let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes( + module.vmp_apply_tmp_bytes(
out_size, out_size,

View File

@@ -4,7 +4,7 @@ use backend::{
}; };
use sampling::source::Source; use sampling::source::Source;
use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, derive_size}; use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, div_ceil};
pub struct GLWECiphertextFourier<C, B: Backend> { pub struct GLWECiphertextFourier<C, B: Backend> {
pub data: VecZnxDft<C, B>, pub data: VecZnxDft<C, B>,
@@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> { impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self { pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self { Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)), data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)),
basek: basek, basek: basek,
k: k, k: k,
} }
} }
pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize { pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, k)) module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k))
} }
} }
@@ -51,16 +51,16 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
impl GLWECiphertextFourier<Vec<u8>, FFT64> { impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)] #[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k)) module.bytes_of_vec_znx(1, div_ceil(basek, k))
+ (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
} }
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize { pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
} }
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k); let size: usize = div_ceil(basek, k);
(module.vec_znx_big_normalize_tmp_bytes() (module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, size) | module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
@@ -99,9 +99,9 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
rank: usize, rank: usize,
) -> usize { ) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let out_size: usize = derive_size(basek, out_k); let out_size: usize = div_ceil(basek, out_k);
let in_size: usize = derive_size(basek, in_k); let in_size: usize = div_ceil(basek, in_k);
let ggsw_size: usize = derive_size(basek, ggsw_k); let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank); let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();

View File

@@ -1,6 +1,6 @@
use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef};
use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, derive_size}; use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, div_ceil};
pub struct GLWEPlaintext<C> { pub struct GLWEPlaintext<C> {
pub data: VecZnx<C>, pub data: VecZnx<C>,
@@ -37,14 +37,14 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> SetMetaData for GLWEPlaintext<DataSelf
impl GLWEPlaintext<Vec<u8>> { impl GLWEPlaintext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self { pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
Self { Self {
data: module.new_vec_znx(1, derive_size(basek, k)), data: module.new_vec_znx(1, div_ceil(basek, k)),
basek: basek, basek: basek,
k, k,
} }
} }
pub fn byte_of(module: &Module<FFT64>, basek: usize, k: usize) -> usize { pub fn byte_of(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k)) module.bytes_of_vec_znx(1, div_ceil(basek, k))
} }
} }

View File

@@ -6,14 +6,14 @@ use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret,
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>); pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
impl GLWESwitchingKey<Vec<u8>, FFT64> { impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self {
GLWESwitchingKey(GGLWECiphertext::alloc( GLWESwitchingKey(GGLWECiphertext::alloc(
module, basek, k, rows, rank_in, rank_out, module, basek, k, rows, digits, rank_in, rank_out,
)) ))
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize {
GGLWECiphertext::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, rank_in, rank_out) GGLWECiphertext::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out)
} }
} }
@@ -45,6 +45,10 @@ impl<T, B: Backend> GLWESwitchingKey<T, B> {
pub fn rank_out(&self) -> usize { pub fn rank_out(&self) -> usize {
self.0.data.cols_out() - 1 self.0.data.cols_out() - 1
} }
pub fn digits(&self) -> usize {
self.0.digits()
}
} }
impl<C: AsRef<[u8]>> GetRow<FFT64> for GLWESwitchingKey<C, FFT64> { impl<C: AsRef<[u8]>> GetRow<FFT64> for GLWESwitchingKey<C, FFT64> {

View File

@@ -33,7 +33,7 @@ pub use tensor_key::*;
pub use backend::Scratch; pub use backend::Scratch;
pub use backend::ScratchOwned; pub use backend::ScratchOwned;
use utils::derive_size; use utils::div_ceil;
pub(crate) const SIX_SIGMA: f64 = 6.0; pub(crate) const SIX_SIGMA: f64 = 6.0;
@@ -46,6 +46,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank_in: usize, rank_in: usize,
rank_out: usize, rank_out: usize,
) -> (GGLWECiphertext<&mut [u8], B>, &mut Self); ) -> (GGLWECiphertext<&mut [u8], B>, &mut Self);
@@ -55,6 +56,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self);
fn tmp_glwe_fourier( fn tmp_glwe_fourier(
@@ -78,6 +80,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank_in: usize, rank_in: usize,
rank_out: usize, rank_out: usize,
) -> (GLWESwitchingKey<&mut [u8], B>, &mut Self); ) -> (GLWESwitchingKey<&mut [u8], B>, &mut Self);
@@ -87,6 +90,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (TensorKey<&mut [u8], B>, &mut Self); ) -> (TensorKey<&mut [u8], B>, &mut Self);
fn tmp_autokey( fn tmp_autokey(
@@ -95,6 +99,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (AutomorphismKey<&mut [u8], B>, &mut Self); ) -> (AutomorphismKey<&mut [u8], B>, &mut Self);
} }
@@ -107,12 +112,12 @@ impl ScratchCore<FFT64> for Scratch {
k: usize, k: usize,
rank: usize, rank: usize,
) -> (GLWECiphertext<&mut [u8]>, &mut Self) { ) -> (GLWECiphertext<&mut [u8]>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx(module, rank + 1, derive_size(basek, k)); let (data, scratch) = self.tmp_vec_znx(module, rank + 1, div_ceil(basek, k));
(GLWECiphertext { data, basek, k }, scratch) (GLWECiphertext { data, basek, k }, scratch)
} }
fn tmp_glwe_pt(&mut self, module: &Module<FFT64>, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { fn tmp_glwe_pt(&mut self, module: &Module<FFT64>, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx(module, 1, derive_size(basek, k)); let (data, scratch) = self.tmp_vec_znx(module, 1, div_ceil(basek, k));
(GLWEPlaintext { data, basek, k }, scratch) (GLWEPlaintext { data, basek, k }, scratch)
} }
@@ -122,15 +127,17 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank_in: usize, rank_in: usize,
rank_out: usize, rank_out: usize,
) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) { ) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank_in, rank_out + 1, derive_size(basek, k)); let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k));
( (
GGLWECiphertext { GGLWECiphertext {
data: data, data: data,
basek: basek, basek: basek,
k, k,
digits,
}, },
scratch, scratch,
) )
@@ -142,14 +149,16 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) { ) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank + 1, rank + 1, derive_size(basek, k)); let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k));
( (
GGSWCiphertext { GGSWCiphertext {
data: data, data,
basek: basek, basek,
k, k,
digits,
}, },
scratch, scratch,
) )
@@ -162,7 +171,7 @@ impl ScratchCore<FFT64> for Scratch {
k: usize, k: usize,
rank: usize, rank: usize,
) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) { ) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, derive_size(basek, k)); let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(basek, k));
(GLWECiphertextFourier { data, basek, k }, scratch) (GLWECiphertextFourier { data, basek, k }, scratch)
} }
@@ -202,10 +211,11 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank_in: usize, rank_in: usize,
rank_out: usize, rank_out: usize,
) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) { ) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, rank_in, rank_out); let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out);
(GLWESwitchingKey(data), scratch) (GLWESwitchingKey(data), scratch)
} }
@@ -215,9 +225,10 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) { ) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, rank, rank); let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank);
(AutomorphismKey { key: data, p: 0 }, scratch) (AutomorphismKey { key: data, p: 0 }, scratch)
} }
@@ -227,6 +238,7 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize, basek: usize,
k: usize, k: usize,
rows: usize, rows: usize,
digits: usize,
rank: usize, rank: usize,
) -> (TensorKey<&mut [u8], FFT64>, &mut Self) { ) -> (TensorKey<&mut [u8], FFT64>, &mut Self) {
let mut keys: Vec<GLWESwitchingKey<&mut [u8], FFT64>> = Vec::new(); let mut keys: Vec<GLWESwitchingKey<&mut [u8], FFT64>> = Vec::new();
@@ -235,12 +247,12 @@ impl ScratchCore<FFT64> for Scratch {
let mut scratch: &mut Scratch = self; let mut scratch: &mut Scratch = self;
if pairs != 0 { if pairs != 0 {
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank); let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
scratch = s; scratch = s;
keys.push(gglwe); keys.push(gglwe);
} }
for _ in 1..pairs { for _ in 1..pairs {
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank); let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
scratch = s; scratch = s;
keys.push(gglwe); keys.push(gglwe);
} }

View File

@@ -8,18 +8,18 @@ pub struct TensorKey<C, B: Backend> {
} }
impl TensorKey<Vec<u8>, FFT64> { impl TensorKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self { pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new(); let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
let pairs: usize = (((rank + 1) * rank) >> 1).max(1); let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
(0..pairs).for_each(|_| { (0..pairs).for_each(|_| {
keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank)); keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, digits,1, rank));
}); });
Self { keys: keys } Self { keys: keys }
} }
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize { pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
let pairs: usize = (((rank + 1) * rank) >> 1).max(1); let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
pairs * GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, 1, rank) pairs * GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits,1, rank)
} }
} }
@@ -51,6 +51,10 @@ impl<T, B: Backend> TensorKey<T, B> {
pub fn rank_out(&self) -> usize { pub fn rank_out(&self) -> usize {
self.keys[0].rank_out() self.keys[0].rank_out()
} }
pub fn digits(&self) -> usize {
self.keys[0].digits()
}
} }
impl TensorKey<Vec<u8>, FFT64> { impl TensorKey<Vec<u8>, FFT64> {

View File

@@ -1,3 +1,3 @@
pub(crate) fn derive_size(basek: usize, k: usize) -> usize { pub(crate) fn div_ceil(a: usize, b: usize) -> usize {
(k + basek - 1) / basek (a + b - 1) / b
} }