fixed all tests

This commit is contained in:
Jean-Philippe Bossuat
2025-06-06 14:06:36 +02:00
parent 33795df6c2
commit 113231da55
28 changed files with 1817 additions and 959 deletions

View File

@@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier<C, B: Backend> {
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rank: usize) -> Self {
Self {
data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)),
data: module.new_vec_znx_dft(rank + 1, div_ceil(k, basek)),
basek: basek,
k: k,
}
}
pub fn bytes_of(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k))
module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k, basek))
}
}
@@ -51,16 +51,16 @@ impl<T, B: Backend> GLWECiphertextFourier<T, B> {
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
#[allow(dead_code)]
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, div_ceil(basek, k))
module.bytes_of_vec_znx(1, div_ceil(k, basek))
+ (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
}
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
module.bytes_of_vec_znx(rank + 1, div_ceil(k, basek)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
}
pub fn decrypt_scratch_space(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
let size: usize = div_ceil(basek, k);
let size: usize = div_ceil(k, basek);
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
@@ -70,40 +70,45 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
out_rank: usize,
in_k: usize,
in_rank: usize,
ksk_k: usize,
k_out: usize,
k_in: usize,
k_ksk: usize,
digits: usize,
rank_in: usize,
rank_out: usize,
) -> usize {
GLWECiphertext::bytes_of(module, basek, out_k, out_rank)
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, out_rank, in_k, in_rank, ksk_k)
GLWECiphertext::bytes_of(module, basek, k_out, rank_out)
+ GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out)
}
pub fn keyswitch_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
out_rank: usize,
ksk_k: usize,
k_out: usize,
k_ksk: usize,
digits: usize,
rank: usize,
) -> usize {
Self::keyswitch_scratch_space(module, basek, out_k, out_rank, out_k, out_rank, ksk_k)
Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank)
}
// WARNING TODO: UPDATE
pub fn external_product_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
in_k: usize,
ggsw_k: usize,
k_out: usize,
k_in: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
) -> usize {
let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let out_size: usize = div_ceil(basek, out_k);
let in_size: usize = div_ceil(basek, in_k);
let ggsw_size: usize = div_ceil(basek, ggsw_k);
let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank);
let ggsw_size: usize = div_ceil(k_ggsw, basek);
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size);
let in_size: usize = div_ceil(div_ceil(k_in, basek), digits);
let ggsw_size: usize = div_ceil(k_ggsw, basek);
let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(ggsw_size, in_size, in_size, rank + 1, rank + 1, ggsw_size);
let res_small: usize = module.bytes_of_vec_znx(rank + 1, ggsw_size);
let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
res_dft + (vmp | (res_small + normalize))
}
@@ -111,11 +116,12 @@ impl GLWECiphertextFourier<Vec<u8>, FFT64> {
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
ggsw_k: usize,
k_out: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
) -> usize {
Self::external_product_scratch_space(module, basek, out_k, out_k, ggsw_k, rank)
Self::external_product_scratch_space(module, basek, k_out, k_out, k_ggsw, digits, rank)
}
}
@@ -176,6 +182,18 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
assert_eq!(rhs.n(), module.n());
assert_eq!(self.n(), module.n());
assert_eq!(lhs.n(), module.n());
assert!(
scratch.available()
>= GLWECiphertextFourier::external_product_scratch_space(
module,
self.basek(),
self.k(),
lhs.k(),
rhs.k(),
rhs.digits(),
rhs.rank(),
)
);
}
let cols: usize = rhs.rank() + 1;
@@ -184,7 +202,22 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size());
{
module.vmp_apply(&mut res_dft, &lhs.data, &rhs.data, scratch1);
let digits = rhs.digits();
(0..digits).for_each(|di| {
// (lhs.size() + di) / digits = (a - (digit - di - 1) + digit - 1) / digits
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, (lhs.size() + di) / digits);
(0..cols).for_each(|col_i| {
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut a_dft, col_i, &lhs.data, col_i);
});
if di == 0 {
module.vmp_apply(&mut res_dft, &a_dft, &rhs.data, scratch2);
} else {
module.vmp_apply_add(&mut res_dft, &a_dft, &rhs.data, di, scratch2);
}
});
}
// VMP result in high precision
@@ -194,7 +227,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
let (mut res_small, scratch2) = scratch1.tmp_vec_znx(module, cols, rhs.size());
(0..cols).for_each(|i| {
module.vec_znx_big_normalize(basek, &mut res_small, i, &res_big, i, scratch2);
module.vec_znx_dft(&mut self.data, i, &res_small, i);
module.vec_znx_dft(1, 0, &mut self.data, i, &res_small, i);
});
}