added rgsw encrypt + test

This commit is contained in:
Jean-Philippe Bossuat
2025-05-08 18:32:19 +02:00
parent 107e83c65c
commit de3b34477d
8 changed files with 384 additions and 162 deletions

View File

@@ -196,7 +196,7 @@ impl Scratch {
}
}
pub fn tmp_scalar<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
pub fn tmp_scalar_znx<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnx<&mut [u8]>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx(module, cols));
(
@@ -205,7 +205,7 @@ impl Scratch {
)
}
pub fn tmp_scalar_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
pub fn tmp_scalar_znx_dft<B: Backend>(&mut self, module: &Module<B>, cols: usize) -> (ScalarZnxDft<&mut [u8], B>, &mut Self) {
let (take_slice, rem_slice) = Self::take_slice_aligned(&mut self.data, bytes_of_scalar_znx_dft(module, cols));
(

View File

@@ -1,5 +1,5 @@
use crate::znx_base::ZnxInfos;
use crate::{Backend, DataView, DataViewMut, Module, ZnxSliceSize, ZnxView, ZnxViewMut, alloc_aligned};
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnx, VecZnxToMut, VecZnxToRef, ZnxSliceSize, ZnxView, ZnxViewMut};
use rand::seq::SliceRandom;
use rand_core::RngCore;
use rand_distr::{Distribution, weighted::WeightedIndex};
@@ -144,6 +144,17 @@ impl ScalarZnxToMut for ScalarZnx<Vec<u8>> {
}
}
impl VecZnxToMut for ScalarZnx<Vec<u8>>{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data.as_mut_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
@@ -154,6 +165,17 @@ impl ScalarZnxToRef for ScalarZnx<Vec<u8>> {
}
}
impl VecZnxToRef for ScalarZnx<Vec<u8>>{
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data.as_slice(),
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
ScalarZnx {
@@ -164,6 +186,17 @@ impl ScalarZnxToMut for ScalarZnx<&mut [u8]> {
}
}
impl VecZnxToMut for ScalarZnx<&mut [u8]> {
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
@@ -174,6 +207,17 @@ impl ScalarZnxToRef for ScalarZnx<&mut [u8]> {
}
}
impl VecZnxToRef for ScalarZnx<&mut [u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}
impl ScalarZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> ScalarZnx<&[u8]> {
ScalarZnx {
@@ -183,3 +227,14 @@ impl ScalarZnxToRef for ScalarZnx<&[u8]> {
}
}
}
impl VecZnxToRef for ScalarZnx<&[u8]> {
fn to_ref(&self) -> VecZnx<&[u8]> {
VecZnx {
data: self.data,
n: self.n,
cols: self.cols,
size: 1,
}
}
}

View File

@@ -1,103 +1,103 @@
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxDftToRef<FFT64>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}
use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B>;
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize;
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B>;
}
pub trait ScalarZnxDftOps<BACKEND: Backend> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<BACKEND>,
A: ScalarZnxToRef;
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>,
B: VecZnxDftToRef<FFT64>;
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<BACKEND>,
A: ScalarZnxDftToRef<BACKEND>;
}
impl<B: Backend> ScalarZnxDftAlloc<B> for Module<B> {
fn new_scalar_znx_dft(&self, cols: usize) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new(self, cols)
}
fn bytes_of_scalar_znx_dft(&self, cols: usize) -> usize {
ScalarZnxDftOwned::bytes_of(self, cols)
}
fn new_scalar_znx_dft_from_bytes(&self, cols: usize, bytes: Vec<u8>) -> ScalarZnxDftOwned<B> {
ScalarZnxDftOwned::new_from_bytes(self, cols, bytes)
}
}
impl ScalarZnxDftOps<FFT64> for Module<FFT64> {
fn svp_prepare<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: ScalarZnxDftToMut<FFT64>,
A: ScalarZnxToRef,
{
unsafe {
svp::svp_prepare(
self.ptr,
res.to_mut().at_mut_ptr(res_col, 0) as *mut svp::svp_ppol_t,
a.to_ref().at_ptr(a_col, 0),
)
}
}
fn svp_apply<R, A, B>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize, b: &B, b_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
B: VecZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
let b: VecZnxDft<&[u8], FFT64> = b.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
b.at_ptr(b_col, 0) as *const vec_znx_dft_t,
b.size() as u64,
b.cols() as u64,
)
}
}
fn svp_apply_inplace<R, A>(&self, res: &mut R, res_col: usize, a: &A, a_col: usize)
where
R: VecZnxDftToMut<FFT64>,
A: ScalarZnxDftToRef<FFT64>,
{
let mut res: VecZnxDft<&mut [u8], FFT64> = res.to_mut();
let a: ScalarZnxDft<&[u8], FFT64> = a.to_ref();
unsafe {
svp::svp_apply_dft_to_dft(
self.ptr,
res.at_mut_ptr(res_col, 0) as *mut vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
a.at_ptr(a_col, 0) as *const svp::svp_ppol_t,
res.at_ptr(res_col, 0) as *const vec_znx_dft_t,
res.size() as u64,
res.cols() as u64,
)
}
}
}

View File

@@ -20,9 +20,9 @@ use std::{cmp::min, fmt};
/// are small polynomials of Zn\[X\].
pub struct VecZnx<D> {
pub data: D,
n: usize,
cols: usize,
size: usize,
pub n: usize,
pub cols: usize,
pub size: usize,
}
impl<D> ZnxInfos for VecZnx<D> {

View File

@@ -114,6 +114,9 @@ pub trait VecZnxBigOps<BACKEND: Backend> {
R: VecZnxBigToMut<BACKEND>,
A: VecZnxToRef;
/// Negates `a` inplace.
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, a_col: usize) where A: VecZnxBigToMut<BACKEND>;
/// Normalizes `a` and stores the result on `b`.
///
/// # Arguments
@@ -503,6 +506,25 @@ impl VecZnxBigOps<FFT64> for Module<FFT64> {
}
}
fn vec_znx_big_negate_inplace<A>(&self, a: &mut A, res_col: usize) where A: VecZnxBigToMut<FFT64> {
let mut a: VecZnxBig<&mut [u8], FFT64> = a.to_mut();
#[cfg(debug_assertions)]
{
assert_eq!(a.n(), self.n());
}
unsafe {
vec_znx::vec_znx_negate(
self.ptr,
a.at_mut_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
a.at_ptr(res_col, 0),
a.size() as u64,
a.sl() as u64,
)
}
}
fn vec_znx_big_normalize<R, A>(
&self,
log_base2k: usize,

View File

@@ -91,7 +91,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut GRLWECt<C, FFT64>,
pt: &ScalarZnx<P>,
sk: &SecretKeyDft<S, FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
@@ -131,7 +131,7 @@ pub fn encrypt_grlwe_sk<C, P, S>(
vec_znx_ct.encrypt_sk(
module,
Some(&vec_znx_pt),
sk,
sk_dft,
source_xa,
source_xe,
sigma,
@@ -186,7 +186,7 @@ mod tests {
use super::GRLWECt;
#[test]
fn encrypt_sk_vec_znx_fft64() {
fn encrypt_sk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
@@ -233,7 +233,7 @@ mod tests {
ct_rlwe_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &pt_scalar, 0);
let std_pt: f64 = pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
assert!((sigma - std_pt) <= 0.2, "{} {}", sigma, std_pt);
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
});
module.free();

View File

@@ -1,13 +1,13 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftOps, ZnxView, ZnxViewMut,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxOps,
ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::Infos,
elem_grlwe::GRLWECt,
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt},
elem_rlwe::{RLWECt, RLWECtDft, RLWEPt, encrypt_rlwe_sk},
keys::SecretKeyDft,
utils::derive_size,
};
@@ -62,28 +62,32 @@ where
}
}
impl GRLWECt<Vec<u8>, FFT64> {
impl RGSWCt<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
RLWECt::encrypt_sk_scratch_bytes(module, size)
+ module.bytes_of_vec_znx(2, size)
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size)
}
}
pub fn encrypt_pk_scratch_bytes(module: &Module<FFT64>, pk_size: usize) -> usize {
RLWECt::encrypt_pk_scratch_bytes(module, pk_size)
}
pub fn decrypt_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
RLWECtDft::decrypt_scratch_bytes(module, size)
impl<C> RGSWCt<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
pub fn get_row(&self, module: &Module<FFT64>, row_i: usize, col_j: usize, res: &mut RLWECtDft<C, FFT64>)
where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64>,
{
module.vmp_extract_row(res, self, row_i, col_j);
}
}
pub fn encrypt_grlwe_sk<C, P, S>(
pub fn encrypt_rgsw_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut GRLWECt<C, FFT64>,
ct: &mut RGSWCt<C, FFT64>,
pt: &ScalarZnx<P>,
sk: &SecretKeyDft<S, FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
@@ -94,47 +98,164 @@ pub fn encrypt_grlwe_sk<C, P, S>(
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let rows: usize = ct.rows();
let size: usize = ct.size();
let log_base2k: usize = ct.log_base2k();
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
let (mut tmp_dft, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, 2, size);
let mut tmp_pt: RLWEPt<&mut [u8]> = RLWEPt {
let mut vec_znx_pt: RLWEPt<&mut [u8]> = RLWEPt {
data: tmp_znx_pt,
log_base2k: ct.log_base2k(),
log_base2k: log_base2k,
log_k: ct.log_k(),
};
let mut tmp_ct: RLWECt<&mut [u8]> = RLWECt {
let mut vec_znx_ct: RLWECt<&mut [u8]> = RLWECt {
data: tmp_znx_ct,
log_base2k: ct.log_base2k(),
log_base2k: log_base2k,
log_k: ct.log_k(),
};
(0..rows).for_each(|row_i| {
tmp_pt
.data
.at_mut(0, row_i)
.copy_from_slice(&pt.to_ref().raw());
(0..ct.rows()).for_each(|row_j| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_j, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scrach_2);
tmp_ct.encrypt_sk(
module,
Some(&tmp_pt),
sk,
source_xa,
source_xe,
sigma,
bound,
scratch_3,
);
(0..ct.cols()).for_each(|col_i| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct
encrypt_rlwe_sk(
module,
&mut vec_znx_ct,
Some((&vec_znx_pt, col_i)),
sk_dft,
source_xa,
source_xe,
sigma,
bound,
scrach_2,
);
tmp_pt.data.at_mut(0, row_i).fill(0);
// Switch vec_znx_ct into DFT domain
{
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, 2, size);
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
module.vmp_prepare_row(ct, row_j, col_i, &vec_znx_dft_ct);
}
});
module.vec_znx_dft(&mut tmp_dft, 0, &tmp_ct, 0);
module.vec_znx_dft(&mut tmp_dft, 1, &tmp_ct, 1);
module.vmp_prepare_row(ct, row_i, 0, &tmp_dft);
vec_znx_pt.data.zero(); // zeroes for next iteration
});
}
impl<C> RGSWCt<C, FFT64> {
pub fn encrypt_sk<P, S>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<P>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_rgsw_sk(
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
}
#[cfg(test)]
mod tests {
use base2k::{
FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxZero,
};
use sampling::source::Source;
use crate::{
elem::Infos,
elem_rlwe::{RLWECtDft, RLWEPt},
keys::{SecretKey, SecretKeyDft},
};
use super::RGSWCt;
#[test]
fn encrypt_rgsw_sk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(2048);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
let rows: usize = 4;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct: RGSWCt<Vec<u8>, FFT64> = RGSWCt::new(&module, log_base2k, log_k_ct, rows);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
let mut pt_scalar: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
pt_scalar.fill_ternary_hw(0, module.n(), &mut source_xs);
let mut scratch: ScratchOwned = ScratchOwned::new(
RGSWCt::encrypt_sk_scratch_bytes(&module, ct.size()) | RLWECtDft::decrypt_scratch_bytes(&module, ct.size()),
);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
ct.encrypt_sk(
&module,
&pt_scalar,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
let mut ct_rlwe_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct.size());
let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct.size());
(0..ct.cols()).for_each(|col_j| {
(0..ct.rows()).for_each(|row_i| {
module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_scalar, 0);
if col_j == 1 {
module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, 0);
module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
module.vec_znx_big_normalize(log_base2k, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
}
ct.get_row(&module, row_i, col_j, &mut ct_rlwe_dft);
ct_rlwe_dft.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_have, 0, &pt_want, 0);
let std_pt: f64 = pt_have.data.std(0, log_base2k) * (log_k_ct as f64).exp2();
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
pt_want.data.zero();
});
});
module.free();
}
}

View File

@@ -180,7 +180,7 @@ impl RLWECt<Vec<u8>> {
pub fn encrypt_rlwe_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
pt: Option<(&RLWEPt<P>, usize)>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
@@ -213,8 +213,18 @@ pub fn encrypt_rlwe_sk<C, P, S>(
}
// c0_big = m - c0_big
if let Some(pt) = pt {
module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
if let Some((pt, col)) = pt {
match col {
0 => module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0),
1 => {
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
module.vec_znx_add_inplace(ct, 1, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, ct, 1, scratch_1);
}
_ => panic!("invalid target column: {}", col),
}
} else {
module.vec_znx_big_negate_inplace(&mut c0_big, 0);
}
// c0_big += e
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
@@ -273,9 +283,23 @@ impl<C> RLWECt<C> {
VecZnx<P>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_rlwe_sk(
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
if let Some(pt) = pt {
encrypt_rlwe_sk(
module,
self,
Some((pt, 0)),
sk_dft,
source_xa,
source_xe,
sigma,
bound,
scratch,
)
} else {
encrypt_rlwe_sk::<C, P, S>(
module, self, None, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
}
pub fn decrypt<P, S>(
@@ -483,10 +507,10 @@ pub(crate) fn encrypt_rlwe_pk<C, P, S>(
let size_pk: usize = pk.size();
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1);
let (mut u_dft, scratch_1) = scratch.tmp_scalar_znx_dft(module, 1);
{
let (mut u, _) = scratch_1.tmp_scalar(module, 1);
let (mut u, _) = scratch_1.tmp_scalar_znx(module, 1);
match pk.dist {
SecretDistribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"