working GGSW key-switch + added test (missing noise formula)

This commit is contained in:
Jean-Philippe Bossuat
2025-05-20 13:51:13 +02:00
parent 06b3cccbff
commit 7d84477e64
2 changed files with 168 additions and 14 deletions

View File

@@ -1,7 +1,7 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
};
use sampling::source::Source;
@@ -81,6 +81,27 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
+ module.bytes_of_vec_znx_dft(rank + 1, size)
}
pub fn keyswitch_scratch_space(
module: &Module<FFT64>,
out_size: usize,
in_size: usize,
ksk_size: usize,
tsk_size: usize,
rank: usize,
) -> usize {
let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size);
let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size);
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size)
+ module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size);
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size);
let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size);
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm))))
}
pub fn automorphism_scratch_space(
module: &Module<FFT64>,
out_size: usize,
@@ -90,7 +111,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
) -> usize {
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size);
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size);
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size);
let vmp: usize =
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size);
tmp_dft + tmp_idft + vmp
}
@@ -198,7 +220,6 @@ where
});
}
pub fn keyswitch<DataLhs, DataRhs0, DataRhs1>(
&mut self,
module: &Module<FFT64>,
@@ -239,7 +260,7 @@ where
let cols: usize = self.rank() + 1;
// Example for rank 3:
//
//
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
// actually composed of that many rows.
//
@@ -250,14 +271,13 @@ where
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M)
//
// # Output
// # Output
//
// col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 )
// col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 )
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 )
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M)
(0..self.rows()).for_each(|row_j| {
let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
@@ -291,7 +311,6 @@ where
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0)
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i])
(1..cols).for_each(|col_i| {
let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size());
let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_i_data,
@@ -311,7 +330,6 @@ where
// =
// (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2)
(1..cols).for_each(|col_j| {
// Extracts a[i] and multipies with Enc(s'[i]s'[j])
let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size());
tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j);
@@ -336,7 +354,7 @@ where
// Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i
//
// (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2)
// +
// +
// (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0)
// =
// (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2)
@@ -347,7 +365,9 @@ where
let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size());
(0..cols).for_each(|i| {
module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0);
if i == col_i {
module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0);
}
module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5);
module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0);
});
@@ -385,8 +405,7 @@ where
self.rank(),
rhs.rank()
);
}
;
};
let cols: usize = self.rank() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize