mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added tensor key & associated test
This commit is contained in:
@@ -9,9 +9,9 @@ use rand_distr::{Distribution, weighted::WeightedIndex};
|
|||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
pub struct ScalarZnx<D> {
|
pub struct ScalarZnx<D> {
|
||||||
data: D,
|
pub(crate) data: D,
|
||||||
n: usize,
|
pub(crate) n: usize,
|
||||||
cols: usize,
|
pub(crate) cols: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D> ZnxInfos for ScalarZnx<D> {
|
impl<D> ZnxInfos for ScalarZnx<D> {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ use std::marker::PhantomData;
|
|||||||
|
|
||||||
use crate::ffi::svp;
|
use crate::ffi::svp;
|
||||||
use crate::znx_base::ZnxInfos;
|
use crate::znx_base::ZnxInfos;
|
||||||
use crate::{Backend, DataView, DataViewMut, FFT64, Module, ZnxSliceSize, ZnxView, alloc_aligned};
|
use crate::{alloc_aligned, Backend, DataView, DataViewMut, Module, VecZnxDft, VecZnxDftToMut, VecZnxDftToRef, ZnxSliceSize, ZnxView, FFT64};
|
||||||
|
|
||||||
pub struct ScalarZnxDft<D, B: Backend> {
|
pub struct ScalarZnxDft<D, B: Backend> {
|
||||||
data: D,
|
data: D,
|
||||||
@@ -92,6 +92,16 @@ impl<D, B: Backend> ScalarZnxDft<D, B> {
|
|||||||
_phantom: PhantomData,
|
_phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn as_vec_znx_dft(self) -> VecZnxDft<D, B>{
|
||||||
|
VecZnxDft{
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
|
pub type ScalarZnxDftOwned<B> = ScalarZnxDft<Vec<u8>, B>;
|
||||||
@@ -158,3 +168,63 @@ impl<B: Backend> ScalarZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||||
|
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||||
|
VecZnxDft {
|
||||||
|
data: self.data.as_mut_slice(),
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<Vec<u8>, B> {
|
||||||
|
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||||
|
VecZnxDft {
|
||||||
|
data: self.data.as_slice(),
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> VecZnxDftToMut<B> for ScalarZnxDft<&mut [u8], B> {
|
||||||
|
fn to_mut(&mut self) -> VecZnxDft<&mut [u8], B> {
|
||||||
|
VecZnxDft {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&mut [u8], B> {
|
||||||
|
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||||
|
VecZnxDft {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> VecZnxDftToRef<B> for ScalarZnxDft<&[u8], B> {
|
||||||
|
fn to_ref(&self) -> VecZnxDft<&[u8], B> {
|
||||||
|
VecZnxDft {
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
size: 1,
|
||||||
|
_phantom: PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use crate::DataView;
|
use crate::DataView;
|
||||||
use crate::DataViewMut;
|
use crate::DataViewMut;
|
||||||
|
use crate::ScalarZnx;
|
||||||
use crate::ZnxSliceSize;
|
use crate::ZnxSliceSize;
|
||||||
use crate::ZnxZero;
|
use crate::ZnxZero;
|
||||||
use crate::alloc_aligned;
|
use crate::alloc_aligned;
|
||||||
@@ -128,6 +129,15 @@ impl<D> VecZnx<D> {
|
|||||||
size,
|
size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_scalar_znx(self) -> ScalarZnx<D>{
|
||||||
|
debug_assert_eq!(self.size, 1, "cannot convert VecZnx to ScalarZnx if cols: {} != 1", self.cols);
|
||||||
|
ScalarZnx{
|
||||||
|
data: self.data,
|
||||||
|
n: self.n,
|
||||||
|
cols: self.cols,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Copies the coefficients of `a` on the receiver.
|
/// Copies the coefficients of `a` on the receiver.
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ use crate::{
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
pub struct VecZnxDft<D, B: Backend> {
|
pub struct VecZnxDft<D, B: Backend> {
|
||||||
data: D,
|
pub(crate) data: D,
|
||||||
n: usize,
|
pub(crate) n: usize,
|
||||||
cols: usize,
|
pub(crate) cols: usize,
|
||||||
size: usize,
|
pub(crate) size: usize,
|
||||||
_phantom: PhantomData<B>,
|
pub(crate) _phantom: PhantomData<B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<D, B: Backend> VecZnxDft<D, B> {
|
impl<D, B: Backend> VecZnxDft<D, B> {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ use base2k::{
|
|||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
automorphism::AutomorphismKey,
|
||||||
elem::{GetRow, Infos, SetRow},
|
elem::{GetRow, Infos, SetRow},
|
||||||
glwe_ciphertext::GLWECiphertext,
|
glwe_ciphertext::GLWECiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
@@ -78,6 +79,20 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
|
|||||||
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
+ module.bytes_of_vec_znx_dft(rank + 1, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn automorphism_scratch_space(
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
out_size: usize,
|
||||||
|
in_size: usize,
|
||||||
|
auto_key_size: usize,
|
||||||
|
rank: usize,
|
||||||
|
) -> usize {
|
||||||
|
let size: usize = in_size.min(out_size);
|
||||||
|
let tmp_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, size);
|
||||||
|
let tmp_idft: usize = module.bytes_of_vec_znx(rank + 1, size);
|
||||||
|
let vmp: usize = GLWECiphertext::keyswitch_from_fourier_scratch_space(module, size, rank, size, rank, auto_key_size);
|
||||||
|
tmp_dft + tmp_idft + vmp
|
||||||
|
}
|
||||||
|
|
||||||
pub fn external_product_scratch_space(
|
pub fn external_product_scratch_space(
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
out_size: usize,
|
out_size: usize,
|
||||||
@@ -182,6 +197,73 @@ where
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn automorphism<DataLhs, DataRhs>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
lhs: &GGSWCiphertext<DataLhs, FFT64>,
|
||||||
|
rhs: &AutomorphismKey<DataRhs, FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
MatZnxDft<DataLhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
MatZnxDft<DataRhs, FFT64>: MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
self.rank(),
|
||||||
|
lhs.rank(),
|
||||||
|
"ggsw_out rank: {} != ggsw_in rank: {}",
|
||||||
|
self.rank(),
|
||||||
|
lhs.rank()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank(),
|
||||||
|
"ggsw_in rank: {} != auto_key rank: {}",
|
||||||
|
self.rank(),
|
||||||
|
rhs.rank()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let size: usize = self.size().min(lhs.size());
|
||||||
|
let cols: usize = self.rank() + 1;
|
||||||
|
|
||||||
|
let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols, size);
|
||||||
|
|
||||||
|
let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> {
|
||||||
|
data: tmp_dft_data,
|
||||||
|
basek: lhs.basek(),
|
||||||
|
k: lhs.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let (tmp_idft_data, scratch2) = scratch1.tmp_vec_znx(module, cols, size);
|
||||||
|
|
||||||
|
let mut tmp_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> {
|
||||||
|
data: tmp_idft_data,
|
||||||
|
basek: self.basek(),
|
||||||
|
k: self.k(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(0..cols).for_each(|col_i| {
|
||||||
|
(0..self.rows()).for_each(|row_j| {
|
||||||
|
lhs.get_row(module, row_j, col_i, &mut tmp_dft);
|
||||||
|
tmp_idft.keyswitch_from_fourier(module, &tmp_dft, &rhs.key, scratch2);
|
||||||
|
(0..cols).for_each(|i| {
|
||||||
|
module.vec_znx_automorphism_inplace(rhs.p(), &mut tmp_idft, i);
|
||||||
|
});
|
||||||
|
self.set_row(module, row_j, col_i, &tmp_dft);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
tmp_dft.data.zero();
|
||||||
|
|
||||||
|
(self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
|
||||||
|
(0..self.rank() + 1).for_each(|col_j| {
|
||||||
|
self.set_row(module, row_i, col_j, &tmp_dft);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn external_product<DataLhs, DataRhs>(
|
pub fn external_product<DataLhs, DataRhs>(
|
||||||
&mut self,
|
&mut self,
|
||||||
module: &Module<FFT64>,
|
module: &Module<FFT64>,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ pub mod glwe_ciphertext_fourier;
|
|||||||
pub mod glwe_plaintext;
|
pub mod glwe_plaintext;
|
||||||
pub mod keys;
|
pub mod keys;
|
||||||
pub mod keyswitch_key;
|
pub mod keyswitch_key;
|
||||||
|
pub mod tensor_key;
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_fft64;
|
mod test_fft64;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|||||||
125
core/src/tensor_key.rs
Normal file
125
core/src/tensor_key.rs
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
use base2k::{
|
||||||
|
Backend, FFT64, MatZnxDft, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, ScalarZnxDftAlloc,
|
||||||
|
ScalarZnxDftOps, ScalarZnxDftToRef, Scratch, VecZnxDftOps, VecZnxDftToRef,
|
||||||
|
};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
elem::Infos,
|
||||||
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
|
keyswitch_key::GLWESwitchingKey,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub struct TensorKey<C, B: Backend> {
|
||||||
|
pub(crate) keys: Vec<GLWESwitchingKey<C, B>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorKey<Vec<u8>, FFT64> {
|
||||||
|
pub fn new(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||||
|
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
|
||||||
|
let pairs: usize = ((rank + 1) * rank) >> 1;
|
||||||
|
(0..pairs).for_each(|_| {
|
||||||
|
keys.push(GLWESwitchingKey::new(module, basek, k, rows, 1, rank));
|
||||||
|
});
|
||||||
|
Self { keys: keys }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, B: Backend> Infos for TensorKey<T, B> {
|
||||||
|
type Inner = MatZnxDft<T, B>;
|
||||||
|
|
||||||
|
fn inner(&self) -> &Self::Inner {
|
||||||
|
&self.keys[0].inner()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn basek(&self) -> usize {
|
||||||
|
self.keys[0].basek()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn k(&self) -> usize {
|
||||||
|
self.keys[0].k()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, B: Backend> TensorKey<T, B> {
|
||||||
|
pub fn rank(&self) -> usize {
|
||||||
|
self.keys[0].rank()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rank_in(&self) -> usize {
|
||||||
|
self.keys[0].rank_in()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rank_out(&self) -> usize {
|
||||||
|
self.keys[0].rank_out()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TensorKey<Vec<u8>, FFT64> {
|
||||||
|
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
|
||||||
|
module.bytes_of_scalar_znx_dft(1) + GLWESwitchingKey::encrypt_sk_scratch_space(module, rank, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<DataSelf> TensorKey<DataSelf, FFT64>
|
||||||
|
where
|
||||||
|
MatZnxDft<DataSelf, FFT64>: MatZnxDftToMut<FFT64> + MatZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
pub fn encrypt_sk<DataSk>(
|
||||||
|
&mut self,
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
sk_dft: &SecretKeyFourier<DataSk, FFT64>,
|
||||||
|
source_xa: &mut Source,
|
||||||
|
source_xe: &mut Source,
|
||||||
|
sigma: f64,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
ScalarZnxDft<DataSk, FFT64>: VecZnxDftToRef<FFT64> + ScalarZnxDftToRef<FFT64>,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(self.rank(), sk_dft.rank());
|
||||||
|
assert_eq!(self.n(), module.n());
|
||||||
|
assert_eq!(sk_dft.n(), module.n());
|
||||||
|
}
|
||||||
|
|
||||||
|
let rank: usize = self.rank();
|
||||||
|
|
||||||
|
(0..rank).for_each(|i| {
|
||||||
|
(i..rank).for_each(|j| {
|
||||||
|
let (mut sk_ij_dft, scratch1) = scratch.tmp_scalar_znx_dft(module, 1);
|
||||||
|
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
|
||||||
|
let sk_ij: ScalarZnx<&mut [u8]> = module
|
||||||
|
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
|
||||||
|
.to_vec_znx_small()
|
||||||
|
.to_scalar_znx();
|
||||||
|
let sk_ij: SecretKey<&mut [u8]> = SecretKey {
|
||||||
|
data: sk_ij,
|
||||||
|
dist: sk_dft.dist,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.at_mut(i, j).encrypt_sk(
|
||||||
|
module, &sk_ij, sk_dft, source_xa, source_xe, sigma, scratch1,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a reference to GLWESwitchingKey_{s}(s[i] * s[j])
|
||||||
|
pub fn at(&self, mut i: usize, mut j: usize) -> &GLWESwitchingKey<DataSelf, FFT64> {
|
||||||
|
if i > j {
|
||||||
|
std::mem::swap(&mut i, &mut j);
|
||||||
|
};
|
||||||
|
let rank: usize = self.rank();
|
||||||
|
&self.keys[i * rank + j - (i * (i + 1) / 2)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a mutable reference to GLWESwitchingKey_{s}(s[i] * s[j])
|
||||||
|
pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> &mut GLWESwitchingKey<DataSelf, FFT64> {
|
||||||
|
if i > j {
|
||||||
|
std::mem::swap(&mut i, &mut j);
|
||||||
|
};
|
||||||
|
let rank: usize = self.rank();
|
||||||
|
&mut self.keys[i * rank + j - (i * (i + 1) / 2)]
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,6 +18,14 @@ fn automorphism() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn automorphism_inplace() {
|
||||||
|
(1..4).for_each(|rank| {
|
||||||
|
println!("test automorphism_inplace rank: {}", rank);
|
||||||
|
test_automorphism_inplace(-1, 5, 12, 12, 60, 3.2, rank);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
let rows = (k_ksk + basek - 1) / basek;
|
let rows = (k_ksk + basek - 1) / basek;
|
||||||
@@ -115,3 +123,94 @@ fn test_automorphism(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize,
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_automorphism_inplace(p0: i64, p1: i64, log_n: usize, basek: usize, k_ksk: usize, sigma: f64, rank: usize) {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
|
let rows = (k_ksk + basek - 1) / basek;
|
||||||
|
|
||||||
|
let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||||
|
let mut auto_key_apply: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k_ksk, rows, rank);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||||
|
| GLWECiphertextFourier::decrypt_scratch_space(&module, auto_key.size())
|
||||||
|
| AutomorphismKey::automorphism_inplace_scratch_space(&module, auto_key.size(), auto_key_apply.size(), rank),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
|
||||||
|
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
// gglwe_{s1}(s0) = s0 -> s1
|
||||||
|
auto_key.encrypt_sk(
|
||||||
|
&module,
|
||||||
|
p0,
|
||||||
|
&sk,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// gglwe_{s2}(s1) -> s1 -> s2
|
||||||
|
auto_key_apply.encrypt_sk(
|
||||||
|
&module,
|
||||||
|
p1,
|
||||||
|
&sk,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// gglwe_{s1}(s0) (x) gglwe_{s2}(s1) = gglwe_{s2}(s0)
|
||||||
|
auto_key.automorphism_inplace(&module, &auto_key_apply, scratch.borrow());
|
||||||
|
|
||||||
|
let mut ct_glwe_dft: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ksk, rank);
|
||||||
|
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ksk);
|
||||||
|
|
||||||
|
let mut sk_auto: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
|
sk_auto.fill_zero(); // Necessary to avoid panic of unfilled sk
|
||||||
|
(0..rank).for_each(|i| {
|
||||||
|
module.scalar_znx_automorphism(module.galois_element_inv(p0 * p1), &mut sk_auto, i, &sk, i);
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut sk_auto_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
|
sk_auto_dft.dft(&module, &sk_auto);
|
||||||
|
|
||||||
|
(0..auto_key.rank_in()).for_each(|col_i| {
|
||||||
|
(0..auto_key.rows()).for_each(|row_i| {
|
||||||
|
auto_key.get_row(&module, row_i, col_i, &mut ct_glwe_dft);
|
||||||
|
|
||||||
|
ct_glwe_dft.decrypt(&module, &mut pt, &sk_auto_dft, scratch.borrow());
|
||||||
|
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk, col_i);
|
||||||
|
|
||||||
|
let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||||
|
let noise_want: f64 = noise_gglwe_product(
|
||||||
|
module.n() as f64,
|
||||||
|
basek,
|
||||||
|
0.5,
|
||||||
|
0.5,
|
||||||
|
0f64,
|
||||||
|
sigma * sigma,
|
||||||
|
0f64,
|
||||||
|
rank as f64,
|
||||||
|
k_ksk,
|
||||||
|
k_ksk,
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
(noise_have - noise_want).abs() <= 0.1,
|
||||||
|
"{} {}",
|
||||||
|
noise_have,
|
||||||
|
noise_want
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ use base2k::{
|
|||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
automorphism::AutomorphismKey,
|
||||||
elem::{GetRow, Infos},
|
elem::{GetRow, Infos},
|
||||||
ggsw_ciphertext::GGSWCiphertext,
|
ggsw_ciphertext::GGSWCiphertext,
|
||||||
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
@@ -104,6 +105,123 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ggsw: usize, sigma: f64, rank:
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fn test_automorphism(p: i64, log_n: usize, basek: usize, k: usize, rank: usize, sigma: f64) {
|
||||||
|
// let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
|
// let rows: usize = (k_ggsw + basek - 1) / basek;
|
||||||
|
//
|
||||||
|
// let mut ct_ggsw_in: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||||
|
// let mut ct_ggsw_out: GGSWCiphertext<Vec<u8>, FFT64> = GGSWCiphertext::new(&module, basek, k, rows, rank);
|
||||||
|
// let mut auto_key: AutomorphismKey<Vec<u8>, FFT64> = AutomorphismKey::new(&module, basek, k, rows, rank);
|
||||||
|
//
|
||||||
|
// let mut pt_ggsw_in: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
|
// let mut pt_ggsw_out: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||||
|
//
|
||||||
|
// let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
// let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
// let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
//
|
||||||
|
// pt_ggsw_in.fill_ternary_prob(0, 0.5, &mut source_xs);
|
||||||
|
//
|
||||||
|
// let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
|
// AutomorphismKey::encrypt_sk_scratch_space(&module, rank, auto_key.size())
|
||||||
|
// | GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_out.size())
|
||||||
|
// | GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_in.size())
|
||||||
|
// | GGSWCiphertext::automorphism_scratch_space(
|
||||||
|
// &module,
|
||||||
|
// ct_ggsw_out.size(),
|
||||||
|
// ct_ggsw_in.size(),
|
||||||
|
// auto_key.size(),
|
||||||
|
// rank,
|
||||||
|
// ),
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
|
// sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
//
|
||||||
|
// let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
|
// sk_dft.dft(&module, &sk);
|
||||||
|
//
|
||||||
|
// ct_ggsw_in.encrypt_sk(
|
||||||
|
// &module,
|
||||||
|
// &pt_ggsw_in,
|
||||||
|
// &sk_dft,
|
||||||
|
// &mut source_xa,
|
||||||
|
// &mut source_xe,
|
||||||
|
// sigma,
|
||||||
|
// scratch.borrow(),
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// auto_key.encrypt_sk(
|
||||||
|
// &module,
|
||||||
|
// p,
|
||||||
|
// &sk,
|
||||||
|
// &mut source_xa,
|
||||||
|
// &mut source_xe,
|
||||||
|
// sigma,
|
||||||
|
// scratch.borrow(),
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// ct_ggsw_out.automorphism(&module, &ct_ggsw_in, &auto_key, scratch.borrow());
|
||||||
|
//
|
||||||
|
// let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k_ggsw, rank);
|
||||||
|
// let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
|
// let mut pt_dft: VecZnxDft<Vec<u8>, FFT64> = module.new_vec_znx_dft(1, ct_ggsw_lhs_out.size());
|
||||||
|
// let mut pt_big: VecZnxBig<Vec<u8>, FFT64> = module.new_vec_znx_big(1, ct_ggsw_lhs_out.size());
|
||||||
|
// let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k_ggsw);
|
||||||
|
//
|
||||||
|
// module.vec_znx_rotate_inplace(k as i64, &mut pt_ggsw_lhs, 0);
|
||||||
|
//
|
||||||
|
// (0..ct_ggsw_lhs_out.rank() + 1).for_each(|col_j| {
|
||||||
|
// (0..ct_ggsw_lhs_out.rows()).for_each(|row_i| {
|
||||||
|
// module.vec_znx_add_scalar_inplace(&mut pt_want, 0, row_i, &pt_ggsw_lhs, 0);
|
||||||
|
//
|
||||||
|
// if col_j > 0 {
|
||||||
|
// module.vec_znx_dft(&mut pt_dft, 0, &pt_want, 0);
|
||||||
|
// module.svp_apply_inplace(&mut pt_dft, 0, &sk_dft, col_j - 1);
|
||||||
|
// module.vec_znx_idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
|
||||||
|
// module.vec_znx_big_normalize(basek, &mut pt_want, 0, &pt_big, 0, scratch.borrow());
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// ct_ggsw_lhs_out.get_row(&module, row_i, col_j, &mut ct_glwe_fourier);
|
||||||
|
// ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
|
//
|
||||||
|
// module.vec_znx_sub_ab_inplace(&mut pt, 0, &pt_want, 0);
|
||||||
|
//
|
||||||
|
// let noise_have: f64 = pt.data.std(0, basek).log2();
|
||||||
|
//
|
||||||
|
// let var_gct_err_lhs: f64 = sigma * sigma;
|
||||||
|
// let var_gct_err_rhs: f64 = 0f64;
|
||||||
|
//
|
||||||
|
// let var_msg: f64 = 1f64 / module.n() as f64; // X^{k}
|
||||||
|
// let var_a0_err: f64 = sigma * sigma;
|
||||||
|
// let var_a1_err: f64 = 1f64 / 12f64;
|
||||||
|
//
|
||||||
|
// let noise_want: f64 = noise_ggsw_product(
|
||||||
|
// module.n() as f64,
|
||||||
|
// basek,
|
||||||
|
// 0.5,
|
||||||
|
// var_msg,
|
||||||
|
// var_a0_err,
|
||||||
|
// var_a1_err,
|
||||||
|
// var_gct_err_lhs,
|
||||||
|
// var_gct_err_rhs,
|
||||||
|
// rank as f64,
|
||||||
|
// k_ggsw,
|
||||||
|
// k_ggsw,
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// assert!(
|
||||||
|
// (noise_have - noise_want).abs() <= 0.1,
|
||||||
|
// "have: {} want: {}",
|
||||||
|
// noise_have,
|
||||||
|
// noise_want
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// pt_want.data.zero();
|
||||||
|
// });
|
||||||
|
// });
|
||||||
|
// }
|
||||||
|
|
||||||
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize, sigma: f64) {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
|
|
||||||
@@ -126,8 +244,7 @@ fn test_external_product(log_n: usize, basek: usize, k_ggsw: usize, rank: usize,
|
|||||||
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
pt_ggsw_rhs.to_mut().raw_mut()[k] = 1; //X^{k}
|
||||||
|
|
||||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||||
GLWESwitchingKey::encrypt_sk_scratch_space(&module, rank, ct_ggsw_rhs.size())
|
GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
||||||
| GLWECiphertextFourier::decrypt_scratch_space(&module, ct_ggsw_lhs_out.size())
|
|
||||||
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
|
| GGSWCiphertext::encrypt_sk_scratch_space(&module, rank, ct_ggsw_lhs_in.size())
|
||||||
| GGSWCiphertext::external_product_scratch_space(
|
| GGSWCiphertext::external_product_scratch_space(
|
||||||
&module,
|
&module,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use base2k::{
|
use base2k::{
|
||||||
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
|
Decoding, Encoding, FFT64, FillUniform, Module, ScalarZnx, ScalarZnxAlloc, ScratchOwned, Stats, VecZnxOps, VecZnxToMut,
|
||||||
ZnxView, ZnxViewMut, ZnxZero,
|
ZnxViewMut, ZnxZero,
|
||||||
};
|
};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
@@ -75,6 +75,22 @@ fn external_product_inplace() {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn automorphism_inplace() {
|
||||||
|
(1..4).for_each(|rank| {
|
||||||
|
println!("test automorphism_inplace rank: {}", rank);
|
||||||
|
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn automorphism() {
|
||||||
|
(1..4).for_each(|rank| {
|
||||||
|
println!("test automorphism rank: {}", rank);
|
||||||
|
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
|
fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma: f64, rank: usize) {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
|
|
||||||
@@ -416,14 +432,6 @@ fn test_keyswitch_inplace(log_n: usize, basek: usize, k_ksk: usize, k_ct: usize,
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn automorphism() {
|
|
||||||
(1..4).for_each(|rank| {
|
|
||||||
println!("test automorphism rank: {}", rank);
|
|
||||||
test_automorphism(12, 12, -5, 60, 45, 60, rank, 3.2);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn test_automorphism(
|
fn test_automorphism(
|
||||||
log_n: usize,
|
log_n: usize,
|
||||||
basek: usize,
|
basek: usize,
|
||||||
@@ -515,14 +523,6 @@ fn test_automorphism(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn automorphism_inplace() {
|
|
||||||
(1..4).for_each(|rank| {
|
|
||||||
println!("test automorphism_inplace rank: {}", rank);
|
|
||||||
test_automorphism_inplace(12, 12, -5, 60, 60, rank, 3.2);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
|
fn test_automorphism_inplace(log_n: usize, basek: usize, p: i64, k_autokey: usize, k_ct: usize, rank: usize, sigma: f64) {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
let rows: usize = (k_ct + basek - 1) / basek;
|
let rows: usize = (k_ct + basek - 1) / basek;
|
||||||
|
|||||||
@@ -3,3 +3,4 @@ mod gglwe;
|
|||||||
mod ggsw;
|
mod ggsw;
|
||||||
mod glwe;
|
mod glwe;
|
||||||
mod glwe_fourier;
|
mod glwe_fourier;
|
||||||
|
mod tensor_key;
|
||||||
|
|||||||
77
core/src/test_fft64/tensor_key.rs
Normal file
77
core/src/test_fft64/tensor_key.rs
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
use base2k::{FFT64, Module, ScalarZnx, ScalarZnxDftAlloc, ScalarZnxDftOps, ScratchOwned, Stats, VecZnxDftOps, VecZnxOps};
|
||||||
|
use sampling::source::Source;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
elem::{GetRow, Infos},
|
||||||
|
glwe_ciphertext_fourier::GLWECiphertextFourier,
|
||||||
|
glwe_plaintext::GLWEPlaintext,
|
||||||
|
keys::{SecretKey, SecretKeyFourier},
|
||||||
|
tensor_key::TensorKey,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn encrypt_sk() {
|
||||||
|
(1..4).for_each(|rank| {
|
||||||
|
println!("test encrypt_sk rank: {}", rank);
|
||||||
|
test_encrypt_sk(12, 16, 54, 3.2, rank);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_encrypt_sk(log_n: usize, basek: usize, k: usize, sigma: f64, rank: usize) {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
|
||||||
|
|
||||||
|
let rows: usize = (k + basek - 1) / basek;
|
||||||
|
|
||||||
|
let mut tensor_key: TensorKey<Vec<u8>, FFT64> = TensorKey::new(&module, basek, k, rows, rank);
|
||||||
|
|
||||||
|
let mut source_xs: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xe: Source = Source::new([0u8; 32]);
|
||||||
|
let mut source_xa: Source = Source::new([0u8; 32]);
|
||||||
|
|
||||||
|
let mut scratch: ScratchOwned = ScratchOwned::new(TensorKey::encrypt_sk_scratch_space(
|
||||||
|
&module,
|
||||||
|
rank,
|
||||||
|
tensor_key.size(),
|
||||||
|
));
|
||||||
|
|
||||||
|
let mut sk: SecretKey<Vec<u8>> = SecretKey::new(&module, rank);
|
||||||
|
sk.fill_ternary_prob(0.5, &mut source_xs);
|
||||||
|
|
||||||
|
let mut sk_dft: SecretKeyFourier<Vec<u8>, FFT64> = SecretKeyFourier::new(&module, rank);
|
||||||
|
sk_dft.dft(&module, &sk);
|
||||||
|
|
||||||
|
tensor_key.encrypt_sk(
|
||||||
|
&module,
|
||||||
|
&sk_dft,
|
||||||
|
&mut source_xa,
|
||||||
|
&mut source_xe,
|
||||||
|
sigma,
|
||||||
|
scratch.borrow(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut ct_glwe_fourier: GLWECiphertextFourier<Vec<u8>, FFT64> = GLWECiphertextFourier::new(&module, basek, k, rank);
|
||||||
|
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::new(&module, basek, k);
|
||||||
|
|
||||||
|
(0..rank).for_each(|i| {
|
||||||
|
(0..rank).for_each(|j| {
|
||||||
|
let mut sk_ij_dft: base2k::ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
|
||||||
|
module.svp_apply(&mut sk_ij_dft, 0, &sk_dft.data, i, &sk_dft.data, j);
|
||||||
|
let sk_ij: ScalarZnx<Vec<u8>> = module
|
||||||
|
.vec_znx_idft_consume(sk_ij_dft.as_vec_znx_dft())
|
||||||
|
.to_vec_znx_small()
|
||||||
|
.to_scalar_znx();
|
||||||
|
|
||||||
|
(0..tensor_key.rank_in()).for_each(|col_i| {
|
||||||
|
(0..tensor_key.rows()).for_each(|row_i| {
|
||||||
|
tensor_key
|
||||||
|
.at(i, j)
|
||||||
|
.get_row(&module, row_i, col_i, &mut ct_glwe_fourier);
|
||||||
|
ct_glwe_fourier.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
|
||||||
|
module.vec_znx_sub_scalar_inplace(&mut pt, 0, row_i, &sk_ij, col_i);
|
||||||
|
let std_pt: f64 = pt.data.std(0, basek) * (k as f64).exp2();
|
||||||
|
assert!((sigma - std_pt).abs() <= 0.2, "{} {}", sigma, std_pt);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user