working on adding rank to glwe (all test passing)

This commit is contained in:
Jean-Philippe Bossuat
2025-05-13 17:21:41 +02:00
parent 31b14ee585
commit dee889dc0c
22 changed files with 2020 additions and 2030 deletions

View File

@@ -1,170 +1,63 @@
use base2k::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftScratch, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx,
ScalarZnxDft, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigOps, VecZnxBigScratch,
VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, VecZnxToMut, VecZnxToRef, ZnxInfos,
ZnxZero,
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftToRef,
ScalarZnxToRef, Scratch, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef,
};
use sampling::source::Source;
use crate::{
elem::{GetRow, Infos, SetRow},
encryption::EncryptSkScratchSpace,
external_product::{
ExternalProduct, ExternalProductInplace, ExternalProductInplaceScratchSpace, ExternalProductScratchSpace,
},
ggsw::GGSWCiphertext,
glwe::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext},
gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::SecretKeyFourier,
keyswitch::{KeySwitch, KeySwitchInplace, KeySwitchInplaceScratchSpace, KeySwitchScratchSpace},
utils::derive_size,
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
};
pub struct GLWEKeySwitchKey<C, B: Backend> {
pub data: MatZnxDft<C, B>,
pub log_base2k: usize,
pub log_k: usize,
}
pub struct GLWESwitchingKey<Data, B: Backend>(pub(crate) GGLWECiphertext<Data, B>);
impl<B: Backend> GLWEKeySwitchKey<Vec<u8>, B> {
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rows: usize) -> Self {
Self {
data: module.new_mat_znx_dft(rows, 1, 2, derive_size(log_base2k, log_k)),
log_base2k: log_base2k,
log_k: log_k,
}
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn new(module: &Module<FFT64>, base2k: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self {
GLWESwitchingKey(GGLWECiphertext::new(
module, base2k, k, rows, rank_in, rank_out,
))
}
}
impl<T, B: Backend> Infos for GLWEKeySwitchKey<T, B> {
impl<T, B: Backend> Infos for GLWESwitchingKey<T, B> {
type Inner = MatZnxDft<T, B>;
fn inner(&self) -> &Self::Inner {
&self.data
&self.0.inner()
}
fn basek(&self) -> usize {
self.log_base2k
self.0.basek()
}
fn k(&self) -> usize {
self.log_k
self.0.k()
}
}
impl<C, B: Backend> MatZnxDftToMut<B> for GLWEKeySwitchKey<C, B>
impl<DataSelf, B: Backend> MatZnxDftToMut<B> for GLWESwitchingKey<DataSelf, B>
where
MatZnxDft<C, B>: MatZnxDftToMut<B>,
MatZnxDft<DataSelf, B>: MatZnxDftToMut<B>,
{
fn to_mut(&mut self) -> MatZnxDft<&mut [u8], B> {
self.data.to_mut()
self.0.data.to_mut()
}
}
impl<C, B: Backend> MatZnxDftToRef<B> for GLWEKeySwitchKey<C, B>
impl<DataSelf, B: Backend> MatZnxDftToRef<B> for GLWESwitchingKey<DataSelf, B>
where
MatZnxDft<C, B>: MatZnxDftToRef<B>,
MatZnxDft<DataSelf, B>: MatZnxDftToRef<B>,
{
fn to_ref(&self) -> MatZnxDft<&[u8], B> {
self.data.to_ref()
self.0.data.to_ref()
}
}
impl GLWEKeySwitchKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
GLWECiphertext::encrypt_sk_scratch_space(module, size)
+ module.bytes_of_vec_znx(2, size)
+ module.bytes_of_vec_znx(1, size)
+ module.bytes_of_vec_znx_dft(2, size)
}
}
pub fn encrypt_glwe_key_switch_key_sk<C, P, S>(
module: &Module<FFT64>,
ct: &mut GLWEKeySwitchKey<C, FFT64>,
pt: &ScalarZnx<P>,
sk_dft: &SecretKeyFourier<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
let rows: usize = ct.rows();
let size: usize = ct.size();
let log_base2k: usize = ct.basek();
let (tmp_znx_pt, scrach_1) = scratch.tmp_vec_znx(module, 1, size);
let (tmp_znx_ct, scrach_2) = scrach_1.tmp_vec_znx(module, 2, size);
let (mut vec_znx_dft_ct, scratch_3) = scrach_2.tmp_vec_znx_dft(module, 2, size);
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
log_base2k: log_base2k,
log_k: ct.k(),
};
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_znx_ct,
log_base2k: log_base2k,
log_k: ct.k(),
};
(0..rows).for_each(|row_i| {
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(log_base2k, &mut vec_znx_pt, 0, scratch_3);
// rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk(
module,
&vec_znx_pt,
sk_dft,
source_xa,
source_xe,
sigma,
bound,
scratch_3,
);
vec_znx_pt.data.zero(); // zeroes for next iteration
// Switch vec_znx_ct into DFT domain
module.vec_znx_dft(&mut vec_znx_dft_ct, 0, &vec_znx_ct, 0);
module.vec_znx_dft(&mut vec_znx_dft_ct, 1, &vec_znx_ct, 1);
// Stores vec_znx_dft_ct into thw i-th row of the MatZnxDft
module.vmp_prepare_row(ct, row_i, 0, &vec_znx_dft_ct);
});
}
impl<C> GLWEKeySwitchKey<C, FFT64> {
pub fn encrypt_sk<P, S>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<P>,
sk_dft: &SecretKeyFourier<S, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
ScalarZnx<P>: ScalarZnxToRef,
ScalarZnxDft<S, FFT64>: ScalarZnxDftToRef<FFT64>,
{
encrypt_glwe_key_switch_key_sk(
module, self, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
)
}
}
impl<C> GetRow<FFT64> for GLWEKeySwitchKey<C, FFT64>
impl<C> GetRow<FFT64> for GLWESwitchingKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
{
@@ -180,7 +73,7 @@ where
}
}
impl<C> SetRow<FFT64> for GLWEKeySwitchKey<C, FFT64>
impl<C> SetRow<FFT64> for GLWESwitchingKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToMut<FFT64>,
{
@@ -196,138 +89,117 @@ where
}
}
impl KeySwitchScratchSpace for GLWEKeySwitchKey<Vec<u8>, FFT64> {
fn keyswitch_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GLWEKeySwitchKey<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GGLWECiphertext::encrypt_sk_scratch_space(module, rank, size)
}
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
GGLWECiphertext::encrypt_pk_scratch_space(module, rank, pk_size)
}
}
impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
{
pub fn encrypt_sk<DataPt, DataSk>(
&mut self,
module: &Module<FFT64>,
pt: &ScalarZnx<DataPt>,
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
bound: f64,
scratch: &mut Scratch,
) where
ScalarZnx<DataPt>: ScalarZnxToRef,
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
{
self.0.encrypt_sk(
module, pt, sk_dft, source_xa, source_xe, sigma, bound, scratch,
);
}
}
impl GLWESwitchingKey<Vec<u8>, FFT64> {
pub fn keyswitch_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<DataSelf, DataLhs, DataRhs> KeySwitch<DataLhs, DataRhs> for GLWEKeySwitchKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
type Lhs = GLWEKeySwitchKey<DataLhs, FFT64>;
type Rhs = GLWEKeySwitchKey<DataRhs, FFT64>;
fn keyswitch(&mut self, module: &Module<FFT64>, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) {
rhs.prod_with_vec_glwe(module, self, lhs, scratch);
}
}
impl KeySwitchInplaceScratchSpace for GLWEKeySwitchKey<Vec<u8>, FFT64> {
fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
<GLWEKeySwitchKey<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_inplace_scratch_space(
module, res_size, rhs,
)
}
}
impl<DataSelf, DataRhs> KeySwitchInplace<DataRhs> for GLWEKeySwitchKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
type Rhs = GLWEKeySwitchKey<DataRhs, FFT64>;
fn keyswitch_inplace(&mut self, module: &Module<FFT64>, rhs: &Self::Rhs, scratch: &mut Scratch) {
rhs.prod_with_vec_glwe(module, self, rhs, scratch);
}
}
impl ExternalProductScratchSpace for GLWEKeySwitchKey<Vec<u8>, FFT64> {
fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_vec_glwe_scratch_space(
module, res_size, lhs, rhs,
)
}
}
impl<DataSelf, DataLhs, DataRhs> ExternalProduct<DataLhs, DataRhs> for GLWEKeySwitchKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
type Lhs = GLWEKeySwitchKey<DataLhs, FFT64>;
type Rhs = GGSWCiphertext<DataRhs, FFT64>;
fn external_product(&mut self, module: &Module<FFT64>, lhs: &Self::Lhs, rhs: &Self::Rhs, scratch: &mut Scratch) {
rhs.prod_with_vec_glwe(module, self, lhs, scratch);
}
}
impl ExternalProductInplaceScratchSpace for GLWEKeySwitchKey<Vec<u8>, FFT64> {
fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
module, res_size, rhs,
)
}
}
impl<DataSelf, DataRhs> ExternalProductInplace<DataRhs> for GLWEKeySwitchKey<DataSelf, FFT64>
impl<DataSelf> GLWESwitchingKey<DataSelf, FFT64>
where
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
type Rhs = GGSWCiphertext<DataRhs, FFT64>;
fn external_product_inplace(&mut self, module: &Module<FFT64>, rhs: &Self::Rhs, scratch: &mut Scratch) {
rhs.prod_with_vec_glwe_inplace(module, self, scratch);
}
}
impl VecGLWEProductScratchSpace for GLWEKeySwitchKey<Vec<u8>, FFT64> {
fn prod_with_glwe_scratch_space(module: &Module<FFT64>, res_size: usize, a_size: usize, grlwe_size: usize) -> usize {
module.bytes_of_vec_znx_dft(2, grlwe_size)
+ (module.vec_znx_big_normalize_tmp_bytes()
| (module.vmp_apply_tmp_bytes(res_size, a_size, a_size, 1, 2, grlwe_size)
+ module.bytes_of_vec_znx_dft(1, a_size)))
}
}
impl<C> VecGLWEProduct for GLWEKeySwitchKey<C, FFT64>
where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64> + ZnxInfos,
{
fn prod_with_glwe<R, A>(
&self,
pub fn keyswitch<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
res: &mut GLWECiphertext<R>,
a: &GLWECiphertext<A>,
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0
.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
}
pub fn keyswitch_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut base2k::Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.0
.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
}
pub fn external_product<DataLhs, DataRhs>(
&mut self,
module: &Module<FFT64>,
lhs: &GLWESwitchingKey<DataLhs, FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<C, FFT64>: MatZnxDftToRef<FFT64>,
VecZnx<R>: VecZnxToMut,
VecZnx<A>: VecZnxToRef,
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
let log_base2k: usize = self.basek();
rhs.prod_with_vec_glwe(module, &mut self.0, &lhs.0, scratch);
}
#[cfg(debug_assertions)]
{
assert_eq!(res.basek(), log_base2k);
assert_eq!(a.basek(), log_base2k);
assert_eq!(self.n(), module.n());
assert_eq!(res.n(), module.n());
assert_eq!(a.n(), module.n());
}
let (mut res_dft, scratch1) = scratch.tmp_vec_znx_dft(module, 2, self.size()); // Todo optimise
{
let (mut a1_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, 1, a.size());
module.vec_znx_dft(&mut a1_dft, 0, a, 1);
module.vmp_apply(&mut res_dft, &a1_dft, self, scratch2);
}
let mut res_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(res_dft);
module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
module.vec_znx_big_normalize(log_base2k, res, 0, &res_big, 0, scratch1);
module.vec_znx_big_normalize(log_base2k, res, 1, &res_big, 1, scratch1);
pub fn external_product_inplace<DataRhs>(
&mut self,
module: &Module<FFT64>,
rhs: &GGSWCiphertext<DataRhs, FFT64>,
scratch: &mut Scratch,
) where
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
{
rhs.prod_with_vec_glwe_inplace(module, &mut self.0, scratch);
}
}