From 48ac28c4ce403f3f0c36ed84545c8b3750844a00 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 17:04:42 +0200 Subject: [PATCH] Added sk/pk encryption for rlwe/rlwedft with tests --- rlwe/src/elem.rs | 4 +- rlwe/src/encryption.rs | 406 +++++++++++++++++++++++++++++++++++------ rlwe/src/keys.rs | 113 ++++++++++-- 3 files changed, 451 insertions(+), 72 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index fe1b3b4..d1ddb74 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -154,9 +154,9 @@ pub struct RLWECtDft { } impl RLWECtDft, B> { - pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { Self { - data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)), + data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)), log_base2k: log_base2k, log_k: log_k, } diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index 148ded4..0bdae33 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,16 +1,16 @@ use std::cmp::min; use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, - VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, - VecZnxDftToRef, VecZnxToMut, VecZnxToRef, + AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, + VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, }; use sampling::source::Source; use crate::{ elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, - keys::SecretKeyDft, + keys::{PublicKey, SecretDistribution, SecretKeyDft}, }; pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { @@ -24,9 +24,9 @@ pub fn encrypt_rlwe_sk( sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx

: VecZnxToRef, @@ -74,12 +74,10 @@ pub fn decrypt_rlwe( VecZnx: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let size: usize = min(pt.size(), ct.size()); - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct module.vec_znx_dft(&mut c0_dft, 0, ct, 1); // c0_dft = DFT(a) * DFT(s) @@ -111,16 +109,16 @@ impl RLWECt { sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnx: VecZnxToMut + VecZnxToRef, VecZnx

: VecZnxToRef, ScalarZnxDft: ScalarZnxDftToRef, { encrypt_rlwe_sk( - module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound, + module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch, ) } @@ -132,84 +130,258 @@ impl RLWECt { { decrypt_rlwe(module, pt, self, sk, scratch); } + + pub fn encrypt_pk( + &mut self, + module: &Module, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, + ) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + { + encrypt_rlwe_pk( + module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch, + ) + } } -pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +pub(crate) fn encrypt_zero_rlwe_dft_sk( + module: &Module, + ct: &mut RLWECtDft, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let log_base2k: usize = ct.log_base2k(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); + + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"), + _ => {} + } + assert_eq!(ct.cols(), 2); + } + + // ct[1] = DFT(a) + { + let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); + tmp_znx.fill_uniform(log_base2k, 0, size, source_xa); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); + } + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + // c0_dft = ct[1] * DFT(s) + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); + } + + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as - e), NOTE: e is centered at 0. + let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); + module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); + module.vec_znx_negate_inplace(&mut tmp_znx, 0); + // ct[0] = DFT(-as + e) + module.vec_znx_dft(ct, 0, &tmp_znx, 0); +} + +pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + + module.bytes_of_vec_znx_big(1, size) + + module.bytes_of_vec_znx(1, size) + + module.vec_znx_big_normalize_tmp_bytes() +} + +pub fn decrypt_rlwe_dft( + module: &Module, + pt: &mut RLWEPt

, + ct: &RLWECtDft, + sk: &SecretKeyDft, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, +{ + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct + // c0_dft = DFT(a) * DFT(s) + module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + { + let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size()); + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2); + module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0); + } + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.log_base2k(); + pt.log_k = min(pt.log_k(), ct.log_k()); +} + +pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() + | module.bytes_of_vec_znx_dft(1, size) + | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) + + module.bytes_of_vec_znx_big(1, size) } impl RLWECtDft { - fn encrypt_zero( + pub(crate) fn encrypt_zero_sk( + &mut self, module: &Module, - ct: &mut RLWECtDft, - sk: &SecretKeyDft, + sk_dft: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, - scratch: &mut Scratch, sigma: f64, bound: f64, + scratch: &mut Scratch, ) where VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, ScalarZnxDft: ScalarZnxDftToRef, { - let log_base2k: usize = ct.log_base2k(); - let log_k: usize = ct.log_k(); - let size: usize = ct.size(); - - // ct[1] = DFT(a) - { - let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); - tmp_znx.fill_uniform(log_base2k, 1, size, source_xa); - module.vec_znx_dft(ct, 1, &tmp_znx, 0); - } - - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); - } - - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + e) - let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); - module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); - // ct[0] = DFT(-as + e) - module.vec_znx_dft(ct, 0, &tmp_znx, 0); + encrypt_zero_rlwe_dft_sk( + module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch, + ) } - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize { - (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) - + module.bytes_of_vec_znx_big(1, size) - + module.bytes_of_vec_znx(1, size) - + module.vec_znx_big_normalize_tmp_bytes() + pub fn decrypt( + &self, + module: &Module, + pt: &mut RLWEPt

, + sk_dft: &SecretKeyDft, + scratch: &mut Scratch, + ) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnxDft: VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, + { + decrypt_rlwe_dft(module, pt, self, sk_dft, scratch); } } +pub fn encrypt_rlwe_pk_scratch_bytes(module: &Module, pk_size: usize) -> usize { + ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) + + module.bytes_of_scalar_znx_dft(1) + + module.vec_znx_big_normalize_tmp_bytes() +} + +pub(crate) fn encrypt_rlwe_pk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + pk: &PublicKey, + source_xu: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + scratch: &mut Scratch, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + VecZnxDft: VecZnxDftToRef, +{ + #[cfg(debug_assertions)] + { + assert_eq!(ct.log_base2k(), pk.log_base2k()); + assert_eq!(ct.n(), module.n()); + assert_eq!(pk.n(), module.n()); + if let Some(pt) = pt { + assert_eq!(pt.log_base2k(), pk.log_base2k()); + assert_eq!(pt.n(), module.n()); + } + } + + let log_base2k: usize = pk.log_base2k(); + let size_pk: usize = pk.size(); + + // Generates u according to the underlying secret distribution. + let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1); + + { + let (mut u, _) = scratch_1.tmp_scalar(module, 1); + match pk.dist { + SecretDistribution::NONE => panic!( + "invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate" + ), + SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu), + SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu), + } + + module.svp_prepare(&mut u_dft, 0, &u, 0); + } + + let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity) + + // ct[0] = pk[0] * u + m + e0 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + + if let Some(pt) = pt { + module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0); + } + + module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3); + + // ct[1] = pk[1] * u + e1 + module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1); + module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0); + tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound); + module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3); +} + #[cfg(test)] mod tests { - use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero}; + use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero}; use itertools::izip; use sampling::source::Source; use crate::{ - elem::{Infos, RLWECt, RLWEPt}, - keys::{SecretKey, SecretKeyDft}, + elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes}, + keys::{PublicKey, SecretKey, SecretKeyDft}, }; - use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; + use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; let log_k_ct: usize = 54; - let log_k_pt: usize = 40; + let log_k_pt: usize = 30; let sigma: f64 = 3.2; let bound: f64 = sigma * 6.0; @@ -217,13 +389,16 @@ mod tests { let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); + let mut source_xs: Source = Source::new([0u8; 32]); let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); - let sk: SecretKey> = SecretKey::new(&module); + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); sk_dft.dft(&module, &sk); @@ -242,9 +417,9 @@ mod tests { &sk_dft, &mut source_xa, &mut source_xe, - scratch.borrow(), sigma, bound, + scratch.borrow(), ); pt.data.zero(); @@ -256,6 +431,7 @@ mod tests { pt.data .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + // TODO: properly assert the decryption noise through std(dec(ct) - pt) let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { let b_scaled = (*b as f64) / scale; @@ -269,4 +445,118 @@ mod tests { module.free(); } + + #[test] + fn encrypt_zero_rlwe_dft_sk_fft64() { + let module: Module = Module::::new(1024); + let log_base2k: usize = 8; + let log_k_ct: usize = 55; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([1u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut ct_dft: RLWECtDft, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2); + + let mut scratch: ScratchOwned = ScratchOwned::new( + encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size()) + | decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size()) + | encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()), + ); + + ct_dft.encrypt_zero_sk( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2); + module.free(); + } + + #[test] + fn encrypt_pk_vec_znx_fft64() { + let module: Module = Module::::new(32); + let log_base2k: usize = 8; + let log_k_ct: usize = 54; + let log_k_pk: usize = 64; + + let sigma: f64 = 3.2; + let bound: f64 = sigma * 6.0; + + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); + let mut pt_want: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + let mut source_xs: Source = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); + let mut source_xa: Source = Source::new([0u8; 32]); + let mut source_xu: Source = Source::new([0u8; 32]); + + let mut sk: SecretKey> = SecretKey::new(&module); + sk.fill_ternary_prob(0.5, &mut source_xs); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut pk: PublicKey, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk); + pk.generate( + &module, + &sk_dft, + &mut source_xa, + &mut source_xe, + sigma, + bound, + ); + + let mut scratch: ScratchOwned = ScratchOwned::new( + encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) + | decrypt_rlwe_scratch_bytes(&module, ct.size()) + | encrypt_rlwe_pk_scratch_bytes(&module, pk.size()), + ); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0); + + pt_want + .data + .encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10); + + ct.encrypt_pk( + &module, + Some(&pt_want), + &pk, + &mut source_xu, + &mut source_xe, + sigma, + bound, + scratch.borrow(), + ); + + let mut pt_have: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_ct); + + ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow()); + + module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0); + + assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2); + + module.free(); + } } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index 767d1eb..89c33e3 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,19 +1,31 @@ use base2k::{ Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, - ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos, }; use sampling::source::Source; -use crate::elem::derive_size; +use crate::{ + elem::{Infos, RLWECtDft}, + encryption::encrypt_zero_rlwe_dft_scratch_bytes, +}; + +#[derive(Clone, Copy, Debug)] +pub enum SecretDistribution { + TernaryFixed(usize), // Ternary with fixed Hamming weight + TernaryProb(f64), // Ternary with probabilistic Hamming weight + NONE, +} pub struct SecretKey { pub data: ScalarZnx, + pub dist: SecretDistribution, } impl SecretKey> { pub fn new(module: &Module) -> Self { Self { - data: module.new_scalar(1), + data: module.new_scalar_znx(1), + dist: SecretDistribution::NONE, } } } @@ -24,10 +36,12 @@ where { pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); + self.dist = SecretDistribution::TernaryProb(prob); } pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_ternary_hw(0, hw, source); + self.dist = SecretDistribution::TernaryFixed(hw); } } @@ -51,12 +65,14 @@ where pub struct SecretKeyDft { pub data: ScalarZnxDft, + pub dist: SecretDistribution, } impl SecretKeyDft, B> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar_znx_dft(1), + dist: SecretDistribution::NONE, } } @@ -65,7 +81,16 @@ impl SecretKeyDft, B> { SecretKeyDft, B>: ScalarZnxDftToMut, SecretKey: ScalarZnxToRef, { - module.svp_prepare(self, 0, sk, 0) + #[cfg(debug_assertions)] + { + match sk.dist { + SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"), + _ => {} + } + } + + module.svp_prepare(self, 0, sk, 0); + self.dist = sk.dist; } } @@ -88,21 +113,85 @@ where } pub struct PublicKey { - pub data: VecZnxDft, + pub data: RLWECtDft, + pub dist: SecretDistribution, } impl PublicKey, B> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { Self { - data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_q)), + data: RLWECtDft::new(module, log_base2k, log_k, 2), + dist: SecretDistribution::NONE, } } } -impl> PublicKey { - pub fn generate(&mut self, module: &Module, sk: &SecretKey>, scratch: &mut Scratch) - where - ScalarZnxDft: ScalarZnxDftToMut, - { +impl Infos for PublicKey { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { + &self.data.data + } + + fn log_base2k(&self) -> usize { + self.data.log_base2k + } + + fn log_k(&self) -> usize { + self.data.log_k + } +} + +impl VecZnxDftToMut for PublicKey +where + VecZnxDft: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl VecZnxDftToRef for PublicKey +where + VecZnxDft: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data.to_ref() + } +} + +impl PublicKey { + pub fn generate( + &mut self, + module: &Module, + sk_dft: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + sigma: f64, + bound: f64, + ) where + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef + ZnxInfos, + { + #[cfg(debug_assertions)] + { + match sk_dft.dist { + SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"), + _ => {} + } + } + + // Its ok to allocate scratch space here since pk is usually generated only once. + let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size())); + self.data.encrypt_zero_sk( + module, + sk_dft, + source_xa, + source_xe, + sigma, + bound, + scratch.borrow(), + ); + self.dist = sk_dft.dist; } }