mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
renamed crate & files
This commit is contained in:
12
core/Cargo.toml
Normal file
12
core/Cargo.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "rlwe"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
rug = {workspace = true}
|
||||
criterion = {workspace = true}
|
||||
base2k = {path="../base2k"}
|
||||
sampling = {path="../sampling"}
|
||||
rand_distr = {workspace = true}
|
||||
itertools = {workspace = true}
|
||||
213
core/src/elem.rs
Normal file
213
core/src/elem.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, Scratch, VecZnx, VecZnxDft, VecZnxDftOps, VecZnxDftToMut,
|
||||
VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxZero,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
grlwe::GRLWECt,
|
||||
rlwe::{RLWECt, RLWECtDft},
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub trait Infos {
|
||||
type Inner: ZnxInfos;
|
||||
|
||||
fn inner(&self) -> &Self::Inner;
|
||||
|
||||
/// Returns the ring degree of the polynomials.
|
||||
fn n(&self) -> usize {
|
||||
self.inner().n()
|
||||
}
|
||||
|
||||
/// Returns the base two logarithm of the ring dimension of the polynomials.
|
||||
fn log_n(&self) -> usize {
|
||||
self.inner().log_n()
|
||||
}
|
||||
|
||||
/// Returns the number of rows.
|
||||
fn rows(&self) -> usize {
|
||||
self.inner().rows()
|
||||
}
|
||||
|
||||
/// Returns the number of polynomials in each row.
|
||||
fn cols(&self) -> usize {
|
||||
self.inner().cols()
|
||||
}
|
||||
|
||||
/// Returns the number of size per polynomial.
|
||||
fn size(&self) -> usize {
|
||||
let size: usize = self.inner().size();
|
||||
debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k()));
|
||||
size
|
||||
}
|
||||
|
||||
/// Returns the total number of small polynomials.
|
||||
fn poly_count(&self) -> usize {
|
||||
self.rows() * self.cols() * self.size()
|
||||
}
|
||||
|
||||
/// Returns the base 2 logarithm of the ciphertext base.
|
||||
fn log_base2k(&self) -> usize;
|
||||
|
||||
/// Returns the bit precision of the ciphertext.
|
||||
fn log_k(&self) -> usize;
|
||||
}
|
||||
|
||||
pub trait GetRow<B: Backend> {
|
||||
fn get_row<R>(&self, module: &Module<B>, row_i: usize, col_j: usize, res: &mut RLWECtDft<R, B>)
|
||||
where
|
||||
VecZnxDft<R, B>: VecZnxDftToMut<B>;
|
||||
}
|
||||
|
||||
pub trait SetRow<B: Backend> {
|
||||
fn set_row<A>(&mut self, module: &Module<B>, row_i: usize, col_j: usize, a: &RLWECtDft<A, B>)
|
||||
where
|
||||
VecZnxDft<A, B>: VecZnxDftToRef<B>;
|
||||
}
|
||||
|
||||
pub(crate) trait MatZnxDftProducts<D, C>: Infos
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef;
|
||||
|
||||
fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnx<R>: VecZnxToMut + VecZnxToRef,
|
||||
{
|
||||
unsafe {
|
||||
let res_ptr: *mut RLWECt<R> = res as *mut RLWECt<R>; // This is ok because [Self::mul_rlwe] only updates res at the end.
|
||||
self.mul_rlwe(&module, &mut *res_ptr, &*res_ptr, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
fn mul_rlwe_dft<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut RLWECtDft<R, FFT64>,
|
||||
a: &RLWECtDft<A, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
let log_base2k: usize = self.log_base2k();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.log_base2k(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
|
||||
let (a_data, scratch_1) = scratch.tmp_vec_znx(module, 2, a.size());
|
||||
|
||||
let mut a_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> {
|
||||
data: a_data,
|
||||
log_base2k: a.log_base2k(),
|
||||
log_k: a.log_k(),
|
||||
};
|
||||
|
||||
a.idft(module, &mut a_idft, scratch_1);
|
||||
|
||||
let (res_data, scratch_2) = scratch_1.tmp_vec_znx(module, 2, res.size());
|
||||
|
||||
let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> {
|
||||
data: res_data,
|
||||
log_base2k: res.log_base2k(),
|
||||
log_k: res.log_k(),
|
||||
};
|
||||
|
||||
self.mul_rlwe(module, &mut res_idft, &a_idft, scratch_2);
|
||||
|
||||
module.vec_znx_dft(res, 0, &res_idft, 0);
|
||||
module.vec_znx_dft(res, 1, &res_idft, 1);
|
||||
}
|
||||
|
||||
fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
let log_base2k: usize = self.log_base2k();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.log_base2k(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
}
|
||||
|
||||
let (res_data, scratch_1) = scratch.tmp_vec_znx(module, 2, res.size());
|
||||
|
||||
let mut res_idft: RLWECt<&mut [u8]> = RLWECt::<&mut [u8]> {
|
||||
data: res_data,
|
||||
log_base2k: res.log_base2k(),
|
||||
log_k: res.log_k(),
|
||||
};
|
||||
|
||||
res.idft(module, &mut res_idft, scratch_1);
|
||||
|
||||
self.mul_rlwe_inplace(module, &mut res_idft, scratch_1);
|
||||
|
||||
module.vec_znx_dft(res, 0, &res_idft, 0);
|
||||
module.vec_znx_dft(res, 1, &res_idft, 1);
|
||||
}
|
||||
|
||||
fn mul_grlwe<R, A>(&self, module: &Module<FFT64>, res: &mut GRLWECt<R, FFT64>, a: &GRLWECt<A, FFT64>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, a.size());
|
||||
|
||||
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
|
||||
data: tmp_row_data,
|
||||
log_base2k: a.log_base2k(),
|
||||
log_k: a.log_k(),
|
||||
};
|
||||
|
||||
let min_rows: usize = res.rows().min(a.rows());
|
||||
|
||||
(0..min_rows).for_each(|row_i| {
|
||||
a.get_row(module, row_i, &mut tmp_row);
|
||||
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
|
||||
res.set_row(module, row_i, &tmp_row);
|
||||
});
|
||||
|
||||
tmp_row.data.zero();
|
||||
|
||||
(min_rows..res.rows()).for_each(|row_i| {
|
||||
res.set_row(module, row_i, &tmp_row);
|
||||
})
|
||||
}
|
||||
|
||||
fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
|
||||
{
|
||||
let (tmp_row_data, scratch1) = scratch.tmp_vec_znx_dft(module, 2, res.size());
|
||||
|
||||
let mut tmp_row: RLWECtDft<&mut [u8], FFT64> = RLWECtDft::<&mut [u8], FFT64> {
|
||||
data: tmp_row_data,
|
||||
log_base2k: res.log_base2k(),
|
||||
log_k: res.log_k(),
|
||||
};
|
||||
|
||||
(0..self.cols()).for_each(|col_j| {
|
||||
(0..res.rows()).for_each(|row_i| {
|
||||
res.get_row(module, row_i, col_j, &mut tmp_row);
|
||||
self.mul_rlwe_dft_inplace(module, &mut tmp_row, scratch1);
|
||||
res.set_row(module, row_i, col_j, &tmp_row);
|
||||
});
|
||||
})
|
||||
}
|
||||
}
|
||||
343
core/src/grlwe.rs
Normal file
343
core/src/grlwe.rs
Normal file
@@ -0,0 +1,343 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, MatZnxDftProducts, SetRow},
|
||||
keys::SecretKeyDft,
|
||||
rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct GRLWECt<C, B: Backend> {
|
||||
pub data: MatZnxDft<C, B>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GRLWECt<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GRLWECt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, res: &mut RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, 0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GRLWECt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
pub fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, a: &RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, 0, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GRLWECt<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToMut<B> for GRLWECt<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToRef<B> for GRLWECt<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GRLWECt<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
RLWECt::encrypt_sk_scratch_space(module, size)
|
||||
+ module.bytes_of_vec_znx(2, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.bytes_of_vec_znx_dft(2, size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(2, grlwe_size)
|
||||
+ (module.vec_znx_big_normalize_tmp_bytes()
|
||||
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
|
||||
+ module.bytes_of_vec_znx_dft(1, a_size)))
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_scratch_space(module, res_size, res_size, grlwe_size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
(Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
|
||||
+ module.bytes_of_vec_znx(2, a_size)
|
||||
+ module.bytes_of_vec_znx(2, res_size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
|
||||
(Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
|
||||
+ module.bytes_of_vec_znx(2, res_size)
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_grlwe_sk<C, P, S>(
|
||||
module: &Module<FFT64>,
|
||||
ct: &mut GRLWECt<C, FFT64>,
|
||||
pt: &ScalarZnx<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
ScalarZnx<P>: ScalarZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let rows: usize = ct.rows();
|
||||
let size: usize = ct.size();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
|
||||
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
|
||||
let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
|
||||
|
||||
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
|
||||
data: tmp_znx_pt,
|
||||
log_base2k: log_base2k,
|
||||
log_k: ct.log_k(),
|
||||
};
|
||||
|
||||
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
|
||||
data: tmp_znx_ct,
|
||||
log_base2k: log_base2k,
|
||||
log_k: ct.log_k(),
|
||||
};
|
||||
|
||||
(0..rows).for_each(|row_i| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
|
||||
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3);
|
||||
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
vec_znx_ct.encrypt_sk(
|
||||
module,
|
||||
Some(&vec_znx_pt),
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch_3,
|
||||
);
|
||||
|
||||
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
|
||||
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
|
||||
|
||||
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
|
||||
module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct);
|
||||
});
|
||||
}
|
||||
|
||||
impl<C> GRLWECt<C, FFT64> {
|
||||
pub fn encrypt_sk<P, S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &ScalarZnx<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
ScalarZnx<P>: ScalarZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
encrypt_grlwe_sk(
|
||||
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnx<R>: VecZnxToMut + VecZnxToRef,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut RLWECtDft<R, FFT64>,
|
||||
a: &RLWECtDft<A, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_grlwe<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GRLWECt<R, FFT64>,
|
||||
a: &GRLWECt<A, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
|
||||
{
|
||||
MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GetRow<FFT64> for GRLWECt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(col_j, 0);
|
||||
}
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> SetRow<FFT64> for GRLWECt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(col_j, 0);
|
||||
}
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> MatZnxDftProducts<GRLWECt<C, FFT64>, C> for GRLWECt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
let log_base2k: usize = self.log_base2k();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.log_base2k(), log_base2k);
|
||||
assert_eq!(a.log_base2k(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size());
|
||||
module.vec_znx_dft(&mut a1_dft, 0, a, 1);
|
||||
module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
|
||||
|
||||
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
|
||||
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
|
||||
}
|
||||
}
|
||||
204
core/src/keys.rs
Normal file
204
core/src/keys.rs
Normal file
@@ -0,0 +1,204 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
|
||||
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{elem::Infos, rlwe::RLWECtDft};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub enum SecretDistribution {
|
||||
TernaryFixed(usize), // Ternary with fixed Hamming weight
|
||||
TernaryProb(f64), // Ternary with probabilistic Hamming weight
|
||||
ZERO, // Debug mod
|
||||
NONE,
|
||||
}
|
||||
|
||||
pub struct SecretKey<T> {
|
||||
pub data: ScalarZnx<T>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl SecretKey<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>) -> Self {
|
||||
Self {
|
||||
data: module.new_scalar_znx(1),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> SecretKey<S>
|
||||
where
|
||||
S: AsMut<[u8]> + AsRef<[u8]>,
|
||||
{
|
||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
||||
self.data.fill_ternary_prob(0, prob, source);
|
||||
self.dist = SecretDistribution::TernaryProb(prob);
|
||||
}
|
||||
|
||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||
self.data.fill_ternary_hw(0, hw, source);
|
||||
self.dist = SecretDistribution::TernaryFixed(hw);
|
||||
}
|
||||
|
||||
pub fn fill_zero(&mut self) {
|
||||
self.data.zero();
|
||||
self.dist = SecretDistribution::ZERO;
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> ScalarZnxToMut for SecretKey<C>
|
||||
where
|
||||
ScalarZnx<C>: ScalarZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> ScalarZnxToRef for SecretKey<C>
|
||||
where
|
||||
ScalarZnx<C>: ScalarZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SecretKeyDft<T, B: Backend> {
|
||||
pub data: ScalarZnxDft<T, B>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>) -> Self {
|
||||
Self {
|
||||
data: module.new_scalar_znx_dft(1),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
|
||||
where
|
||||
SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
|
||||
SecretKey<S>: ScalarZnxToRef,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
match sk.dist {
|
||||
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
module.svp_prepare(self, 0, sk, 0);
|
||||
self.dist = sk.dist;
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyDft<C, B>
|
||||
where
|
||||
ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyDft<C, B>
|
||||
where
|
||||
ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PublicKey<D, B: Backend> {
|
||||
pub data: RLWECtDft<D, B>,
|
||||
pub dist: SecretDistribution,
|
||||
}
|
||||
|
||||
impl<B: Backend> PublicKey<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
||||
Self {
|
||||
data: RLWECtDft::new(module, log_base2k, log_k),
|
||||
dist: SecretDistribution::NONE,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for PublicKey<T, B> {
|
||||
type Inner = VecZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.data.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.data.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for PublicKey<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for PublicKey<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> PublicKey<C, FFT64> {
|
||||
pub fn generate<S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
) where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
match sk_dft.dist {
|
||||
SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Its ok to allocate scratch space here since pk is usually generated only once.
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(RLWECtDft::encrypt_zero_sk_scratch_space(
|
||||
module,
|
||||
self.size(),
|
||||
));
|
||||
self.data.encrypt_zero_sk(
|
||||
module,
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
self.dist = sk_dft.dist;
|
||||
}
|
||||
}
|
||||
7
core/src/lib.rs
Normal file
7
core/src/lib.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod elem;
|
||||
pub mod grlwe;
|
||||
pub mod keys;
|
||||
pub mod rgsw;
|
||||
pub mod rlwe;
|
||||
mod test_fft64;
|
||||
mod utils;
|
||||
320
core/src/rgsw.rs
Normal file
320
core/src/rgsw.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
|
||||
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos, MatZnxDftProducts, SetRow},
|
||||
grlwe::GRLWECt,
|
||||
keys::SecretKeyDft,
|
||||
rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct RGSWCt<C, B: Backend> {
|
||||
pub data: MatZnxDft<C, B>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> RGSWCt<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_mat_znx_dft(rows, 2, 2, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for RGSWCt<T, B> {
|
||||
type Inner = MatZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToMut<B> for RGSWCt<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> MatZnxDftToRef<B> for RGSWCt<C, B>
|
||||
where
|
||||
MatZnxDft<C, B>: MatZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl RGSWCt<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
RLWECt::encrypt_sk_scratch_space(module, size)
|
||||
+ module.bytes_of_vec_znx(2, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.bytes_of_vec_znx_dft(2, size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, rgsw_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx_dft(2, rgsw_size)
|
||||
+ ((module.bytes_of_vec_znx_dft(2, a_size) + module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 2, 2, rgsw_size))
|
||||
| module.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rgsw_size: usize) -> usize {
|
||||
Self::mul_rlwe_scratch_space(module, res_size, res_size, rgsw_size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
(Self::mul_rlwe_scratch_space(module, res_size, a_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
|
||||
+ module.bytes_of_vec_znx(2, a_size)
|
||||
+ module.bytes_of_vec_znx(2, res_size)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, grlwe_size: usize) -> usize {
|
||||
(Self::mul_rlwe_inplace_scratch_space(module, res_size, grlwe_size) | module.vec_znx_idft_tmp_bytes())
|
||||
+ module.bytes_of_vec_znx(2, res_size)
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
|
||||
pub fn mul_rgsw_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
|
||||
pub fn mul_rgsw_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
|
||||
Self::mul_rlwe_dft_inplace_scratch_space(module, res_size, grlwe_size) + module.bytes_of_vec_znx_dft(2, a_size)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_rgsw_sk<C, P, S>(
|
||||
module: &Module<FFT64>,
|
||||
ct: &mut RGSWCt<C, FFT64>,
|
||||
pt: &ScalarZnx<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
ScalarZnx<P>: ScalarZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let size: usize = ct.size();
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
|
||||
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
|
||||
let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size);
|
||||
|
||||
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
|
||||
data: tmp_znx_pt,
|
||||
log_base2k: log_base2k,
|
||||
log_k: ct.log_k(),
|
||||
};
|
||||
|
||||
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
|
||||
data: tmp_znx_ct,
|
||||
log_base2k: log_base2k,
|
||||
log_k: ct.log_k(),
|
||||
};
|
||||
|
||||
(0..ct.rows()).for_each(|row_j| {
|
||||
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
|
||||
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
|
||||
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
|
||||
|
||||
(0..ct.cols()).for_each(|col_i| {
|
||||
// rlwe encrypt of vec_znx_pt into vec_znx_ct
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
&mut vec_znx_ct,
|
||||
Some((&vec_znx_pt, col_i)),
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scrach_2,
|
||||
);
|
||||
|
||||
// Switch vec_znx_ct into DFT domain
|
||||
{
|
||||
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size);
|
||||
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
|
||||
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
|
||||
module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct);
|
||||
}
|
||||
});
|
||||
|
||||
vec_znx_pt.data.zero(); // zeroes for next iteration
|
||||
});
|
||||
}
|
||||
|
||||
impl<C> RGSWCt<C, FFT64> {
|
||||
pub fn encrypt_sk<P, S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &ScalarZnx<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
ScalarZnx<P>: ScalarZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
encrypt_rgsw_sk(
|
||||
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnx<R>: VecZnxToMut + VecZnxToRef,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_inplace(self, module, res, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut RLWECtDft<R, FFT64>,
|
||||
a: &RLWECtDft<A, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<A, FFT64>: VecZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_dft(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_rlwe_dft_inplace<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64> + VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
MatZnxDftProducts::mul_rlwe_dft_inplace(self, module, res, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_grlwe<R, A>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GRLWECt<R, FFT64>,
|
||||
a: &GRLWECt<A, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<R, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
MatZnxDftProducts::mul_grlwe(self, module, res, a, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_inplace<R>(&self, module: &Module<FFT64>, res: &mut R, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
R: GetRow<FFT64> + SetRow<FFT64> + Infos,
|
||||
{
|
||||
MatZnxDftProducts::mul_grlwe_inplace(self, module, res, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> GetRow<FFT64> for RGSWCt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
fn get_row<R>(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64>,
|
||||
{
|
||||
module.vmp_extract_row(res, self, row_i, col_j);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> SetRow<FFT64> for RGSWCt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
|
||||
{
|
||||
fn set_row<R>(&mut self, module: &Module<FFT64>, row_i: usize, col_j: usize, a: &RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
module.vmp_prepare_row(self, row_i, col_j, a);
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> MatZnxDftProducts<RGSWCt<C, FFT64>, C> for RGSWCt<C, FFT64>
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
|
||||
{
|
||||
fn mul_rlwe<R, A>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, a: &RLWECt<A>, scratch: &mut Scratch)
|
||||
where
|
||||
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
VecZnx<A>: VecZnxToRef,
|
||||
{
|
||||
let log_base2k: usize = self.log_base2k();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(res.log_base2k(), log_base2k);
|
||||
assert_eq!(a.log_base2k(), log_base2k);
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(res.n(), module.n());
|
||||
assert_eq!(a.n(), module.n());
|
||||
}
|
||||
|
||||
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
|
||||
|
||||
{
|
||||
let (mut a_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 2, a.size());
|
||||
module.vec_znx_dft(&mut a_dft, 0, a, 0);
|
||||
module.vec_znx_dft(&mut a_dft, 1, a, 1);
|
||||
module.vmp_apply(&mut res_dft, &a_dft, self, scratch2);
|
||||
}
|
||||
|
||||
let res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
|
||||
|
||||
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
|
||||
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
|
||||
}
|
||||
}
|
||||
606
core/src/rlwe.rs
Normal file
606
core/src/rlwe.rs
Normal file
@@ -0,0 +1,606 @@
|
||||
use base2k::{
|
||||
AddNormal, Backend, FFT64, FillUniform, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft,
|
||||
VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
grlwe::GRLWECt,
|
||||
keys::{PublicKey, SecretDistribution, SecretKeyDft},
|
||||
utils::derive_size,
|
||||
};
|
||||
|
||||
pub struct RLWECt<C> {
|
||||
pub data: VecZnx<C>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
}
|
||||
|
||||
impl RLWECt<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx(2, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Infos for RLWECt<T> {
|
||||
type Inner = VecZnx<T>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToMut for RLWECt<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToRef for RLWECt<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> RLWECt<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn dft<R>(&self, module: &Module<FFT64>, res: &mut RLWECtDft<R, FFT64>)
|
||||
where
|
||||
VecZnxDft<R, FFT64>: VecZnxDftToMut<FFT64> + ZnxInfos,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.cols(), 2);
|
||||
assert_eq!(res.cols(), 2);
|
||||
assert_eq!(self.log_base2k(), res.log_base2k())
|
||||
}
|
||||
|
||||
module.vec_znx_dft(res, 0, self, 0);
|
||||
module.vec_znx_dft(res, 1, self, 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RLWEPt<C> {
|
||||
pub data: VecZnx<C>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
}
|
||||
|
||||
impl<T> Infos for RLWEPt<T> {
|
||||
type Inner = VecZnx<T>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToMut for RLWEPt<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToMut,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> VecZnxToRef for RLWEPt<C>
|
||||
where
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnx<&[u8]> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl RLWEPt<Vec<u8>> {
|
||||
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx(1, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RLWECtDft<C, B: Backend> {
|
||||
pub data: VecZnxDft<C, B>,
|
||||
pub log_base2k: usize,
|
||||
pub log_k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_k)),
|
||||
log_base2k: log_base2k,
|
||||
log_k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for RLWECtDft<T, B> {
|
||||
type Inner = VecZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn log_base2k(&self) -> usize {
|
||||
self.log_base2k
|
||||
}
|
||||
|
||||
fn log_k(&self) -> usize {
|
||||
self.log_k
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for RLWECtDft<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for RLWECtDft<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C> RLWECtDft<C, FFT64>
|
||||
where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
module.bytes_of_vec_znx(2, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
|
||||
}
|
||||
|
||||
pub(crate) fn idft<R>(&self, module: &Module<FFT64>, res: &mut RLWECt<R>, scratch: &mut Scratch)
|
||||
where
|
||||
VecZnx<R>: VecZnxToMut,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.cols(), 2);
|
||||
assert_eq!(res.cols(), 2);
|
||||
assert_eq!(self.log_base2k(), res.log_base2k())
|
||||
}
|
||||
|
||||
let min_size: usize = self.size().min(res.size());
|
||||
|
||||
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 2, min_size);
|
||||
|
||||
module.vec_znx_idft(&mut res_big, 0, &self.data, 0, scratch1);
|
||||
module.vec_znx_idft(&mut res_big, 1, &self.data, 1, scratch1);
|
||||
module.vec_znx_big_normalize(self.log_base2k(), res, 0, &res_big, 0, scratch1);
|
||||
module.vec_znx_big_normalize(self.log_base2k(), res, 1, &res_big, 1, scratch1);
|
||||
}
|
||||
}
|
||||
|
||||
impl RLWECt<Vec<u8>> {
|
||||
pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
|
||||
}
|
||||
|
||||
pub fn encrypt_pk_scratch_space<B: Backend>(module: &Module<B>, pk_size: usize) -> usize {
|
||||
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
|
||||
+ module.bytes_of_scalar_znx_dft(1)
|
||||
+ module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
pub fn decrypt_scratch_space<B: Backend>(module: &Module<B>, size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn encrypt_rlwe_sk<C, P, S>(
|
||||
module: &Module<FFT64>,
|
||||
ct: &mut RLWECt<C>,
|
||||
pt: Option<(&RLWEPt<P>, usize)>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<P>: VecZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
let log_k: usize = ct.log_k();
|
||||
let size: usize = ct.size();
|
||||
|
||||
// c1 = a
|
||||
ct.data.fill_uniform(log_base2k, 1, size, source_xa);
|
||||
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
|
||||
|
||||
{
|
||||
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
||||
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
||||
|
||||
// c0_dft = DFT(a) * DFT(s)
|
||||
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
|
||||
|
||||
// c0_big = IDFT(c0_dft)
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||
}
|
||||
|
||||
// c0_big = m - c0_big
|
||||
if let Some((pt, col)) = pt {
|
||||
match col {
|
||||
0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0),
|
||||
1 => {
|
||||
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
|
||||
module.vec_znx_add_inplace(ct, 1, pt, 0);
|
||||
module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1);
|
||||
}
|
||||
_ => panic!("invalid target column: {}", col),
|
||||
}
|
||||
} else {
|
||||
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
|
||||
}
|
||||
// c0_big += e
|
||||
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
||||
|
||||
// c0 = norm(c0_big = -as + m + e)
|
||||
module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1);
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe<P, C, S>(
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut RLWEPt<P>,
|
||||
ct: &RLWECt<C>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||
|
||||
{
|
||||
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
||||
|
||||
// c0_dft = DFT(a) * DFT(s)
|
||||
module.svp_apply_inplace(&mut c0_dft, 0, sk_dft, 0);
|
||||
|
||||
// c0_big = IDFT(c0_dft)
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||
}
|
||||
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0);
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
|
||||
|
||||
pt.log_base2k = ct.log_base2k();
|
||||
pt.log_k = pt.log_k().min(ct.log_k());
|
||||
}
|
||||
|
||||
impl<C> RLWECt<C> {
|
||||
pub fn encrypt_sk<P, S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: Option<&RLWEPt<P>>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<P>: VecZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
if let Some(pt) = pt {
|
||||
encrypt_rlwe_sk(
|
||||
module,
|
||||
self,
|
||||
Some((pt, 0)),
|
||||
sk_dft,
|
||||
source_xa,
|
||||
source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch,
|
||||
)
|
||||
} else {
|
||||
encrypt_rlwe_sk::<C, P, S>(
|
||||
module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt<P, S>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut RLWEPt<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<C>: VecZnxToRef,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
decrypt_rlwe(module, pt, self, sk_dft, scratch);
|
||||
}
|
||||
|
||||
pub fn encrypt_pk<P, S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
pt: Option<&RLWEPt<P>>,
|
||||
pk: &PublicKey<S, FFT64>,
|
||||
source_xu: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<P>: VecZnxToRef,
|
||||
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
encrypt_rlwe_pk(
|
||||
module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_zero_rlwe_dft_sk<C, S>(
|
||||
module: &Module<FFT64>,
|
||||
ct: &mut RLWECtDft<C, FFT64>,
|
||||
sk: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let log_base2k: usize = ct.log_base2k();
|
||||
let log_k: usize = ct.log_k();
|
||||
let size: usize = ct.size();
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
match sk.dist {
|
||||
SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"),
|
||||
_ => {}
|
||||
}
|
||||
assert_eq!(ct.cols(), 2);
|
||||
}
|
||||
|
||||
// ct[1] = DFT(a)
|
||||
{
|
||||
let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
|
||||
tmp_znx.fill_uniform(log_base2k, 0, size, source_xa);
|
||||
module.vec_znx_dft(ct, 1, &tmp_znx, 0);
|
||||
}
|
||||
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
|
||||
|
||||
{
|
||||
let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
||||
// c0_dft = ct[1] * DFT(s)
|
||||
module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
|
||||
|
||||
// c0_big = IDFT(c0_dft)
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
|
||||
}
|
||||
|
||||
// c0_big += e
|
||||
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
||||
|
||||
// c0 = norm(c0_big = -as - e), NOTE: e is centered at 0.
|
||||
let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
|
||||
module.vec_znx_negate_inplace(&mut tmp_znx, 0);
|
||||
// ct[0] = DFT(-as + e)
|
||||
module.vec_znx_dft(ct, 0, &tmp_znx, 0);
|
||||
}
|
||||
|
||||
impl RLWECtDft<Vec<u8>, FFT64> {
|
||||
pub fn encrypt_zero_sk_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
(module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
|
||||
+ module.bytes_of_vec_znx_big(1, size)
|
||||
+ module.bytes_of_vec_znx(1, size)
|
||||
+ module.vec_znx_big_normalize_tmp_bytes()
|
||||
}
|
||||
|
||||
pub fn decrypt_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes()
|
||||
| module.bytes_of_vec_znx_dft(1, size)
|
||||
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
|
||||
+ module.bytes_of_vec_znx_big(1, size)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_rlwe_dft<P, C, S>(
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut RLWEPt<P>,
|
||||
ct: &RLWECtDft<C, FFT64>,
|
||||
sk: &SecretKeyDft<S, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||
|
||||
{
|
||||
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||
// c0_dft = DFT(a) * DFT(s)
|
||||
module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1);
|
||||
// c0_big = IDFT(c0_dft)
|
||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||
}
|
||||
|
||||
{
|
||||
let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size());
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2);
|
||||
module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0);
|
||||
}
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
|
||||
|
||||
pt.log_base2k = ct.log_base2k();
|
||||
pt.log_k = pt.log_k().min(ct.log_k());
|
||||
}
|
||||
|
||||
impl<C> RLWECtDft<C, FFT64> {
|
||||
pub(crate) fn encrypt_zero_sk<S>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
encrypt_zero_rlwe_dft_sk(
|
||||
module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn decrypt<P, S>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut RLWEPt<P>,
|
||||
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
decrypt_rlwe_dft(module, pt, self, sk_dft, scratch);
|
||||
}
|
||||
|
||||
pub fn mul_grlwe_assign<A>(&mut self, module: &Module<FFT64>, a: &GRLWECt<A, FFT64>, scratch: &mut Scratch)
|
||||
where
|
||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<A, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
a.mul_rlwe_dft_inplace(module, self, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn encrypt_rlwe_pk<C, P, S>(
|
||||
module: &Module<FFT64>,
|
||||
ct: &mut RLWECt<C>,
|
||||
pt: Option<&RLWEPt<P>>,
|
||||
pk: &PublicKey<S, FFT64>,
|
||||
source_xu: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||
VecZnx<P>: VecZnxToRef,
|
||||
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(ct.log_base2k(), pk.log_base2k());
|
||||
assert_eq!(ct.n(), module.n());
|
||||
assert_eq!(pk.n(), module.n());
|
||||
if let Some(pt) = pt {
|
||||
assert_eq!(pt.log_base2k(), pk.log_base2k());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
}
|
||||
}
|
||||
|
||||
let log_base2k: usize = pk.log_base2k();
|
||||
let size_pk: usize = pk.size();
|
||||
|
||||
// Generates u according to the underlying secret distribution.
|
||||
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
|
||||
|
||||
{
|
||||
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
|
||||
match pk.dist {
|
||||
SecretDistribution::NONE => panic!(
|
||||
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
|
||||
),
|
||||
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
|
||||
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
|
||||
SecretDistribution::ZERO => {}
|
||||
}
|
||||
|
||||
module.svp_prepare(&mut u_dft, 0, &u, 0);
|
||||
}
|
||||
|
||||
let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
|
||||
let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
|
||||
|
||||
// ct[0] = pk[0] * u + m + e0
|
||||
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0);
|
||||
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
|
||||
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
|
||||
|
||||
if let Some(pt) = pt {
|
||||
module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0);
|
||||
}
|
||||
|
||||
module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3);
|
||||
|
||||
// ct[1] = pk[1] * u + e1
|
||||
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1);
|
||||
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
|
||||
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
|
||||
module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3);
|
||||
}
|
||||
722
core/src/test_fft64/grlwe.rs
Normal file
722
core/src/test_fft64/grlwe.rs
Normal file
@@ -0,0 +1,722 @@
|
||||
#[cfg(test)]
|
||||
|
||||
mod tests {
|
||||
use base2k::{FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
grlwe::GRLWECt,
|
||||
keys::{SecretKey, SecretKeyDft},
|
||||
rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
||||
test_fft64::grlwe::noise_grlwe_rlwe_product,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 8;
|
||||
let log_k_ct: usize = 54;
|
||||
let rows: usize = 4;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_ct, rows);
|
||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct);
|
||||
|
||||
(0..ct.rows()).for_each(|row_i| {
|
||||
ct.get_row(&module, row_i, &mut ct_rlwe_dft);
|
||||
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
|
||||
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
});
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_rlwe() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let log_k_rlwe_in: usize = 45;
|
||||
let log_k_rlwe_out: usize = 60;
|
||||
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
|
||||
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
|
||||
| GRLWECt::mul_rlwe_scratch_space(
|
||||
&module,
|
||||
ct_rlwe_out.size(),
|
||||
ct_rlwe_in.size(),
|
||||
ct_grlwe.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
ct_grlwe.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&sk0_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_grlwe.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow());
|
||||
|
||||
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_rlwe_in,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_rlwe_inplace() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let log_k_rlwe: usize = 45;
|
||||
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
|
||||
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
|
||||
| GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
ct_grlwe.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&sk0_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_grlwe.mul_rlwe_inplace(&module, &mut ct_rlwe, scratch.borrow());
|
||||
|
||||
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_rlwe,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_rlwe_dft() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let log_k_rlwe_in: usize = 45;
|
||||
let log_k_rlwe_out: usize = 60;
|
||||
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut ct_rlwe_in_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
let mut ct_rlwe_out_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe_out);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
|
||||
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
|
||||
| GRLWECt::mul_rlwe_scratch_space(
|
||||
&module,
|
||||
ct_rlwe_out.size(),
|
||||
ct_rlwe_in.size(),
|
||||
ct_grlwe.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
ct_grlwe.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&sk0_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.dft(&module, &mut ct_rlwe_in_dft);
|
||||
ct_grlwe.mul_rlwe_dft(
|
||||
&module,
|
||||
&mut ct_rlwe_out_dft,
|
||||
&ct_rlwe_in_dft,
|
||||
scratch.borrow(),
|
||||
);
|
||||
ct_rlwe_out_dft.idft(&module, &mut ct_rlwe_out, scratch.borrow());
|
||||
|
||||
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_rlwe_in,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_rlwe_dft_inplace() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let log_k_rlwe: usize = 45;
|
||||
let rows: usize = (log_k_rlwe + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_rlwe: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe);
|
||||
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_rlwe);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
pt_want
|
||||
.data
|
||||
.fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct_rlwe.size())
|
||||
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe.size())
|
||||
| GRLWECt::mul_rlwe_scratch_space(&module, ct_rlwe.size(), ct_rlwe.size(), ct_grlwe.size()),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
ct_grlwe.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&sk0_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe.dft(&module, &mut ct_rlwe_dft);
|
||||
ct_grlwe.mul_rlwe_dft_inplace(&module, &mut ct_rlwe_dft, scratch.borrow());
|
||||
ct_rlwe_dft.idft(&module, &mut ct_rlwe, scratch.borrow());
|
||||
|
||||
ct_rlwe.decrypt(&module, &mut pt_have, &sk1_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_rlwe,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_grlwe() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe_s0s1: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_grlwe_s1s2: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_grlwe_s0s2: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size())
|
||||
| RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s2.size())
|
||||
| GRLWECt::mul_grlwe_scratch_space(
|
||||
&module,
|
||||
ct_grlwe_s0s2.size(),
|
||||
ct_grlwe_s0s1.size(),
|
||||
ct_grlwe_s1s2.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk2.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk2_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk2_dft.dft(&module, &sk2);
|
||||
|
||||
// GRLWE_{s1}(s0) = s0 -> s1
|
||||
ct_grlwe_s0s1.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// GRLWE_{s2}(s1) -> s1 -> s2
|
||||
ct_grlwe_s1s2.encrypt_sk(
|
||||
&module,
|
||||
&sk1.data,
|
||||
&sk2_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0)
|
||||
ct_grlwe_s1s2.mul_grlwe(
|
||||
&module,
|
||||
&mut ct_grlwe_s0s2,
|
||||
&ct_grlwe_s0s1,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_rlwe_dft_s0s2: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe);
|
||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_grlwe);
|
||||
|
||||
(0..ct_grlwe_s0s2.rows()).for_each(|row_i| {
|
||||
ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2);
|
||||
ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_grlwe,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_grlwe_inplace() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let rows: usize = (log_k_grlwe + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_grlwe_s0s1: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_grlwe_s1s2: GRLWECt<Vec<u8>, FFT64> = GRLWECt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
GRLWECt::encrypt_sk_scratch_space(&module, ct_grlwe_s0s1.size())
|
||||
| RLWECtDft::decrypt_scratch_space(&module, ct_grlwe_s0s1.size())
|
||||
| GRLWECt::mul_grlwe_scratch_space(
|
||||
&module,
|
||||
ct_grlwe_s0s1.size(),
|
||||
ct_grlwe_s0s1.size(),
|
||||
ct_grlwe_s1s2.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk0: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk0.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk0_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk0_dft.dft(&module, &sk0);
|
||||
|
||||
let mut sk1: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk1.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk1_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk1_dft.dft(&module, &sk1);
|
||||
|
||||
let mut sk2: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk2.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk2_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk2_dft.dft(&module, &sk2);
|
||||
|
||||
// GRLWE_{s1}(s0) = s0 -> s1
|
||||
ct_grlwe_s0s1.encrypt_sk(
|
||||
&module,
|
||||
&sk0.data,
|
||||
&sk1_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// GRLWE_{s2}(s1) -> s1 -> s2
|
||||
ct_grlwe_s1s2.encrypt_sk(
|
||||
&module,
|
||||
&sk1.data,
|
||||
&sk2_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
// GRLWE_{s1}(s0) (x) GRLWE_{s2}(s1) = GRLWE_{s2}(s0)
|
||||
ct_grlwe_s1s2.mul_grlwe_inplace(&module, &mut ct_grlwe_s0s1, scratch.borrow());
|
||||
|
||||
let ct_grlwe_s0s2: GRLWECt<Vec<u8>, FFT64> = ct_grlwe_s0s1;
|
||||
|
||||
let mut ct_rlwe_dft_s0s2: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_grlwe);
|
||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_grlwe);
|
||||
|
||||
(0..ct_grlwe_s0s2.rows()).for_each(|row_i| {
|
||||
ct_grlwe_s0s2.get_row(&module, row_i, &mut ct_rlwe_dft_s0s2);
|
||||
ct_rlwe_dft_s0s2.decrypt(&module, &mut pt, &sk2_dft, scratch.borrow());
|
||||
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk0, 0);
|
||||
|
||||
let noise_have: f64 = pt.data.std(0, log_base2k).log2();
|
||||
let noise_want: f64 = noise_grlwe_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
0.5,
|
||||
0f64,
|
||||
sigma * sigma,
|
||||
0f64,
|
||||
log_k_grlwe,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
});
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn noise_grlwe_rlwe_product(
|
||||
n: f64,
|
||||
log_base2k: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = a_logq.min(b_logq);
|
||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (log_base2k)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a_err * a_scale * a_scale * n;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
3
core/src/test_fft64/mod.rs
Normal file
3
core/src/test_fft64/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod grlwe;
|
||||
mod rgsw;
|
||||
mod rlwe;
|
||||
235
core/src/test_fft64/rgsw.rs
Normal file
235
core/src/test_fft64/rgsw.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use base2k::{
|
||||
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
|
||||
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, VecZnxToMut, ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::{GetRow, Infos},
|
||||
keys::{SecretKey, SecretKeyDft},
|
||||
rgsw::RGSWCt,
|
||||
rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
||||
test_fft64::rgsw::noise_rgsw_rlwe_product,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_rgsw_sk() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 8;
|
||||
let log_k_ct: usize = 54;
|
||||
let rows: usize = 4;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
RGSWCt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECtDft::decrypt_scratch_space(&module, ct.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
&pt_scalar,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
|
||||
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
|
||||
|
||||
(0..ct.cols()).for_each(|col_j| {
|
||||
(0..ct.rows()).for_each(|row_i| {
|
||||
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
|
||||
|
||||
if col_j == 1 {
|
||||
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
|
||||
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||
module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||
}
|
||||
|
||||
ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
|
||||
|
||||
ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
|
||||
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||
|
||||
pt_want.data.zero();
|
||||
});
|
||||
});
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul_rlwe() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let log_base2k: usize = 12;
|
||||
let log_k_grlwe: usize = 60;
|
||||
let log_k_rlwe_in: usize = 45;
|
||||
let log_k_rlwe_out: usize = 60;
|
||||
let rows: usize = (log_k_rlwe_in + log_base2k - 1) / log_base2k;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct_rgsw: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_grlwe, rows);
|
||||
let mut ct_rlwe_in: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut ct_rlwe_out: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
let mut pt_rgsw: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_in);
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_rlwe_out);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
// Random input plaintext
|
||||
// pt_want
|
||||
// .data
|
||||
// .fill_uniform(log_base2k, 0, pt_want.size(), &mut source_xa);
|
||||
|
||||
pt_want.to_mut().at_mut(0, 0)[1] = 1;
|
||||
|
||||
let k: usize = 1;
|
||||
|
||||
pt_rgsw.raw_mut()[k] = 1; // X^{k}
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
RGSWCt::encrypt_sk_scratch_space(&module, ct_rgsw.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct_rlwe_out.size())
|
||||
| RLWECt::encrypt_sk_scratch_space(&module, ct_rlwe_in.size())
|
||||
| RGSWCt::mul_rlwe_scratch_space(
|
||||
&module,
|
||||
ct_rlwe_out.size(),
|
||||
ct_rlwe_in.size(),
|
||||
ct_rgsw.size(),
|
||||
),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
ct_rgsw.encrypt_sk(
|
||||
&module,
|
||||
&pt_rgsw,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rlwe_in.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
ct_rgsw.mul_rlwe(&module, &mut ct_rlwe_out, &ct_rlwe_in, scratch.borrow());
|
||||
|
||||
ct_rlwe_out.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_rotate_inplace(k as i64, &mut pt_want, 0);
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
|
||||
|
||||
let noise_have: f64 = pt_have.data.std(0, log_base2k).log2();
|
||||
|
||||
let var_gct_err_lhs: f64 = sigma * sigma;
|
||||
let var_gct_err_rhs: f64 = 0f64;
|
||||
|
||||
let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||
let var_a0_err: f64 = sigma * sigma;
|
||||
let var_a1_err: f64 = 1f64 / 12f64;
|
||||
|
||||
let noise_want: f64 = noise_rgsw_rlwe_product(
|
||||
module.n() as f64,
|
||||
log_base2k,
|
||||
0.5,
|
||||
var_msg,
|
||||
var_a0_err,
|
||||
var_a1_err,
|
||||
var_gct_err_lhs,
|
||||
var_gct_err_rhs,
|
||||
log_k_rlwe_in,
|
||||
log_k_grlwe,
|
||||
);
|
||||
|
||||
assert!(
|
||||
(noise_have - noise_want).abs() <= 0.1,
|
||||
"{} {}",
|
||||
noise_have,
|
||||
noise_want
|
||||
);
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn noise_rgsw_rlwe_product(
|
||||
n: f64,
|
||||
log_base2k: usize,
|
||||
var_xs: f64,
|
||||
var_msg: f64,
|
||||
var_a0_err: f64,
|
||||
var_a1_err: f64,
|
||||
var_gct_err_lhs: f64,
|
||||
var_gct_err_rhs: f64,
|
||||
a_logq: usize,
|
||||
b_logq: usize,
|
||||
) -> f64 {
|
||||
let a_logq: usize = a_logq.min(b_logq);
|
||||
let a_cols: usize = (a_logq + log_base2k - 1) / log_base2k;
|
||||
|
||||
let b_scale = 2.0f64.powi(b_logq as i32);
|
||||
let a_scale: f64 = 2.0f64.powi((b_logq - a_logq) as i32);
|
||||
|
||||
let base: f64 = (1 << (log_base2k)) as f64;
|
||||
let var_base: f64 = base * base / 12f64;
|
||||
|
||||
// lhs = a_cols * n * (var_base * var_gct_err_lhs + var_e_a * var_msg * p^2)
|
||||
// rhs = a_cols * n * var_base * var_gct_err_rhs * var_xs
|
||||
let mut noise: f64 = 2.0 * (a_cols as f64) * n * var_base * (var_gct_err_lhs + var_xs * var_gct_err_rhs);
|
||||
noise += var_msg * var_a0_err * a_scale * a_scale * n;
|
||||
noise += var_msg * var_a1_err * a_scale * a_scale * n * var_xs;
|
||||
noise = noise.sqrt();
|
||||
noise /= b_scale;
|
||||
noise.log2().min(-1.0) // max noise is [-2^{-1}, 2^{-1}]
|
||||
}
|
||||
196
core/src/test_fft64/rlwe.rs
Normal file
196
core/src/test_fft64/rlwe.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero};
|
||||
use itertools::izip;
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
keys::{PublicKey, SecretKey, SecretKeyDft},
|
||||
rlwe::{RLWECt, RLWECtDft, RLWEPt},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn encrypt_sk() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||
let log_base2k: usize = 8;
|
||||
let log_k_ct: usize = 54;
|
||||
let log_k_pt: usize = 30;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
RLWECt::encrypt_sk_scratch_space(&module, ct.size()) | RLWECt::decrypt_scratch_space(&module, ct.size()),
|
||||
);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut data_want: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
data_want
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
|
||||
|
||||
pt.data
|
||||
.encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10);
|
||||
|
||||
ct.encrypt_sk(
|
||||
&module,
|
||||
Some(&pt),
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
pt.data.zero();
|
||||
|
||||
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
let mut data_have: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
pt.data
|
||||
.decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
|
||||
|
||||
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
|
||||
let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
|
||||
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
|
||||
let b_scaled = (*b as f64) / scale;
|
||||
assert!(
|
||||
(*a as f64 - b_scaled).abs() < 0.1,
|
||||
"{} {}",
|
||||
*a as f64,
|
||||
b_scaled
|
||||
)
|
||||
});
|
||||
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_zero_sk() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(1024);
|
||||
let log_base2k: usize = 8;
|
||||
let log_k_ct: usize = 55;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([1u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut ct_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
RLWECtDft::decrypt_scratch_space(&module, ct_dft.size())
|
||||
| RLWECtDft::encrypt_zero_sk_scratch_space(&module, ct_dft.size()),
|
||||
);
|
||||
|
||||
ct_dft.encrypt_zero_sk(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||
|
||||
assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2);
|
||||
module.free();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encrypt_pk() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||
let log_base2k: usize = 8;
|
||||
let log_k_ct: usize = 54;
|
||||
let log_k_pk: usize = 64;
|
||||
|
||||
let sigma: f64 = 3.2;
|
||||
let bound: f64 = sigma * 6.0;
|
||||
|
||||
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct);
|
||||
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
|
||||
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||
let mut source_xu: Source = Source::new([0u8; 32]);
|
||||
|
||||
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||
sk_dft.dft(&module, &sk);
|
||||
|
||||
let mut pk: PublicKey<Vec<u8>, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk);
|
||||
pk.generate(
|
||||
&module,
|
||||
&sk_dft,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
RLWECt::encrypt_sk_scratch_space(&module, ct.size())
|
||||
| RLWECt::decrypt_scratch_space(&module, ct.size())
|
||||
| RLWECt::encrypt_pk_scratch_space(&module, pk.size()),
|
||||
);
|
||||
|
||||
let mut data_want: Vec<i64> = vec![0i64; module.n()];
|
||||
|
||||
data_want
|
||||
.iter_mut()
|
||||
.for_each(|x| *x = source_xa.next_i64() & 0);
|
||||
|
||||
pt_want
|
||||
.data
|
||||
.encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10);
|
||||
|
||||
ct.encrypt_pk(
|
||||
&module,
|
||||
Some(&pt_want),
|
||||
&pk,
|
||||
&mut source_xu,
|
||||
&mut source_xe,
|
||||
sigma,
|
||||
bound,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||
|
||||
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||
|
||||
module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
|
||||
|
||||
assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2);
|
||||
|
||||
module.free();
|
||||
}
|
||||
}
|
||||
3
core/src/utils.rs
Normal file
3
core/src/utils.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize {
|
||||
(log_k + log_base2k - 1) / log_base2k
|
||||
}
|
||||
Reference in New Issue
Block a user