mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Updated arguments to get scratch space size for ops
This commit is contained in:
@@ -60,8 +60,9 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
|
||||
}
|
||||
|
||||
impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, size)
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
|
||||
let size = derive_size(basek, k);
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
|
||||
+ module.bytes_of_vec_znx(rank + 1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
@@ -69,112 +70,116 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
|
||||
pub(crate) fn expand_row_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
self_size: usize,
|
||||
tensor_key_size: usize,
|
||||
basek: usize,
|
||||
self_k: usize,
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tensor_key_size);
|
||||
let tsk_size: usize = derive_size(basek, tsk_k);
|
||||
let self_size: usize = derive_size(basek, self_k);
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
|
||||
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
|
||||
let vmp: usize =
|
||||
tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tensor_key_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tensor_key_size);
|
||||
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size);
|
||||
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm))
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_internal_col0_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ksk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, ksk_size)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, in_size)
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k))
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
tensor_key_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ksk_k: usize,
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let out_size: usize = derive_size(basek, out_k);
|
||||
|
||||
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, ksk_size, rank);
|
||||
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
|
||||
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, ksk_k, rank);
|
||||
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
res_znx + ci_dft + (ks | expand_rows | res_dft)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ksk_size: usize,
|
||||
tensor_key_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
ksk_k: usize,
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::keyswitch_scratch_space(module, out_size, out_size, ksk_size, tensor_key_size, rank)
|
||||
GGSWCiphertext::keyswitch_scratch_space(module, basek, out_k, out_k, ksk_k, tsk_k, rank)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
auto_key_size: usize,
|
||||
tensor_key_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
atk_k: usize,
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let cols: usize = rank + 1;
|
||||
let out_size: usize = derive_size(basek, out_k);
|
||||
let res: usize = module.bytes_of_vec_znx(cols, out_size);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ks_internal: usize =
|
||||
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank);
|
||||
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
|
||||
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, atk_k, rank);
|
||||
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
|
||||
res + ci_dft + (ks_internal | expand | res_dft)
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
auto_key_size: usize,
|
||||
tensor_key_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
atk_k: usize,
|
||||
tsk_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::automorphism_scratch_space(
|
||||
module,
|
||||
out_size,
|
||||
out_size,
|
||||
auto_key_size,
|
||||
tensor_key_size,
|
||||
rank,
|
||||
)
|
||||
GGSWCiphertext::automorphism_scratch_space(module, basek, out_k, out_k, atk_k, tsk_k, rank)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ggsw_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ggsw_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
|
||||
let tmp_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, out_size, in_size, ggsw_size, rank);
|
||||
let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank);
|
||||
let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
ggsw_size: usize,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
ggsw_k: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, out_size, ggsw_size, rank);
|
||||
let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
@@ -248,7 +253,9 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
{
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.size(), tsk.size(), tsk.rank()));
|
||||
assert!(
|
||||
scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.basek(), self.k(), tsk.k(), self.rank())
|
||||
);
|
||||
|
||||
// Example for rank 3:
|
||||
//
|
||||
@@ -414,10 +421,11 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
scratch.available()
|
||||
>= GGSWCiphertext::automorphism_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
lhs.size(),
|
||||
auto_key.size(),
|
||||
tensor_key.size(),
|
||||
self.basek(),
|
||||
self.k(),
|
||||
lhs.k(),
|
||||
auto_key.k(),
|
||||
tensor_key.k(),
|
||||
self.rank()
|
||||
)
|
||||
)
|
||||
@@ -570,9 +578,10 @@ impl<DataSelf: AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
scratch.available()
|
||||
>= GGSWCiphertext::keyswitch_internal_col0_scratch_space(
|
||||
module,
|
||||
res.size(),
|
||||
self.size(),
|
||||
ksk.size(),
|
||||
self.basek(),
|
||||
res.k(),
|
||||
self.k(),
|
||||
ksk.k(),
|
||||
ksk.rank()
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user