Added API in poulpy for updated vmp_add (+tests)

This commit is contained in:
Jean-Philippe Bossuat
2025-06-04 11:39:11 +02:00
parent fcdc8f53d3
commit 159cd8025f
14 changed files with 216 additions and 82 deletions

View File

@@ -12,15 +12,15 @@ pub struct AutomorphismKey<Data, B: Backend> {
}
impl AutomorphismKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
AutomorphismKey {
key: GLWESwitchingKey::alloc(module, basek, k, rows, rank, rank),
key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank),
p: 0,
}
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize {
GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, rank, rank)
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits,rank, rank)
}
}
@@ -45,6 +45,10 @@ impl<T, B: Backend> AutomorphismKey<T, B> {
self.p
}
pub fn digits(&self) -> usize {
self.key.digits()
}
pub fn rank(&self) -> usize {
self.key.rank()
}

View File

@@ -1,6 +1,6 @@
use backend::{Backend, Module, ZnxInfos};
use crate::{GLWECiphertextFourier, derive_size};
use crate::{GLWECiphertextFourier, div_ceil};
pub trait Infos {
type Inner: ZnxInfos;
@@ -34,7 +34,7 @@ pub trait Infos {
/// Returns the number of size per polynomial.
fn size(&self) -> usize {
let size: usize = self.inner().size();
debug_assert_eq!(size, derive_size(self.basek(), self.k()));
debug_assert_eq!(size, div_ceil(self.basek(), self.k()));
size
}

View File

@@ -4,25 +4,27 @@ use backend::{
};
use sampling::source::Source;
use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, derive_size};
use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, div_ceil};
pub struct GGLWECiphertext<C, B: Backend> {
pub(crate) data: MatZnxDft<C, B>,
pub(crate) basek: usize,
pub(crate) k: usize,
pub(crate) digits: usize,
}
impl<B: Backend> GGLWECiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)),
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)),
basek: basek,
k,
digits,
}
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize {
module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k))
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize {
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k))
}
}
@@ -47,6 +49,10 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
self.data.cols_out() - 1
}
pub fn digits(&self) -> usize{
self.digits
}
pub fn rank_in(&self) -> usize {
self.data.cols_in()
}
@@ -58,7 +64,7 @@ impl<T, B: Backend> GGLWECiphertext<T, B> {
impl GGLWECiphertext<Vec<u8>, FFT64> {
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = derive_size(basek, k);
let size = div_ceil(basek, k);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
@@ -101,6 +107,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGLWECiphertext<DataSelf, FFT64> {
}
let rows: usize = self.rows();
let digits: usize = self.digits();
let basek: usize = self.basek();
let k: usize = self.k();
let rank_in: usize = self.rank_in();
@@ -125,7 +132,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGLWECiphertext<DataSelf, FFT64> {
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
tmp_pt.data.zero(); // zeroes for next iteration
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, col_i); // Selects the i-th
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, col_i);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct

View File

@@ -7,26 +7,28 @@ use sampling::source::Source;
use crate::{
AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow,
TensorKey, derive_size,
TensorKey, div_ceil,
};
pub struct GGSWCiphertext<C, B: Backend> {
pub data: MatZnxDft<C, B>,
pub basek: usize,
pub k: usize,
pub(crate) data: MatZnxDft<C, B>,
pub(crate) basek: usize,
pub(crate) k: usize,
pub(crate) digits: usize,
}
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)),
basek: basek,
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
basek,
k: k,
digits,
}
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k))
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
}
}
@@ -50,11 +52,15 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
pub fn rank(&self) -> usize {
self.data.cols_out() - 1
}
pub fn digits(&self) -> usize {
self.digits
}
}
impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = derive_size(basek, k);
let size = div_ceil(basek, k);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
@@ -68,8 +74,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize,
rank: usize,
) -> usize {
let tsk_size: usize = derive_size(basek, tsk_k);
let self_size: usize = derive_size(basek, self_k);
let tsk_size: usize = div_ceil(basek, tsk_k);
let self_size: usize = div_ceil(basek, self_k);
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
@@ -87,7 +93,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
+ module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k))
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
}
pub fn keyswitch_scratch_space(
@@ -99,7 +105,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tsk_k: usize,
rank: usize,
) -> usize {
let out_size: usize = derive_size(basek, out_k);
let out_size: usize = div_ceil(basek, out_k);
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
@@ -130,7 +136,7 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
let cols: usize = rank + 1;
let out_size: usize = derive_size(basek, out_k);
let out_size: usize = div_ceil(basek, out_k);
let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
@@ -199,6 +205,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let basek: usize = self.basek();
let k: usize = self.k();
let rank: usize = self.rank();
let digits: usize = self.digits();
let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k);
let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank);
@@ -207,7 +214,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0);
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
(0..rank + 1).for_each(|col_j| {

View File

@@ -7,7 +7,7 @@ use sampling::source::Source;
use crate::{
AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey,
Infos, SIX_SIGMA, SecretDistribution, SetMetaData, derive_size,
Infos, SIX_SIGMA, SecretDistribution, SetMetaData, div_ceil,
};
pub struct GLWECiphertext<C> {
@@ -19,14 +19,14 @@ pub struct GLWECiphertext<C> {
impl GLWECiphertext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx(rank + 1, derive_size(basek, k)),
data: module.new_vec_znx(rank + 1, div_ceil(basek, k)),
basek,
k,
}
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k))
module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k))
}
}
@@ -69,18 +69,18 @@ impl<C: AsRef<[u8]>> GLWECiphertext<C> {
impl GLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
let size: usize = div_ceil(basek, k);
module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size)
}
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
let size: usize = div_ceil(basek, k);
((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1))
+ module.bytes_of_scalar_znx_dft(1)
+ module.vec_znx_big_normalize_tmp_bytes()
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
let size: usize = div_ceil(basek, k);
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
}
@@ -94,9 +94,9 @@ impl GLWECiphertext<Vec<u8>> {
ksk_k: usize,
) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank);
let in_size: usize = derive_size(basek, in_k);
let out_size: usize = derive_size(basek, out_k);
let ksk_size: usize = derive_size(basek, ksk_k);
let in_size: usize = div_ceil(basek, in_k);
let out_size: usize = div_ceil(basek, out_k);
let ksk_size: usize = div_ceil(basek, ksk_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size)
+ module.bytes_of_vec_znx_dft(in_rank, in_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
@@ -155,9 +155,9 @@ impl GLWECiphertext<Vec<u8>> {
rank: usize,
) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let in_size: usize = derive_size(basek, in_k);
let out_size: usize = derive_size(basek, out_k);
let ggsw_size: usize = derive_size(basek, ggsw_k);
let in_size: usize = div_ceil(basek, in_k);
let out_size: usize = div_ceil(basek, out_k);
let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,

View File

@@ -4,7 +4,7 @@ use backend::{
};
use sampling::source::Source;
use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, derive_size};
use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, div_ceil};
pub struct GLWECiphertextFourier<C, B: Backend> {
pub data: VecZnxDft<C, B>,
@@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)),
data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)),
basek: basek,
k: k,
}
}
pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, k))
module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k))
}
}
@@ -51,16 +51,16 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k))
module.bytes_of_vec_znx(1, div_ceil(basek, k))
+ (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
let size: usize = div_ceil(basek, k);
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
@@ -99,9 +99,9 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
rank: usize,
) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let out_size: usize = derive_size(basek, out_k);
let in_size: usize = derive_size(basek, in_k);
let ggsw_size: usize = derive_size(basek, ggsw_k);
let out_size: usize = div_ceil(basek, out_k);
let in_size: usize = div_ceil(basek, in_k);
let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();

View File

@@ -1,6 +1,6 @@
use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef};
use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, derive_size};
use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, div_ceil};
pub struct GLWEPlaintext<C> {
pub data: VecZnx<C>,
@@ -37,14 +37,14 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> SetMetaData for GLWEPlaintext<DataSelf
impl GLWEPlaintext<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(basek, k)),
data: module.new_vec_znx(1, div_ceil(basek, k)),
basek: basek,
k,
}
}
pub fn byte_of(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k))
module.bytes_of_vec_znx(1, div_ceil(basek, k))
}
}

View File

@@ -6,14 +6,14 @@ use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret,
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self {
GLWESwitchingKey(GGLWECiphertext::alloc(
module, basek, k, rows, rank_in, rank_out,
module, basek, k, rows, digits, rank_in, rank_out,
))
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize {
GGLWECiphertext::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, rank_in, rank_out)
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize {
GGLWECiphertext::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out)
}
}
@@ -45,6 +45,10 @@ impl<T, B: Backend> GLWESwitchingKey<T, B> {
pub fn rank_out(&self) -> usize {
self.0.data.cols_out() - 1
}
pub fn digits(&self) -> usize {
self.0.digits()
}
}
impl<C: AsRef<[u8]>> GetRow<FFT64> for GLWESwitchingKey<C, FFT64> {

View File

@@ -33,7 +33,7 @@ pub use tensor_key::*;
pub use backend::Scratch;
pub use backend::ScratchOwned;
use utils::derive_size;
use utils::div_ceil;
pub(crate) const SIX_SIGMA: f64 = 6.0;
@@ -46,6 +46,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank_in: usize,
rank_out: usize,
) -> (GGLWECiphertext<&mut [u8], B>, &mut Self);
@@ -55,6 +56,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (GGSWCiphertext<&mut [u8], B>, &mut Self);
fn tmp_glwe_fourier(
@@ -78,6 +80,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank_in: usize,
rank_out: usize,
) -> (GLWESwitchingKey<&mut [u8], B>, &mut Self);
@@ -87,6 +90,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (TensorKey<&mut [u8], B>, &mut Self);
fn tmp_autokey(
@@ -95,6 +99,7 @@ pub trait ScratchCore<B: Backend> {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (AutomorphismKey<&mut [u8], B>, &mut Self);
}
@@ -107,12 +112,12 @@ impl ScratchCore<FFT64> for Scratch {
k: usize,
rank: usize,
) -> (GLWECiphertext<&mut [u8]>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx(module, rank + 1, derive_size(basek, k));
let (data, scratch) = self.tmp_vec_znx(module, rank + 1, div_ceil(basek, k));
(GLWECiphertext { data, basek, k }, scratch)
}
fn tmp_glwe_pt(&mut self, module: &Module<FFT64>, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx(module, 1, derive_size(basek, k));
let (data, scratch) = self.tmp_vec_znx(module, 1, div_ceil(basek, k));
(GLWEPlaintext { data, basek, k }, scratch)
}
@@ -122,15 +127,17 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank_in: usize,
rank_out: usize,
) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank_in, rank_out + 1, derive_size(basek, k));
let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k));
(
GGLWECiphertext {
data: data,
basek: basek,
k,
digits,
},
scratch,
)
@@ -142,14 +149,16 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank + 1, rank + 1, derive_size(basek, k));
let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k));
(
GGSWCiphertext {
data: data,
basek: basek,
data,
basek,
k,
digits,
},
scratch,
)
@@ -162,7 +171,7 @@ impl ScratchCore<FFT64> for Scratch {
k: usize,
rank: usize,
) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, derive_size(basek, k));
let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(basek, k));
(GLWECiphertextFourier { data, basek, k }, scratch)
}
@@ -202,10 +211,11 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank_in: usize,
rank_out: usize,
) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, rank_in, rank_out);
let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out);
(GLWESwitchingKey(data), scratch)
}
@@ -215,9 +225,10 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, rank, rank);
let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank);
(AutomorphismKey { key: data, p: 0 }, scratch)
}
@@ -227,6 +238,7 @@ impl ScratchCore<FFT64> for Scratch {
basek: usize,
k: usize,
rows: usize,
digits: usize,
rank: usize,
) -> (TensorKey<&mut [u8], FFT64>, &mut Self) {
let mut keys: Vec<GLWESwitchingKey<&mut [u8], FFT64>> = Vec::new();
@@ -235,12 +247,12 @@ impl ScratchCore<FFT64> for Scratch {
let mut scratch: &mut Scratch = self;
if pairs != 0 {
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank);
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
scratch = s;
keys.push(gglwe);
}
for _ in 1..pairs {
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank);
let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank);
scratch = s;
keys.push(gglwe);
}

View File

@@ -8,18 +8,18 @@ pub struct TensorKey<C, B: Backend> {
}
impl TensorKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
(0..pairs).for_each(|_| {
keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank));
keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, digits,1, rank));
});
Self { keys: keys }
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> usize {
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
pairs * GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, 1, rank)
pairs * GLWESwitchingKey::<Vec<u8>, FFT64>::bytes_of(module, basek, k, rows, digits,1, rank)
}
}
@@ -51,6 +51,10 @@ impl<T, B: Backend> TensorKey<T, B> {
pub fn rank_out(&self) -> usize {
self.keys[0].rank_out()
}
pub fn digits(&self) -> usize {
self.keys[0].digits()
}
}
impl TensorKey<Vec<u8>, FFT64> {

View File

@@ -1,3 +1,3 @@
pub(crate) fn derive_size(basek: usize, k: usize) -> usize {
(k + basek - 1) / basek
pub(crate) fn div_ceil(a: usize, b: usize) -> usize {
(a + b - 1) / b
}