mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added sk/pk encryption for rlwe/rlwedft with tests
This commit is contained in:
@@ -154,9 +154,9 @@ pub struct RLWECtDft<C, B: Backend> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
|
impl<B: Backend> RLWECtDft<Vec<u8>, B> {
|
||||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, cols: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)),
|
data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_k)),
|
||||||
log_base2k: log_base2k,
|
log_base2k: log_base2k,
|
||||||
log_k: log_k,
|
log_k: log_k,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
use std::cmp::min;
|
use std::cmp::min;
|
||||||
|
|
||||||
use base2k::{
|
use base2k::{
|
||||||
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
|
AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
|
||||||
VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
|
ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc,
|
||||||
VecZnxDftToRef, VecZnxToMut, VecZnxToRef,
|
VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef,
|
||||||
};
|
};
|
||||||
|
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
|
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
|
||||||
keys::SecretKeyDft,
|
keys::{PublicKey, SecretDistribution, SecretKeyDft},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn encrypt_rlwe_sk_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
|
pub fn encrypt_rlwe_sk_scratch_bytes<B: Backend>(module: &Module<B>, size: usize) -> usize {
|
||||||
@@ -24,9 +24,9 @@ pub fn encrypt_rlwe_sk<C, P, S>(
|
|||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
scratch: &mut Scratch,
|
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
bound: f64,
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
) where
|
) where
|
||||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||||
VecZnx<P>: VecZnxToRef,
|
VecZnx<P>: VecZnxToRef,
|
||||||
@@ -74,12 +74,10 @@ pub fn decrypt_rlwe<P, C, S>(
|
|||||||
VecZnx<C>: VecZnxToRef,
|
VecZnx<C>: VecZnxToRef,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
let size: usize = min(pt.size(), ct.size());
|
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||||
|
|
||||||
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size);
|
|
||||||
|
|
||||||
{
|
{
|
||||||
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||||
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
module.vec_znx_dft(&mut c0_dft, 0, ct, 1);
|
||||||
|
|
||||||
// c0_dft = DFT(a) * DFT(s)
|
// c0_dft = DFT(a) * DFT(s)
|
||||||
@@ -111,16 +109,16 @@ impl<C> RLWECt<C> {
|
|||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
scratch: &mut Scratch,
|
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
bound: f64,
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
) where
|
) where
|
||||||
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||||
VecZnx<P>: VecZnxToRef,
|
VecZnx<P>: VecZnxToRef,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
{
|
{
|
||||||
encrypt_rlwe_sk(
|
encrypt_rlwe_sk(
|
||||||
module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound,
|
module, self, pt, sk, source_xa, source_xe, sigma, bound, scratch,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,22 +130,37 @@ impl<C> RLWECt<C> {
|
|||||||
{
|
{
|
||||||
decrypt_rlwe(module, pt, self, sk, scratch);
|
decrypt_rlwe(module, pt, self, sk, scratch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_pk<P, S>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: Option<&RLWEPt<P>>,
|
||||||
|
pk: &PublicKey<S, FFT64>,
|
||||||
|
source_xu: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||||
|
VecZnx<P>: VecZnxToRef,
|
||||||
|
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
encrypt_rlwe_pk(
|
||||||
|
module, self, pt, pk, source_xu, source_xe, sigma, bound, scratch,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes<B: Backend>(module: &Module<FFT64>, size: usize) -> usize {
|
pub(crate) fn encrypt_zero_rlwe_dft_sk<C, S>(
|
||||||
(module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<C> RLWECtDft<C, FFT64> {
|
|
||||||
fn encrypt_zero<S>(
|
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
ct: &mut RLWECtDft<C, FFT64>,
|
ct: &mut RLWECtDft<C, FFT64>,
|
||||||
sk: &SecretKeyDft<S, FFT64>,
|
sk: &SecretKeyDft<S, FFT64>,
|
||||||
source_xa: &mut Source,
|
source_xa: &mut Source,
|
||||||
source_xe: &mut Source,
|
source_xe: &mut Source,
|
||||||
scratch: &mut Scratch,
|
|
||||||
sigma: f64,
|
sigma: f64,
|
||||||
bound: f64,
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
) where
|
) where
|
||||||
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||||
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
@@ -156,10 +169,19 @@ impl<C> RLWECtDft<C, FFT64> {
|
|||||||
let log_k: usize = ct.log_k();
|
let log_k: usize = ct.log_k();
|
||||||
let size: usize = ct.size();
|
let size: usize = ct.size();
|
||||||
|
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
match sk.dist {
|
||||||
|
SecretDistribution::NONE => panic!("invalid sk.dist = SecretDistribution::NONE"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
assert_eq!(ct.cols(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
// ct[1] = DFT(a)
|
// ct[1] = DFT(a)
|
||||||
{
|
{
|
||||||
let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
|
let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size);
|
||||||
tmp_znx.fill_uniform(log_base2k, 1, size, source_xa);
|
tmp_znx.fill_uniform(log_base2k, 0, size, source_xa);
|
||||||
module.vec_znx_dft(ct, 1, &tmp_znx, 0);
|
module.vec_znx_dft(ct, 1, &tmp_znx, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,8 +189,9 @@ impl<C> RLWECtDft<C, FFT64> {
|
|||||||
|
|
||||||
{
|
{
|
||||||
let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size);
|
||||||
// c0_dft = DFT(a) * DFT(s)
|
// c0_dft = ct[1] * DFT(s)
|
||||||
module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
|
module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1);
|
||||||
|
|
||||||
// c0_big = IDFT(c0_dft)
|
// c0_big = IDFT(c0_dft)
|
||||||
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
|
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0);
|
||||||
}
|
}
|
||||||
@@ -176,40 +199,189 @@ impl<C> RLWECtDft<C, FFT64> {
|
|||||||
// c0_big += e
|
// c0_big += e
|
||||||
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound);
|
||||||
|
|
||||||
// c0 = norm(c0_big = -as + e)
|
// c0 = norm(c0_big = -as - e), NOTE: e is centered at 0.
|
||||||
let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
|
let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size);
|
||||||
module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
|
module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2);
|
||||||
|
module.vec_znx_negate_inplace(&mut tmp_znx, 0);
|
||||||
// ct[0] = DFT(-as + e)
|
// ct[0] = DFT(-as + e)
|
||||||
module.vec_znx_dft(ct, 0, &tmp_znx, 0);
|
module.vec_znx_dft(ct, 0, &tmp_znx, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn encrypt_zero_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
pub(crate) fn encrypt_zero_rlwe_dft_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
||||||
(module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
|
(module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size))
|
||||||
+ module.bytes_of_vec_znx_big(1, size)
|
+ module.bytes_of_vec_znx_big(1, size)
|
||||||
+ module.bytes_of_vec_znx(1, size)
|
+ module.bytes_of_vec_znx(1, size)
|
||||||
+ module.vec_znx_big_normalize_tmp_bytes()
|
+ module.vec_znx_big_normalize_tmp_bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn decrypt_rlwe_dft<P, C, S>(
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: &mut RLWEPt<P>,
|
||||||
|
ct: &RLWECtDft<C, FFT64>,
|
||||||
|
sk: &SecretKeyDft<S, FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||||
|
|
||||||
|
{
|
||||||
|
let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, ct.size()); // TODO optimize size when pt << ct
|
||||||
|
// c0_dft = DFT(a) * DFT(s)
|
||||||
|
module.svp_apply(&mut c0_dft, 0, sk, 0, ct, 1);
|
||||||
|
// c0_big = IDFT(c0_dft)
|
||||||
|
module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let (mut c1_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, ct.size());
|
||||||
|
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||||
|
module.vec_znx_idft(&mut c1_big, 0, ct, 0, scratch_2);
|
||||||
|
module.vec_znx_big_add_inplace(&mut c0_big, 0, &c1_big, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// pt = norm(BIG(m + e))
|
||||||
|
module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1);
|
||||||
|
|
||||||
|
pt.log_base2k = ct.log_base2k();
|
||||||
|
pt.log_k = min(pt.log_k(), ct.log_k());
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrypt_rlwe_dft_scratch_bytes(module: &Module<FFT64>, size: usize) -> usize {
|
||||||
|
(module.vec_znx_big_normalize_tmp_bytes()
|
||||||
|
| module.bytes_of_vec_znx_dft(1, size)
|
||||||
|
| (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes()))
|
||||||
|
+ module.bytes_of_vec_znx_big(1, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> RLWECtDft<C, FFT64> {
|
||||||
|
pub(crate) fn encrypt_zero_sk<S>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
encrypt_zero_rlwe_dft_sk(
|
||||||
|
module, self, sk_dft, source_xa, source_xe, sigma, bound, scratch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decrypt<P, S>(
|
||||||
|
&self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
pt: &mut RLWEPt<P>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnx<P>: VecZnxToMut + VecZnxToRef,
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
decrypt_rlwe_dft(module, pt, self, sk_dft, scratch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encrypt_rlwe_pk_scratch_bytes<B: Backend>(module: &Module<B>, pk_size: usize) -> usize {
|
||||||
|
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
|
||||||
|
+ module.bytes_of_scalar_znx_dft(1)
|
||||||
|
+ module.vec_znx_big_normalize_tmp_bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn encrypt_rlwe_pk<C, P, S>(
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
ct: &mut RLWECt<C>,
|
||||||
|
pt: Option<&RLWEPt<P>>,
|
||||||
|
pk: &PublicKey<S, FFT64>,
|
||||||
|
source_xu: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
VecZnx<C>: VecZnxToMut + VecZnxToRef,
|
||||||
|
VecZnx<P>: VecZnxToRef,
|
||||||
|
VecZnxDft<S, FFT64>: VecZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(ct.log_base2k(), pk.log_base2k());
|
||||||
|
assert_eq!(ct.n(), module.n());
|
||||||
|
assert_eq!(pk.n(), module.n());
|
||||||
|
if let Some(pt) = pt {
|
||||||
|
assert_eq!(pt.log_base2k(), pk.log_base2k());
|
||||||
|
assert_eq!(pt.n(), module.n());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let log_base2k: usize = pk.log_base2k();
|
||||||
|
let size_pk: usize = pk.size();
|
||||||
|
|
||||||
|
// Generates u according to the underlying secret distribution.
|
||||||
|
let (mut u_dft, scratch_1) = scratch.tmp_scalar_dft(module, 1);
|
||||||
|
|
||||||
|
{
|
||||||
|
let (mut u, _) = scratch_1.tmp_scalar(module, 1);
|
||||||
|
match pk.dist {
|
||||||
|
SecretDistribution::NONE => panic!(
|
||||||
|
"invalid public key: SecretDistribution::NONE, ensure it has been correctly intialized through Self::generate"
|
||||||
|
),
|
||||||
|
SecretDistribution::TernaryFixed(hw) => u.fill_ternary_hw(0, hw, source_xu),
|
||||||
|
SecretDistribution::TernaryProb(prob) => u.fill_ternary_prob(0, prob, source_xu),
|
||||||
|
}
|
||||||
|
|
||||||
|
module.svp_prepare(&mut u_dft, 0, &u, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (mut tmp_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
|
||||||
|
let (mut tmp_dft, scratch_3) = scratch_2.tmp_vec_znx_dft(module, 1, size_pk); // TODO optimize size (e.g. when encrypting at low homomorphic capacity)
|
||||||
|
|
||||||
|
// ct[0] = pk[0] * u + m + e0
|
||||||
|
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 0);
|
||||||
|
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
|
||||||
|
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
|
||||||
|
|
||||||
|
if let Some(pt) = pt {
|
||||||
|
module.vec_znx_big_add_small_inplace(&mut tmp_big, 0, pt, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
module.vec_znx_big_normalize(log_base2k, ct, 0, &tmp_big, 0, scratch_3);
|
||||||
|
|
||||||
|
// ct[1] = pk[1] * u + e1
|
||||||
|
module.svp_apply(&mut tmp_dft, 0, &u_dft, 0, pk, 1);
|
||||||
|
module.vec_znx_idft_tmp_a(&mut tmp_big, 0, &mut tmp_dft, 0);
|
||||||
|
tmp_big.add_normal(log_base2k, 0, pk.log_k(), source_xe, sigma, bound);
|
||||||
|
module.vec_znx_big_normalize(log_base2k, ct, 1, &tmp_big, 0, scratch_3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero};
|
use base2k::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
elem::{Infos, RLWECt, RLWEPt},
|
elem::{Infos, RLWECt, RLWECtDft, RLWEPt},
|
||||||
keys::{SecretKey, SecretKeyDft},
|
encryption::{decrypt_rlwe_dft_scratch_bytes, encrypt_zero_rlwe_dft_scratch_bytes},
|
||||||
|
keys::{PublicKey, SecretKey, SecretKeyDft},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};
|
use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_pk_scratch_bytes, encrypt_rlwe_sk_scratch_bytes};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn encrypt_sk_vec_znx_fft64() {
|
fn encrypt_sk_vec_znx_fft64() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||||
let log_base2k: usize = 8;
|
let log_base2k: usize = 8;
|
||||||
let log_k_ct: usize = 54;
|
let log_k_ct: usize = 54;
|
||||||
let log_k_pt: usize = 40;
|
let log_k_pt: usize = 30;
|
||||||
|
|
||||||
let sigma: f64 = 3.2;
|
let sigma: f64 = 3.2;
|
||||||
let bound: f64 = sigma * 6.0;
|
let bound: f64 = sigma * 6.0;
|
||||||
@@ -217,13 +389,16 @@ mod tests {
|
|||||||
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
|
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
|
||||||
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
|
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_pt);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
let mut source_xe: Source = Source::new([0u8; 32]);
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
let mut source_xa: Source = Source::new([0u8; 32]);
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
let mut scratch: ScratchOwned =
|
let mut scratch: ScratchOwned =
|
||||||
ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));
|
ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size()));
|
||||||
|
|
||||||
let sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
|
||||||
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||||
sk_dft.dft(&module, &sk);
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
@@ -242,9 +417,9 @@ mod tests {
|
|||||||
&sk_dft,
|
&sk_dft,
|
||||||
&mut source_xa,
|
&mut source_xa,
|
||||||
&mut source_xe,
|
&mut source_xe,
|
||||||
scratch.borrow(),
|
|
||||||
sigma,
|
sigma,
|
||||||
bound,
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
);
|
);
|
||||||
|
|
||||||
pt.data.zero();
|
pt.data.zero();
|
||||||
@@ -256,6 +431,7 @@ mod tests {
|
|||||||
pt.data
|
pt.data
|
||||||
.decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
|
.decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have);
|
||||||
|
|
||||||
|
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
|
||||||
let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
|
let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64;
|
||||||
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
|
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
|
||||||
let b_scaled = (*b as f64) / scale;
|
let b_scaled = (*b as f64) / scale;
|
||||||
@@ -269,4 +445,118 @@ mod tests {
|
|||||||
|
|
||||||
module.free();
|
module.free();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypt_zero_rlwe_dft_sk_fft64() {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(1024);
|
||||||
|
let log_base2k: usize = 8;
|
||||||
|
let log_k_ct: usize = 55;
|
||||||
|
|
||||||
|
let sigma: f64 = 3.2;
|
||||||
|
let bound: f64 = sigma * 6.0;
|
||||||
|
|
||||||
|
let mut pt: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([1u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
let mut ct_dft: RLWECtDft<Vec<u8>, FFT64> = RLWECtDft::new(&module, log_base2k, log_k_ct, 2);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
encrypt_rlwe_sk_scratch_bytes(&module, ct_dft.size())
|
||||||
|
| decrypt_rlwe_dft_scratch_bytes(&module, ct_dft.size())
|
||||||
|
| encrypt_zero_rlwe_dft_scratch_bytes(&module, ct_dft.size()),
|
||||||
|
);
|
||||||
|
|
||||||
|
ct_dft.encrypt_zero_sk(
|
||||||
|
&module,
|
||||||
|
&sk_dft,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
ct_dft.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
|
|
||||||
|
assert!((sigma - pt.data.std(0, log_base2k) * (log_k_ct as f64).exp2()) <= 0.2);
|
||||||
|
module.free();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypt_pk_vec_znx_fft64() {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(32);
|
||||||
|
let log_base2k: usize = 8;
|
||||||
|
let log_k_ct: usize = 54;
|
||||||
|
let log_k_pk: usize = 64;
|
||||||
|
|
||||||
|
let sigma: f64 = 3.2;
|
||||||
|
let bound: f64 = sigma * 6.0;
|
||||||
|
|
||||||
|
let mut ct: RLWECt<Vec<u8>> = RLWECt::new(&module, log_base2k, log_k_ct, 2);
|
||||||
|
let mut pt_want: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xu: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
let mut sk_dft: SecretKeyDft<Vec<u8>, FFT64> = SecretKeyDft::new(&module);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
let mut pk: PublicKey<Vec<u8>, FFT64> = PublicKey::new(&module, log_base2k, log_k_pk);
|
||||||
|
pk.generate(
|
||||||
|
&module,
|
||||||
|
&sk_dft,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
encrypt_rlwe_sk_scratch_bytes(&module, ct.size())
|
||||||
|
| decrypt_rlwe_scratch_bytes(&module, ct.size())
|
||||||
|
| encrypt_rlwe_pk_scratch_bytes(&module, pk.size()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut data_want: Vec<i64> = vec![0i64; module.n()];
|
||||||
|
|
||||||
|
data_want
|
||||||
|
.iter_mut()
|
||||||
|
.for_each(|x| *x = source_xa.next_i64() & 0);
|
||||||
|
|
||||||
|
pt_want
|
||||||
|
.data
|
||||||
|
.encode_vec_i64(0, log_base2k, log_k_ct, &data_want, 10);
|
||||||
|
|
||||||
|
ct.encrypt_pk(
|
||||||
|
&module,
|
||||||
|
Some(&pt_want),
|
||||||
|
&pk,
|
||||||
|
&mut source_xu,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut pt_have: RLWEPt<Vec<u8>> = RLWEPt::new(&module, log_base2k, log_k_ct);
|
||||||
|
|
||||||
|
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
|
||||||
|
|
||||||
|
module.vec_znx_sub_ab_inplace(&mut pt_want, 0, &pt_have, 0);
|
||||||
|
|
||||||
|
assert!(((1.0f64 / 12.0).sqrt() - pt_want.data.std(0, log_base2k) * (log_k_ct as f64).exp2()).abs() < 0.2);
|
||||||
|
|
||||||
|
module.free();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
109
rlwe/src/keys.rs
109
rlwe/src/keys.rs
@@ -1,19 +1,31 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
|
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut,
|
||||||
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut,
|
ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, ScratchOwned, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxInfos,
|
||||||
};
|
};
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::elem::derive_size;
|
use crate::{
|
||||||
|
elem::{Infos, RLWECtDft},
|
||||||
|
encryption::encrypt_zero_rlwe_dft_scratch_bytes,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub enum SecretDistribution {
|
||||||
|
TernaryFixed(usize), // Ternary with fixed Hamming weight
|
||||||
|
TernaryProb(f64), // Ternary with probabilistic Hamming weight
|
||||||
|
NONE,
|
||||||
|
}
|
||||||
|
|
||||||
pub struct SecretKey<T> {
|
pub struct SecretKey<T> {
|
||||||
pub data: ScalarZnx<T>,
|
pub data: ScalarZnx<T>,
|
||||||
|
pub dist: SecretDistribution,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SecretKey<Vec<u8>> {
|
impl SecretKey<Vec<u8>> {
|
||||||
pub fn new<B: Backend>(module: &Module<B>) -> Self {
|
pub fn new<B: Backend>(module: &Module<B>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_scalar(1),
|
data: module.new_scalar_znx(1),
|
||||||
|
dist: SecretDistribution::NONE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -24,10 +36,12 @@ where
|
|||||||
{
|
{
|
||||||
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) {
|
||||||
self.data.fill_ternary_prob(0, prob, source);
|
self.data.fill_ternary_prob(0, prob, source);
|
||||||
|
self.dist = SecretDistribution::TernaryProb(prob);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) {
|
||||||
self.data.fill_ternary_hw(0, hw, source);
|
self.data.fill_ternary_hw(0, hw, source);
|
||||||
|
self.dist = SecretDistribution::TernaryFixed(hw);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,12 +65,14 @@ where
|
|||||||
|
|
||||||
pub struct SecretKeyDft<T, B: Backend> {
|
pub struct SecretKeyDft<T, B: Backend> {
|
||||||
pub data: ScalarZnxDft<T, B>,
|
pub data: ScalarZnxDft<T, B>,
|
||||||
|
pub dist: SecretDistribution,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
|
impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
|
||||||
pub fn new(module: &Module<B>) -> Self {
|
pub fn new(module: &Module<B>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_scalar_znx_dft(1),
|
data: module.new_scalar_znx_dft(1),
|
||||||
|
dist: SecretDistribution::NONE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,7 +81,16 @@ impl<B: Backend> SecretKeyDft<Vec<u8>, B> {
|
|||||||
SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
|
SecretKeyDft<Vec<u8>, B>: ScalarZnxDftToMut<base2k::FFT64>,
|
||||||
SecretKey<S>: ScalarZnxToRef,
|
SecretKey<S>: ScalarZnxToRef,
|
||||||
{
|
{
|
||||||
module.svp_prepare(self, 0, sk, 0)
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
match sk.dist {
|
||||||
|
SecretDistribution::NONE => panic!("invalid sk: SecretDistribution::NONE"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module.svp_prepare(self, 0, sk, 0);
|
||||||
|
self.dist = sk.dist;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,21 +113,85 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub struct PublicKey<D, B: Backend> {
|
pub struct PublicKey<D, B: Backend> {
|
||||||
pub data: VecZnxDft<D, B>,
|
pub data: RLWECtDft<D, B>,
|
||||||
|
pub dist: SecretDistribution,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> PublicKey<Vec<u8>, B> {
|
impl<B: Backend> PublicKey<Vec<u8>, B> {
|
||||||
pub fn new(module: &Module<B>, log_base2k: usize, log_q: usize) -> Self {
|
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize) -> Self {
|
||||||
Self {
|
Self {
|
||||||
data: module.new_vec_znx_dft(2, derive_size(log_base2k, log_q)),
|
data: RLWECtDft::new(module, log_base2k, log_k, 2),
|
||||||
|
dist: SecretDistribution::NONE,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend, D: VecZnxDftToMut<B>> PublicKey<D, B> {
|
impl<T, B: Backend> Infos for PublicKey<T, B> {
|
||||||
pub fn generate<S>(&mut self, module: &Module<B>, sk: &SecretKey<ScalarZnxDft<S, B>>, scratch: &mut Scratch)
|
type Inner = VecZnxDft<T, B>;
|
||||||
|
|
||||||
|
fn inner(&self) -> &Self::Inner {
|
||||||
|
&self.data.data
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_base2k(&self) -> usize {
|
||||||
|
self.data.log_base2k
|
||||||
|
}
|
||||||
|
|
||||||
|
fn log_k(&self) -> usize {
|
||||||
|
self.data.log_k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, B: Backend> VecZnxDftToMut<B> for PublicKey<C, B>
|
||||||
where
|
where
|
||||||
ScalarZnxDft<S, B>: ScalarZnxDftToMut<B>,
|
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||||
{
|
{
|
||||||
|
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||||
|
self.data.to_mut()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, B: Backend> VecZnxDftToRef<B> for PublicKey<C, B>
|
||||||
|
where
|
||||||
|
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||||
|
{
|
||||||
|
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||||
|
self.data.to_ref()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C> PublicKey<C, FFT64> {
|
||||||
|
pub fn generate<S>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
sk_dft: &SecretKeyDft<S, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
bound: f64,
|
||||||
|
) where
|
||||||
|
VecZnxDft<C, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||||
|
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64> + ZnxInfos,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
match sk_dft.dist {
|
||||||
|
SecretDistribution::NONE => panic!("invalid sk_dft: SecretDistribution::NONE"),
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Its ok to allocate scratch space here since pk is usually generated only once.
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(encrypt_zero_rlwe_dft_scratch_bytes(module, self.size()));
|
||||||
|
self.data.encrypt_zero_sk(
|
||||||
|
module,
|
||||||
|
sk_dft,
|
||||||
|
source_xa,
|
||||||
|
source_xe,
|
||||||
|
sigma,
|
||||||
|
bound,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
self.dist = sk_dft.dist;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user