mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
added rgsw encrypt + test
This commit is contained in:
@@ -196,7 +196,7 @@ impl Scratch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tmp_scalar<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
|
||||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
|
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
|
||||||
|
|
||||||
(
|
(
|
||||||
@@ -205,7 +205,7 @@ impl Scratch {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn tmp_scalar_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
|
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
|
||||||
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
|
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
|
||||||
|
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use crate::znx_base::ZnxInfos;
|
use crate::znx_base::ZnxInfos;
|
||||||
use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
|
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut};
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use rand_core::RngCore;
|
use rand_core::RngCore;
|
||||||
use rand_distr::{Distribution, weighted::WeightedIndex};
|
use rand_distr::{Distribution, weighted::WeightedIndex};
|
||||||
@@ -144,6 +144,17 @@ impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl VecZnxToMut for ScalarZnx<Vec<u8>>{
|
||||||
|
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||||
|
VecZnx {
|
||||||
|
data: self.data.as_mut_slice(),
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
|
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
|
||||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||||
ScalarZnx {
|
ScalarZnx {
|
||||||
@@ -154,6 +165,17 @@ impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl VecZnxToRef for ScalarZnx<Vec<u8>>{
|
||||||
|
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||||
|
VecZnx {
|
||||||
|
data: self.data.as_slice(),
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
|
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
|
||||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||||
ScalarZnx {
|
ScalarZnx {
|
||||||
@@ -164,6 +186,17 @@ impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
|
||||||
|
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||||
|
VecZnx {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
|
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
|
||||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||||
ScalarZnx {
|
ScalarZnx {
|
||||||
@@ -174,6 +207,17 @@ impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
|
||||||
|
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||||
|
VecZnx {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
|
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
|
||||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||||
ScalarZnx {
|
ScalarZnx {
|
||||||
@@ -183,3 +227,14 @@ impl ScalarZnxToRef for ScalarZnx<&[u8]> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl VecZnxToRef for ScalarZnx<&[u8]> {
|
||||||
|
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||||
|
VecZnx {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ use std::{cmp::min, fmt};
|
|||||||
/// are small polynomials of Zn\[X\].
|
/// are small polynomials of Zn\[X\].
|
||||||
pub struct VecZnx<D> {
|
pub struct VecZnx<D> {
|
||||||
pub data: D,
|
pub data: D,
|
||||||
n: usize,
|
pub n: usize,
|
||||||
cols: usize,
|
pub cols: usize,
|
||||||
size: usize,
|
pub size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D> ZnxInfos for VecZnx<D> {
|
impl<D> ZnxInfos for VecZnx<D> {
|
||||||
|
|||||||
@@ -114,6 +114,9 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
|
|||||||
R: VecZnxBigToMut<BACKEND>,
|
R: VecZnxBigToMut<BACKEND>,
|
||||||
A: VecZnxToRef;
|
A: VecZnxToRef;
|
||||||
|
|
||||||
|
/// Negates `a` inplace.
|
||||||
|
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut<BACKEND>;
|
||||||
|
|
||||||
/// Normalizes `a` and stores the result on `b`.
|
/// Normalizes `a` and stores the result on `b`.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
@@ -503,6 +506,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut<FFT64> {
|
||||||
|
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(a.n(), self.n());
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
vec_znx::vec_znx_negate(
|
||||||
|
self.ptr,
|
||||||
|
a.at_mut_ptr(res_col, 0),
|
||||||
|
a.size() as u64,
|
||||||
|
a.sl() as u64,
|
||||||
|
a.at_ptr(res_col, 0),
|
||||||
|
a.size() as u64,
|
||||||
|
a.sl() as u64,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn vec_znx_big_normalize<R, A>(
|
fn vec_znx_big_normalize<R, A>(
|
||||||
&self,
|
&self,
|
||||||
log_base2k: usize,
|
log_base2k: usize,
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
|
|||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
ct: &mut GRLWECt<C, FFT64>,
|
ct: &mut GRLWECt<C, FFT64>,
|
||||||
pt: &ScalarZnx<P>,
|
pt: &ScalarZnx<P>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
@@ -131,7 +131,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
|
|||||||
vec_znx_ct.encrypt_sk(
|
vec_znx_ct.encrypt_sk(
|
||||||
module,
|
module,
|
||||||
Some(&vec_znx_pt),
|
Some(&vec_znx_pt),
|
||||||
sk,
|
sk_dft,
|
||||||
source_xa,
|
source_xa,
|
||||||
source_xe,
|
source_xe,
|
||||||
sigma,
|
sigma,
|
||||||
@@ -186,7 +186,7 @@ mod tests {
|
|||||||
use super::GRLWECt;
|
use super::GRLWECt;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_sk_vec_znx_fft64() {
|
fn encrypt_sk_fft64() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||||
let log_base2k: usize = 8;
|
let log_base2k: usize = 8;
|
||||||
let log_k_ct: usize = 54;
|
let log_k_ct: usize = 54;
|
||||||
@@ -233,7 +233,7 @@ mod tests {
|
|||||||
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
|
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
|
||||||
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
||||||
assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt);
|
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||||
});
|
});
|
||||||
|
|
||||||
module.free();
|
module.free();
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
|
||||||
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut,
|
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps,
|
||||||
|
ZnxZero,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elem::Infos,
|
elem::Infos,
|
||||||
elem_grlwe::GRLWECt,
|
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
|
||||||
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
|
||||||
keys::SecretKeyDft,
|
keys::SecretKeyDft,
|
||||||
utils::derive_size,
|
utils::derive_size,
|
||||||
};
|
};
|
||||||
@@ -62,28 +62,32 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GRLWECt<Vec<u8>, FFT64> {
|
impl RGSWCt<Vec<u8>, FFT64> {
|
||||||
pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
||||||
RLWECt::encrypt_sk_scratch_bytes(module, size)
|
RLWECt::encrypt_sk_scratch_bytes(module, size)
|
||||||
+ module.bytes_of_vec_znx(2, size)
|
+ module.bytes_of_vec_znx(2, size)
|
||||||
+ module.bytes_of_vec_znx(1, size)
|
+ module.bytes_of_vec_znx(1, size)
|
||||||
+ module.bytes_of_vec_znx_dft(2, size)
|
+ module.bytes_of_vec_znx_dft(2, size)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn encrypt_pk_scratch_bytes(module: &Module<FFT64>, pk_size: usize) -> usize {
|
impl<C> RGSWCt<C, FFT64>
|
||||||
RLWECt::encrypt_pk_scratch_bytes(module, pk_size)
|
where
|
||||||
}
|
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
pub fn decrypt_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
pub fn get_row(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<C, FFT64>)
|
||||||
RLWECtDft::decrypt_scratch_bytes(module, size)
|
where
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
|
||||||
|
{
|
||||||
|
module.vmp_extract_row(res, self, row_i, col_j);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn encrypt_grlwe_sk<C, P, S>(
|
pub fn encrypt_rgsw_sk<C, P, S>(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
ct: &mut GRLWECt<C, FFT64>,
|
ct: &mut RGSWCt<C, FFT64>,
|
||||||
pt: &ScalarZnx<P>,
|
pt: &ScalarZnx<P>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
@@ -94,47 +98,164 @@ pub fn encrypt_grlwe_sk<C, P, S>(
|
|||||||
ScalarZnx<P>: ScalarZnxToRef,
|
ScalarZnx<P>: ScalarZnxToRef,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
let rows: usize = ct.rows();
|
|
||||||
let size: usize = ct.size();
|
let size: usize = ct.size();
|
||||||
|
let log_base2k: usize = ct.log_base2k();
|
||||||
|
|
||||||
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
|
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||||
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
|
let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size);
|
||||||
let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
|
|
||||||
|
|
||||||
let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt {
|
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
|
||||||
data: tmp_znx_pt,
|
data: tmp_znx_pt,
|
||||||
log_base2k: ct.log_base2k(),
|
log_base2k: log_base2k,
|
||||||
log_k: ct.log_k(),
|
log_k: ct.log_k(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt {
|
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
|
||||||
data: tmp_znx_ct,
|
data: tmp_znx_ct,
|
||||||
log_base2k: ct.log_base2k(),
|
log_base2k: log_base2k,
|
||||||
log_k: ct.log_k(),
|
log_k: ct.log_k(),
|
||||||
};
|
};
|
||||||
|
|
||||||
(0..rows).for_each(|row_i| {
|
(0..ct.rows()).for_each(|row_j| {
|
||||||
tmp_pt
|
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||||
.data
|
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
|
||||||
.at_mut(0, row_i)
|
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
|
||||||
.copy_from_slice(&pt.to_ref().raw());
|
|
||||||
|
|
||||||
tmp_ct.encrypt_sk(
|
(0..ct.cols()).for_each(|col_i| {
|
||||||
|
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||||
|
encrypt_rlwe_sk(
|
||||||
module,
|
module,
|
||||||
Some(&tmp_pt),
|
&mut vec_znx_ct,
|
||||||
sk,
|
Some((&vec_znx_pt, col_i)),
|
||||||
|
sk_dft,
|
||||||
source_xa,
|
source_xa,
|
||||||
source_xe,
|
source_xe,
|
||||||
sigma,
|
sigma,
|
||||||
bound,
|
bound,
|
||||||
scratch_3,
|
scrach_2,
|
||||||
);
|
);
|
||||||
|
|
||||||
tmp_pt.data.at_mut(0, row_i).fill(0);
|
// Switch vec_znx_ct into DFT domain
|
||||||
|
{
|
||||||
|
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size);
|
||||||
|
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
|
||||||
|
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
|
||||||
|
module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0);
|
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||||
module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1);
|
|
||||||
|
|
||||||
module.vmp_prepare_row(ct, row_i, 0, &tmp_dft);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<C> RGSWCt<C, FFT64> {
|
||||||
|
pub fn encrypt_sk<P, S>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: &ScalarZnx<P>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||||
|
ScalarZnx<P>: ScalarZnxToRef,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
encrypt_rgsw_sk(
|
||||||
|
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use base2k::{
|
||||||
|
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
|
||||||
|
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero,
|
||||||
|
};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
elem::Infos,
|
||||||
|
elem_rlwe::{RLWECtDft, RLWEPt},
|
||||||
|
keys::{SecretKey, SecretKeyDft},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::RGSWCt;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypt_rgsw_sk_fft64() {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||||
|
let log_base2k: usize = 8;
|
||||||
|
let log_k_ct: usize = 54;
|
||||||
|
let rows: usize = 4;
|
||||||
|
|
||||||
|
let sigma: f64 = 3.2;
|
||||||
|
let bound: f64 = sigma * 6.0;
|
||||||
|
|
||||||
|
let mut ct: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows);
|
||||||
|
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
|
||||||
|
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
ct.encrypt_sk(
|
||||||
|
&module,
|
||||||
|
&pt_scalar,
|
||||||
|
&sk_dft,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
|
||||||
|
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||||
|
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||||
|
|
||||||
|
(0..ct.cols()).for_each(|col_j| {
|
||||||
|
(0..ct.rows()).for_each(|row_i| {
|
||||||
|
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||||
|
|
||||||
|
if col_j == 1 {
|
||||||
|
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||||
|
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
|
||||||
|
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||||
|
module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||||
|
}
|
||||||
|
|
||||||
|
ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
||||||
|
|
||||||
|
ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||||
|
|
||||||
|
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||||
|
|
||||||
|
let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
||||||
|
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||||
|
|
||||||
|
pt_want.data.zero();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
module.free();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ impl RLWECt<Vec<u8>> {
|
|||||||
pub fn encrypt_rlwe_sk<C, P, S>(
|
pub fn encrypt_rlwe_sk<C, P, S>(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
ct: &mut RLWECt<C>,
|
ct: &mut RLWECt<C>,
|
||||||
pt: Option<&RLWEPt<P>>,
|
pt: Option<(&RLWEPt<P>, usize)>,
|
||||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
@@ -213,8 +213,18 @@ pub fn encrypt_rlwe_sk<C, P, S>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// c0_big = m - c0_big
|
// c0_big = m - c0_big
|
||||||
if let Some(pt) = pt {
|
if let Some((pt, col)) = pt {
|
||||||
module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
|
match col {
|
||||||
|
0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0),
|
||||||
|
1 => {
|
||||||
|
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
|
||||||
|
module.vec_znx_add_inplace(ct, 1, pt, 0);
|
||||||
|
module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1);
|
||||||
|
}
|
||||||
|
_ => panic!("invalid target column: {}", col),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
|
||||||
}
|
}
|
||||||
// c0_big += e
|
// c0_big += e
|
||||||
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
||||||
@@ -273,9 +283,23 @@ impl<C> RLWECt<C> {
|
|||||||
VecZnx<P>: VecZnxToRef,
|
VecZnx<P>: VecZnxToRef,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
|
if let Some(pt) = pt {
|
||||||
encrypt_rlwe_sk(
|
encrypt_rlwe_sk(
|
||||||
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
module,
|
||||||
|
self,
|
||||||
|
Some((pt, 0)),
|
||||||
|
sk_dft,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch,
|
||||||
)
|
)
|
||||||
|
} else {
|
||||||
|
encrypt_rlwe_sk::<C, P, S>(
|
||||||
|
module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decrypt<P, S>(
|
pub fn decrypt<P, S>(
|
||||||
@@ -483,10 +507,10 @@ pub(crate) fn encrypt_rlwe_pk<C, P, S>(
|
|||||||
let size_pk: usize = pk.size();
|
let size_pk: usize = pk.size();
|
||||||
|
|
||||||
// Generates u according to the underlying secret distribution.
|
// Generates u according to the underlying secret distribution.
|
||||||
let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1);
|
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
|
||||||
|
|
||||||
{
|
{
|
||||||
let (mut u, _) = scratch_1.tmp_scalar(module, 1);
|
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
|
||||||
match pk.dist {
|
match pk.dist {
|
||||||
SecretDistribution::NONE => panic!(
|
SecretDistribution::NONE => panic!(
|
||||||
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
|
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
|
||||||
|
|||||||
Reference in New Issue
Block a user