Updated arguments to get scratch space size for ops

This commit is contained in:
Jean-Philippe Bossuat
2025-05-28 18:46:24 +02:00
parent 8209fb4e40
commit f9440c5407
20 changed files with 599 additions and 529 deletions

View File

@@ -53,59 +53,72 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k))
+ (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, ct_size: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, ct_size)
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = derive_size(basek, k);
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, ct_size)
| (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes()))
+ module.bytes_of_vec_znx_big(1, ct_size)
| module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
+ module.bytes_of_vec_znx_big(1, size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
basek: usize,
out_k: usize,
out_rank: usize,
in_size: usize,
in_k: usize,
in_rank: usize,
ksk_size: usize,
ksk_k: usize,
) -> usize {
module.bytes_of_vec_znx(out_rank + 1, out_size)
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, out_rank, in_size, in_rank, ksk_size)
GLWECiphertext::bytes_of(module, basek, out_k, out_rank)
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, out_rank, in_k, in_rank, ksk_k)
}
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, out_size: usize, out_rank: usize, ksk_size: usize) -> usize {
Self::keyswitch_scratch_space(module, out_size, out_rank, out_size, out_rank, ksk_size)
pub fn keyswitch_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
out_rank: usize,
ksk_k: usize,
) -> usize {
Self::keyswitch_scratch_space(module, basek, out_k, out_rank, out_k, out_rank, ksk_k)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ggsw_size: usize,
basek: usize,
out_k: usize,
in_k: usize,
ggsw_k: usize,
rank: usize,
) -> usize {
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let out_size: usize = derive_size(basek, out_k);
let in_size: usize = derive_size(basek, in_k);
let ggsw_size: usize = derive_size(basek, ggsw_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | (res_small + normalize))
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
out_size: usize,
ggsw_size: usize,
basek: usize,
out_k: usize,
ggsw_k: usize,
rank: usize,
) -> usize {
Self::external_product_scratch_space(module, out_size, out_size, ggsw_size, rank)
Self::external_product_scratch_space(module, basek, out_k, out_k, ggsw_k, rank)
}
}