fixed all tests

This commit is contained in:
Jean-Philippe Bossuat
2025-06-06 14:06:36 +02:00
parent 33795df6c2
commit 113231da55
28 changed files with 1817 additions and 959 deletions

View File

@@ -19,8 +19,24 @@ pub struct GGSWCiphertext<C, B: Backend> {
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
let size: usize = div_ceil(k, basek);
debug_assert!(
size > digits,
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
size,
digits
);
assert!(
rows * digits <= size,
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
rows,
digits,
size
);
Self {
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, div_ceil(k, basek)),
basek,
k: k,
digits,
@@ -28,7 +44,23 @@ impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
}
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
let size: usize = div_ceil(k, basek);
debug_assert!(
size > digits,
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
size,
digits
);
assert!(
rows * digits <= size,
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
rows,
digits,
size
);
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, size)
}
}
@@ -60,7 +92,7 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
let size = div_ceil(basek, k);
let size = div_ceil(k, basek);
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
+ module.bytes_of_vec_znx(rank + 1, size)
+ module.bytes_of_vec_znx(1, size)
@@ -71,46 +103,59 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
module: &Module<FFT64>,
basek: usize,
self_k: usize,
tsk_k: usize,
k_tsk: usize,
digits: usize,
rank: usize,
) -> usize {
let tsk_size: usize = div_ceil(basek, tsk_k);
let self_size: usize = div_ceil(basek, self_k);
let tsk_size: usize = div_ceil(k_tsk, basek);
let self_size_out: usize = div_ceil(self_k, basek);
let self_size_in: usize = div_ceil(self_size_out, digits);
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
let tmp_a: usize = module.bytes_of_vec_znx_dft(1, self_size_in);
let vmp: usize = module.vmp_apply_tmp_bytes(
self_size_out,
self_size_in,
self_size_in,
rank,
rank,
tsk_size,
);
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size);
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm))
tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
}
pub(crate) fn keyswitch_internal_col0_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
in_k: usize,
ksk_k: usize,
k_out: usize,
k_in: usize,
k_ksk: usize,
digits: usize,
rank: usize,
) -> usize {
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k_in, basek))
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
in_k: usize,
ksk_k: usize,
tsk_k: usize,
k_out: usize,
k_in: usize,
k_ksk: usize,
digits_ksk: usize,
k_tsk: usize,
digits_tsk: usize,
rank: usize,
) -> usize {
let out_size: usize = div_ceil(basek, out_k);
let out_size: usize = div_ceil(k_out, basek);
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, ksk_k, rank);
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
let ks: usize =
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank);
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
res_znx + ci_dft + (ks | expand_rows | res_dft)
}
@@ -118,67 +163,81 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn keyswitch_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
ksk_k: usize,
tsk_k: usize,
k_out: usize,
k_ksk: usize,
digits_ksk: usize,
k_tsk: usize,
digits_tsk: usize,
rank: usize,
) -> usize {
GGSWCiphertext::keyswitch_scratch_space(module, basek, out_k, out_k, ksk_k, tsk_k, rank)
GGSWCiphertext::keyswitch_scratch_space(
module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
)
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
in_k: usize,
atk_k: usize,
tsk_k: usize,
k_out: usize,
k_in: usize,
k_ksk: usize,
digits_ksk: usize,
k_tsk: usize,
digits_tsk: usize,
rank: usize,
) -> usize {
let cols: usize = rank + 1;
let out_size: usize = div_ceil(basek, out_k);
let out_size: usize = div_ceil(k_out, basek);
let res: usize = module.bytes_of_vec_znx(cols, out_size);
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, atk_k, rank);
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
let ks_internal: usize =
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank);
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
res + ci_dft + (ks_internal | expand | res_dft)
}
pub fn automorphism_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
atk_k: usize,
tsk_k: usize,
k_out: usize,
k_ksk: usize,
digits_ksk: usize,
k_tsk: usize,
digits_tsk: usize,
rank: usize,
) -> usize {
GGSWCiphertext::automorphism_scratch_space(module, basek, out_k, out_k, atk_k, tsk_k, rank)
GGSWCiphertext::automorphism_scratch_space(
module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
)
}
pub fn external_product_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
in_k: usize,
ggsw_k: usize,
k_out: usize,
k_in: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
) -> usize {
let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank);
let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank);
let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank);
let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank);
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank);
tmp_in + tmp_out + ggsw
}
pub fn external_product_inplace_scratch_space(
module: &Module<FFT64>,
basek: usize,
out_k: usize,
ggsw_k: usize,
k_out: usize,
k_ggsw: usize,
digits: usize,
rank: usize,
) -> usize {
let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank);
let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank);
let ggsw: usize =
GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank);
tmp + ggsw
}
}
@@ -214,7 +273,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
(0..rank + 1).for_each(|col_j| {
@@ -254,7 +313,15 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let cols: usize = self.rank() + 1;
assert!(
scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.basek(), self.k(), tsk.k(), self.rank())
scratch.available()
>= GGSWCiphertext::expand_row_scratch_space(
module,
self.basek(),
self.k(),
tsk.k(),
tsk.digits(),
tsk.rank()
)
);
// Example for rank 3:
@@ -279,8 +346,6 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
{
let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size());
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
//
// # Example for col=1
@@ -293,23 +358,27 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// =
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
(1..cols).for_each(|col_i| {
// Extracts a[i] and multipies with Enc(s[i]s[j])
tmp_dft_col_data.extract_column(0, ci_dft, col_i);
let digits: usize = tsk.digits();
let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j])
// Extracts a[i] and multipies with Enc(s[i]s[j])
if col_i == 1 {
module.vmp_apply(
&mut tmp_dft_i,
&tmp_dft_col_data,
&tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
scratch2,
);
(0..digits).for_each(|di| {
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits);
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
if di == 0 {
module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2);
} else {
module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2);
}
});
} else {
module.vmp_apply_add(
&mut tmp_dft_i,
&tmp_dft_col_data,
&tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
scratch2,
);
(0..digits).for_each(|di| {
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits);
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2);
});
}
});
}
@@ -344,7 +413,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let basek: usize = self.basek();
let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
@@ -354,7 +423,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// Isolates DFT(a[i])
(0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i);
});
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
@@ -425,8 +494,10 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
self.k(),
lhs.k(),
auto_key.k(),
auto_key.digits(),
tensor_key.k(),
self.rank()
tensor_key.digits(),
self.rank(),
)
)
};
@@ -436,7 +507,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
let basek: usize = self.basek();
let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
// Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| {
@@ -448,7 +519,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
(0..cols).for_each(|col_i| {
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i);
module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i);
});
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
@@ -510,6 +581,19 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
self.rank(),
rhs.rank()
);
assert!(
scratch.available()
>= GGSWCiphertext::external_product_scratch_space(
module,
self.basek(),
self.k(),
lhs.k(),
rhs.k(),
rhs.digits(),
rhs.rank()
)
)
}
let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
@@ -582,6 +666,7 @@ impl<DataSelf: AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
res.k(),
self.k(),
ksk.k(),
ksk.digits(),
ksk.rank()
)
)