mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
working GGSW key-switch + added test (missing noise formula)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
@@ -81,6 +81,27 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
in_size: usize,
|
||||
ksk_size: usize,
|
||||
tsk_size: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let tmp_dft_out: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size);
|
||||
let vmp_ksk: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
|
||||
+ GLWECiphertextFourier::keyswitch_scratch_space(module, out_size, rank, in_size, rank, ksk_size);
|
||||
let tmp_c0: usize = module.bytes_of_vec_znx_big(1, out_size);
|
||||
let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size);
|
||||
let vmp_tsk: usize = module.bytes_of_vec_znx_dft(1, out_size)
|
||||
+ module.vmp_apply_tmp_bytes(out_size, out_size, rank + 1, rank + 1, rank + 1, tsk_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx_big(1, tsk_size);
|
||||
let tmp_znx_small: usize = module.bytes_of_vec_znx(1, out_size);
|
||||
let norm: usize = module.vec_znx_big_normalize_tmp_bytes();
|
||||
tmp_dft_out + (vmp_ksk | (tmp_c0 + tmp_dft_i + (vmp_tsk | (tmp_idft + tmp_znx_small + norm))))
|
||||
}
|
||||
|
||||
pub fn automorphism_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
out_size: usize,
|
||||
@@ -90,7 +111,8 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
||||
) -> usize {
|
||||
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, auto_key_size);
|
||||
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, out_size);
|
||||
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size);
|
||||
let vmp: usize =
|
||||
GLWECiphertext::keyswitch_from_fourier_scratch_space(module, out_size, rank, in_size, rank, auto_key_size);
|
||||
tmp_dft + tmp_idft + vmp
|
||||
}
|
||||
|
||||
@@ -198,7 +220,6 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs0, DataRhs1>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
@@ -239,7 +260,7 @@ where
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
// Example for rank 3:
|
||||
//
|
||||
//
|
||||
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
|
||||
// actually composed of that many rows.
|
||||
//
|
||||
@@ -250,14 +271,13 @@ where
|
||||
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + M, c2 )
|
||||
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + M)
|
||||
//
|
||||
// # Output
|
||||
// # Output
|
||||
//
|
||||
// col 0: (-(a0s0' + a1s1' + a2s2') + M, a0 , a1 , a2 )
|
||||
// col 1: (-(b0s0' + b1s1' + b2s2') , b0 + M, b1 , b2 )
|
||||
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M, c2 )
|
||||
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M)
|
||||
(0..self.rows()).for_each(|row_j| {
|
||||
|
||||
let (tmp_dft_out_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
|
||||
|
||||
let mut tmp_dft_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
@@ -291,7 +311,6 @@ where
|
||||
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) = KS_{s1's0', s1's1', s1's2'}(a1) + (0, 0, -(a0s0' + a1s1' + a2s2') + M[i], 0)
|
||||
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) = KS_{s2's0', s2's1', s2's2'}(a2) + (0, 0, 0, -(a0s0' + a1s1' + a2s2') + M[i])
|
||||
(1..cols).for_each(|col_i| {
|
||||
|
||||
let (tmp_dft_i_data, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, tsk.size());
|
||||
let mut tmp_dft_i: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||
data: tmp_dft_i_data,
|
||||
@@ -311,7 +330,6 @@ where
|
||||
// =
|
||||
// (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0, x1, x2)
|
||||
(1..cols).for_each(|col_j| {
|
||||
|
||||
// Extracts a[i] and multipies with Enc(s'[i]s'[j])
|
||||
let (mut tmp_dft_col_data, scratch4) = scratch3.tmp_vec_znx_dft(module, 1, self.size());
|
||||
tmp_dft_col_data.extract_column(0, &tmp_dft_out.data, col_j);
|
||||
@@ -336,7 +354,7 @@ where
|
||||
// Adds -(sum a[i] * s[i]) + m) on the i-th column of tmp_idft_i
|
||||
//
|
||||
// (-(x0s0' + x1s1' + x2s2') + a0s0's0' + a1s0's1' + a2s0's2', x0, x1, x2)
|
||||
// +
|
||||
// +
|
||||
// (0, -(a0s0' + a1s1' + a2s2') + M[i], 0, 0)
|
||||
// =
|
||||
// (-(x0s0' + x1s1' + x2s2') + s0'(a0s0' + a1s1' + a2s2'), x0 -(a0s0' + a1s1' + a2s2') + M[i], x1, x2)
|
||||
@@ -347,7 +365,9 @@ where
|
||||
let (mut tmp_znx_small, scratch5) = scratch3.tmp_vec_znx(module, 1, self.size());
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft_tmp_a(&mut tmp_idft, 0, &mut tmp_dft_i, i);
|
||||
module.vec_znx_big_add_inplace(&mut tmp_idft, col_i, &tmp_c0_data, 0);
|
||||
if i == col_i {
|
||||
module.vec_znx_big_add_inplace(&mut tmp_idft, 0, &tmp_c0_data, 0);
|
||||
}
|
||||
module.vec_znx_big_normalize(self.basek(), &mut tmp_znx_small, 0, &tmp_idft, 0, scratch5);
|
||||
module.vec_znx_dft(&mut tmp_dft_i, i, &tmp_znx_small, 0);
|
||||
});
|
||||
@@ -385,8 +405,7 @@ where
|
||||
self.rank(),
|
||||
rhs.rank()
|
||||
);
|
||||
}
|
||||
;
|
||||
};
|
||||
let cols: usize = self.rank() + 1;
|
||||
|
||||
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rhs.size()); //TODO optimize
|
||||
|
||||
Reference in New Issue
Block a user