From 9afe9372bd843458544ef49031880023a0739581 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Tue, 6 May 2025 18:02:00 +0200 Subject: [PATCH] wip, playing with base2k traits in rlwe crate to ensure inherent compatibility --- rlwe/src/elem.rs | 171 +++++++++++++++++++++++++++++++++++++---- rlwe/src/encryption.rs | 31 ++++---- rlwe/src/keys.rs | 4 +- 3 files changed, 172 insertions(+), 34 deletions(-) diff --git a/rlwe/src/elem.rs b/rlwe/src/elem.rs index 5749208..3cb1360 100644 --- a/rlwe/src/elem.rs +++ b/rlwe/src/elem.rs @@ -1,4 +1,52 @@ -use base2k::{Backend, FFT64, MatZnxDft, MatZnxDftAlloc, Module, VecZnx, VecZnxAlloc, VecZnxDft, VecZnxDftAlloc}; +use base2k::{ + Backend, DataView, DataViewMut, MatZnxDft, MatZnxDftAlloc, MatZnxDftToMut, MatZnxDftToRef, Module, VecZnx, VecZnxAlloc, + VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxInfos, +}; + +pub trait Infos +where + T: ZnxInfos, +{ + fn inner(&self) -> &T; + + /// Returns the ring degree of the polynomials. + fn n(&self) -> usize { + self.inner().n() + } + + /// Returns the base two logarithm of the ring dimension of the polynomials. + fn log_n(&self) -> usize { + self.inner().log_n() + } + + /// Returns the number of rows. + fn rows(&self) -> usize { + self.inner().rows() + } + + /// Returns the number of polynomials in each row. + fn cols(&self) -> usize { + self.inner().cols() + } + + /// Returns the number of size per polynomial. + fn size(&self) -> usize { + let size: usize = self.inner().size(); + debug_assert_eq!(size, derive_size(self.log_base2k(), self.log_q())); + size + } + + /// Returns the total number of small polynomials. + fn poly_count(&self) -> usize { + self.rows() * self.cols() * self.size() + } + + /// Returns the base 2 logarithm of the ciphertext base. + fn log_base2k(&self) -> usize; + + /// Returns the base 2 logarithm of the ciphertext modulus. + fn log_q(&self) -> usize; +} pub struct Ciphertext { data: T, @@ -6,20 +54,32 @@ pub struct Ciphertext { log_q: usize, } -impl Ciphertext { - pub fn log_base2k(&self) -> usize { - self.log_base2k - } - - pub fn log_q(&self) -> usize { - self.log_q - } - - pub fn data(&self) -> &T { +impl Infos for Ciphertext +where + T: ZnxInfos, +{ + fn inner(&self) -> &T { &self.data } - pub fn data_mut(&mut self) -> &mut T { + fn log_base2k(&self) -> usize { + self.log_base2k + } + + fn log_q(&self) -> usize { + self.log_q + } +} + +impl DataView for Ciphertext { + type D = D; + fn data(&self) -> &Self::D { + &self.data + } +} + +impl DataViewMut for Ciphertext { + fn data_mut(&mut self) -> &mut Self::D { &mut self.data } } @@ -30,15 +90,24 @@ pub struct Plaintext { log_q: usize, } -impl Plaintext { - pub fn log_base2k(&self) -> usize { +impl Infos for Plaintext +where + T: ZnxInfos, +{ + fn inner(&self) -> &T { + &self.data + } + + fn log_base2k(&self) -> usize { self.log_base2k } - pub fn log_q(&self) -> usize { + fn log_q(&self) -> usize { self.log_q } +} +impl Plaintext { pub fn data(&self) -> &T { &self.data } @@ -55,6 +124,24 @@ pub(crate) type PtVecZnx = Plaintext>; pub(crate) type PtVecZnxDft = Plaintext>; pub(crate) type PtMatZnxDft = Plaintext>; +impl VecZnxToMut for Ciphertext +where + D: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data_mut().to_mut() + } +} + +impl VecZnxToRef for Ciphertext +where + D: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data().to_ref() + } +} + impl Ciphertext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { @@ -65,6 +152,24 @@ impl Ciphertext>> { } } +impl VecZnxToMut for Plaintext +where + D: VecZnxToMut, +{ + fn to_mut(&mut self) -> VecZnx<&mut [u8]> { + self.data_mut().to_mut() + } +} + +impl VecZnxToRef for Plaintext +where + D: VecZnxToRef, +{ + fn to_ref(&self) -> VecZnx<&[u8]> { + self.data().to_ref() + } +} + impl Plaintext>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize) -> Self { Self { @@ -75,6 +180,24 @@ impl Plaintext>> { } } +impl VecZnxDftToMut for Ciphertext +where + D: VecZnxDftToMut, +{ + fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> { + self.data_mut().to_mut() + } +} + +impl VecZnxDftToRef for Ciphertext +where + D: VecZnxDftToRef, +{ + fn to_ref(&self) -> VecZnxDft<&[u8], B> { + self.data().to_ref() + } +} + impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, log_q: usize, cols: usize) -> Self { Self { @@ -85,6 +208,24 @@ impl Ciphertext, B>> { } } +impl MatZnxDftToMut for Ciphertext +where + D: MatZnxDftToMut, +{ + fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> { + self.data_mut().to_mut() + } +} + +impl MatZnxDftToRef for Ciphertext +where + D: MatZnxDftToRef, +{ + fn to_ref(&self) -> MatZnxDft<&[u8], B> { + self.data().to_ref() + } +} + impl Ciphertext, B>> { pub fn new(module: &Module, log_base2k: usize, rows: usize, cols_in: usize, cols_out: usize, log_q: usize) -> Self { Self { diff --git a/rlwe/src/encryption.rs b/rlwe/src/encryption.rs index 3d62bfe..de3146f 100644 --- a/rlwe/src/encryption.rs +++ b/rlwe/src/encryption.rs @@ -1,15 +1,12 @@ use base2k::{ AddNormal, Backend, FFT64, FillUniform, Module, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx, VecZnxAlloc, - VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, - VecZnxToMut, VecZnxToRef, ZnxInfos, + VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxToMut, + VecZnxToRef, ZnxInfos, }; use sampling::source::Source; -use crate::{ - elem::{CtVecZnx, CtVecZnxDft, PtVecZnx}, - keys::SecretKey, -}; +use crate::{elem::Infos, keys::SecretKey}; pub trait EncryptSk { fn encrypt( @@ -30,15 +27,15 @@ pub trait EncryptSk { } } -impl EncryptSk, PtVecZnx

> for CtVecZnx +impl EncryptSk for C where - VecZnx: VecZnxToMut + VecZnxToRef, - VecZnx

: VecZnxToRef, + C: VecZnxToMut + ZnxInfos + Infos, + P: VecZnxToRef, { fn encrypt( module: &Module, - ct: &mut CtVecZnx, - pt: Option<&PtVecZnx

>, + ct: &mut C, + pt: Option<&P>, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -50,7 +47,7 @@ where { let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnx<&mut [u8]> = ct.data_mut().to_mut(); + let mut ct_mut: VecZnx<&mut [u8]> = ct.to_mut(); let size: usize = ct_mut.size(); // c1 = a @@ -71,7 +68,7 @@ where // c0_big = m - c0_big if let Some(pt) = pt { - module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, &pt.data().to_ref(), 0); + module.vec_znx_big_sub_small_b_inplace(&mut c0_big, 0, pt, 0); } // c0_big += e c0_big.add_normal(log_base2k, 0, log_q, source_xe, sigma, bound); @@ -102,13 +99,13 @@ pub trait EncryptZeroSk { } } -impl EncryptZeroSk> for CtVecZnxDft +impl EncryptZeroSk for C where - VecZnxDft: VecZnxDftToMut + VecZnxDftToRef, + C: VecZnxDftToMut + ZnxInfos + Infos, { fn encrypt_zero( module: &Module, - ct: &mut CtVecZnxDft, + ct: &mut C, sk: &SecretKey, source_xa: &mut Source, source_xe: &mut Source, @@ -120,7 +117,7 @@ where { let log_base2k: usize = ct.log_base2k(); let log_q: usize = ct.log_q(); - let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.data_mut().to_mut(); + let mut ct_mut: VecZnxDft<&mut [u8], FFT64> = ct.to_mut(); let size: usize = ct_mut.size(); // ct[1] = DFT(a) diff --git a/rlwe/src/keys.rs b/rlwe/src/keys.rs index d84abc0..77f1d9a 100644 --- a/rlwe/src/keys.rs +++ b/rlwe/src/keys.rs @@ -1,5 +1,5 @@ use base2k::{ - Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnx, VecZnxDft, + Backend, FFT64, Module, Scalar, ScalarAlloc, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToMut, Scratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftToMut, }; use sampling::source::Source; @@ -56,7 +56,7 @@ impl PublicKey, B> { } impl> PublicKey { - pub fn generate(&mut self, module: &Module, sk: &SecretKey>) + pub fn generate(&mut self, module: &Module, sk: &SecretKey>, scratch: &mut Scratch) where ScalarZnxDft: ScalarZnxDftToMut, {