added sk encryption

This commit is contained in:
Jean-Philippe Bossuat
2025-05-07 12:05:12 +02:00
parent 240884db8d
commit 6ce525e5a1
6 changed files with 333 additions and 299 deletions

View File

@@ -1,5 +1,5 @@
use base2k::{
AddNormal, Encoding, FFT64, FillUniform, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc,
VecZnxDftOps, VecZnxOps, ZnxInfos,
};
@@ -20,7 +20,7 @@ fn main() {
let mut source: Source = Source::new(seed);
// s <- Z_{-1, 0, 1}[X]/(X^{N}+1)
let mut s: Scalar<Vec<u8>> = module.new_scalar(1);
let mut s: ScalarZnx<Vec<u8>> = module.new_scalar(1);
s.fill_ternary_prob(0, 0.5, &mut source);
// Buffer to store s in the DFT domain

View File

@@ -2,8 +2,8 @@ use crate::ffi::svp;
use crate::ffi::vec_znx_dft::vec_znx_dft_t;
use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut};
use crate::{
Backend, FFT64, Module, ScalarZnxToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef,
VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft,
VecZnxDftToMut, VecZnxDftToRef,
};
pub trait ScalarZnxDftAlloc<B: Backend> {

View File

@@ -107,21 +107,13 @@ where
{
fn zero(&mut self) {
unsafe {
std::ptr::write_bytes(
self.as_mut_ptr(),
0,
self.n() * self.poly_count(),
);
std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count());
}
}
fn zero_at(&mut self, i: usize, j: usize) {
unsafe {
std::ptr::write_bytes(
self.at_mut_ptr(i, j),
0,
self.n(),
);
std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n());
}
}
}

View File

@@ -1,6 +1,6 @@
use base2k::{
Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx,
VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos,
Backend, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
ZnxInfos,
};
pub trait Infos {
@@ -31,7 +31,7 @@ pub trait Infos {
/// Returns the number of size per polynomial.
fn size(&self) -> usize {
let size: usize = self.inner().size();
debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q()));
debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k()));
size
}
@@ -43,18 +43,18 @@ pub trait Infos {
/// Returns the base 2 logarithm of the ciphertext base.
fn log_base2k(&self) -> usize;
/// Returns the base 2 logarithm of the ciphertext modulus.
fn log_q(&self) -> usize;
/// Returns the bit precision of the ciphertext.
fn log_k(&self) -> usize;
}
pub struct RLWECt<C>{
data: VecZnx<C>,
log_base2k: usize,
log_q: usize,
pub struct RLWECt<C> {
pub data: VecZnx<C>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<T: ZnxInfos> Infos for RLWECt<T> {
type Inner = T;
impl<T> Infos for RLWECt<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
@@ -64,32 +64,37 @@ impl<T: ZnxInfos> Infos for RLWECt<T> {
self.log_base2k
}
fn log_q(&self) -> usize {
self.log_q
fn log_k(&self) -> usize {
self.log_k
}
}
impl<D> DataView for Ciphertext<D> {
type D = D;
fn data(&self) -> &Self::D {
&self.data
impl<C> VecZnxToMut for RLWECt<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<D> DataViewMut for Ciphertext<D> {
fn data_mut(&mut self) -> &mut Self::D {
&mut self.data
impl<C> VecZnxToRef for RLWECt<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
pub struct Plaintext<T> {
data: T,
log_base2k: usize,
log_q: usize,
pub struct RLWEPt<C> {
pub data: VecZnx<C>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<T: ZnxInfos> Infos for Plaintext<T> {
type Inner = T;
impl<T> Infos for RLWEPt<T> {
type Inner = VecZnx<T>;
fn inner(&self) -> &Self::Inner {
&self.data
@@ -99,140 +104,99 @@ impl<T: ZnxInfos> Infos for Plaintext<T> {
self.log_base2k
}
fn log_q(&self) -> usize {
self.log_q
fn log_k(&self) -> usize {
self.log_k
}
}
impl<T> Plaintext<T> {
pub fn data(&self) -> &T {
impl<C> VecZnxToMut for RLWEPt<C>
where
VecZnx<C>: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data.to_mut()
}
}
impl<C> VecZnxToRef for RLWEPt<C>
where
VecZnx<C>: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data.to_ref()
}
}
impl RLWECt<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize, cols: usize) -> Self {
Self {
data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
impl RLWEPt<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
pub struct RLWECtDft<C, B: Backend> {
pub data: VecZnxDft<C, B>,
pub log_base2k: usize,
pub log_k: usize,
}
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
}
}
impl<T, B: Backend> Infos for RLWECtDft<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
}
pub fn data_mut(&mut self) -> &mut T {
&mut self.data
fn log_base2k(&self) -> usize {
self.log_base2k
}
fn log_k(&self) -> usize {
self.log_k
}
}
pub(crate) type CtVecZnx<C> = Ciphertext<VecZnx<C>>;
pub(crate) type CtVecZnxDft<C, B: Backend> = Ciphertext<VecZnxDft<C, B>>;
pub(crate) type CtMatZnxDft<C, B: Backend> = Ciphertext<MatZnxDft<C, B>>;
pub(crate) type PtVecZnx<C> = Plaintext<VecZnx<C>>;
pub(crate) type PtVecZnxDft<C, B: Backend> = Plaintext<VecZnxDft<C, B>>;
pub(crate) type PtMatZnxDft<C, B: Backend> = Plaintext<MatZnxDft<C, B>>;
impl<D> VecZnxToMut for Ciphertext<D>
impl<C, B: Backend> VecZnxDftToMut<B> for RLWECtDft<C, B>
where
D: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data_mut().to_mut()
}
}
impl<D> VecZnxToRef for Ciphertext<D>
where
D: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data().to_ref()
}
}
impl Ciphertext<VecZnx<Vec<u8>>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
Self {
data: module.new_vec_znx(cols, derive_size(log_base2k, log_q)),
log_base2k: log_base2k,
log_q: log_q,
}
}
}
impl<D> VecZnxToMut for Plaintext<D>
where
D: VecZnxToMut,
{
fn to_mut(&mut self) -> VecZnx<&mut [u8]> {
self.data_mut().to_mut()
}
}
impl<D> VecZnxToRef for Plaintext<D>
where
D: VecZnxToRef,
{
fn to_ref(&self) -> VecZnx<&[u8]> {
self.data().to_ref()
}
}
impl Plaintext<VecZnx<Vec<u8>>> {
pub fn new<B: Backend>(module: &Module<B>, log_base2k: usize, log_q: usize) -> Self {
Self {
data: module.new_vec_znx(1, derive_size(log_base2k, log_q)),
log_base2k: log_base2k,
log_q: log_q,
}
}
}
impl<D, B: Backend> VecZnxDftToMut<B> for Ciphertext<D>
where
D: VecZnxDftToMut<B>,
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data_mut().to_mut()
self.data.to_mut()
}
}
impl<D, B: Backend> VecZnxDftToRef<B> for Ciphertext<D>
impl<C, B: Backend> VecZnxDftToRef<B> for RLWECtDft<C, B>
where
D: VecZnxDftToRef<B>,
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data().to_ref()
self.data.to_ref()
}
}
impl<B: Backend> Ciphertext<VecZnxDft<Vec<u8>, B>> {
pub fn new(module: &Module<B>, log_base2k: usize, log_q: usize, cols: usize) -> Self {
Self {
data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)),
log_base2k: log_base2k,
log_q: log_q,
}
}
}
impl<D, B: Backend> MatZnxDftToMut<B> for Ciphertext<D>
where
D: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.data_mut().to_mut()
}
}
impl<D, B: Backend> MatZnxDftToRef<B> for Ciphertext<D>
where
D: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.data().to_ref()
}
}
impl<B: Backend> Ciphertext<MatZnxDft<Vec<u8>, B>> {
pub fn new(module: &Module<B>, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)),
log_base2k: log_base2k,
log_q: log_q,
}
}
}
pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize {
(log_q + log_base2k - 1) / log_base2k
pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize {
(log_k + log_base2k - 1) / log_base2k
}

View File

@@ -1,161 +1,166 @@
use std::cmp::min;
use base2k::{
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut,
VecZnxToRef, ZnxInfos,
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
};
use sampling::source::Source;
use crate::{
elem::{Ciphertext, Infos, Plaintext},
keys::SecretKey,
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
keys::SecretKeyDft,
};
pub trait EncryptSk<B: Backend, C, P> {
fn encrypt<S>(
module: &Module<B>,
res: &mut Ciphertext<C>,
pt: Option<&Plaintext<P>>,
sk: &SecretKey<S>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
S: ScalarZnxDftToRef<B>;
fn encrypt_scratch_bytes(module: &Module<B>, size: usize) -> usize;
pub fn encrypt_rlwe_sk_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
}
impl<C, P> EncryptSk<FFT64, C, P> for Ciphertext<C>
where
C: VecZnxToMut + ZnxInfos,
P: VecZnxToRef + ZnxInfos,
pub fn encrypt_rlwe_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
fn encrypt<S>(
module: &Module<FFT64>,
ct: &mut Ciphertext<C>,
pt: Option<&Plaintext<P>>,
sk: &SecretKey<S>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
S: ScalarZnxDftToRef<FFT64>,
let log_base2k: usize = ct.log_base2k();
let log_k: usize = ct.log_k();
let size: usize = ct.size();
// c1 = a
ct.data.fill_uniform(log_base2k, 1, size, source_xa);
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
{
let log_base2k: usize = ct.log_base2k();
let log_q: usize = ct.log_q();
let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut();
let size: usize = ct_mut.size();
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
// c1 = a
ct_mut.fill_uniform(log_base2k, 1, size, source_xa);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
{
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
}
// c0_big = m - c0_big
if let Some(pt) = pt {
module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
}
// c0_big += e
c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound);
// c0 = norm(c0_big = -as + m + e)
module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
}
fn encrypt_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
// c0_big = m - c0_big
if let Some(pt) = pt {
module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0);
}
// c0_big += e
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
// c0 = norm(c0_big = -as + m + e)
module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1);
}
impl<C> Ciphertext<C>
where
C: VecZnxToMut + ZnxInfos,
pub fn decrypt_rlwe<P, C, S>(
module: &Module<FFT64>,
pt: &mut RLWEPt<P>,
ct: &RLWECt<C>,
sk: &SecretKeyDft<S, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<P>: VecZnxToMut + VecZnxToRef,
VecZnx<C>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let size: usize = min(pt.size(), ct.size());
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
{
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply_inplace(&mut c0_dft, 0, sk, 0);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
}
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0);
// pt = norm(BIG(m + e))
module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
pt.log_base2k = ct.log_base2k();
pt.log_k = min(pt.log_k(), ct.log_k());
}
pub fn decrypt_rlwe_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
}
impl<C> RLWECt<C> {
pub fn encrypt_sk<P, S>(
&mut self,
module: &Module<FFT64>,
pt: Option<&Plaintext<P>>,
sk: &SecretKey<S>,
pt: Option<&RLWEPt<P>>,
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
P: VecZnxToRef + ZnxInfos,
S: ScalarZnxDftToRef<FFT64>,
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
<Self as EncryptSk<FFT64, _, _>>::encrypt(
encrypt_rlwe_sk(
module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound,
);
)
}
pub fn encrypt_sk_scratch_bytes<P>(module: &Module<FFT64>, size: usize) -> usize
pub fn decrypt<P, S>(&self, module: &Module<FFT64>, pt: &mut RLWEPt<P>, sk: &SecretKeyDft<S, FFT64>, scratch: &mut Scratch)
where
Self: EncryptSk<FFT64, C, P>,
VecZnx<P>: VecZnxToMut + VecZnxToRef,
VecZnx<C>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
<Self as EncryptSk<FFT64, C, P>>::encrypt_scratch_bytes(module, size)
decrypt_rlwe(module, pt, self, sk, scratch);
}
}
pub trait EncryptZeroSk<B: Backend, D> {
fn encrypt_zero<S>(
module: &Module<B>,
res: &mut D,
sk: &SecretKey<S>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
S: ScalarZnxDftToRef<B>;
fn encrypt_zero_scratch_bytes(module: &Module<B>, size: usize) -> usize;
pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes<B: Backend>(module: &Module<FFT64>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
}
impl<C> EncryptZeroSk<FFT64, C> for C
where
C: VecZnxDftToMut<FFT64> + ZnxInfos + Infos,
{
impl<C> RLWECtDft<C, FFT64> {
fn encrypt_zero<S>(
module: &Module<FFT64>,
ct: &mut C,
sk: &SecretKey<S>,
ct: &mut RLWECtDft<C, FFT64>,
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
S: ScalarZnxDftToRef<FFT64>,
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let log_base2k: usize = ct.log_base2k();
let log_q: usize = ct.log_q();
let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut();
let size: usize = ct_mut.size();
let log_k: usize = ct.log_k();
let size: usize = ct.size();
// ct[1] = DFT(a)
{
let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
tmp_znx.fill_uniform(log_base2k, 1, size, source_xa);
module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0);
module.vec_znx_dft(ct, 1, &tmp_znx, 0);
}
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
@@ -163,22 +168,22 @@ where
{
let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
// c0_dft = DFT(a) * DFT(s)
module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1);
module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
}
// c0_big += e
c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound);
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
// c0 = norm(c0_big = -as + e)
let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
// ct[0] = DFT(-as + e)
module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0);
module.vec_znx_dft(ct, 0, &tmp_znx, 0);
}
fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize{
fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
(module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
+ module.bytes_of_vec_znx_big(1, size)
+ module.bytes_of_vec_znx(1, size)
@@ -188,42 +193,80 @@ where
#[cfg(test)]
mod tests {
use base2k::{FFT64, Module, ScratchOwned, VecZnx, Scalar};
use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero};
use itertools::izip;
use sampling::source::Source;
use crate::{elem::{Ciphertext, Infos, Plaintext}, keys::SecretKey};
use crate::{
elem::{Infos, RLWECt, RLWEPt},
keys::{SecretKey, SecretKeyDft},
};
use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};
#[test]
fn encrypt_sk_vec_znx_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let log_base2k: usize = 8;
let log_q: usize = 54;
let log_k_ct: usize = 54;
let log_k_pt: usize = 40;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6;
let bound: f64 = sigma * 6.0;
let mut ct: Ciphertext<VecZnx<Vec<u8>>> = Ciphertext::<VecZnx<Vec<u8>>>::new(&module, log_base2k, log_q, 2);
let mut pt: Plaintext<VecZnx<Vec<u8>>> = Plaintext::<VecZnx<Vec<u8>>>::new(&module, log_base2k, log_q);
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
let mut source_xe = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned =
ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));
let mut scratch: ScratchOwned = ScratchOwned::new(ct.encrypt_encsk_scratch_bytes(&module, ct.size()));
let sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
let mut sk: SecretKey<Scalar<Vec<u8>>> = SecretKey::new(&module);
let mut sk_prep
sk.svp_prepare(&module, &mut sk_prep);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
pt.data
.encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10);
ct.encrypt_sk(
&module,
Some(&pt),
&sk_prep,
&sk_dft,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
sigma,
bound,
);
pt.data.zero();
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
let mut data_have: Vec<i64> = vec![0i64; module.n()];
pt.data
.decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
let b_scaled = (*b as f64) / scale;
assert!(
(*a as f64 - b_scaled).abs() < 0.1,
"{} {}",
*a as f64,
b_scaled
)
});
module.free();
}
}

View File

@@ -1,31 +1,27 @@
use base2k::{
Backend, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, ZnxInfos, FFT64
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut,
};
use sampling::source::Source;
use crate::elem::derive_size;
pub struct SecretKey<T> {
data: T,
pub data: ScalarZnx<T>,
}
impl<T> SecretKey<T> {
pub fn data(&self) -> &T {
&self.data
}
pub fn data_mut(&mut self) -> &mut T {
&mut self.data
}
}
impl SecretKey<Scalar<Vec<u8>>> {
impl SecretKey<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>) -> Self {
Self {
data: module.new_scalar(1),
}
}
}
impl<S> SecretKey<S>
where
S: AsMut<[u8]> + AsRef<[u8]>,
{
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
self.data.fill_ternary_prob(0, prob, source);
}
@@ -33,27 +29,66 @@ impl SecretKey<Scalar<Vec<u8>>> {
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
self.data.fill_ternary_hw(0, hw, source);
}
}
pub fn svp_prepare<D>(&self, module: &Module<FFT64>, sk_prep: &mut SecretKey<ScalarZnxDft<D, FFT64>>)
where
ScalarZnxDft<D, base2k::FFT64>: ScalarZnxDftToMut<base2k::FFT64>,
{
module.svp_prepare(&mut sk_prep.data, 0, &self.data, 0)
impl<C> ScalarZnxToMut for SecretKey<C>
where
ScalarZnx<C>: ScalarZnxToMut,
{
fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> {
self.data.to_mut()
}
}
type SecretKeyPrep<C, B> = SecretKey<ScalarZnxDft<C, B>>;
impl<C> ScalarZnxToRef for SecretKey<C>
where
ScalarZnx<C>: ScalarZnxToRef,
{
fn to_ref(&self) -> ScalarZnx<&[u8]> {
self.data.to_ref()
}
}
impl<B: Backend> SecretKey<ScalarZnxDft<Vec<u8>, B>> {
pub fn new(module: &Module<B>) -> Self{
Self{
data: module.new_scalar_znx_dft(1)
pub struct SecretKeyDft<T, B: Backend> {
pub data: ScalarZnxDft<T, B>,
}
impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
pub fn new(module: &Module<B>) -> Self {
Self {
data: module.new_scalar_znx_dft(1),
}
}
pub fn dft<S>(&mut self, module: &Module<FFT64>, sk: &SecretKey<S>)
where
SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
SecretKey<S>: ScalarZnxToRef,
{
module.svp_prepare(self, 0, sk, 0)
}
}
impl<C, B: Backend> ScalarZnxDftToMut<B> for SecretKeyDft<C, B>
where
ScalarZnxDft<C, B>: ScalarZnxDftToMut<B>,
{
fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> ScalarZnxDftToRef<B> for SecretKeyDft<C, B>
where
ScalarZnxDft<C, B>: ScalarZnxDftToRef<B>,
{
fn to_ref(&self) -> ScalarZnxDft<&[u8], B> {
self.data.to_ref()
}
}
pub struct PublicKey<D, B: Backend> {
data: VecZnxDft<D, B>,
pub data: VecZnxDft<D, B>,
}
impl<B: Backend> PublicKey<Vec<u8>, B> {