From 6ce525e5a1516d6156b239862dd2a03b8374d237 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 7 May 2025 12:05:12 +0200 Subject: [PATCH] added sk encryption --- base2k/examples/rlwe_encrypt.rs | 4 +- base2k/src/scalar_znx_dft_ops.rs | 4 +- base2k/src/znx_base.rs | 12 +- rlwe/src/elem.rs | 246 ++++++++++++--------------- rlwe/src/encryption.rs | 283 ++++++++++++++++++------------- rlwe/src/keys.rs | 83 ++++++--- 6 files changed, 333 insertions(+), 299 deletions(-) diff --git a/base2k/examples/rlwe_encrypt.rs b/base2k/examples/rlwe_encrypt.rs index 4d2961c..16b7d3a 100644 --- a/base2k/examples/rlwe_encrypt.rs +++ b/base2k/examples/rlwe_encrypt.rs @@ -1,5 +1,5 @@ use base2k::{ - AddNormal, Encoding, FFT64, FillUniform, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, + AddNormal, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxInfos, }; @@ -20,7 +20,7 @@ fn main() { let mut source: Source = Source::new(seed); // s <- Z_{-1, 0, 1}[X]/(X^{N}+1) - let mut s: Scalar> = module.new_scalar(1); + let mut s: ScalarZnx> = module.new_scalar(1); s.fill_ternary_prob(0, 0.5, &mut source); // Buffer to store s in the DFT domain diff --git a/base2k/src/scalar_znx_dft_ops.rs b/base2k/src/scalar_znx_dft_ops.rs index 888b2a9..f5f8f7f 100644 --- a/base2k/src/scalar_znx_dft_ops.rs +++ b/base2k/src/scalar_znx_dft_ops.rs @@ -2,8 +2,8 @@ use crate::ffi::svp; use crate::ffi::vec_znx_dft::vec_znx_dft_t; use crate::znx_base::{ZnxInfos, ZnxView, ZnxViewMut}; use crate::{ - Backend, FFT64, Module, ScalarZnxToRef, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, - VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, + Backend, FFT64, Module, ScalarZnxDft, ScalarZnxDftOwned, ScalarZnxDftToMut, ScalarZnxDftToRef, ScalarZnxToRef, VecZnxDft, + VecZnxDftToMut, VecZnxDftToRef, }; pub trait ScalarZnxDftAlloc { diff --git a/base2k/src/znx_base.rs b/base2k/src/znx_base.rs index e8dcab2..5230dfd 100644 --- a/base2k/src/znx_base.rs +++ b/base2k/src/znx_base.rs @@ -107,21 +107,13 @@ where { fn zero(&mut self) { unsafe { - std::ptr::write_bytes( - self.as_mut_ptr(), - 0, - self.n() * self.poly_count(), - ); + std::ptr::write_bytes(self.as_mut_ptr(), 0, self.n() * self.poly_count()); } } fn zero_at(&mut self, i: usize, j: usize) { unsafe { - std::ptr::write_bytes( - self.at_mut_ptr(i, j), - 0, - self.n(), - ); + std::ptr::write_bytes(self.at_mut_ptr(i, j), 0, self.n()); } } } diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 1126ed4..fe1b3b4 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,6 +1,6 @@ use base2k::{ - Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnxDftToRef, VecZnx, - VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, + Backend, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, + ZnxInfos, }; pub trait Infos { @@ -31,7 +31,7 @@ pub trait Infos { /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q())); + debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_k())); size } @@ -43,18 +43,18 @@ pub trait Infos { /// Returns the base 2 logarithm of the ciphertext base. fn log_base2k(&self) -> usize; - /// Returns the base 2 logarithm of the ciphertext modulus. - fn log_q(&self) -> usize; + /// Returns the bit precision of the ciphertext. + fn log_k(&self) -> usize; } -pub struct RLWECt{ - data: VecZnx, - log_base2k: usize, - log_q: usize, +pub struct RLWECt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, } -impl Infos for RLWECt { - type Inner = T; +impl Infos for RLWECt { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -64,32 +64,37 @@ impl Infos for RLWECt { self.log_base2k } - fn log_q(&self) -> usize { - self.log_q + fn log_k(&self) -> usize { + self.log_k } } -impl DataView for Ciphertext { - type D = D; - fn data(&self) -> &Self::D { - &self.data +impl VecZnxToMut for RLWECt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() } } -impl DataViewMut for Ciphertext { - fn data_mut(&mut self) -> &mut Self::D { - &mut self.data +impl VecZnxToRef for RLWECt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() } } -pub struct Plaintext { - data: T, - log_base2k: usize, - log_q: usize, +pub struct RLWEPt { + pub data: VecZnx, + pub log_base2k: usize, + pub log_k: usize, } -impl Infos for Plaintext { - type Inner = T; +impl Infos for RLWEPt { + type Inner = VecZnx; fn inner(&self) -> &Self::Inner { &self.data @@ -99,140 +104,99 @@ impl Infos for Plaintext { self.log_base2k } - fn log_q(&self) -> usize { - self.log_q + fn log_k(&self) -> usize { + self.log_k } } -impl Plaintext { - pub fn data(&self) -> &T { +impl VecZnxToMut for RLWEPt +where + VecZnx: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data.to_mut() + } +} + +impl VecZnxToRef for RLWEPt +where + VecZnx: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data.to_ref() + } +} + +impl RLWECt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize, cols: usize) -> Self { + Self { + data: module.new_vec_znx(cols, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl RLWEPt> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +pub struct RLWECtDft { + pub data: VecZnxDft, + pub log_base2k: usize, + pub log_k: usize, +} + +impl RLWECtDft, B> { + pub fn new(module: &Module, log_base2k: usize, log_k: usize) -> Self { + Self { + data: module.new_vec_znx_dft(1, derive_size(log_base2k, log_k)), + log_base2k: log_base2k, + log_k: log_k, + } + } +} + +impl Infos for RLWECtDft { + type Inner = VecZnxDft; + + fn inner(&self) -> &Self::Inner { &self.data } - pub fn data_mut(&mut self) -> &mut T { - &mut self.data + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_k(&self) -> usize { + self.log_k } } -pub(crate) type CtVecZnx = Ciphertext>; -pub(crate) type CtVecZnxDft = Ciphertext>; -pub(crate) type CtMatZnxDft = Ciphertext>; -pub(crate) type PtVecZnx = Plaintext>; -pub(crate) type PtVecZnxDft = Plaintext>; -pub(crate) type PtMatZnxDft = Plaintext>; - -impl VecZnxToMut for Ciphertext +impl VecZnxDftToMut for RLWECtDft where - D: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data_mut().to_mut() - } -} - -impl VecZnxToRef for Ciphertext -where - D: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data().to_ref() - } -} - -impl Ciphertext>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx(cols, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -impl VecZnxToMut for Plaintext -where - D: VecZnxToMut, -{ - fn to_mut(&mut self) -> VecZnx<&mut [u8]> { - self.data_mut().to_mut() - } -} - -impl VecZnxToRef for Plaintext -where - D: VecZnxToRef, -{ - fn to_ref(&self) -> VecZnx<&[u8]> { - self.data().to_ref() - } -} - -impl Plaintext>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { - Self { - data: module.new_vec_znx(1, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -impl VecZnxDftToMut for Ciphertext -where - D: VecZnxDftToMut, + VecZnxDft: VecZnxDftToMut, { fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { - self.data_mut().to_mut() + self.data.to_mut() } } -impl VecZnxDftToRef for Ciphertext +impl VecZnxDftToRef for RLWECtDft where - D: VecZnxDftToRef, + VecZnxDft: VecZnxDftToRef, { fn to_ref(&self) -> VecZnxDft<&[u8], B> { - self.data().to_ref() + self.data.to_ref() } } -impl Ciphertext, B>> { - pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { - Self { - data: module.new_vec_znx_dft(cols, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -impl MatZnxDftToMut for Ciphertext -where - D: MatZnxDftToMut, -{ - fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { - self.data_mut().to_mut() - } -} - -impl MatZnxDftToRef for Ciphertext -where - D: MatZnxDftToRef, -{ - fn to_ref(&self) -> MatZnxDft<&[u8], B> { - self.data().to_ref() - } -} - -impl Ciphertext, B>> { - pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { - Self { - data: module.new_mat_znx_dft(rows, cols_in, cols_out, derive_size(log_base2k, log_q)), - log_base2k: log_base2k, - log_q: log_q, - } - } -} - -pub(crate) fn derive_size(log_base2k: usize, log_q: usize) -> usize { - (log_q + log_base2k - 1) / log_base2k +pub(crate) fn derive_size(log_base2k: usize, log_k: usize) -> usize { + (log_k + log_base2k - 1) / log_base2k } diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index e0f9e1f..148ded4 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,161 +1,166 @@ +use std::cmp::min; + use base2k::{ - AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut, - VecZnxToRef, ZnxInfos, + AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, + VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, + VecZnxDftToRef, VecZnxToMut, VecZnxToRef, }; use sampling::source::Source; use crate::{ - elem::{Ciphertext, Infos, Plaintext}, - keys::SecretKey, + elem::{Infos, RLWECt, RLWECtDft, RLWEPt}, + keys::SecretKeyDft, }; -pub trait EncryptSk { - fn encrypt( - module: &Module, - res: &mut Ciphertext, - pt: Option<&Plaintext

>, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef; - - fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize; +pub fn encrypt_rlwe_sk_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } -impl EncryptSk for Ciphertext -where - C: VecZnxToMut + ZnxInfos, - P: VecZnxToRef + ZnxInfos, +pub fn encrypt_rlwe_sk( + module: &Module, + ct: &mut RLWECt, + pt: Option<&RLWEPt

>, + sk: &SecretKeyDft, + source_xa: &mut Source, + source_xe: &mut Source, + scratch: &mut Scratch, + sigma: f64, + bound: f64, +) where + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - fn encrypt( - module: &Module, - ct: &mut Ciphertext, - pt: Option<&Plaintext

>, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef, + let log_base2k: usize = ct.log_base2k(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); + + // c1 = a + ct.data.fill_uniform(log_base2k, 1, size, source_xa); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + { - let log_base2k: usize = ct.log_base2k(); - let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut(); - let size: usize = ct_mut.size(); + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); - // c1 = a - ct_mut.fill_uniform(log_base2k, 1, size, source_xa); + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); - let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); - - { - let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); - module.vec_znx_dft(&mut c0_dft, 0, &ct_mut, 1); - - // c0_dft = DFT(a) * DFT(s) - module.svp_apply_inplace(&mut c0_dft, 0, &sk.data().to_ref(), 0); - - // c0_big = IDFT(c0_dft) - module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); - } - - // c0_big = m - c0_big - if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); - } - // c0_big += e - c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); - - // c0 = norm(c0_big = -as + m + e) - module.vec_znx_big_normalize(log_base2k, &mut ct_mut, 0, &c0_big, 0, scratch_1); + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); } - fn encrypt_scratch_bytes(module: &Module, size: usize) -> usize { - (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + // c0_big = m - c0_big + if let Some(pt) = pt { + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); } + // c0_big += e + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); + + // c0 = norm(c0_big = -as + m + e) + module.vec_znx_big_normalize(log_base2k, ct, 0, &c0_big, 0, scratch_1); } -impl Ciphertext -where - C: VecZnxToMut + ZnxInfos, +pub fn decrypt_rlwe( + module: &Module, + pt: &mut RLWEPt

, + ct: &RLWECt, + sk: &SecretKeyDft, + scratch: &mut Scratch, +) where + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { + let size: usize = min(pt.size(), ct.size()); + + let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); + + { + let (mut c0_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); + module.vec_znx_dft(&mut c0_dft, 0, ct, 1); + + // c0_dft = DFT(a) * DFT(s) + module.svp_apply_inplace(&mut c0_dft, 0, sk, 0); + + // c0_big = IDFT(c0_dft) + module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut c0_dft, 0); + } + + // c0_big = (a * s) + (-a * s + m + e) = BIG(m + e) + module.vec_znx_big_add_small_inplace(&mut c0_big, 0, ct, 0); + + // pt = norm(BIG(m + e)) + module.vec_znx_big_normalize(ct.log_base2k(), pt, 0, &mut c0_big, 0, scratch_1); + + pt.log_base2k = ct.log_base2k(); + pt.log_k = min(pt.log_k(), ct.log_k()); +} + +pub fn decrypt_rlwe_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) +} + +impl RLWECt { pub fn encrypt_sk( &mut self, module: &Module, - pt: Option<&Plaintext

>, - sk: &SecretKey, + pt: Option<&RLWEPt

>, + sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, sigma: f64, bound: f64, ) where - P: VecZnxToRef + ZnxInfos, - S: ScalarZnxDftToRef, + VecZnx: VecZnxToMut + VecZnxToRef, + VecZnx

: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - >::encrypt( + encrypt_rlwe_sk( module, self, pt, sk, source_xa, source_xe, scratch, sigma, bound, - ); + ) } - pub fn encrypt_sk_scratch_bytes

(module: &Module, size: usize) -> usize + pub fn decrypt(&self, module: &Module, pt: &mut RLWEPt

, sk: &SecretKeyDft, scratch: &mut Scratch) where - Self: EncryptSk, + VecZnx

: VecZnxToMut + VecZnxToRef, + VecZnx: VecZnxToRef, + ScalarZnxDft: ScalarZnxDftToRef, { - >::encrypt_scratch_bytes(module, size) + decrypt_rlwe(module, pt, self, sk, scratch); } } -pub trait EncryptZeroSk { - fn encrypt_zero( - module: &Module, - res: &mut D, - sk: &SecretKey, - source_xa: &mut Source, - source_xe: &mut Source, - scratch: &mut Scratch, - sigma: f64, - bound: f64, - ) where - S: ScalarZnxDftToRef; - - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize; +pub(crate) fn encrypt_rlwe_zero_dft_scratch_bytes(module: &Module, size: usize) -> usize { + (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } -impl EncryptZeroSk for C -where - C: VecZnxDftToMut + ZnxInfos + Infos, -{ +impl RLWECtDft { fn encrypt_zero( module: &Module, - ct: &mut C, - sk: &SecretKey, + ct: &mut RLWECtDft, + sk: &SecretKeyDft, source_xa: &mut Source, source_xe: &mut Source, scratch: &mut Scratch, sigma: f64, bound: f64, ) where - S: ScalarZnxDftToRef, + VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + ScalarZnxDft: ScalarZnxDftToRef, { let log_base2k: usize = ct.log_base2k(); - let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut(); - let size: usize = ct_mut.size(); + let log_k: usize = ct.log_k(); + let size: usize = ct.size(); // ct[1] = DFT(a) { let (mut tmp_znx, _) = scratch.tmp_vec_znx(module, 1, size); tmp_znx.fill_uniform(log_base2k, 1, size, source_xa); - module.vec_znx_dft(&mut ct_mut, 1, &tmp_znx, 0); + module.vec_znx_dft(ct, 1, &tmp_znx, 0); } let (mut c0_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, size); @@ -163,22 +168,22 @@ where { let (mut tmp_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, size); // c0_dft = DFT(a) * DFT(s) - module.svp_apply(&mut tmp_dft, 0, &sk.data().to_ref(), 0, &ct_mut, 1); + module.svp_apply(&mut tmp_dft, 0, sk, 0, ct, 1); // c0_big = IDFT(c0_dft) module.vec_znx_idft_tmp_a(&mut c0_big, 0, &mut tmp_dft, 0); } // c0_big += e - c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); + c0_big.add_normal(log_base2k, 0, log_k, source_xe, sigma, bound); // c0 = norm(c0_big = -as + e) let (mut tmp_znx, scratch_2) = scratch_1.tmp_vec_znx(module, 1, size); module.vec_znx_big_normalize(log_base2k, &mut tmp_znx, 0, &c0_big, 0, scratch_2); // ct[0] = DFT(-as + e) - module.vec_znx_dft(&mut ct_mut, 0, &tmp_znx, 0); + module.vec_znx_dft(ct, 0, &tmp_znx, 0); } - fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize{ + fn encrypt_zero_scratch_bytes(module: &Module, size: usize) -> usize { (module.bytes_of_vec_znx(1, size) | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) + module.bytes_of_vec_znx(1, size) @@ -188,42 +193,80 @@ where #[cfg(test)] mod tests { - use base2k::{FFT64, Module, ScratchOwned, VecZnx, Scalar}; + use base2k::{Encoding, FFT64, Module, ScratchOwned, ZnxZero}; + use itertools::izip; use sampling::source::Source; - use crate::{elem::{Ciphertext, Infos, Plaintext}, keys::SecretKey}; + use crate::{ + elem::{Infos, RLWECt, RLWEPt}, + keys::{SecretKey, SecretKeyDft}, + }; + + use super::{decrypt_rlwe_scratch_bytes, encrypt_rlwe_sk_scratch_bytes}; #[test] fn encrypt_sk_vec_znx_fft64() { let module: Module = Module::::new(32); let log_base2k: usize = 8; - let log_q: usize = 54; + let log_k_ct: usize = 54; + let log_k_pt: usize = 40; let sigma: f64 = 3.2; - let bound: f64 = sigma * 6; + let bound: f64 = sigma * 6.0; - let mut ct: Ciphertext>> = Ciphertext::>>::new(&module, log_base2k, log_q, 2); - let mut pt: Plaintext>> = Plaintext::>>::new(&module, log_base2k, log_q); + let mut ct: RLWECt> = RLWECt::new(&module, log_base2k, log_k_ct, 2); + let mut pt: RLWEPt> = RLWEPt::new(&module, log_base2k, log_k_pt); - let mut source_xe = Source::new([0u8; 32]); + let mut source_xe: Source = Source::new([0u8; 32]); let mut source_xa: Source = Source::new([0u8; 32]); - - let mut scratch: ScratchOwned = ScratchOwned::new(ct.encrypt_encsk_scratch_bytes(&module, ct.size())); + let mut scratch: ScratchOwned = + ScratchOwned::new(encrypt_rlwe_sk_scratch_bytes(&module, ct.size()) | decrypt_rlwe_scratch_bytes(&module, ct.size())); - let mut sk: SecretKey>> = SecretKey::new(&module); - let mut sk_prep - sk.svp_prepare(&module, &mut sk_prep); + let sk: SecretKey> = SecretKey::new(&module); + let mut sk_dft: SecretKeyDft, FFT64> = SecretKeyDft::new(&module); + sk_dft.dft(&module, &sk); + + let mut data_want: Vec = vec![0i64; module.n()]; + + data_want + .iter_mut() + .for_each(|x| *x = source_xa.next_i64() & 0xFF); + + pt.data + .encode_vec_i64(0, log_base2k, log_k_pt, &data_want, 10); ct.encrypt_sk( &module, Some(&pt), - &sk_prep, + &sk_dft, &mut source_xa, &mut source_xe, scratch.borrow(), sigma, bound, ); + + pt.data.zero(); + + ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow()); + + let mut data_have: Vec = vec![0i64; module.n()]; + + pt.data + .decode_vec_i64(0, log_base2k, pt.size() * log_base2k, &mut data_have); + + let scale: f64 = (1 << (pt.size() * log_base2k - log_k_pt)) as f64; + izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| { + let b_scaled = (*b as f64) / scale; + assert!( + (*a as f64 - b_scaled).abs() < 0.1, + "{} {}", + *a as f64, + b_scaled + ) + }); + + module.free(); } } diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index ee8bb94..767d1eb 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,31 +1,27 @@ use base2k::{ - Backend, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, ZnxInfos, FFT64 + Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxDftToMut, + ScalarZnxDftToRef, ScalarZnxToMut, ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, }; use sampling::source::Source; use crate::elem::derive_size; pub struct SecretKey { - data: T, + pub data: ScalarZnx, } -impl SecretKey { - pub fn data(&self) -> &T { - &self.data - } - - pub fn data_mut(&mut self) -> &mut T { - &mut self.data - } -} - -impl SecretKey>> { +impl SecretKey> { pub fn new(module: &Module) -> Self { Self { data: module.new_scalar(1), } } +} +impl SecretKey +where + S: AsMut<[u8]> + AsRef<[u8]>, +{ pub fn fill_ternary_prob(&mut self, prob: f64, source: &mut Source) { self.data.fill_ternary_prob(0, prob, source); } @@ -33,27 +29,66 @@ impl SecretKey>> { pub fn fill_ternary_hw(&mut self, hw: usize, source: &mut Source) { self.data.fill_ternary_hw(0, hw, source); } +} - pub fn svp_prepare(&self, module: &Module, sk_prep: &mut SecretKey>) - where - ScalarZnxDft: ScalarZnxDftToMut, - { - module.svp_prepare(&mut sk_prep.data, 0, &self.data, 0) +impl ScalarZnxToMut for SecretKey +where + ScalarZnx: ScalarZnxToMut, +{ + fn to_mut(&mut self) -> ScalarZnx<&mut [u8]> { + self.data.to_mut() } } -type SecretKeyPrep = SecretKey>; +impl ScalarZnxToRef for SecretKey +where + ScalarZnx: ScalarZnxToRef, +{ + fn to_ref(&self) -> ScalarZnx<&[u8]> { + self.data.to_ref() + } +} -impl SecretKey, B>> { - pub fn new(module: &Module) -> Self{ - Self{ - data: module.new_scalar_znx_dft(1) +pub struct SecretKeyDft { + pub data: ScalarZnxDft, +} + +impl SecretKeyDft, B> { + pub fn new(module: &Module) -> Self { + Self { + data: module.new_scalar_znx_dft(1), } } + + pub fn dft(&mut self, module: &Module, sk: &SecretKey) + where + SecretKeyDft, B>: ScalarZnxDftToMut, + SecretKey: ScalarZnxToRef, + { + module.svp_prepare(self, 0, sk, 0) + } +} + +impl ScalarZnxDftToMut for SecretKeyDft +where + ScalarZnxDft: ScalarZnxDftToMut, +{ + fn to_mut(&mut self) -> ScalarZnxDft<&mut [u8], B> { + self.data.to_mut() + } +} + +impl ScalarZnxDftToRef for SecretKeyDft +where + ScalarZnxDft: ScalarZnxDftToRef, +{ + fn to_ref(&self) -> ScalarZnxDft<&[u8], B> { + self.data.to_ref() + } } pub struct PublicKey { - data: VecZnxDft, + pub data: VecZnxDft, } impl PublicKey, B> {