mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
working on adding rank to glwe (all test passing)
This commit is contained in:
261
core/src/glwe_ciphertext_fourier.rs
Normal file
261
core/src/glwe_ciphertext_fourier.rs
Normal file
@@ -0,0 +1,261 @@
|
||||
use base2k::{
|
||||
Backend, FFT64, MatZnxDft, MatZnxDftToRef, Module, ScalarZnxDft, ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnx,
|
||||
VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps,
|
||||
VecZnxDftToMut, VecZnxDftToRef, VecZnxToMut, VecZnxToRef, ZnxZero,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
elem::Infos,
|
||||
gglwe_ciphertext::GGLWECiphertext,
|
||||
ggsw_ciphertext::GGSWCiphertext,
|
||||
glwe_ciphertext::GLWECiphertext,
|
||||
glwe_plaintext::GLWEPlaintext,
|
||||
keys::SecretKeyFourier,
|
||||
keyswitch_key::GLWESwitchingKey,
|
||||
utils::derive_size,
|
||||
vec_glwe_product::{VecGLWEProduct, VecGLWEProductScratchSpace},
|
||||
};
|
||||
|
||||
pub struct GLWECiphertextFourier<C, B: Backend> {
|
||||
pub data: VecZnxDft<C, B>,
|
||||
pub basek: usize,
|
||||
pub k: usize,
|
||||
}
|
||||
|
||||
impl<B: Backend> GLWECiphertextFourier<Vec<u8>, B> {
|
||||
pub fn new(module: &Module<B>, log_base2k: usize, log_k: usize, rank: usize) -> Self {
|
||||
Self {
|
||||
data: module.new_vec_znx_dft(rank + 1, derive_size(log_base2k, log_k)),
|
||||
basek: log_base2k,
|
||||
k: log_k,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> Infos for GLWECiphertextFourier<T, B> {
|
||||
type Inner = VecZnxDft<T, B>;
|
||||
|
||||
fn inner(&self) -> &Self::Inner {
|
||||
&self.data
|
||||
}
|
||||
|
||||
fn basek(&self) -> usize {
|
||||
self.basek
|
||||
}
|
||||
|
||||
fn k(&self) -> usize {
|
||||
self.k
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B: Backend> GLWECiphertextFourier<T, B> {
|
||||
pub fn rank(&self) -> usize {
|
||||
self.cols() - 1
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToMut<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToMut<B>,
|
||||
{
|
||||
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||
self.data.to_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, B: Backend> VecZnxDftToRef<B> for GLWECiphertextFourier<C, B>
|
||||
where
|
||||
VecZnxDft<C, B>: VecZnxDftToRef<B>,
|
||||
{
|
||||
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||
self.data.to_ref()
|
||||
}
|
||||
}
|
||||
|
||||
impl GLWECiphertextFourier<Vec<u8>, FFT64> {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn idft_scratch_space(module: &Module<FFT64>, size: usize) -> usize {
|
||||
module.bytes_of_vec_znx(1, size) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())
|
||||
}
|
||||
|
||||
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, ct_size: usize) -> usize {
|
||||
module.bytes_of_vec_znx(rank + 1, ct_size) + GLWECiphertext::encrypt_sk_scratch_space(module, rank, ct_size)
|
||||
}
|
||||
|
||||
pub fn decrypt_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
|
||||
(module.vec_znx_big_normalize_tmp_bytes()
|
||||
| module.bytes_of_vec_znx_dft(1, ct_size)
|
||||
| (module.bytes_of_vec_znx_big(1, ct_size) + module.vec_znx_idft_tmp_bytes()))
|
||||
+ module.bytes_of_vec_znx_big(1, ct_size)
|
||||
}
|
||||
|
||||
pub fn keyswitch_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
|
||||
<GGLWECiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn external_product_scratch_space(module: &Module<FFT64>, res_size: usize, lhs: usize, rhs: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_scratch_space(module, res_size, lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn external_product_inplace_scratch_space(module: &Module<FFT64>, res_size: usize, rhs: usize) -> usize {
|
||||
<GGSWCiphertext<Vec<u8>, FFT64> as VecGLWEProductScratchSpace>::prod_with_glwe_inplace_scratch_space(
|
||||
module, res_size, rhs,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToMut<FFT64> + VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn encrypt_zero_sk<DataSk>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
bound: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size());
|
||||
let mut ct_idft = GLWECiphertext {
|
||||
data: vec_znx_tmp,
|
||||
basek: self.basek,
|
||||
k: self.k,
|
||||
};
|
||||
ct_idft.encrypt_zero_sk(
|
||||
module, sk_dft, source_xa, source_xe, sigma, bound, scratch_1,
|
||||
);
|
||||
|
||||
ct_idft.dft(module, self);
|
||||
}
|
||||
|
||||
pub fn keyswitch<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe_fourier(module, self, lhs, scratch);
|
||||
}
|
||||
|
||||
pub fn keyswitch_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GLWESwitchingKey<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.0.prod_with_glwe_fourier_inplace(module, self, scratch);
|
||||
}
|
||||
|
||||
pub fn external_product<DataLhs, DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
lhs: &GLWECiphertextFourier<DataLhs, FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnxDft<DataLhs, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe_fourier(module, self, lhs, scratch);
|
||||
}
|
||||
|
||||
pub fn external_product_inplace<DataRhs>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
rhs: &GGSWCiphertext<DataRhs, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||
{
|
||||
rhs.prod_with_glwe_fourier_inplace(module, self, scratch);
|
||||
}
|
||||
}
|
||||
|
||||
impl<DataSelf> GLWECiphertextFourier<DataSelf, FFT64>
|
||||
where
|
||||
VecZnxDft<DataSelf, FFT64>: VecZnxDftToRef<FFT64>,
|
||||
{
|
||||
pub fn decrypt<DataPt, DataSk>(
|
||||
&self,
|
||||
module: &Module<FFT64>,
|
||||
pt: &mut GLWEPlaintext<DataPt>,
|
||||
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
VecZnx<DataPt>: VecZnxToMut + VecZnxToRef,
|
||||
ScalarZnxDft<DataSk, FFT64>: ScalarZnxDftToRef<FFT64>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), sk_dft.rank());
|
||||
assert_eq!(self.n(), module.n());
|
||||
assert_eq!(pt.n(), module.n());
|
||||
assert_eq!(sk_dft.n(), module.n());
|
||||
}
|
||||
|
||||
let cols = self.rank() + 1;
|
||||
|
||||
let (mut pt_big, scratch_1) = scratch.tmp_vec_znx_big(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
pt_big.zero();
|
||||
|
||||
{
|
||||
(1..cols).for_each(|i| {
|
||||
let (mut ci_dft, _) = scratch_1.tmp_vec_znx_dft(module, 1, self.size()); // TODO optimize size when pt << ct
|
||||
module.svp_apply(&mut ci_dft, 0, sk_dft, i - 1, self, i);
|
||||
let ci_big: VecZnxBig<&mut [u8], FFT64> = module.vec_znx_idft_consume(ci_dft);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &ci_big, 0);
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
let (mut c0_big, scratch_2) = scratch_1.tmp_vec_znx_big(module, 1, self.size());
|
||||
// c0_big = (a * s) + (-a * s + m + e) = BIG(m + e)
|
||||
module.vec_znx_idft(&mut c0_big, 0, self, 0, scratch_2);
|
||||
module.vec_znx_big_add_inplace(&mut pt_big, 0, &c0_big, 0);
|
||||
}
|
||||
|
||||
// pt = norm(BIG(m + e))
|
||||
module.vec_znx_big_normalize(self.basek(), pt, 0, &mut pt_big, 0, scratch_1);
|
||||
|
||||
pt.basek = self.basek();
|
||||
pt.k = pt.k().min(self.k());
|
||||
}
|
||||
|
||||
pub(crate) fn idft<DataRes>(&self, module: &Module<FFT64>, res: &mut GLWECiphertext<DataRes>, scratch: &mut Scratch)
|
||||
where
|
||||
GLWECiphertext<DataRes>: VecZnxToMut,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.rank(), res.rank());
|
||||
assert_eq!(self.basek(), res.basek())
|
||||
}
|
||||
|
||||
let min_size: usize = self.size().min(res.size());
|
||||
|
||||
let (mut res_big, scratch1) = scratch.tmp_vec_znx_big(module, 1, min_size);
|
||||
|
||||
(0..self.rank() + 1).for_each(|i| {
|
||||
module.vec_znx_idft(&mut res_big, 0, self, i, scratch1);
|
||||
module.vec_znx_big_normalize(self.basek(), res, i, &res_big, 0, scratch1);
|
||||
});
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user