mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
fixed all tests
This commit is contained in:
@@ -19,8 +19,24 @@ pub struct GGSWCiphertext<C, B: Backend> {
|
||||
|
||||
impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||
pub fn alloc(module: &Module<B>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
|
||||
let size: usize = div_ceil(k, basek);
|
||||
debug_assert!(
|
||||
size > digits,
|
||||
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
|
||||
size,
|
||||
digits
|
||||
);
|
||||
|
||||
assert!(
|
||||
rows * digits <= size,
|
||||
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
|
||||
rows,
|
||||
digits,
|
||||
size
|
||||
);
|
||||
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)),
|
||||
data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, div_ceil(k, basek)),
|
||||
basek,
|
||||
k: k,
|
||||
digits,
|
||||
@@ -28,7 +44,23 @@ impl<B: Backend> GGSWCiphertext<Vec<u8>, B> {
|
||||
}
|
||||
|
||||
pub fn bytes_of(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
|
||||
module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k))
|
||||
let size: usize = div_ceil(k, basek);
|
||||
debug_assert!(
|
||||
size > digits,
|
||||
"invalid ggsw: ceil(k/basek): {} <= digits: {}",
|
||||
size,
|
||||
digits
|
||||
);
|
||||
|
||||
assert!(
|
||||
rows * digits <= size,
|
||||
"invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
|
||||
rows,
|
||||
digits,
|
||||
size
|
||||
);
|
||||
|
||||
module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, size)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,7 +92,7 @@ impl<T, B: Backend> GGSWCiphertext<T, B> {
|
||||
|
||||
impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
|
||||
let size = div_ceil(basek, k);
|
||||
let size = div_ceil(k, basek);
|
||||
GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
|
||||
+ module.bytes_of_vec_znx(rank + 1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
@@ -71,46 +103,59 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
self_k: usize,
|
||||
tsk_k: usize,
|
||||
k_tsk: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tsk_size: usize = div_ceil(basek, tsk_k);
|
||||
let self_size: usize = div_ceil(basek, self_k);
|
||||
let tsk_size: usize = div_ceil(k_tsk, basek);
|
||||
let self_size_out: usize = div_ceil(self_k, basek);
|
||||
let self_size_in: usize = div_ceil(self_size_out, digits);
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
|
||||
let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size);
|
||||
let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size);
|
||||
let tmp_a: usize = module.bytes_of_vec_znx_dft(1, self_size_in);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(
|
||||
self_size_out,
|
||||
self_size_in,
|
||||
self_size_in,
|
||||
rank,
|
||||
rank,
|
||||
tsk_size,
|
||||
);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size);
|
||||
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
tmp_dft_i + ((tmp_dft_col_data + vmp) | (tmp_idft + norm))
|
||||
tmp_dft_i + ((tmp_a + vmp) | (tmp_idft + norm))
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_internal_col0_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ksk_k: usize,
|
||||
k_out: usize,
|
||||
k_in: usize,
|
||||
k_ksk: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k))
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, div_ceil(k_in, basek))
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ksk_k: usize,
|
||||
tsk_k: usize,
|
||||
k_out: usize,
|
||||
k_in: usize,
|
||||
k_ksk: usize,
|
||||
digits_ksk: usize,
|
||||
k_tsk: usize,
|
||||
digits_tsk: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let out_size: usize = div_ceil(basek, out_k);
|
||||
let out_size: usize = div_ceil(k_out, basek);
|
||||
|
||||
let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let ks: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, ksk_k, rank);
|
||||
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
|
||||
let ks: usize =
|
||||
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank);
|
||||
let expand_rows: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
res_znx + ci_dft + (ks | expand_rows | res_dft)
|
||||
}
|
||||
@@ -118,67 +163,81 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
pub fn keyswitch_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
ksk_k: usize,
|
||||
tsk_k: usize,
|
||||
k_out: usize,
|
||||
k_ksk: usize,
|
||||
digits_ksk: usize,
|
||||
k_tsk: usize,
|
||||
digits_tsk: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::keyswitch_scratch_space(module, basek, out_k, out_k, ksk_k, tsk_k, rank)
|
||||
GGSWCiphertext::keyswitch_scratch_space(
|
||||
module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
atk_k: usize,
|
||||
tsk_k: usize,
|
||||
k_out: usize,
|
||||
k_in: usize,
|
||||
k_ksk: usize,
|
||||
digits_ksk: usize,
|
||||
k_tsk: usize,
|
||||
digits_tsk: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let cols: usize = rank + 1;
|
||||
let out_size: usize = div_ceil(basek, out_k);
|
||||
let out_size: usize = div_ceil(k_out, basek);
|
||||
let res: usize = module.bytes_of_vec_znx(cols, out_size);
|
||||
let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
|
||||
let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, out_k, in_k, atk_k, rank);
|
||||
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, out_k, tsk_k, rank);
|
||||
let ks_internal: usize =
|
||||
GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank);
|
||||
let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
|
||||
res + ci_dft + (ks_internal | expand | res_dft)
|
||||
}
|
||||
|
||||
pub fn automorphism_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
atk_k: usize,
|
||||
tsk_k: usize,
|
||||
k_out: usize,
|
||||
k_ksk: usize,
|
||||
digits_ksk: usize,
|
||||
k_tsk: usize,
|
||||
digits_tsk: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
GGSWCiphertext::automorphism_scratch_space(module, basek, out_k, out_k, atk_k, tsk_k, rank)
|
||||
GGSWCiphertext::automorphism_scratch_space(
|
||||
module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
in_k: usize,
|
||||
ggsw_k: usize,
|
||||
k_out: usize,
|
||||
k_in: usize,
|
||||
k_ggsw: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, in_k, rank);
|
||||
let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, out_k, in_k, ggsw_k, rank);
|
||||
let tmp_in: usize = GLWECiphertextFourier::bytes_of(module, basek, k_in, rank);
|
||||
let tmp_out: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_scratch_space(module, basek, k_out, k_in, k_ggsw, digits, rank);
|
||||
tmp_in + tmp_out + ggsw
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
basek: usize,
|
||||
out_k: usize,
|
||||
ggsw_k: usize,
|
||||
k_out: usize,
|
||||
k_ggsw: usize,
|
||||
digits: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank);
|
||||
let ggsw: usize = GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, out_k, ggsw_k, rank);
|
||||
let tmp: usize = GLWECiphertextFourier::bytes_of(module, basek, k_out, rank);
|
||||
let ggsw: usize =
|
||||
GLWECiphertextFourier::external_product_inplace_scratch_space(module, basek, k_out, k_ggsw, digits, rank);
|
||||
tmp + ggsw
|
||||
}
|
||||
}
|
||||
@@ -214,7 +273,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
tmp_pt.data.zero();
|
||||
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0);
|
||||
module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
|
||||
module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
|
||||
|
||||
(0..rank + 1).for_each(|col_j| {
|
||||
@@ -254,7 +313,15 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
assert!(
|
||||
scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.basek(), self.k(), tsk.k(), self.rank())
|
||||
scratch.available()
|
||||
>= GGSWCiphertext::expand_row_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
tsk.k(),
|
||||
tsk.digits(),
|
||||
tsk.rank()
|
||||
)
|
||||
);
|
||||
|
||||
// Example for rank 3:
|
||||
@@ -279,8 +346,6 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
|
||||
let (mut tmp_dft_i, scratch1) = scratch.tmp_vec_znx_dft(module, cols, tsk.size());
|
||||
{
|
||||
let (mut tmp_dft_col_data, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, self.size());
|
||||
|
||||
// Performs a key-switch for each combination of s[i]*s[j], i.e. for a0, a1, a2
|
||||
//
|
||||
// # Example for col=1
|
||||
@@ -293,23 +358,27 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
// =
|
||||
// (-(x0s0 + x1s1 + x2s2) + s0(a0s0 + a1s1 + a2s2), x0, x1, x2)
|
||||
(1..cols).for_each(|col_i| {
|
||||
// Extracts a[i] and multipies with Enc(s[i]s[j])
|
||||
tmp_dft_col_data.extract_column(0, ci_dft, col_i);
|
||||
let digits: usize = tsk.digits();
|
||||
|
||||
let pmat: &MatZnxDft<DataTsk, FFT64> = &tsk.at(col_i - 1, col_j - 1).0.data; // Selects Enc(s[i]s[j])
|
||||
|
||||
// Extracts a[i] and multipies with Enc(s[i]s[j])
|
||||
if col_i == 1 {
|
||||
module.vmp_apply(
|
||||
&mut tmp_dft_i,
|
||||
&tmp_dft_col_data,
|
||||
&tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
|
||||
scratch2,
|
||||
);
|
||||
(0..digits).for_each(|di| {
|
||||
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits);
|
||||
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
|
||||
if di == 0 {
|
||||
module.vmp_apply(&mut tmp_dft_i, &tmp_a, pmat, scratch2);
|
||||
} else {
|
||||
module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
module.vmp_apply_add(
|
||||
&mut tmp_dft_i,
|
||||
&tmp_dft_col_data,
|
||||
&tsk.at(col_i - 1, col_j - 1).0.data, // Selects Enc(s[i]s[j])
|
||||
scratch2,
|
||||
);
|
||||
(0..digits).for_each(|di| {
|
||||
let (mut tmp_a, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, (ci_dft.size() + di) / digits);
|
||||
module.vec_znx_dft_copy(digits, digits - 1 - di, &mut tmp_a, 0, ci_dft, col_i);
|
||||
module.vmp_apply_add(&mut tmp_dft_i, &tmp_a, pmat, di, scratch2);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -344,7 +413,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
let basek: usize = self.basek();
|
||||
|
||||
let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
|
||||
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
@@ -354,7 +423,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
|
||||
// Isolates DFT(a[i])
|
||||
(0..cols).for_each(|col_i| {
|
||||
module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
|
||||
module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i);
|
||||
});
|
||||
|
||||
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
|
||||
@@ -425,8 +494,10 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
self.k(),
|
||||
lhs.k(),
|
||||
auto_key.k(),
|
||||
auto_key.digits(),
|
||||
tensor_key.k(),
|
||||
self.rank()
|
||||
tensor_key.digits(),
|
||||
self.rank(),
|
||||
)
|
||||
)
|
||||
};
|
||||
@@ -436,7 +507,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
let basek: usize = self.basek();
|
||||
|
||||
let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
|
||||
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
|
||||
|
||||
// Keyswitch the j-th row of the col 0
|
||||
(0..lhs.rows()).for_each(|row_i| {
|
||||
@@ -448,7 +519,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
(0..cols).for_each(|col_i| {
|
||||
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
|
||||
module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i);
|
||||
module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
|
||||
module.vec_znx_dft(1, 0, &mut ci_dft, col_i, &tmp_res.data, col_i);
|
||||
});
|
||||
|
||||
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
|
||||
@@ -510,6 +581,19 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
|
||||
assert!(
|
||||
scratch.available()
|
||||
>= GGSWCiphertext::external_product_scratch_space(
|
||||
module,
|
||||
self.basek(),
|
||||
self.k(),
|
||||
lhs.k(),
|
||||
rhs.k(),
|
||||
rhs.digits(),
|
||||
rhs.rank()
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
|
||||
@@ -582,6 +666,7 @@ impl<DataSelf: AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
|
||||
res.k(),
|
||||
self.k(),
|
||||
ksk.k(),
|
||||
ksk.digits(),
|
||||
ksk.rank()
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user