mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added grlwe encrypt + test
This commit is contained in:
@@ -1,6 +1,16 @@
|
|||||||
use base2k::{Backend, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module};
|
use base2k::{
|
||||||
|
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||||
|
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps,
|
||||||
|
ZnxZero,
|
||||||
|
};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{elem::Infos, utils::derive_size};
|
use crate::{
|
||||||
|
elem::Infos,
|
||||||
|
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
||||||
|
keys::SecretKeyDft,
|
||||||
|
utils::derive_size,
|
||||||
|
};
|
||||||
|
|
||||||
pub struct GRLWECt<C, B: Backend> {
|
pub struct GRLWECt<C, B: Backend> {
|
||||||
pub data: MatZnxDft<C, B>,
|
pub data: MatZnxDft<C, B>,
|
||||||
@@ -18,6 +28,18 @@ impl<B: Backend> GRLWECt<Vec<u8>, B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C> GRLWECt<C, FFT64>
|
||||||
|
where
|
||||||
|
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
pub fn get_row(&self, module: &Module<FFT64>, i: usize, res: &mut RLWECtDft<C, FFT64>)
|
||||||
|
where
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
|
||||||
|
{
|
||||||
|
module.vmp_extract_row(res, self, i, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<T, B: Backend> Infos for GRLWECt<T, B> {
|
impl<T, B: Backend> Infos for GRLWECt<T, B> {
|
||||||
type Inner = MatZnxDft<T, B>;
|
type Inner = MatZnxDft<T, B>;
|
||||||
|
|
||||||
@@ -51,3 +73,169 @@ where
|
|||||||
self.data.to_ref()
|
self.data.to_ref()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl GRLWECt<Vec<u8>, FFT64> {
|
||||||
|
pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
||||||
|
RLWECt::encrypt_sk_scratch_bytes(module, size)
|
||||||
|
+ module.bytes_of_vec_znx(2, size)
|
||||||
|
+ module.bytes_of_vec_znx(1, size)
|
||||||
|
+ module.bytes_of_vec_znx_dft(2, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn encrypt_pk_scratch_bytes(module: &Module<FFT64>, pk_size: usize) -> usize {
|
||||||
|
// RLWECt::encrypt_pk_scratch_bytes(module, pk_size)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_grlwe_sk<C, P, S>(
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
ct: &mut GRLWECt<C, FFT64>,
|
||||||
|
pt: &ScalarZnx<P>,
|
||||||
|
sk: &SecretKeyDft<S, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||||
|
ScalarZnx<P>: ScalarZnxToRef,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
let rows: usize = ct.rows();
|
||||||
|
let size: usize = ct.size();
|
||||||
|
let log_base2k: usize = ct.log_base2k();
|
||||||
|
|
||||||
|
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||||
|
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
|
||||||
|
let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
|
||||||
|
|
||||||
|
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
|
||||||
|
data: tmp_znx_pt,
|
||||||
|
log_base2k: log_base2k,
|
||||||
|
log_k: ct.log_k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
|
||||||
|
data: tmp_znx_ct,
|
||||||
|
log_base2k: log_base2k,
|
||||||
|
log_k: ct.log_k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..rows).for_each(|row_i| {
|
||||||
|
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||||
|
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
|
||||||
|
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3);
|
||||||
|
|
||||||
|
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||||
|
vec_znx_ct.encrypt_sk(
|
||||||
|
module,
|
||||||
|
Some(&vec_znx_pt),
|
||||||
|
sk,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch_3,
|
||||||
|
);
|
||||||
|
|
||||||
|
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||||
|
|
||||||
|
// Switch vec_znx_ct into DFT domain
|
||||||
|
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
|
||||||
|
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
|
||||||
|
|
||||||
|
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
|
||||||
|
module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> GRLWECt<C, FFT64> {
|
||||||
|
pub fn encrypt_sk<P, S>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: &ScalarZnx<P>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||||
|
ScalarZnx<P>: ScalarZnxToRef,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
encrypt_grlwe_sk(
|
||||||
|
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
elem::Infos,
|
||||||
|
elem_rlwe::{RLWECtDft, RLWEPt},
|
||||||
|
keys::{SecretKey, SecretKeyDft},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::GRLWECt;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypt_sk_vec_znx_fft64() {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||||
|
let log_base2k: usize = 8;
|
||||||
|
let log_k_ct: usize = 54;
|
||||||
|
let rows: usize = 4;
|
||||||
|
|
||||||
|
let sigma: f64 = 3.2;
|
||||||
|
let bound: f64 = sigma * 6.0;
|
||||||
|
|
||||||
|
let mut ct: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows);
|
||||||
|
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
GRLWECt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
|
||||||
|
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
ct.encrypt_sk(
|
||||||
|
&module,
|
||||||
|
&pt_scalar,
|
||||||
|
&sk_dft,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
|
||||||
|
|
||||||
|
(0..ct.rows()).for_each(|row_i| {
|
||||||
|
ct.get_row(&module, row_i, &mut ct_rlwe_dft);
|
||||||
|
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
|
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
|
||||||
|
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
||||||
|
assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt);
|
||||||
|
});
|
||||||
|
|
||||||
|
module.free();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -181,7 +181,7 @@ pub fn encrypt_rlwe_sk<C, P, S>(
|
|||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
ct: &mut RLWECt<C>,
|
ct: &mut RLWECt<C>,
|
||||||
pt: Option<&RLWEPt<P>>,
|
pt: Option<&RLWEPt<P>>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
@@ -206,7 +206,7 @@ pub fn encrypt_rlwe_sk<C, P, S>(
|
|||||||
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
||||||
|
|
||||||
// c0_dft = DFT(a) * DFT(s)
|
// c0_dft = DFT(a) * DFT(s)
|
||||||
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
|
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
|
||||||
|
|
||||||
// c0_big = IDFT(c0_dft)
|
// c0_big = IDFT(c0_dft)
|
||||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||||
@@ -227,7 +227,7 @@ pub fn decrypt_rlwe<P, C, S>(
|
|||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
pt: &mut RLWEPt<P>,
|
pt: &mut RLWEPt<P>,
|
||||||
ct: &RLWECt<C>,
|
ct: &RLWECt<C>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
scratch: &mut Scratch,
|
scratch: &mut Scratch,
|
||||||
) where
|
) where
|
||||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||||
@@ -241,7 +241,7 @@ pub fn decrypt_rlwe<P, C, S>(
|
|||||||
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
||||||
|
|
||||||
// c0_dft = DFT(a) * DFT(s)
|
// c0_dft = DFT(a) * DFT(s)
|
||||||
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
|
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
|
||||||
|
|
||||||
// c0_big = IDFT(c0_dft)
|
// c0_big = IDFT(c0_dft)
|
||||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||||
@@ -262,7 +262,7 @@ impl<C> RLWECt<C> {
|
|||||||
&mut self,
|
&mut self,
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
pt: Option<&RLWEPt<P>>,
|
pt: Option<&RLWEPt<P>>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
@@ -274,17 +274,22 @@ impl<C> RLWECt<C> {
|
|||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
encrypt_rlwe_sk(
|
encrypt_rlwe_sk(
|
||||||
module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch,
|
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrypt<P, S>(&self, module: &Module<FFT64>, pt: &mut RLWEPt<P>, sk: &SecretKeyDft<S, FFT64>, scratch: &mut Scratch)
|
pub fn decrypt<P, S>(
|
||||||
where
|
&self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: &mut RLWEPt<P>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||||
VecZnx<C>: VecZnxToRef,
|
VecZnx<C>: VecZnxToRef,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
decrypt_rlwe(module, pt, self, sk, scratch);
|
decrypt_rlwe(module, pt, self, sk_dft, scratch);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn encrypt_pk<P, S>(
|
pub fn encrypt_pk<P, S>(
|
||||||
@@ -526,7 +531,7 @@ mod tests {
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_sk_vec_znx_fft64() {
|
fn encrypt_sk_fft64() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||||
let log_base2k: usize = 8;
|
let log_base2k: usize = 8;
|
||||||
let log_k_ct: usize = 54;
|
let log_k_ct: usize = 54;
|
||||||
@@ -597,7 +602,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_zero_rlwe_dft_sk_fft64() {
|
fn encrypt_zero_sk_fft64() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(1024);
|
let module: Module<FFT64> = Module::<FFT64>::new(1024);
|
||||||
let log_base2k: usize = 8;
|
let log_base2k: usize = 8;
|
||||||
let log_k_ct: usize = 55;
|
let log_k_ct: usize = 55;
|
||||||
@@ -639,7 +644,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_pk_vec_znx_fft64() {
|
fn encrypt_pk_fft64() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||||
let log_base2k: usize = 8;
|
let log_base2k: usize = 8;
|
||||||
let log_k_ct: usize = 54;
|
let log_k_ct: usize = 54;
|
||||||
|
|||||||
Reference in New Issue
Block a user