Replaced manual core structs scratch allocation by new API on Scratch

This commit is contained in:
Jean-Philippe Bossuat
2025-05-28 15:59:49 +02:00
parent f2b671329d
commit 8209fb4e40
9 changed files with 95 additions and 234 deletions

View File

@@ -5,13 +5,8 @@ use backend::{
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
elem::{GetRow, Infos, SetRow}, GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWESwitchingKey, GetRow, Infos, ScratchCore,
gglwe_ciphertext::GGLWECiphertext, SecretKey, SetRow,
ggsw_ciphertext::GGSWCiphertext,
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
keys::{SecretKey, SecretKeyFourier},
keyswitch_key::GLWESwitchingKey,
}; };
pub struct AutomorphismKey<Data, B: Backend> { pub struct AutomorphismKey<Data, B: Backend> {
@@ -179,12 +174,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> AutomorphismKey<DataSelf, FFT64> {
) )
} }
let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); let (mut sk_out_dft, scratch_1) = scratch.tmp_sk_fourier(module, sk.rank());
let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier {
data: sk_out_dft_data,
dist: sk.dist,
};
{ {
(0..self.rank()).for_each(|i| { (0..self.rank()).for_each(|i| {
@@ -249,13 +239,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> AutomorphismKey<DataSelf, FFT64> {
let cols_out: usize = rhs.rank_out() + 1; let cols_out: usize = rhs.rank_out() + 1;
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size()); let (mut tmp_dft, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_dft_data,
basek: lhs.basek(),
k: lhs.k(),
};
(0..self.rank_in()).for_each(|col_i| { (0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {

View File

@@ -1,6 +1,6 @@
use backend::{Backend, Module, ZnxInfos}; use backend::{Backend, Module, ZnxInfos};
use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; use crate::{GLWECiphertextFourier, derive_size};
pub trait Infos { pub trait Infos {
type Inner: ZnxInfos; type Inner: ZnxInfos;

View File

@@ -4,14 +4,7 @@ use backend::{
}; };
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, GetRow, Infos, SecretKeyFourier, SetRow, derive_size};
elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier,
utils::derive_size,
};
pub struct GGLWECiphertext<C, B: Backend> { pub struct GGLWECiphertext<C, B: Backend> {
pub(crate) data: MatZnxDft<C, B>, pub(crate) data: MatZnxDft<C, B>,

View File

@@ -6,11 +6,11 @@ use backend::{
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
ScratchCore,
automorphism::AutomorphismKey, automorphism::AutomorphismKey,
elem::{GetRow, Infos, SetRow}, elem::{GetRow, Infos, SetRow},
glwe_ciphertext::GLWECiphertext, glwe_ciphertext::GLWECiphertext,
glwe_ciphertext_fourier::GLWECiphertextFourier, glwe_ciphertext_fourier::GLWECiphertextFourier,
glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier, keys::SecretKeyFourier,
keyswitch_key::GLWESwitchingKey, keyswitch_key::GLWESwitchingKey,
tensor_key::TensorKey, tensor_key::TensorKey,
@@ -198,55 +198,38 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
assert_eq!(sk_dft.n(), module.n()); assert_eq!(sk_dft.n(), module.n());
} }
let size: usize = self.size();
let basek: usize = self.basek(); let basek: usize = self.basek();
let k: usize = self.k(); let k: usize = self.k();
let cols: usize = self.rank() + 1; let rank: usize = self.rank();
let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k);
let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, cols, size); let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank);
let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext {
data: tmp_znx_pt,
basek: basek,
k: k,
};
let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_znx_ct,
basek: basek,
k,
};
(0..self.rows()).for_each(|row_i| { (0..self.rows()).for_each(|row_i| {
vec_znx_pt.data.zero(); tmp_pt.data.zero();
// Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
module.vec_znx_add_scalar_inplace(&mut vec_znx_pt.data, 0, row_i, pt, 0); module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0);
module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt.data, 0, scrach_2); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2);
(0..cols).for_each(|col_j| { (0..rank + 1).for_each(|col_j| {
// rlwe encrypt of vec_znx_pt into vec_znx_ct // rlwe encrypt of vec_znx_pt into vec_znx_ct
vec_znx_ct.encrypt_sk_private( tmp_ct.encrypt_sk_private(
module, module,
Some((&vec_znx_pt, col_j)), Some((&tmp_pt, col_j)),
sk_dft, sk_dft,
source_xa, source_xa,
source_xe, source_xe,
sigma, sigma,
scrach_2, scratch2,
); );
// Switch vec_znx_ct into DFT domain // Switch vec_znx_ct into DFT domain
{ {
let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size); let (mut tmp_ct_dft, _) = scratch2.tmp_glwe_fourier(module, basek, k, rank);
tmp_ct.dft(module, &mut tmp_ct_dft);
(0..cols).for_each(|i| { self.set_row(module, row_i, col_j, &tmp_ct_dft);
module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct.data, i);
});
module.vmp_prepare_row(&mut self.data, row_i, col_j, &vec_znx_dft_ct);
} }
}); });
}); });
@@ -349,26 +332,22 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
tsk: &TensorKey<DataTsk, FFT64>, tsk: &TensorKey<DataTsk, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) { ) {
let cols: usize = self.rank() + 1; let rank: usize = self.rank();
let cols: usize = rank + 1;
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); let basek: usize = self.basek();
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_data,
basek: self.basek(),
k: self.k(),
};
let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| { (0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e. // Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2)
lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2); lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, ksk, scratch2);
// Isolates DFT(a[i]) // Isolates DFT(a[i])
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i); module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
}); });
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
@@ -379,14 +358,10 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 )
// col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i])
(1..cols).for_each(|col_j| { (1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res.data, &ci_dft, tsk, scratch2); self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2);
let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank);
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); tmp_res.dft(module, &mut tmp_res_dft);
(0..cols).for_each(|i| { self.set_row(module, row_i, col_j, &tmp_res_dft);
module.vec_znx_dft(&mut res_dft, i, &res.data, i);
});
module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft);
}); });
}) })
} }
@@ -448,28 +423,24 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
) )
}; };
let cols: usize = self.rank() + 1; let rank: usize = self.rank();
let cols: usize = rank + 1;
let basek: usize = self.basek();
let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank);
let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
data: res_data,
basek: self.basek(),
k: self.k(),
};
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size());
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| { (0..lhs.rows()).for_each(|row_i| {
// Key-switch column 0, i.e. // Key-switch column 0, i.e.
// col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2); lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, &auto_key.key, scratch2);
// Isolates DFT(AUTO(a[i])) // Isolates DFT(AUTO(a[i]))
(0..cols).for_each(|col_i| { (0..cols).for_each(|col_i| {
// (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2)
module.vec_znx_automorphism_inplace(auto_key.p(), &mut res.data, col_i); module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i);
module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i); module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i);
}); });
module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft);
@@ -480,14 +451,17 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
// col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 )
// col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i]))
(1..cols).for_each(|col_j| { (1..cols).for_each(|col_j| {
self.expand_row(module, col_j, &mut res.data, &ci_dft, tensor_key, scratch2); self.expand_row(
module,
let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); col_j,
(0..cols).for_each(|i| { &mut tmp_res.data,
module.vec_znx_dft(&mut res_dft, i, &res.data, i); &ci_dft,
}); tensor_key,
scratch2,
module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft); );
let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank);
tmp_res.dft(module, &mut tmp_res_dft);
self.set_row(module, row_i, col_j, &tmp_res_dft);
}); });
}) })
} }
@@ -530,35 +504,22 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
); );
} }
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
let (mut tmp_ct_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| { (0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {
lhs.get_row(module, row_j, col_i, &mut tmp_in); lhs.get_row(module, row_j, col_i, &mut tmp_ct_in);
tmp_out.external_product(module, &tmp_in, rhs, scratch2); tmp_ct_out.external_product(module, &tmp_ct_in, rhs, scratch2);
self.set_row(module, row_j, col_i, &tmp_out); self.set_row(module, row_j, col_i, &tmp_ct_out);
}); });
}); });
tmp_out.data.zero(); tmp_ct_out.data.zero();
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
(0..self.rank() + 1).for_each(|col_j| { (0..self.rank() + 1).for_each(|col_j| {
self.set_row(module, row_i, col_j, &tmp_out); self.set_row(module, row_i, col_j, &tmp_ct_out);
}); });
}); });
} }
@@ -580,19 +541,13 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
); );
} }
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); let (mut tmp_ct, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank() + 1).for_each(|col_i| { (0..self.rank() + 1).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {
self.get_row(module, row_j, col_i, &mut tmp); self.get_row(module, row_j, col_i, &mut tmp_ct);
tmp.external_product_inplace(module, rhs, scratch1); tmp_ct.external_product_inplace(module, rhs, scratch1);
self.set_row(module, row_j, col_i, &tmp); self.set_row(module, row_j, col_i, &tmp_ct);
}); });
}); });
} }
@@ -622,15 +577,9 @@ impl<DataSelf: AsRef<[u8]>> GGSWCiphertext<DataSelf, FFT64> {
) )
) )
} }
let (mut tmp_dft_dft, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); self.get_row(module, row_i, 0, &mut tmp_dft_dft);
let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1);
data: tmp_dft_in_data,
basek: self.basek(),
k: self.k(),
};
self.get_row(module, row_i, 0, &mut tmp_dft_in);
res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2);
} }
} }

View File

@@ -5,7 +5,7 @@ use backend::{
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, ScratchCore, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext,
keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size, keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size,
}; };
@@ -119,15 +119,9 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
sigma: f64, sigma: f64,
scratch: &mut Scratch, scratch: &mut Scratch,
) { ) {
let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank());
let mut ct_idft = GLWECiphertext { tmp_ct.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch1);
data: vec_znx_tmp, tmp_ct.dft(module, self);
basek: self.basek,
k: self.k,
};
ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1);
ct_idft.dft(module, self);
} }
pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>( pub fn keyswitch<DataLhs: AsRef<[u8]>, DataRhs: AsRef<[u8]>>(
@@ -137,22 +131,9 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWECiphertextFourier<DataSelf, FFT64>
rhs: &GLWESwitchingKey<DataRhs, FFT64>, rhs: &GLWESwitchingKey<DataRhs, FFT64>,
scratch: &mut Scratch, scratch: &mut Scratch,
) { ) {
let cols_out: usize = rhs.rank_out() + 1; let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank());
tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1);
// Space fr normalized VMP result outside of DFT domain tmp_ct.dft(module, self);
let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size());
let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
data: res_idft_data,
basek: lhs.basek,
k: lhs.k,
};
res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1);
(0..cols_out).for_each(|i| {
module.vec_znx_dft(&mut self.data, i, &res_idft.data, i);
});
} }
pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>( pub fn keyswitch_inplace<DataRhs: AsRef<[u8]>>(

View File

@@ -1,4 +1,4 @@
use crate::{automorphism::AutomorphismKey, elem::Infos, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps}; use crate::{ScratchCore, automorphism::AutomorphismKey, elem::Infos, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps};
use std::collections::HashMap; use std::collections::HashMap;
use backend::{FFT64, Module, Scratch, VecZnxAlloc}; use backend::{FFT64, Module, Scratch, VecZnxAlloc};
@@ -223,8 +223,6 @@ fn combine<D: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
let basek: usize = a.basek(); let basek: usize = a.basek();
let k: usize = a.k(); let k: usize = a.k();
let rank: usize = a.rank(); let rank: usize = a.rank();
let cols: usize = rank + 1;
let size: usize = a.size();
let gal_el: i64; let gal_el: i64;
@@ -245,20 +243,9 @@ fn combine<D: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
a.rsh(1, scratch); a.rsh(1, scratch);
if let Some(b) = b { if let Some(b) = b {
let (tmp_b_data, scratch_1) = scratch.tmp_vec_znx(module, cols, size); let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank);
let mut tmp_b: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_b_data,
k: k,
basek: basek,
};
{ {
let (tmp_a_data, scratch_2) = scratch_1.tmp_vec_znx(module, cols, size); let (mut tmp_a, scratch_2) = scratch_1.tmp_glwe_ct(module, basek, k, rank); //TODO can we skip tmp_a by reordering X^k ?
let mut tmp_a: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_a_data,
k: k,
basek: basek,
};
// tmp_a = b * X^t // tmp_a = b * X^t
tmp_a.rotate(module, 1 << (log_n - i - 1), b); tmp_a.rotate(module, 1 << (log_n - i - 1), b);
@@ -294,13 +281,7 @@ fn combine<D: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
} }
} else { } else {
if let Some(b) = b { if let Some(b) = b {
let (tmp_b_data, scratch_1) = scratch.tmp_vec_znx(module, cols, size); let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank);
let mut tmp_b: GLWECiphertext<&mut [u8]> = GLWECiphertext {
data: tmp_b_data,
k: k,
basek: basek,
};
tmp_b.rotate(module, 1 << (log_n - i - 1), b); tmp_b.rotate(module, 1 << (log_n - i - 1), b);
tmp_b.rsh(1, scratch_1); tmp_b.rsh(1, scratch_1);

View File

@@ -1,4 +1,4 @@
use backend::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef};
use crate::{ use crate::{
elem::{Infos, SetMetaData}, elem::{Infos, SetMetaData},
@@ -47,6 +47,10 @@ impl GLWEPlaintext<Vec<u8>> {
k, k,
} }
} }
pub fn byte_of(module: &Module<FFT64>, basek: usize, k: usize) -> usize {
module.bytes_of_vec_znx(1, derive_size(basek, k))
}
} }
impl<D: AsRef<[u8]>> GLWECiphertextToRef for GLWEPlaintext<D> { impl<D: AsRef<[u8]>> GLWECiphertextToRef for GLWEPlaintext<D> {

View File

@@ -2,6 +2,7 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, VecZnxDf
use sampling::source::Source; use sampling::source::Source;
use crate::{ use crate::{
ScratchCore,
elem::{GetRow, Infos, SetRow}, elem::{GetRow, Infos, SetRow},
gglwe_ciphertext::GGLWECiphertext, gglwe_ciphertext::GGLWECiphertext,
ggsw_ciphertext::GGSWCiphertext, ggsw_ciphertext::GGSWCiphertext,
@@ -184,21 +185,8 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWESwitchingKey<DataSelf, FFT64> {
); );
} }
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| { (0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {
@@ -234,13 +222,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWESwitchingKey<DataSelf, FFT64> {
); );
} }
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| { (0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {
@@ -283,21 +265,8 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWESwitchingKey<DataSelf, FFT64> {
); );
} }
let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank());
let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_in_data,
basek: lhs.basek(),
k: lhs.k(),
};
let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size());
let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_out_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| { (0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {
@@ -333,13 +302,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> GLWESwitchingKey<DataSelf, FFT64> {
); );
} }
let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank());
let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
data: tmp_data,
basek: self.basek(),
k: self.k(),
};
(0..self.rank_in()).for_each(|col_i| { (0..self.rank_in()).for_each(|col_i| {
(0..self.rows()).for_each(|row_j| { (0..self.rows()).for_each(|row_j| {

View File

@@ -38,7 +38,8 @@ use utils::derive_size;
pub(crate) const SIX_SIGMA: f64 = 6.0; pub(crate) const SIX_SIGMA: f64 = 6.0;
pub trait ScratchCore<B: Backend> { pub trait ScratchCore<B: Backend> {
fn tmp_glwe(&mut self, module: &Module<B>, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); fn tmp_glwe_ct(&mut self, module: &Module<B>, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self);
fn tmp_glwe_pt(&mut self, module: &Module<B>, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self);
fn tmp_gglwe( fn tmp_gglwe(
&mut self, &mut self,
module: &Module<B>, module: &Module<B>,
@@ -100,7 +101,7 @@ pub trait ScratchCore<B: Backend> {
} }
impl ScratchCore<FFT64> for Scratch { impl ScratchCore<FFT64> for Scratch {
fn tmp_glwe( fn tmp_glwe_ct(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
basek: usize, basek: usize,
@@ -111,6 +112,11 @@ impl ScratchCore<FFT64> for Scratch {
(GLWECiphertext { data, basek, k }, scratch) (GLWECiphertext { data, basek, k }, scratch)
} }
fn tmp_glwe_pt(&mut self, module: &Module<FFT64>, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) {
let (data, scratch) = self.tmp_vec_znx(module, 1, derive_size(basek, k));
(GLWEPlaintext { data, basek, k }, scratch)
}
fn tmp_gglwe( fn tmp_gglwe(
&mut self, &mut self,
module: &Module<FFT64>, module: &Module<FFT64>,
@@ -190,7 +196,7 @@ impl ScratchCore<FFT64> for Scratch {
} }
fn tmp_sk_fourier(&mut self, module: &Module<FFT64>, rank: usize) -> (SecretKeyFourier<&mut [u8], FFT64>, &mut Self) { fn tmp_sk_fourier(&mut self, module: &Module<FFT64>, rank: usize) -> (SecretKeyFourier<&mut [u8], FFT64>, &mut Self) {
let (data, scratch) = self.tmp_scalar_znx_dft(module, rank + 1); let (data, scratch) = self.tmp_scalar_znx_dft(module, rank);
( (
SecretKeyFourier { SecretKeyFourier {
data, data,