mirror of https://github.com/arnaucube/poulpy.git, synced 2026-02-10 05:06:44 +01:00

added rgsw encrypt + test
@@ -196,7 +196,7 @@ impl Scratch {
        }
    }

-    pub fn tmp_scalar<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
+    pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));

        (
@@ -205,7 +205,7 @@ impl Scratch {
        )
    }

-    pub fn tmp_scalar_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
+    pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
        let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));

        (
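For orientation, a minimal usage sketch of the renamed scratch helpers (not part of this commit; it assumes a Module and Scratch already constructed elsewhere, with the items in scope as re-exported by the crate):

// Borrow a temporary ScalarZnx and a temporary ScalarZnxDft from the scratch arena.
// Each call splits off an aligned slice and hands back the remaining scratch, so
// further temporaries can be taken from `rest`.
fn with_scalar_temporaries<B: Backend>(module: &Module<B>, scratch: &mut Scratch) {
    let (mut s, rest) = scratch.tmp_scalar_znx(module, 1);
    let (mut s_dft, _rest) = rest.tmp_scalar_znx_dft(module, 1);
    // ... fill `s`, then e.g. svp_prepare it into `s_dft` ...
    let _ = (&mut s, &mut s_dft);
}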
@@ -1,5 +1,5 @@
use crate::znx_base::ZnxInfos;
-use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
+use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -144,6 +144,17 @@ impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
    }
}

+impl VecZnxToMut for ScalarZnx<Vec<u8>> {
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        VecZnx {
+            data: self.data.as_mut_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
+
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
    fn to_ref(&self) -> ScalarZnx<&[u8]> {
        ScalarZnx {
@@ -154,6 +165,17 @@ impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
    }
}

+impl VecZnxToRef for ScalarZnx<Vec<u8>> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data.as_slice(),
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
+
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
        ScalarZnx {
@@ -164,6 +186,17 @@ impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
    }
}

+impl VecZnxToMut for ScalarZnx<&mut [u8]> {
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        VecZnx {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
+
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
    fn to_ref(&self) -> ScalarZnx<&[u8]> {
        ScalarZnx {
@@ -174,6 +207,17 @@ impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
    }
}

+impl VecZnxToRef for ScalarZnx<&mut [u8]> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
+
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
    fn to_ref(&self) -> ScalarZnx<&[u8]> {
        ScalarZnx {
@@ -183,3 +227,14 @@ impl ScalarZnxToRef for ScalarZnx<&[u8]> {
        }
    }
}
+
+impl VecZnxToRef for ScalarZnx<&[u8]> {
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        VecZnx {
+            data: self.data,
+            n: self.n,
+            cols: self.cols,
+            size: 1,
+        }
+    }
+}
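These new impls let a ScalarZnx be handed to any routine that accepts a VecZnx view, as a single-limb vector. A small illustrative sketch (not part of the commit; `limb_count` is a hypothetical helper, and it relies on the now-public `size` field shown further down in this diff):

// Hypothetical generic helper: it only needs a read-only VecZnx view, so with the
// impls above it accepts a ScalarZnx as well as a VecZnx argument.
fn limb_count<A: VecZnxToRef>(a: &A) -> usize {
    a.to_ref().size // a ScalarZnx is viewed as a VecZnx with size = 1
}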
@@ -1,103 +1,103 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
    Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
    VecZnxDftToMut, VecZnxDftToRef,
};

pub trait ScalarZnxDftAlloc<B: Backend> {
    fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
    fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
    fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}

pub trait ScalarZnxDftOps<BACKEND: Backend> {
    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxDftToMut<BACKEND>,
        A: ScalarZnxToRef;
    fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
    where
        R: VecZnxDftToMut<BACKEND>,
        A: ScalarZnxDftToRef<BACKEND>,
        B: VecZnxDftToRef<FFT64>;
    fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<BACKEND>,
        A: ScalarZnxDftToRef<BACKEND>;
}

impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
    fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
        ScalarZnxDftOwned::new(self, cols)
    }

    fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
        ScalarZnxDftOwned::bytes_of(self, cols)
    }

    fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
        ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
    }
}

impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
    fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: ScalarZnxDftToMut<FFT64>,
        A: ScalarZnxToRef,
    {
        unsafe {
            svp::svp_prepare(
                self.ptr,
                res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
                a.to_ref().at_ptr(a_col, 0),
            )
        }
    }

    fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
    where
        R: VecZnxDftToMut<FFT64>,
        A: ScalarZnxDftToRef<FFT64>,
        B: VecZnxDftToRef<FFT64>,
    {
        let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
        let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
        let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
        unsafe {
            svp::svp_apply_dft_to_dft(
                self.ptr,
                res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
                a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
                b.size() as u64,
                b.cols() as u64,
            )
        }
    }

    fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
    where
        R: VecZnxDftToMut<FFT64>,
        A: ScalarZnxDftToRef<FFT64>,
    {
        let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
        let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
        unsafe {
            svp::svp_apply_dft_to_dft(
                self.ptr,
                res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
                a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
                res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
                res.size() as u64,
                res.cols() as u64,
            )
        }
    }
}
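A brief sketch of how the SVP (scalar-vector product) ops compose (not part of the commit; it assumes an FFT64 module and that the owned containers implement the corresponding ToRef/ToMut traits, as their usage elsewhere in this diff suggests):

// Prepare a scalar polynomial for pointwise multiplication in the DFT domain,
// then multiply column 0 of a VecZnxDft by it in place.
fn scale_by_scalar(
    module: &Module<FFT64>,
    s: &ScalarZnx<Vec<u8>>,
    v_dft: &mut VecZnxDft<Vec<u8>, FFT64>,
) {
    let mut s_dft: ScalarZnxDftOwned<FFT64> = module.new_scalar_znx_dft(1);
    module.svp_prepare(&mut s_dft, 0, s, 0); // s -> prepared DFT-domain form
    module.svp_apply_inplace(v_dft, 0, &s_dft, 0); // v_dft[col 0] *= s (pointwise)
}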
@@ -20,9 +20,9 @@ use std::{cmp::min, fmt};
/// are small polynomials of Zn[X].
pub struct VecZnx<D> {
    pub data: D,
-    n: usize,
-    cols: usize,
-    size: usize,
+    pub n: usize,
+    pub cols: usize,
+    pub size: usize,
}

impl<D> ZnxInfos for VecZnx<D> {
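Making these fields public is what lets other containers build VecZnx views by struct literal, as the ScalarZnx impls above do. A hypothetical illustration (not part of the commit; the caller is responsible for `data` matching the `n`/`cols` layout):

// Hypothetical: borrow an existing byte buffer as a single-limb VecZnx view.
fn as_single_limb(data: &[u8], n: usize, cols: usize) -> VecZnx<&[u8]> {
    VecZnx { data, n, cols, size: 1 }
}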
@@ -114,6 +114,9 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
        R: VecZnxBigToMut<BACKEND>,
        A: VecZnxToRef;

+    /// Negates `a` inplace.
+    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut<BACKEND>;
+
    /// Normalizes `a` and stores the result on `b`.
    ///
    /// # Arguments
@@ -503,6 +506,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
        }
    }

+    fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut<FFT64> {
+        let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
+        #[cfg(debug_assertions)]
+        {
+            assert_eq!(a.n(), self.n());
+        }
+        unsafe {
+            vec_znx::vec_znx_negate(
+                self.ptr,
+                a.at_mut_ptr(res_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+                a.at_ptr(res_col, 0),
+                a.size() as u64,
+                a.sl() as u64,
+            )
+        }
+    }
+
    fn vec_znx_big_normalize<R, A>(
        &self,
        log_base2k: usize,
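This negation is what the updated RLWE encryption path further down relies on: with c1 = a sampled uniformly and c0 = -(a*s) + m + e, decryption recovers c0 + c1*s = m + e. A minimal call sketch (not part of the commit; it assumes VecZnxBig<Vec<u8>, FFT64> implements VecZnxBigToMut, as the other owned containers do):

// Flip the sign of column 0 of a "big" (non-normalized) accumulator in place,
// e.g. turning an accumulated a*s into -(a*s) before plaintext and noise are added.
fn negate_col0(module: &Module<FFT64>, acc: &mut VecZnxBig<Vec<u8>, FFT64>) {
    module.vec_znx_big_negate_inplace(acc, 0);
}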
@@ -91,7 +91,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
    module: &Module<FFT64>,
    ct: &mut GRLWECt<C, FFT64>,
    pt: &ScalarZnx<P>,
-    sk: &SecretKeyDft<S, FFT64>,
+    sk_dft: &SecretKeyDft<S, FFT64>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
@@ -131,7 +131,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
        vec_znx_ct.encrypt_sk(
            module,
            Some(&vec_znx_pt),
-            sk,
+            sk_dft,
            source_xa,
            source_xe,
            sigma,
@@ -186,7 +186,7 @@ mod tests {
    use super::GRLWECt;

    #[test]
-    fn encrypt_sk_vec_znx_fft64() {
+    fn encrypt_sk_fft64() {
        let module: Module<FFT64> = Module::<FFT64>::new(2048);
        let log_base2k: usize = 8;
        let log_k_ct: usize = 54;
@@ -233,7 +233,7 @@ mod tests {
        ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
        module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
        let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
-        assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt);
+        assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
    });

    module.free();
@@ -1,13 +1,13 @@
use base2k::{
    Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
-    ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut,
+    ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps,
+    ZnxZero,
};
use sampling::source::Source;

use crate::{
    elem::Infos,
    elem_grlwe::GRLWECt,
-    elem_rlwe::{RLWECt, RLWECtDft, RLWEPt},
+    elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
    keys::SecretKeyDft,
    utils::derive_size,
};
@@ -62,28 +62,32 @@
    }
}

-impl GRLWECt<Vec<u8>, FFT64> {
+impl RGSWCt<Vec<u8>, FFT64> {
    pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
        RLWECt::encrypt_sk_scratch_bytes(module, size)
-            + module.bytes_of_vec_znx(2, size)
+            + module.bytes_of_vec_znx(1, size)
            + module.bytes_of_vec_znx_dft(2, size)
    }
}

-    pub fn encrypt_pk_scratch_bytes(module: &Module<FFT64>, pk_size: usize) -> usize {
-        RLWECt::encrypt_pk_scratch_bytes(module, pk_size)
-    }
-
-    pub fn decrypt_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
-        RLWECtDft::decrypt_scratch_bytes(module, size)
+impl<C> RGSWCt<C, FFT64>
+where
+    MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
+{
+    pub fn get_row(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<C, FFT64>)
+    where
+        VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
+    {
+        module.vmp_extract_row(res, self, row_i, col_j);
+    }
+}

-pub fn encrypt_grlwe_sk<C, P, S>(
+pub fn encrypt_rgsw_sk<C, P, S>(
    module: &Module<FFT64>,
-    ct: &mut GRLWECt<C, FFT64>,
+    ct: &mut RGSWCt<C, FFT64>,
    pt: &ScalarZnx<P>,
-    sk: &SecretKeyDft<S, FFT64>,
+    sk_dft: &SecretKeyDft<S, FFT64>,
    source_xa: &mut Source,
    source_xe: &mut Source,
    sigma: f64,
@@ -94,47 +98,164 @@ pub fn encrypt_grlwe_sk<C, P, S>(
    ScalarZnx<P>: ScalarZnxToRef,
    ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
    let rows: usize = ct.rows();
    let size: usize = ct.size();
    let log_base2k: usize = ct.log_base2k();

-    let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
-    let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
-    let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
+    let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
+    let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size);

-    let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt {
+    let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
        data: tmp_znx_pt,
-        log_base2k: ct.log_base2k(),
+        log_base2k: log_base2k,
        log_k: ct.log_k(),
    };

-    let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt {
+    let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
        data: tmp_znx_ct,
-        log_base2k: ct.log_base2k(),
+        log_base2k: log_base2k,
        log_k: ct.log_k(),
    };

-    (0..rows).for_each(|row_i| {
-        tmp_pt
-            .data
-            .at_mut(0, row_i)
-            .copy_from_slice(&pt.to_ref().raw());
+    (0..ct.rows()).for_each(|row_j| {
+        // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
+        module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
+        module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);

-        tmp_ct.encrypt_sk(
-            module,
-            Some(&tmp_pt),
-            sk,
-            source_xa,
-            source_xe,
-            sigma,
-            bound,
-            scratch_3,
-        );
+        (0..ct.cols()).for_each(|col_i| {
+            // rlwe encrypt of vec_znx_pt into vec_znx_ct
+            encrypt_rlwe_sk(
+                module,
+                &mut vec_znx_ct,
+                Some((&vec_znx_pt, col_i)),
+                sk_dft,
+                source_xa,
+                source_xe,
+                sigma,
+                bound,
+                scrach_2,
+            );

-        tmp_pt.data.at_mut(0, row_i).fill(0);
+            // Switch vec_znx_ct into DFT domain
+            {
+                let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size);
+                module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
+                module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
+                module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct);
+            }
+        });

-        module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0);
-        module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1);
-
-        module.vmp_prepare_row(ct, row_i, 0, &tmp_dft);
+        vec_znx_pt.data.zero(); // zeroes for next iteration
    });
}
+impl<C> RGSWCt<C, FFT64> {
+    pub fn encrypt_sk<P, S>(
+        &mut self,
+        module: &Module<FFT64>,
+        pt: &ScalarZnx<P>,
+        sk_dft: &SecretKeyDft<S, FFT64>,
+        source_xa: &mut Source,
+        source_xe: &mut Source,
+        sigma: f64,
+        bound: f64,
+        scratch: &mut Scratch,
+    ) where
+        MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
+        ScalarZnx<P>: ScalarZnxToRef,
+        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
+    {
+        encrypt_rgsw_sk(
+            module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
+        )
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use base2k::{
+        FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
+        VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero,
+    };
+    use sampling::source::Source;
+
+    use crate::{
+        elem::Infos,
+        elem_rlwe::{RLWECtDft, RLWEPt},
+        keys::{SecretKey, SecretKeyDft},
+    };
+
+    use super::RGSWCt;
+
+    #[test]
+    fn encrypt_rgsw_sk_fft64() {
+        let module: Module<FFT64> = Module::<FFT64>::new(2048);
+        let log_base2k: usize = 8;
+        let log_k_ct: usize = 54;
+        let rows: usize = 4;
+
+        let sigma: f64 = 3.2;
+        let bound: f64 = sigma * 6.0;
+
+        let mut ct: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows);
+        let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
+        let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
+        let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
+
+        let mut source_xs: Source = Source::new([0u8; 32]);
+        let mut source_xe: Source = Source::new([0u8; 32]);
+        let mut source_xa: Source = Source::new([0u8; 32]);
+
+        pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
+
+        let mut scratch: ScratchOwned = ScratchOwned::new(
+            RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()),
+        );
+
+        let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
+        sk.fill_ternary_prob(0.5, &mut source_xs);
+
+        let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
+        sk_dft.dft(&module, &sk);
+
+        ct.encrypt_sk(
+            &module,
+            &pt_scalar,
+            &sk_dft,
+            &mut source_xa,
+            &mut source_xe,
+            sigma,
+            bound,
+            scratch.borrow(),
+        );
+
+        let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
+        let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
+        let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
+
+        (0..ct.cols()).for_each(|col_j| {
+            (0..ct.rows()).for_each(|row_i| {
+                module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
+
+                if col_j == 1 {
+                    module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
+                    module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
+                    module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
+                    module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
+                }
+
+                ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
+
+                ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
+
+                module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
+
+                let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
+                assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
+
+                pt_want.data.zero();
+            });
+        });
+
+        module.free();
+    }
+}
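For orientation (not part of the commit): the test above exercises the usual RGSW layout produced by encrypt_rgsw_sk, sketched below. Row i of column 0 encrypts the plaintext placed at limb i of the base-2^log_base2k representation, and row i of column 1 encrypts the same limb multiplied by the secret, which is why the test applies sk_dft to pt_want in the DFT domain before comparing whenever col_j == 1.

RGSW_s(m), row i, column 0:  RLWE_s(m at limb i)
RGSW_s(m), row i, column 1:  RLWE_s(s * m at limb i)

The per-column mechanics live in the encrypt_rlwe_sk change shown next.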
@@ -180,7 +180,7 @@ impl RLWECt<Vec<u8>> {
pub fn encrypt_rlwe_sk<C, P, S>(
    module: &Module<FFT64>,
    ct: &mut RLWECt<C>,
-    pt: Option<&RLWEPt<P>>,
+    pt: Option<(&RLWEPt<P>, usize)>,
    sk_dft: &SecretKeyDft<S, FFT64>,
    source_xa: &mut Source,
    source_xe: &mut Source,
@@ -213,8 +213,18 @@ pub fn encrypt_rlwe_sk<C, P, S>(
    }

    // c0_big = m - c0_big
-    if let Some(pt) = pt {
-        module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
+    if let Some((pt, col)) = pt {
+        match col {
+            0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0),
+            1 => {
+                module.vec_znx_big_negate_inplace(&mut c0_big, 0);
+                module.vec_znx_add_inplace(ct, 1, pt, 0);
+                module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1);
+            }
+            _ => panic!("invalid target column: {}", col),
+        }
    } else {
        module.vec_znx_big_negate_inplace(&mut c0_big, 0);
    }
    // c0_big += e
    c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
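A worked check of the two plaintext targets (not part of the commit; a is the uniform mask, s the secret, e the sampled noise, m the plaintext):

column 0:  c0 = -a*s + m + e,  c1 = a      =>  c0 + s*c1 = m + e
column 1:  c0 = -a*s + e,      c1 = a + m  =>  c0 + s*c1 = s*m + e

Targeting column 1 therefore yields an encryption of s*m, which is exactly what the column-1 rows of an RGSW ciphertext require.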
@@ -273,9 +283,23 @@ impl<C> RLWECt<C> {
        VecZnx<P>: VecZnxToRef,
        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
    {
-        encrypt_rlwe_sk(
-            module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
-        )
+        if let Some(pt) = pt {
+            encrypt_rlwe_sk(
+                module,
+                self,
+                Some((pt, 0)),
+                sk_dft,
+                source_xa,
+                source_xe,
+                sigma,
+                bound,
+                scratch,
+            )
+        } else {
+            encrypt_rlwe_sk::<C, P, S>(
+                module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch,
+            )
+        }
    }

    pub fn decrypt<P, S>(
@@ -483,10 +507,10 @@ pub(crate) fn encrypt_rlwe_pk<C, P, S>(
    let size_pk: usize = pk.size();

    // Generates u according to the underlying secret distribution.
-    let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1);
+    let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);

    {
-        let (mut u, _) = scratch_1.tmp_scalar(module, 1);
+        let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
        match pk.dist {
            SecretDistribution::NONE => panic!(
                "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"