added sk encryption

Jean-Philippe Bossuat
2025-05-07 12:05:12 +02:00
parent 240884db8d
commit 6ce525e5a1
6 changed files with 333 additions and 299 deletions

View File

@@ -1,5 +1,5 @@
 use base2k::{
-    AddNormal, Encoding, FFT64, FillUniform, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
+    AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
     ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc,
     VecZnxDftOps, VecZnxOps, ZnxInfos,
 };
@@ -20,7 +20,7 @@ fn main() {
     let mut source: Source = Source::new(seed);

     // s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
-    let mut s: Scalar<Vec<u8>> = module.new_scalar(1);
+    let mut s: ScalarZnx<Vec<u8>> = module.new_scalar(1);
     s.fill_ternary_prob(0, 0.5, &mut source);

     // Buffer to store s in the DFT domain

View File

@@ -2,8 +2,8 @@ use crate::ffi::svp;
 use crate::ffi::vec_znx_dft::vec_znx_dft_t;
 use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
 use crate::{
-    Backend, FFT64, Module, ScalarZnxToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef,
-    VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
+    Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
+    VecZnxDftToMut, VecZnxDftToRef,
 };

 pub trait ScalarZnxDftAlloc<B: Backend> {

View File

@@ -107,21 +107,13 @@ where
 {
     fn zero(&mut self) {
         unsafe {
-            std::ptr::write_bytes(
-                self.as_mut_ptr(),
-                0,
-                self.n() * self.poly_count(),
-            );
+            std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
         }
     }

     fn zero_at(&mut self, i: usize, j: usize) {
         unsafe {
-            std::ptr::write_bytes(
-                self.at_mut_ptr(i, j),
-                0,
-                self.n(),
-            );
+            std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
         }
     }
 }
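The change above only collapses the existing std::ptr::write_bytes zeroing onto one line; the behaviour is unchanged. As a reminder of the semantics (a standalone sketch, not part of this commit), the count argument of write_bytes is in units of the pointee type, so with a typed pointer it zeroes count * size_of::<T>() bytes, the same effect as filling the slice with 0:

    // Standalone illustration of the zeroing pattern used above (hypothetical
    // example, not from the repository): zero `len` i64 coefficients.
    fn main() {
        let mut coeffs = vec![1i64; 8];
        unsafe { std::ptr::write_bytes(coeffs.as_mut_ptr(), 0, coeffs.len()) };
        assert!(coeffs.iter().all(|&x| x == 0));
    }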

View File

@@ -1,6 +1,6 @@
 use base2k::{
-    Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx,
-    VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos,
+    Backend, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
+    ZnxInfos,
 };

 pub trait Infos {
@@ -31,7 +31,7 @@ pub trait Infos {
     /// Returns the number of size per polynomial.
     fn size(&self) -> usize {
        let size: usize = self.inner().size();
-        debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q()));
+        debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k()));
         size
     }
@@ -43,18 +43,18 @@
     /// Returns the base 2 logarithm of the ciphertext base.
     fn log_base2k(&self) -> usize;

-    /// Returns the base 2 logarithm of the ciphertext modulus.
-    fn log_q(&self) -> usize;
+    /// Returns the bit precision of the ciphertext.
+    fn log_k(&self) -> usize;
 }

-pub struct RLWECt<C>{
-    data: VecZnx<C>,
-    log_base2k: usize,
-    log_q: usize,
+pub struct RLWECt<C> {
+    pub data: VecZnx<C>,
+    pub log_base2k: usize,
+    pub log_k: usize,
 }

-impl<T: ZnxInfos> Infos for RLWECt<T> {
-    type Inner = T;
+impl<T> Infos for RLWECt<T> {
+    type Inner = VecZnx<T>;

     fn inner(&self) -> &Self::Inner {
         &self.data
@@ -64,32 +64,37 @@ impl<T: ZnxInfos> Infos for RLWECt<T> {
         self.log_base2k
     }

-    fn log_q(&self) -> usize {
-        self.log_q
+    fn log_k(&self) -> usize {
+        self.log_k
     }
 }

-impl<D> DataView for Ciphertext<D> {
-    type D = D;
-    fn data(&self) -> &Self::D {
-        &self.data
+impl<C> VecZnxToMut for RLWECt<C>
+where
+    VecZnx<C>: VecZnxToMut,
+{
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        self.data.to_mut()
     }
 }

-impl<D> DataViewMut for Ciphertext<D> {
-    fn data_mut(&mut self) -> &mut Self::D {
-        &mut self.data
+impl<C> VecZnxToRef for RLWECt<C>
+where
+    VecZnx<C>: VecZnxToRef,
+{
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        self.data.to_ref()
     }
 }

-pub struct Plaintext<T> {
-    data: T,
-    log_base2k: usize,
-    log_q: usize,
+pub struct RLWEPt<C> {
+    pub data: VecZnx<C>,
+    pub log_base2k: usize,
+    pub log_k: usize,
 }

-impl<T: ZnxInfos> Infos for Plaintext<T> {
-    type Inner = T;
+impl<T> Infos for RLWEPt<T> {
+    type Inner = VecZnx<T>;

     fn inner(&self) -> &Self::Inner {
         &self.data
@@ -99,140 +104,99 @@ impl<T: ZnxInfos> Infos for Plaintext<T> {
         self.log_base2k
     }

-    fn log_q(&self) -> usize {
-        self.log_q
+    fn log_k(&self) -> usize {
+        self.log_k
     }
 }

-impl<T> Plaintext<T> {
-    pub fn data(&self) -> &T {
+impl<C> VecZnxToMut for RLWEPt<C>
+where
+    VecZnx<C>: VecZnxToMut,
+{
+    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
+        self.data.to_mut()
+    }
+}
+
+impl<C> VecZnxToRef for RLWEPt<C>
+where
+    VecZnx<C>: VecZnxToRef,
+{
+    fn to_ref(&self) -> VecZnx<&[u8]> {
+        self.data.to_ref()
+    }
+}
+
+impl RLWECt<Vec<u8>> {
+    pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize, cols: usize) -> Self {
+        Self {
+            data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)),
+            log_base2k: log_base2k,
+            log_k: log_k,
+        }
+    }
+}
+
+impl RLWEPt<Vec<u8>> {
+    pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
+        Self {
+            data: module.new_vec_znx(1, derive_size(log_base2k, log_k)),
+            log_base2k: log_base2k,
+            log_k: log_k,
+        }
+    }
+}
+
+pub struct RLWECtDft<C, B: Backend> {
+    pub data: VecZnxDft<C, B>,
+    pub log_base2k: usize,
+    pub log_k: usize,
+}
+
+impl<B: Backend> RLWECtDft<Vec<u8>, B> {
+    pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
+        Self {
+            data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)),
+            log_base2k: log_base2k,
+            log_k: log_k,
+        }
+    }
+}
+
+impl<T, B: Backend> Infos for RLWECtDft<T, B> {
+    type Inner = VecZnxDft<T, B>;
+
+    fn inner(&self) -> &Self::Inner {
         &self.data
     }

-    pub fn data_mut(&mut self) -> &mut T {
-        &mut self.data
+    fn log_base2k(&self) -> usize {
+        self.log_base2k
+    }
+
+    fn log_k(&self) -> usize {
+        self.log_k
     }
 }

-pub(crate) type CtVecZnx<C> = Ciphertext<VecZnx<C>>;
-pub(crate) type CtVecZnxDft<C, B: Backend> = Ciphertext<VecZnxDft<C, B>>;
-pub(crate) type CtMatZnxDft<C, B: Backend> = Ciphertext<MatZnxDft<C, B>>;
-pub(crate) type PtVecZnx<C> = Plaintext<VecZnx<C>>;
-pub(crate) type PtVecZnxDft<C, B: Backend> = Plaintext<VecZnxDft<C, B>>;
-pub(crate) type PtMatZnxDft<C, B: Backend> = Plaintext<MatZnxDft<C, B>>;
-
-impl<D> VecZnxToMut for Ciphertext<D>
-where
-    D: VecZnxToMut,
-{
-    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
-        self.data_mut().to_mut()
-    }
-}
-
-impl<D> VecZnxToRef for Ciphertext<D>
-where
-    D: VecZnxToRef,
-{
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        self.data().to_ref()
-    }
-}
-
-impl Ciphertext<VecZnx<Vec<u8>>> {
-    pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
-        Self {
-            data: module.new_vec_znx(cols, derive_size(log_base2k, log_q)),
-            log_base2k: log_base2k,
-            log_q: log_q,
-        }
-    }
-}
-
-impl<D> VecZnxToMut for Plaintext<D>
-where
-    D: VecZnxToMut,
-{
-    fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
-        self.data_mut().to_mut()
-    }
-}
-
-impl<D> VecZnxToRef for Plaintext<D>
-where
-    D: VecZnxToRef,
-{
-    fn to_ref(&self) -> VecZnx<&[u8]> {
-        self.data().to_ref()
-    }
-}
-
-impl Plaintext<VecZnx<Vec<u8>>> {
-    pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_q: usize) -> Self {
-        Self {
-            data: module.new_vec_znx(1, derive_size(log_base2k, log_q)),
-            log_base2k: log_base2k,
-            log_q: log_q,
-        }
-    }
-}
-
-impl<D, B: Backend> VecZnxDftToMut<B> for Ciphertext<D>
+impl<C, B: Backend> VecZnxDftToMut<B> for RLWECtDft<C, B>
 where
-    D: VecZnxDftToMut<B>,
+    VecZnxDft<C, B>: VecZnxDftToMut<B>,
 {
     fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
-        self.data_mut().to_mut()
+        self.data.to_mut()
     }
 }

-impl<D, B: Backend> VecZnxDftToRef<B> for Ciphertext<D>
+impl<C, B: Backend> VecZnxDftToRef<B> for RLWECtDft<C, B>
 where
-    D: VecZnxDftToRef<B>,
+    VecZnxDft<C, B>: VecZnxDftToRef<B>,
 {
     fn to_ref(&self) -> VecZnxDft<&[u8], B> {
-        self.data().to_ref()
+        self.data.to_ref()
     }
 }

-impl<B: Backend> Ciphertext<VecZnxDft<Vec<u8>, B>> {
-    pub fn new(module: &Module<B>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
-        Self {
-            data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)),
-            log_base2k: log_base2k,
-            log_q: log_q,
-        }
-    }
-}
-
-impl<D, B: Backend> MatZnxDftToMut<B> for Ciphertext<D>
-where
-    D: MatZnxDftToMut<B>,
-{
-    fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
-        self.data_mut().to_mut()
-    }
-}
-
-impl<D, B: Backend> MatZnxDftToRef<B> for Ciphertext<D>
-where
-    D: MatZnxDftToRef<B>,
-{
-    fn to_ref(&self) -> MatZnxDft<&[u8], B> {
-        self.data().to_ref()
-    }
-}
-
-impl<B: Backend> Ciphertext<MatZnxDft<Vec<u8>, B>> {
-    pub fn new(module: &Module<B>, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self {
-        Self {
-            data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)),
-            log_base2k: log_base2k,
-            log_q: log_q,
-        }
-    }
-}
-
-pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize {
-    (log_q + log_base2k - 1) / log_base2k
+pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize {
+    (log_k + log_base2k - 1) / log_base2k
 }
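For reference, derive_size is just ceiling division: a value carrying log_k bits in base 2^log_base2k needs ceil(log_k / log_base2k) limbs. A small allocation sketch with the new wrapper types (parameter values borrowed from the test in the encryption module below; the snippet itself is illustrative and not part of this commit):

    // Hypothetical allocation sketch for the new wrapper types.
    let module: Module<FFT64> = Module::<FFT64>::new(32);
    let log_base2k: usize = 8;
    let (log_k_ct, log_k_pt) = (54usize, 40usize);

    // derive_size(8, 54) = (54 + 7) / 8 = 7 limbs; derive_size(8, 40) = 5 limbs.
    let ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2); // ct.size() == 7
    let pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt); // pt.size() == 5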

View File

@@ -1,161 +1,166 @@
+use std::cmp::min;
+
 use base2k::{
-    AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc,
-    VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut,
-    VecZnxToRef, ZnxInfos,
+    AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
+    VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
+    VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
 };

 use sampling::source::Source;

 use crate::{
-    elem::{Ciphertext, Infos, Plaintext},
-    keys::SecretKey,
+    elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
+    keys::SecretKeyDft,
 };

-pub trait EncryptSk<B: Backend, C, P> {
-    fn encrypt<S>(
-        module: &Module<B>,
-        res: &mut Ciphertext<C>,
-        pt: Option<&Plaintext<P>>,
-        sk: &SecretKey<S>,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        scratch: &mut Scratch,
-        sigma: f64,
-        bound: f64,
-    ) where
-        S: ScalarZnxDftToRef<B>;
-    fn encrypt_scratch_bytes(module: &Module<B>, size: usize) -> usize;
+pub fn encrypt_rlwe_sk_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
+    (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
 }

-impl<C, P> EncryptSk<FFT64, C, P> for Ciphertext<C>
-where
-    C: VecZnxToMut + ZnxInfos,
-    P: VecZnxToRef + ZnxInfos,
+pub fn encrypt_rlwe_sk<C, P, S>(
+    module: &Module<FFT64>,
+    ct: &mut RLWECt<C>,
+    pt: Option<&RLWEPt<P>>,
+    sk: &SecretKeyDft<S, FFT64>,
+    source_xa: &mut Source,
+    source_xe: &mut Source,
+    scratch: &mut Scratch,
+    sigma: f64,
+    bound: f64,
+) where
+    VecZnx<C>: VecZnxToMut + VecZnxToRef,
+    VecZnx<P>: VecZnxToRef,
+    ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
 {
-    fn encrypt<S>(
-        module: &Module<FFT64>,
-        ct: &mut Ciphertext<C>,
-        pt: Option<&Plaintext<P>>,
-        sk: &SecretKey<S>,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        scratch: &mut Scratch,
-        sigma: f64,
-        bound: f64,
-    ) where
-        S: ScalarZnxDftToRef<FFT64>,
-    {
-        let log_base2k: usize = ct.log_base2k();
-        let log_q: usize = ct.log_q();
-        let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut();
-        let size: usize = ct_mut.size();
-        // c1 = a
-        ct_mut.fill_uniform(log_base2k, 1, size, source_xa);
-        let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
-        {
-            let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
-            module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1);
-            // c0_dft = DFT(a) * DFT(s)
-            module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0);
-            // c0_big = IDFT(c0_dft)
-            module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
-        }
-        // c0_big = m - c0_big
-        if let Some(pt) = pt {
-            module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
-        }
-        // c0_big += e
-        c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound);
-        // c0 = norm(c0_big = -as + m + e)
-        module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1);
+    let log_base2k: usize = ct.log_base2k();
+    let log_k: usize = ct.log_k();
+    let size: usize = ct.size();
+
+    // c1 = a
+    ct.data.fill_uniform(log_base2k, 1, size, source_xa);
+
+    let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
+    {
+        let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
+        module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
+        // c0_dft = DFT(a) * DFT(s)
+        module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
+        // c0_big = IDFT(c0_dft)
+        module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
     }
-    fn encrypt_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
-        (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
+    // c0_big = m - c0_big
+    if let Some(pt) = pt {
+        module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
     }
+    // c0_big += e
+    c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
+    // c0 = norm(c0_big = -as + m + e)
+    module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1);
 }

-impl<C> Ciphertext<C>
-where
-    C: VecZnxToMut + ZnxInfos,
+pub fn decrypt_rlwe<P, C, S>(
+    module: &Module<FFT64>,
+    pt: &mut RLWEPt<P>,
+    ct: &RLWECt<C>,
+    sk: &SecretKeyDft<S, FFT64>,
+    scratch: &mut Scratch,
+) where
+    VecZnx<P>: VecZnxToMut + VecZnxToRef,
+    VecZnx<C>: VecZnxToRef,
+    ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
 {
+    let size: usize = min(pt.size(), ct.size());
+
+    let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
+    {
+        let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
+        module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
+        // c0_dft = DFT(a) * DFT(s)
+        module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
+        // c0_big = IDFT(c0_dft)
+        module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
+    }
+    // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
+    module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0);
+    // pt = norm(BIG(m + e))
+    module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
+
+    pt.log_base2k = ct.log_base2k();
+    pt.log_k = min(pt.log_k(), ct.log_k());
+}
+
+pub fn decrypt_rlwe_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
+    (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
+}
+
+impl<C> RLWECt<C> {
     pub fn encrypt_sk<P, S>(
         &mut self,
         module: &Module<FFT64>,
-        pt: Option<&Plaintext<P>>,
-        sk: &SecretKey<S>,
+        pt: Option<&RLWEPt<P>>,
+        sk: &SecretKeyDft<S, FFT64>,
         source_xa: &mut Source,
         source_xe: &mut Source,
         scratch: &mut Scratch,
         sigma: f64,
         bound: f64,
     ) where
-        P: VecZnxToRef + ZnxInfos,
-        S: ScalarZnxDftToRef<FFT64>,
+        VecZnx<C>: VecZnxToMut + VecZnxToRef,
+        VecZnx<P>: VecZnxToRef,
+        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
     {
-        <Self as EncryptSk<FFT64, _, _>>::encrypt(
+        encrypt_rlwe_sk(
             module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound,
-        );
+        )
     }

-    pub fn encrypt_sk_scratch_bytes<P>(module: &Module<FFT64>, size: usize) -> usize
+    pub fn decrypt<P, S>(&self, module: &Module<FFT64>, pt: &mut RLWEPt<P>, sk: &SecretKeyDft<S, FFT64>, scratch: &mut Scratch)
     where
-        Self: EncryptSk<FFT64, C, P>,
+        VecZnx<P>: VecZnxToMut + VecZnxToRef,
+        VecZnx<C>: VecZnxToRef,
+        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
     {
-        <Self as EncryptSk<FFT64, C, P>>::encrypt_scratch_bytes(module, size)
+        decrypt_rlwe(module, pt, self, sk, scratch);
     }
 }

-pub trait EncryptZeroSk<B: Backend, D> {
-    fn encrypt_zero<S>(
-        module: &Module<B>,
-        res: &mut D,
-        sk: &SecretKey<S>,
-        source_xa: &mut Source,
-        source_xe: &mut Source,
-        scratch: &mut Scratch,
-        sigma: f64,
-        bound: f64,
-    ) where
-        S: ScalarZnxDftToRef<B>;
-    fn encrypt_zero_scratch_bytes(module: &Module<B>, size: usize) -> usize;
+pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes<B: Backend>(module: &Module<FFT64>, size: usize) -> usize {
+    (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
 }

-impl<C> EncryptZeroSk<FFT64, C> for C
-where
-    C: VecZnxDftToMut<FFT64> + ZnxInfos + Infos,
-{
+impl<C> RLWECtDft<C, FFT64> {
     fn encrypt_zero<S>(
         module: &Module<FFT64>,
-        ct: &mut C,
-        sk: &SecretKey<S>,
+        ct: &mut RLWECtDft<C, FFT64>,
+        sk: &SecretKeyDft<S, FFT64>,
         source_xa: &mut Source,
         source_xe: &mut Source,
         scratch: &mut Scratch,
         sigma: f64,
         bound: f64,
     ) where
-        S: ScalarZnxDftToRef<FFT64>,
+        VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
+        ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
     {
         let log_base2k: usize = ct.log_base2k();
-        let log_q: usize = ct.log_q();
-        let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut();
-        let size: usize = ct_mut.size();
+        let log_k: usize = ct.log_k();
+        let size: usize = ct.size();

         // ct[1] = DFT(a)
         {
             let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
             tmp_znx.fill_uniform(log_base2k, 1, size, source_xa);
-            module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0);
+            module.vec_znx_dft(ct, 1, &tmp_znx, 0);
         }

         let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
@@ -163,22 +168,22 @@ where
         {
             let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
             // c0_dft = DFT(a) * DFT(s)
-            module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1);
+            module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
             // c0_big = IDFT(c0_dft)
             module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
         }

         // c0_big += e
-        c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound);
+        c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);

         // c0 = norm(c0_big = -as + e)
         let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
         module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);

         // ct[0] = DFT(-as + e)
-        module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0);
+        module.vec_znx_dft(ct, 0, &tmp_znx, 0);
     }

-    fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize{
+    fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
         (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
             + module.bytes_of_vec_znx_big(1, size)
             + module.bytes_of_vec_znx(1, size)
@@ -188,42 +193,80 @@ where
 #[cfg(test)]
 mod tests {
-    use base2k::{FFT64, Module, ScratchOwned, VecZnx, Scalar};
+    use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero};
+    use itertools::izip;
     use sampling::source::Source;

-    use crate::{elem::{Ciphertext, Infos, Plaintext}, keys::SecretKey};
+    use crate::{
+        elem::{Infos, RLWECt, RLWEPt},
+        keys::{SecretKey, SecretKeyDft},
+    };
+
+    use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};

     #[test]
     fn encrypt_sk_vec_znx_fft64() {
         let module: Module<FFT64> = Module::<FFT64>::new(32);
         let log_base2k: usize = 8;
-        let log_q: usize = 54;
+        let log_k_ct: usize = 54;
+        let log_k_pt: usize = 40;
         let sigma: f64 = 3.2;
-        let bound: f64 = sigma * 6;
+        let bound: f64 = sigma * 6.0;

-        let mut ct: Ciphertext<VecZnx<Vec<u8>>> = Ciphertext::<VecZnx<Vec<u8>>>::new(&module, log_base2k, log_q, 2);
-        let mut pt: Plaintext<VecZnx<Vec<u8>>> = Plaintext::<VecZnx<Vec<u8>>>::new(&module, log_base2k, log_q);
+        let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
+        let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);

-        let mut source_xe = Source::new([0u8; 32]);
+        let mut source_xe: Source = Source::new([0u8; 32]);
         let mut source_xa: Source = Source::new([0u8; 32]);

-        let mut scratch: ScratchOwned = ScratchOwned::new(ct.encrypt_encsk_scratch_bytes(&module, ct.size()));
+        let mut scratch: ScratchOwned =
+            ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));

-        let mut sk: SecretKey<Scalar<Vec<u8>>> = SecretKey::new(&module);
-        let mut sk_prep
-        sk.svp_prepare(&module, &mut sk_prep);
+        let sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
+        let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
+        sk_dft.dft(&module, &sk);
+
+        let mut data_want: Vec<i64> = vec![0i64; module.n()];
+        data_want
+            .iter_mut()
+            .for_each(|x| *x = source_xa.next_i64() & 0xFF);
+
+        pt.data
+            .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10);

         ct.encrypt_sk(
             &module,
             Some(&pt),
-            &sk_prep,
+            &sk_dft,
             &mut source_xa,
             &mut source_xe,
             scratch.borrow(),
             sigma,
             bound,
         );
+
+        pt.data.zero();
+
+        ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
+
+        let mut data_have: Vec<i64> = vec![0i64; module.n()];
+        pt.data
+            .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
+
+        let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
+        izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
+            let b_scaled = (*b as f64) / scale;
+            assert!(
+                (*a as f64 - b_scaled).abs() < 0.1,
+                "{} {}",
+                *a as f64,
+                b_scaled
+            )
+        });
+
+        module.free();
     }
 }
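The RLWECt::encrypt_sk and RLWECt::decrypt methods added above are thin wrappers over the free functions encrypt_rlwe_sk and decrypt_rlwe. A minimal sketch of calling the free functions directly, reusing the parameters from the test (the source_xs seed and the fill_ternary_prob call on the secret are added assumptions, not taken from the test, which encrypts under the freshly allocated key):

    // Hypothetical driver mirroring the test above; not part of the commit.
    let module: Module<FFT64> = Module::<FFT64>::new(32);
    let (log_base2k, log_k_ct, log_k_pt) = (8usize, 54usize, 40usize);
    let (sigma, bound) = (3.2f64, 3.2 * 6.0);

    let mut source_xs: Source = Source::new([1u8; 32]); // assumed extra seed for the key
    let mut source_xa: Source = Source::new([0u8; 32]);
    let mut source_xe: Source = Source::new([0u8; 32]);

    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
    sk.fill_ternary_prob(0.5, &mut source_xs); // assumption: sample a ternary secret
    let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
    sk_dft.dft(&module, &sk);

    let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
    let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
    let mut scratch: ScratchOwned =
        ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));

    // ct = (-a*s + m + e, a)
    encrypt_rlwe_sk(&module, &mut ct, Some(&pt), &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow(), sigma, bound);
    // pt = m + e, normalized in base 2^log_base2k
    decrypt_rlwe(&module, &mut pt, &ct, &sk_dft, scratch.borrow());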

View File

@@ -1,31 +1,27 @@
 use base2k::{
-    Backend, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, ZnxInfos, FFT64
+    Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
+    ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut,
 };

 use sampling::source::Source;

 use crate::elem::derive_size;

 pub struct SecretKey<T> {
-    data: T,
+    pub data: ScalarZnx<T>,
 }

-impl<T> SecretKey<T> {
-    pub fn data(&self) -> &T {
-        &self.data
-    }
-
-    pub fn data_mut(&mut self) -> &mut T {
-        &mut self.data
-    }
-}
-
-impl SecretKey<Scalar<Vec<u8>>> {
+impl SecretKey<Vec<u8>> {
     pub fn new<B: Backend>(module: &Module<B>) -> Self {
         Self {
             data: module.new_scalar(1),
         }
     }
-}
-
-impl<S> SecretKey<S>
-where
-    S: AsMut<[u8]> + AsRef<[u8]>,
-{
+
     pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
         self.data.fill_ternary_prob(0, prob, source);
     }
@@ -33,27 +29,66 @@ impl SecretKey<Scalar<Vec<u8>>> {
     pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
         self.data.fill_ternary_hw(0, hw, source);
     }
+}

-    pub fn svp_prepare<D>(&self, module: &Module<FFT64>, sk_prep: &mut SecretKey<ScalarZnxDft<D, FFT64>>)
+impl<C> ScalarZnxToMut for SecretKey<C>
 where
-    ScalarZnxDft<D, base2k::FFT64>: ScalarZnxDftToMut<base2k::FFT64>,
+    ScalarZnx<C>: ScalarZnxToMut,
 {
-        module.svp_prepare(&mut sk_prep.data, 0, &self.data, 0)
+    fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
+        self.data.to_mut()
     }
 }

-type SecretKeyPrep<C, B> = SecretKey<ScalarZnxDft<C, B>>;
+impl<C> ScalarZnxToRef for SecretKey<C>
+where
+    ScalarZnx<C>: ScalarZnxToRef,
+{
+    fn to_ref(&self) -> ScalarZnx<&[u8]> {
+        self.data.to_ref()
+    }
+}

-impl<B: Backend> SecretKey<ScalarZnxDft<Vec<u8>, B>> {
-    pub fn new(module: &Module<B>) -> Self{
-        Self{
-            data: module.new_scalar_znx_dft(1)
+pub struct SecretKeyDft<T, B: Backend> {
+    pub data: ScalarZnxDft<T, B>,
+}
+
+impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
+    pub fn new(module: &Module<B>) -> Self {
+        Self {
+            data: module.new_scalar_znx_dft(1),
         }
     }
+
+    pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
+    where
+        SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
+        SecretKey<S>: ScalarZnxToRef,
+    {
+        module.svp_prepare(self, 0, sk, 0)
+    }
+}
+
+impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyDft<C, B>
+where
+    ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
+{
+    fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
+        self.data.to_mut()
+    }
+}
+
+impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyDft<C, B>
+where
+    ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
+{
+    fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
+        self.data.to_ref()
+    }
 }

 pub struct PublicKey<D, B: Backend> {
-    data: VecZnxDft<D, B>,
+    pub data: VecZnxDft<D, B>,
 }

 impl<B: Backend> PublicKey<Vec<u8>, B> {