Added sk/pk encryption for rlwe/rlwedft with tests

Author: Jean-Philippe Bossuat
Date: 2025-05-07 17:04:42 +02:00
Parent: 6cbd2a6a93
Commit: 48ac28c4ce
3 changed files with 451 additions and 72 deletions


@@ -154,9 +154,9 @@ pub struct RLWECtDft<C, B: Backend> {
}
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, cols: usize) -> Self {
Self {
data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)),
data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}

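For context, the extra cols argument makes the number of ciphertext columns explicit at allocation time (a standard RLWE ciphertext uses two). A minimal usage sketch based on the test code later in this commit; the limb-count behaviour attributed to derive_size is an assumption:

    let module: Module<FFT64> = Module::<FFT64>::new(1024);
    // Two columns (c0, c1) of an RLWE ciphertext kept in the DFT domain.
    // derive_size(8, 54) presumably yields the number of base-2^8 limbs needed for 54 bits.
    let ct_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, 8, 54, 2);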

@@ -1,16 +1,16 @@
use std::cmp::min;
use base2k::{
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc,
VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef,
};
use sampling::source::Source;
use crate::{
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
keys::SecretKeyDft,
keys::{PublicKey, SecretDistribution, SecretKeyDft},
};
pub fn encrypt_rlwe_sk_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
@@ -24,9 +24,9 @@ pub fn encrypt_rlwe_sk<C, P, S>(
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
@@ -74,12 +74,10 @@ pub fn decrypt_rlwe<P, C, S>(
VecZnx<C>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let size: usize = min(pt.size(), ct.size());
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
{
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
// c0_dft = DFT(a) * DFT(s)
@@ -111,16 +109,16 @@ impl<C> RLWECt<C> {
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_rlwe_sk(
module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound,
module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch,
)
}
@@ -132,34 +130,58 @@ impl<C> RLWECt<C> {
{
decrypt_rlwe(module, pt, self, sk, scratch);
}
pub fn encrypt_pk<P, S>(
&mut self,
module: &Module<FFT64>,
pt: Option<&RLWEPt<P>>,
pk: &PublicKey<S, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
{
encrypt_rlwe_pk(
module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch,
)
}
}
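Both wrappers now take scratch as the last argument, after sigma and bound. A hedged sketch of the updated call sites, mirroring the tests at the end of this file (variable names are taken from those tests):

    // Secret-key encryption: the plaintext is mandatory.
    ct.encrypt_sk(&module, &pt, &sk_dft, &mut source_xa, &mut source_xe, sigma, bound, scratch.borrow());
    // Public-key encryption: the plaintext is optional (None gives an encryption of zero).
    ct.encrypt_pk(&module, Some(&pt), &pk, &mut source_xu, &mut source_xe, sigma, bound, scratch.borrow());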
pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes<B: Backend>(module: &Module<FFT64>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
}
impl<C> RLWECtDft<C, FFT64> {
fn encrypt_zero<S>(
pub(crate) fn encrypt_zero_rlwe_dft_sk<C, S>(
module: &Module<FFT64>,
ct: &mut RLWECtDft<C, FFT64>,
sk: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
scratch: &mut Scratch,
sigma: f64,
bound: f64,
) where
scratch: &mut Scratch,
) where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
{
let log_base2k: usize = ct.log_base2k();
let log_k: usize = ct.log_k();
let size: usize = ct.size();
#[cfg(debug_assertions)]
{
match sk.dist {
SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"),
_ => {}
}
assert_eq!(ct.cols(), 2);
}
// ct[1] = DFT(a)
{
let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
tmp_znx.fill_uniform(log_base2k, 1, size, source_xa);
tmp_znx.fill_uniform(log_base2k, 0, size, source_xa);
module.vec_znx_dft(ct, 1, &tmp_znx, 0);
}
@@ -167,8 +189,9 @@ impl<C> RLWECtDft<C, FFT64> {
{
let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
// c0_dft = DFT(a) * DFT(s)
// c0_dft = ct[1] * DFT(s)
module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
}
@@ -176,40 +199,189 @@ impl<C> RLWECtDft<C, FFT64> {
// c0_big += e
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
// c0 = norm(c0_big = -as + e)
// c0 = norm(c0_big = -as - e), NOTE: e is centered at 0.
let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
module.vec_znx_negate_inplace(&mut tmp_znx, 0);
// ct[0] = DFT(-as + e)
module.vec_znx_dft(ct, 0, &tmp_znx, 0);
}
}
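For reference, the invariant these steps establish (following the comments above; a is the uniform polynomial sampled from source_xa, e the centered Gaussian error added by add_normal):

    ct[1] = DFT(a)
    ct[0] = DFT(norm(-(a*s + e))) = DFT(-a*s - e)

    IDFT(ct[0]) + IDFT(ct[1] * DFT(s)) = (-a*s - e) + a*s = -e

so decrypting this zero ciphertext returns only centered noise of standard deviation roughly sigma, which is what the encrypt_zero_rlwe_dft_sk_fft64 test below asserts.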
fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
(module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
+ module.bytes_of_vec_znx_big(1, size)
+ module.bytes_of_vec_znx(1, size)
+ module.vec_znx_big_normalize_tmp_bytes()
}
pub fn decrypt_rlwe_dft<P, C, S>(
module: &Module<FFT64>,
pt: &mut RLWEPt<P>,
ct: &RLWECtDft<C, FFT64>,
sk: &SecretKeyDft<S, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<P>: VecZnxToMut + VecZnxToRef,
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
{
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
// c0_dft = DFT(a) * DFT(s)
module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1);
// c0_big = IDFT(c0_dft)
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
}
{
let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size());
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2);
module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0);
}
// pt = norm(BIG(m + e))
module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
pt.log_base2k = ct.log_base2k();
pt.log_k = min(pt.log_k(), ct.log_k());
}
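In equation form (matching the comments above), with ct = (DFT(-a*s + m + e), DFT(a)):

    IDFT(ct[1] * DFT(s)) + IDFT(ct[0]) = a*s + (-a*s + m + e) = m + e

and the final vec_znx_big_normalize writes this sum back into the plaintext's base-2^log_base2k limb representation.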
pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
(module.vec_znx_big_normalize_tmp_bytes()
| module.bytes_of_vec_znx_dft(1, size)
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
+ module.bytes_of_vec_znx_big(1, size)
}
impl<C> RLWECtDft<C, FFT64> {
pub(crate) fn encrypt_zero_sk<S>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_zero_rlwe_dft_sk(
module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
pub fn decrypt<P, S>(
&self,
module: &Module<FFT64>,
pt: &mut RLWEPt<P>,
sk_dft: &SecretKeyDft<S, FFT64>,
scratch: &mut Scratch,
) where
VecZnx<P>: VecZnxToMut + VecZnxToRef,
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
decrypt_rlwe_dft(module, pt, self, sk_dft, scratch);
}
}
pub fn encrypt_rlwe_pk_scratch_bytes<B: Backend>(module: &Module<B>, pk_size: usize) -> usize {
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
+ module.bytes_of_scalar_znx_dft(1)
+ module.vec_znx_big_normalize_tmp_bytes()
}
pub(crate) fn encrypt_rlwe_pk<C, P, S>(
module: &Module<FFT64>,
ct: &mut RLWECt<C>,
pt: Option<&RLWEPt<P>>,
pk: &PublicKey<S, FFT64>,
source_xu: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
VecZnx<C>: VecZnxToMut + VecZnxToRef,
VecZnx<P>: VecZnxToRef,
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
{
#[cfg(debug_assertions)]
{
assert_eq!(ct.log_base2k(), pk.log_base2k());
assert_eq!(ct.n(), module.n());
assert_eq!(pk.n(), module.n());
if let Some(pt) = pt {
assert_eq!(pt.log_base2k(), pk.log_base2k());
assert_eq!(pt.n(), module.n());
}
}
let log_base2k: usize = pk.log_base2k();
let size_pk: usize = pk.size();
// Generates u according to the underlying secret distribution.
let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1);
{
let (mut u, _) = scratch_1.tmp_scalar(module, 1);
match pk.dist {
SecretDistribution::NONE => panic!(
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
),
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
}
module.svp_prepare(&mut u_dft, 0, &u, 0);
}
let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
// ct[0] = pk[0] * u + m + e0
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
if let Some(pt) = pt {
module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0);
}
module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3);
// ct[1] = pk[1] * u + e1
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1);
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3);
}
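For reference, the correctness argument behind these two steps (standard RLWE public-key encryption; pk is generated in keys below as a zero encryption of the form (-a*s - e_pk, a), stored in the DFT domain):

    ct[0] = pk[0]*u + m + e0 = -a*s*u - e_pk*u + m + e0
    ct[1] = pk[1]*u + e1     =  a*u + e1

    ct[0] + ct[1]*s = m + (e0 + e1*s - e_pk*u)

so secret-key decryption recovers m up to a small noise term, as long as u, e0, e1 and e_pk are all small.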
#[cfg(test)]
mod tests {
use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero};
use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero};
use itertools::izip;
use sampling::source::Source;
use crate::{
elem::{Infos, RLWECt, RLWEPt},
keys::{SecretKey, SecretKeyDft},
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes},
keys::{PublicKey, SecretKey, SecretKeyDft},
};
use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};
use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};
#[test]
fn encrypt_sk_vec_znx_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
let log_k_pt: usize = 40;
let log_k_pt: usize = 30;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
@@ -217,13 +389,16 @@ mod tests {
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut scratch: ScratchOwned =
ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));
let sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
@@ -242,9 +417,9 @@ mod tests {
&sk_dft,
&mut source_xa,
&mut source_xe,
scratch.borrow(),
sigma,
bound,
scratch.borrow(),
);
pt.data.zero();
@@ -256,6 +431,7 @@ mod tests {
pt.data
.decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
let b_scaled = (*b as f64) / scale;
@@ -269,4 +445,118 @@ mod tests {
module.free();
}
#[test]
fn encrypt_zero_rlwe_dft_sk_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(1024);
let log_base2k: usize = 8;
let log_k_ct: usize = 55;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([1u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
let mut ct_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
let mut scratch: ScratchOwned = ScratchOwned::new(
encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size())
| decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size())
| encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()),
);
ct_dft.encrypt_zero_sk(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2);
module.free();
}
#[test]
fn encrypt_pk_vec_znx_fft64() {
let module: Module<FFT64> = Module::<FFT64>::new(32);
let log_base2k: usize = 8;
let log_k_ct: usize = 54;
let log_k_pk: usize = 64;
let sigma: f64 = 3.2;
let bound: f64 = sigma * 6.0;
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
let mut source_xa: Source = Source::new([0u8; 32]);
let mut source_xu: Source = Source::new([0u8; 32]);
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
sk.fill_ternary_prob(0.5, &mut source_xs);
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
sk_dft.dft(&module, &sk);
let mut pk: PublicKey<Vec<u8>, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk);
pk.generate(
&module,
&sk_dft,
&mut source_xa,
&mut source_xe,
sigma,
bound,
);
let mut scratch: ScratchOwned = ScratchOwned::new(
encrypt_rlwe_sk_scratch_bytes(&module, ct.size())
| decrypt_rlwe_scratch_bytes(&module, ct.size())
| encrypt_rlwe_pk_scratch_bytes(&module, pk.size()),
);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0);
pt_want
.data
.encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10);
ct.encrypt_pk(
&module,
Some(&pt_want),
&pk,
&mut source_xu,
&mut source_xe,
sigma,
bound,
scratch.borrow(),
);
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2);
module.free();
}
}
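A note on the last assertion: since data_want is forced to zero, pt_want - pt_have is just the encryption and rounding error at the ciphertext precision, so the test rescales its standard deviation by 2^log_k_ct and compares it against

    sqrt(1/12) ≈ 0.2887

the standard deviation of a uniform rounding error on [-1/2, 1/2]. This floor presumably dominates because the Gaussian noise is added at the pk precision (log_k_pk = 64) and is therefore negligible at the 54-bit ciphertext precision.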


@@ -1,19 +1,31 @@
use base2k::{
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut,
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
};
use sampling::source::Source;
use crate::elem::derive_size;
use crate::{
elem::{Infos, RLWECtDft},
encryption::encrypt_zero_rlwe_dft_scratch_bytes,
};
#[derive(Clone, Copy, Debug)]
pub enum SecretDistribution {
TernaryFixed(usize), // Ternary with fixed Hamming weight
TernaryProb(f64), // Ternary with probabilistic Hamming weight
NONE,
}
pub struct SecretKey<T> {
pub data: ScalarZnx<T>,
pub dist: SecretDistribution,
}
impl SecretKey<Vec<u8>> {
pub fn new<B: Backend>(module: &Module<B>) -> Self {
Self {
data: module.new_scalar(1),
data: module.new_scalar_znx(1),
dist: SecretDistribution::NONE,
}
}
}
@@ -24,10 +36,12 @@ where
{
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
self.data.fill_ternary_prob(0, prob, source);
self.dist = SecretDistribution::TernaryProb(prob);
}
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
self.data.fill_ternary_hw(0, hw, source);
self.dist = SecretDistribution::TernaryFixed(hw);
}
}
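A small sketch of the intended flow (taken from the pk test in encryption.rs above): the fill_* setters now record which distribution was used, dft/svp_prepare propagates it into the SecretKeyDft, and PublicKey::generate copies it again so that encrypt_rlwe_pk knows how to sample the ephemeral secret u:

    let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
    sk.fill_ternary_prob(0.5, &mut source_xs);          // sk.dist = TernaryProb(0.5)
    let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
    sk_dft.dft(&module, &sk);                           // copies sk.dist
    let mut pk: PublicKey<Vec<u8>, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk);
    pk.generate(&module, &sk_dft, &mut source_xa, &mut source_xe, sigma, bound); // copies sk_dft.dist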
@@ -51,12 +65,14 @@ where
pub struct SecretKeyDft<T, B: Backend> {
pub data: ScalarZnxDft<T, B>,
pub dist: SecretDistribution,
}
impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
pub fn new(module: &Module<B>) -> Self {
Self {
data: module.new_scalar_znx_dft(1),
dist: SecretDistribution::NONE,
}
}
@@ -65,7 +81,16 @@ impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
SecretKey<S>: ScalarZnxToRef,
{
module.svp_prepare(self, 0, sk, 0)
#[cfg(debug_assertions)]
{
match sk.dist {
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
_ => {}
}
}
module.svp_prepare(self, 0, sk, 0);
self.dist = sk.dist;
}
}
@@ -88,21 +113,85 @@ where
}
pub struct PublicKey<D, B: Backend> {
pub data: VecZnxDft<D, B>,
pub data: RLWECtDft<D, B>,
pub dist: SecretDistribution,
}
impl<B: Backend> PublicKey<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_q: usize) -> Self {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
Self {
data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_q)),
data: RLWECtDft::new(module, log_base2k, log_k, 2),
dist: SecretDistribution::NONE,
}
}
}
impl<B: Backend, D: VecZnxDftToMut<B>> PublicKey<D, B> {
pub fn generate<S>(&mut self, module: &Module<B>, sk: &SecretKey<ScalarZnxDft<S, B>>, scratch: &mut Scratch)
where
ScalarZnxDft<S, B>: ScalarZnxDftToMut<B>,
{
impl<T, B: Backend> Infos for PublicKey<T, B> {
type Inner = VecZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data.data
}
fn log_base2k(&self) -> usize {
self.data.log_base2k
}
fn log_k(&self) -> usize {
self.data.log_k
}
}
impl<C, B: Backend> VecZnxDftToMut<B> for PublicKey<C, B>
where
VecZnxDft<C, B>: VecZnxDftToMut<B>,
{
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
self.data.to_mut()
}
}
impl<C, B: Backend> VecZnxDftToRef<B> for PublicKey<C, B>
where
VecZnxDft<C, B>: VecZnxDftToRef<B>,
{
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
self.data.to_ref()
}
}
impl<C> PublicKey<C, FFT64> {
pub fn generate<S>(
&mut self,
module: &Module<FFT64>,
sk_dft: &SecretKeyDft<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
) where
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
{
#[cfg(debug_assertions)]
{
match sk_dft.dist {
SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"),
_ => {}
}
}
// It's fine to allocate scratch space here since the pk is usually generated only once.
let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size()));
self.data.encrypt_zero_sk(
module,
sk_dft,
source_xa,
source_xe,
sigma,
bound,
scratch.borrow(),
);
self.dist = sk_dft.dist;
}
}
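In other words, with this change a public key is simply a zero RLWE encryption under the secret key, kept in the DFT domain:

    pk = (pk[0], pk[1]) = (DFT(-a*s - e_pk), DFT(a))

which is exactly the shape encrypt_rlwe_pk in encryption.rs expects when it multiplies both columns by the ephemeral secret u.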