diff --git a/backend/spqlios-arithmetic b/backend/spqlios-arithmetic index b919282..d604503 160000 --- a/backend/spqlios-arithmetic +++ b/backend/spqlios-arithmetic @@ -1 +1 @@ -Subproject commit b919282c9b913e8b11418df6afdb0baa02debc9b +Subproject commit d6045033e510315437c46ed2ddb80cfbb454950b diff --git a/backend/src/ffi/vmp.rs b/backend/src/ffi/vmp.rs index fc8a7ae..4f58e9b 100644 --- a/backend/src/ffi/vmp.rs +++ b/backend/src/ffi/vmp.rs @@ -47,6 +47,7 @@ unsafe extern "C" { pmat: *const VMP_PMAT, nrows: u64, ncols: u64, + pmat_scale: u64, tmp_space: *mut u8, ); } @@ -79,6 +80,7 @@ unsafe extern "C" { pmat: *const VMP_PMAT, nrows: u64, ncols: u64, + pmat_scale: u64, tmp_space: *mut u8, ); } diff --git a/backend/src/mat_znx_dft_ops.rs b/backend/src/mat_znx_dft_ops.rs index 5ad724f..d0316a9 100644 --- a/backend/src/mat_znx_dft_ops.rs +++ b/backend/src/mat_znx_dft_ops.rs @@ -101,7 +101,7 @@ pub trait MatZnxDftOps { B: MatZnxToRef; // Same as [MatZnxDftOps::vmp_apply] except result is added on R instead of overwritting R. - fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxDftToRef, @@ -309,7 +309,7 @@ impl MatZnxDftOps for Module { } } - fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scratch: &mut Scratch) + fn vmp_apply_add(&self, res: &mut R, a: &A, b: &B, scale: usize, scratch: &mut Scratch) where R: VecZnxDftToMut, A: VecZnxDftToRef, @@ -358,6 +358,7 @@ impl MatZnxDftOps for Module { b.as_ptr() as *const vmp::vmp_pmat_t, (b.rows() * b.cols_in()) as u64, (b.size() * b.cols_out()) as u64, + scale as u64, tmp_bytes.as_mut_ptr(), ) } @@ -368,6 +369,7 @@ mod tests { use crate::{ Decoding, FFT64, FillUniform, MatZnxDft, MatZnxDftOps, Module, ScratchOwned, VecZnx, VecZnxAlloc, VecZnxBig, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDft, VecZnxDftAlloc, VecZnxDftOps, ZnxInfos, ZnxView, ZnxViewMut, + ZnxZero, }; use sampling::source::Source; @@ -409,7 +411,7 @@ mod tests { let basek: usize = 15; let a_size: usize = 5; let mat_size: usize = 6; - let res_size: usize = 5; + let res_size: usize = a_size; [1, 2].iter().for_each(|in_cols| { [1, 2].iter().for_each(|out_cols| { @@ -419,7 +421,6 @@ mod tests { let mat_rows: usize = a_size; let mat_cols_in: usize = a_cols; let mat_cols_out: usize = res_cols; - let res_cols: usize = mat_cols_out; let mut scratch: ScratchOwned = ScratchOwned::new( module.vmp_apply_tmp_bytes( @@ -435,7 +436,7 @@ mod tests { let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); (0..a_cols).for_each(|i| { - a.at_mut(i, 2)[i + 1] = 1; + a.at_mut(i, a_size - 1)[i + 1] = 1; }); let mut mat_znx_dft: MatZnxDft, FFT64> = @@ -479,7 +480,100 @@ mod tests { (0..a_cols).for_each(|i| { res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; }); - res_have.decode_vec_i64(col_i, basek, basek * 3, &mut res_have_vi64); + res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64); + assert_eq!(res_have_vi64, res_want_vi64); + }); + }); + }); + } + + #[test] + fn vmp_apply_add() { + let log_n: i32 = 5; + let n: usize = 1 << log_n; + + let module: Module = Module::::new(n); + let basek: usize = 15; + let a_size: usize = 5; + let mat_size: usize = 6; + let res_size: usize = a_size; + + [1, 2].iter().for_each(|in_cols| { + [1, 2].iter().for_each(|out_cols| { + let a_cols: usize = *in_cols; + let res_cols: usize = *out_cols; + + let mat_rows: usize = a_size; + let mat_cols_in: usize = a_cols; + let mat_cols_out: usize = res_cols; + + let mut scratch: ScratchOwned = ScratchOwned::new( + module.vmp_apply_tmp_bytes( + res_size, + a_size, + mat_rows, + mat_cols_in, + mat_cols_out, + mat_size, + ) | module.vec_znx_big_normalize_tmp_bytes(), + ); + + let mut a: VecZnx> = module.new_vec_znx(a_cols, a_size); + + (0..a_cols).for_each(|i| { + a.at_mut(i, a_size - 1)[i + 1] = 1; + }); + + let mut mat_znx_dft: MatZnxDft, FFT64> = + module.new_mat_znx_dft(mat_rows, mat_cols_in, mat_cols_out, mat_size); + + let mut c_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(mat_cols_out, mat_size); + let mut c_big: VecZnxBig, FFT64> = module.new_vec_znx_big(mat_cols_out, mat_size); + + let mut tmp: VecZnx> = module.new_vec_znx(mat_cols_out, mat_size); + + // Construts a [VecZnxMatDft] that performs cyclic rotations on each submatrix. + (0..a.size()).for_each(|row_i| { + (0..mat_cols_in).for_each(|col_in_i| { + (0..mat_cols_out).for_each(|col_out_i| { + let idx = 1 + col_in_i * mat_cols_out + col_out_i; + tmp.at_mut(col_out_i, row_i)[idx] = 1 as i64; // X^{idx} + module.vec_znx_dft(&mut c_dft, col_out_i, &tmp, col_out_i); + tmp.at_mut(col_out_i, row_i)[idx] = 0 as i64; + }); + module.vmp_prepare_row(&mut mat_znx_dft, row_i, col_in_i, &c_dft); + }); + }); + + let mut a_dft: VecZnxDft, FFT64> = module.new_vec_znx_dft(a_cols, a_size); + (0..a_cols).for_each(|i| { + module.vec_znx_dft(&mut a_dft, i, &a, i); + }); + + c_dft.zero(); + (0..c_dft.cols()).for_each(|i| { + module.vec_znx_dft(&mut c_dft, i, &a, 0); + }); + + module.vmp_apply_add(&mut c_dft, &a_dft, &mat_znx_dft, 0, scratch.borrow()); + + let mut res_have_vi64: Vec = vec![i64::default(); n]; + + let mut res_have: VecZnx> = module.new_vec_znx(res_cols, res_size); + (0..mat_cols_out).for_each(|i| { + module.vec_znx_idft_tmp_a(&mut c_big, i, &mut c_dft, i); + module.vec_znx_big_normalize(basek, &mut res_have, i, &c_big, i, scratch.borrow()); + }); + + (0..mat_cols_out).for_each(|col_i| { + let mut res_want_vi64: Vec = vec![i64::default(); n]; + (0..a_cols).for_each(|i| { + res_want_vi64[(i + 1) + (1 + i * mat_cols_out + col_i)] = 1; + }); + + res_want_vi64[1] += 1; + + res_have.decode_vec_i64(col_i, basek, basek * a_size, &mut res_have_vi64); assert_eq!(res_have_vi64, res_want_vi64); }); }); diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 339535c..a4165e6 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -12,15 +12,15 @@ pub struct AutomorphismKey { } impl AutomorphismKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { AutomorphismKey { - key: GLWESwitchingKey::alloc(module, basek, k, rows, rank, rank), + key: GLWESwitchingKey::alloc(module, basek, k, rows, digits, rank, rank), p: 0, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> usize { - GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, rank, rank) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits,rank, rank) } } @@ -45,6 +45,10 @@ impl AutomorphismKey { self.p } + pub fn digits(&self) -> usize { + self.key.digits() + } + pub fn rank(&self) -> usize { self.key.rank() } diff --git a/core/src/elem.rs b/core/src/elem.rs index fdfd5bd..9a3a285 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, ZnxInfos}; -use crate::{GLWECiphertextFourier, derive_size}; +use crate::{GLWECiphertextFourier, div_ceil}; pub trait Infos { type Inner: ZnxInfos; @@ -34,7 +34,7 @@ pub trait Infos { /// Returns the number of size per polynomial. fn size(&self) -> usize { let size: usize = self.inner().size(); - debug_assert_eq!(size, derive_size(self.basek(), self.k())); + debug_assert_eq!(size, div_ceil(self.basek(), self.k())); size } diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 0ab99ed..1650f59 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -4,25 +4,27 @@ use backend::{ }; use sampling::source::Source; -use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, derive_size}; +use crate::{GLWECiphertext, GLWECiphertextFourier, GLWESecret, GetRow, Infos, ScratchCore, SetRow, div_ceil}; pub struct GGLWECiphertext { pub(crate) data: MatZnxDft, pub(crate) basek: usize, pub(crate) k: usize, + pub(crate) digits: usize, } impl GGLWECiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)), + data: module.new_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)), basek: basek, k, + digits, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize { - module.bytes_of_mat_znx_dft(rows, rank_in, rank_out + 1, derive_size(basek, k)) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { + module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)) } } @@ -47,6 +49,10 @@ impl GGLWECiphertext { self.data.cols_out() - 1 } + pub fn digits(&self) -> usize{ + self.digits + } + pub fn rank_in(&self) -> usize { self.data.cols_in() } @@ -58,7 +64,7 @@ impl GGLWECiphertext { impl GGLWECiphertext, FFT64> { pub fn generate_from_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = derive_size(basek, k); + let size = div_ceil(basek, k); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) @@ -101,6 +107,7 @@ impl + AsRef<[u8]>> GGLWECiphertext { } let rows: usize = self.rows(); + let digits: usize = self.digits(); let basek: usize = self.basek(); let k: usize = self.k(); let rank_in: usize = self.rank_in(); @@ -125,7 +132,7 @@ impl + AsRef<[u8]>> GGLWECiphertext { (0..rows).for_each(|row_i| { // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt tmp_pt.data.zero(); // zeroes for next iteration - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, col_i); // Selects the i-th + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, col_i); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch_3); // rlwe encrypt of vec_znx_pt into vec_znx_ct diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 062954c..98edf89 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -7,26 +7,28 @@ use sampling::source::Source; use crate::{ AutomorphismKey, GLWECiphertext, GLWECiphertextFourier, GLWESecret, GLWESwitchingKey, GetRow, Infos, ScratchCore, SetRow, - TensorKey, derive_size, + TensorKey, div_ceil, }; pub struct GGSWCiphertext { - pub data: MatZnxDft, - pub basek: usize, - pub k: usize, + pub(crate) data: MatZnxDft, + pub(crate) basek: usize, + pub(crate) k: usize, + pub(crate) digits: usize, } impl GGSWCiphertext, B> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { Self { - data: module.new_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)), - basek: basek, + data: module.new_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)), + basek, k: k, + digits, } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> usize { - module.bytes_of_mat_znx_dft(rows, rank + 1, rank + 1, derive_size(basek, k)) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { + module.bytes_of_mat_znx_dft(div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)) } } @@ -50,11 +52,15 @@ impl GGSWCiphertext { pub fn rank(&self) -> usize { self.data.cols_out() - 1 } + + pub fn digits(&self) -> usize { + self.digits + } } impl GGSWCiphertext, FFT64> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - let size = derive_size(basek, k); + let size = div_ceil(basek, k); GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, size) + module.bytes_of_vec_znx(1, size) @@ -68,8 +74,8 @@ impl GGSWCiphertext, FFT64> { tsk_k: usize, rank: usize, ) -> usize { - let tsk_size: usize = derive_size(basek, tsk_k); - let self_size: usize = derive_size(basek, self_k); + let tsk_size: usize = div_ceil(basek, tsk_k); + let self_size: usize = div_ceil(basek, self_k); let tmp_dft_i: usize = module.bytes_of_vec_znx_dft(rank + 1, tsk_size); let tmp_dft_col_data: usize = module.bytes_of_vec_znx_dft(1, self_size); let vmp: usize = tmp_dft_col_data + module.vmp_apply_tmp_bytes(self_size, self_size, self_size, rank, rank, tsk_size); @@ -87,7 +93,7 @@ impl GGSWCiphertext, FFT64> { rank: usize, ) -> usize { GLWECiphertext::keyswitch_from_fourier_scratch_space(module, basek, out_k, rank, in_k, rank, ksk_k) - + module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, in_k)) + + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, in_k)) } pub fn keyswitch_scratch_space( @@ -99,7 +105,7 @@ impl GGSWCiphertext, FFT64> { tsk_k: usize, rank: usize, ) -> usize { - let out_size: usize = derive_size(basek, out_k); + let out_size: usize = div_ceil(basek, out_k); let res_znx: usize = module.bytes_of_vec_znx(rank + 1, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, out_size); @@ -130,7 +136,7 @@ impl GGSWCiphertext, FFT64> { rank: usize, ) -> usize { let cols: usize = rank + 1; - let out_size: usize = derive_size(basek, out_k); + let out_size: usize = div_ceil(basek, out_k); let res: usize = module.bytes_of_vec_znx(cols, out_size); let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size); @@ -199,6 +205,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { let basek: usize = self.basek(); let k: usize = self.k(); let rank: usize = self.rank(); + let digits: usize = self.digits(); let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k); let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank); @@ -207,7 +214,7 @@ impl + AsRef<[u8]>> GGSWCiphertext { tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0); + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i * digits, pt, 0); module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2); (0..rank + 1).for_each(|col_j| { diff --git a/core/src/glwe_ciphertext.rs b/core/src/glwe_ciphertext.rs index d5f8c47..a084b91 100644 --- a/core/src/glwe_ciphertext.rs +++ b/core/src/glwe_ciphertext.rs @@ -7,7 +7,7 @@ use sampling::source::Source; use crate::{ AutomorphismKey, GGSWCiphertext, GLWECiphertextFourier, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, GLWESwitchingKey, - Infos, SIX_SIGMA, SecretDistribution, SetMetaData, derive_size, + Infos, SIX_SIGMA, SecretDistribution, SetMetaData, div_ceil, }; pub struct GLWECiphertext { @@ -19,14 +19,14 @@ pub struct GLWECiphertext { impl GLWECiphertext> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx(rank + 1, derive_size(basek, k)), + data: module.new_vec_znx(rank + 1, div_ceil(basek, k)), basek, k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) } } @@ -69,18 +69,18 @@ impl> GLWECiphertext { impl GLWECiphertext> { pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = derive_size(basek, k); + let size: usize = div_ceil(basek, k); module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx(1, size) } pub fn encrypt_pk_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = derive_size(basek, k); + let size: usize = div_ceil(basek, k); ((module.bytes_of_vec_znx_dft(1, size) + module.bytes_of_vec_znx_big(1, size)) | module.bytes_of_scalar_znx(1)) + module.bytes_of_scalar_znx_dft(1) + module.vec_znx_big_normalize_tmp_bytes() } pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = derive_size(basek, k); + let size: usize = div_ceil(basek, k); (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size)) + module.bytes_of_vec_znx_big(1, size) } @@ -94,9 +94,9 @@ impl GLWECiphertext> { ksk_k: usize, ) -> usize { let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, out_rank); - let in_size: usize = derive_size(basek, in_k); - let out_size: usize = derive_size(basek, out_k); - let ksk_size: usize = derive_size(basek, ksk_k); + let in_size: usize = div_ceil(basek, in_k); + let out_size: usize = div_ceil(basek, out_k); + let ksk_size: usize = div_ceil(basek, ksk_k); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, in_rank, out_rank + 1, ksk_size) + module.bytes_of_vec_znx_dft(in_rank, in_size); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); @@ -155,9 +155,9 @@ impl GLWECiphertext> { rank: usize, ) -> usize { let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let in_size: usize = derive_size(basek, in_k); - let out_size: usize = derive_size(basek, out_k); - let ggsw_size: usize = derive_size(basek, ggsw_k); + let in_size: usize = div_ceil(basek, in_k); + let out_size: usize = div_ceil(basek, out_k); + let ggsw_size: usize = div_ceil(basek, ggsw_k); let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size) + module.vmp_apply_tmp_bytes( out_size, diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index 92864d3..4fe7e1e 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -4,7 +4,7 @@ use backend::{ }; use sampling::source::Source; -use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, derive_size}; +use crate::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, GLWESecret, GLWESwitchingKey, Infos, ScratchCore, div_ceil}; pub struct GLWECiphertextFourier { pub data: VecZnxDft, @@ -15,14 +15,14 @@ pub struct GLWECiphertextFourier { impl GLWECiphertextFourier, B> { pub fn alloc(module: &Module, basek: usize, k: usize, rank: usize) -> Self { Self { - data: module.new_vec_znx_dft(rank + 1, derive_size(basek, k)), + data: module.new_vec_znx_dft(rank + 1, div_ceil(basek, k)), basek: basek, k: k, } } pub fn bytes_of(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx_dft(rank + 1, derive_size(basek, k)) + module.bytes_of_vec_znx_dft(rank + 1, div_ceil(basek, k)) } } @@ -51,16 +51,16 @@ impl GLWECiphertextFourier { impl GLWECiphertextFourier, FFT64> { #[allow(dead_code)] pub(crate) fn idft_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, derive_size(basek, k)) + module.bytes_of_vec_znx(1, div_ceil(basek, k)) + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes()) } pub fn encrypt_sk_scratch_space(module: &Module, basek: usize, k: usize, rank: usize) -> usize { - module.bytes_of_vec_znx(rank + 1, derive_size(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) + module.bytes_of_vec_znx(rank + 1, div_ceil(basek, k)) + GLWECiphertext::encrypt_sk_scratch_space(module, basek, k) } pub fn decrypt_scratch_space(module: &Module, basek: usize, k: usize) -> usize { - let size: usize = derive_size(basek, k); + let size: usize = div_ceil(basek, k); (module.vec_znx_big_normalize_tmp_bytes() | module.bytes_of_vec_znx_dft(1, size) | (module.bytes_of_vec_znx_big(1, size) + module.vec_znx_idft_tmp_bytes())) @@ -99,9 +99,9 @@ impl GLWECiphertextFourier, FFT64> { rank: usize, ) -> usize { let res_dft: usize = GLWECiphertextFourier::bytes_of(module, basek, out_k, rank); - let out_size: usize = derive_size(basek, out_k); - let in_size: usize = derive_size(basek, in_k); - let ggsw_size: usize = derive_size(basek, ggsw_k); + let out_size: usize = div_ceil(basek, out_k); + let in_size: usize = div_ceil(basek, in_k); + let ggsw_size: usize = div_ceil(basek, ggsw_k); let vmp: usize = module.vmp_apply_tmp_bytes(out_size, in_size, in_size, rank + 1, rank + 1, ggsw_size); let res_small: usize = GLWECiphertext::bytes_of(module, basek, out_k, rank); let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(); diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs index 0aa846f..058c9a5 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe_plaintext.rs @@ -1,6 +1,6 @@ use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; -use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, derive_size}; +use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEOps, Infos, SetMetaData, div_ceil}; pub struct GLWEPlaintext { pub data: VecZnx, @@ -37,14 +37,14 @@ impl + AsRef<[u8]>> SetMetaData for GLWEPlaintext> { pub fn alloc(module: &Module, basek: usize, k: usize) -> Self { Self { - data: module.new_vec_znx(1, derive_size(basek, k)), + data: module.new_vec_znx(1, div_ceil(basek, k)), basek: basek, k, } } pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize { - module.bytes_of_vec_znx(1, derive_size(basek, k)) + module.bytes_of_vec_znx(1, div_ceil(basek, k)) } } diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 0362cd8..bc9baae 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -6,14 +6,14 @@ use crate::{GGLWECiphertext, GGSWCiphertext, GLWECiphertextFourier, GLWESecret, pub struct GLWESwitchingKey(pub(crate) GGLWECiphertext); impl GLWESwitchingKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> Self { GLWESwitchingKey(GGLWECiphertext::alloc( - module, basek, k, rows, rank_in, rank_out, + module, basek, k, rows, digits, rank_in, rank_out, )) } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, rank_in: usize, rank_out: usize) -> usize { - GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, rank_in, rank_out) + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank_in: usize, rank_out: usize) -> usize { + GGLWECiphertext::, FFT64>::bytes_of(module, basek, k, rows, digits, rank_in, rank_out) } } @@ -45,6 +45,10 @@ impl GLWESwitchingKey { pub fn rank_out(&self) -> usize { self.0.data.cols_out() - 1 } + + pub fn digits(&self) -> usize { + self.0.digits() + } } impl> GetRow for GLWESwitchingKey { diff --git a/core/src/lib.rs b/core/src/lib.rs index 3854e21..b82f82a 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -33,7 +33,7 @@ pub use tensor_key::*; pub use backend::Scratch; pub use backend::ScratchOwned; -use utils::derive_size; +use utils::div_ceil; pub(crate) const SIX_SIGMA: f64 = 6.0; @@ -46,6 +46,7 @@ pub trait ScratchCore { basek: usize, k: usize, rows: usize, + digits: usize, rank_in: usize, rank_out: usize, ) -> (GGLWECiphertext<&mut [u8], B>, &mut Self); @@ -55,6 +56,7 @@ pub trait ScratchCore { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], B>, &mut Self); fn tmp_glwe_fourier( @@ -78,6 +80,7 @@ pub trait ScratchCore { basek: usize, k: usize, rows: usize, + digits: usize, rank_in: usize, rank_out: usize, ) -> (GLWESwitchingKey<&mut [u8], B>, &mut Self); @@ -87,6 +90,7 @@ pub trait ScratchCore { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (TensorKey<&mut [u8], B>, &mut Self); fn tmp_autokey( @@ -95,6 +99,7 @@ pub trait ScratchCore { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (AutomorphismKey<&mut [u8], B>, &mut Self); } @@ -107,12 +112,12 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWECiphertext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, rank + 1, derive_size(basek, k)); + let (data, scratch) = self.tmp_vec_znx(module, rank + 1, div_ceil(basek, k)); (GLWECiphertext { data, basek, k }, scratch) } fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx(module, 1, derive_size(basek, k)); + let (data, scratch) = self.tmp_vec_znx(module, 1, div_ceil(basek, k)); (GLWEPlaintext { data, basek, k }, scratch) } @@ -122,15 +127,17 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rows: usize, + digits: usize, rank_in: usize, rank_out: usize, ) -> (GGLWECiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank_in, rank_out + 1, derive_size(basek, k)); + let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank_in, rank_out + 1, div_ceil(basek, k)); ( GGLWECiphertext { data: data, basek: basek, k, + digits, }, scratch, ) @@ -142,14 +149,16 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (GGSWCiphertext<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_mat_znx_dft(module, rows, rank + 1, rank + 1, derive_size(basek, k)); + let (data, scratch) = self.tmp_mat_znx_dft(module, div_ceil(rows, digits), rank + 1, rank + 1, div_ceil(basek, k)); ( GGSWCiphertext { - data: data, - basek: basek, + data, + basek, k, + digits, }, scratch, ) @@ -162,7 +171,7 @@ impl ScratchCore for Scratch { k: usize, rank: usize, ) -> (GLWECiphertextFourier<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, derive_size(basek, k)); + let (data, scratch) = self.tmp_vec_znx_dft(module, rank + 1, div_ceil(basek, k)); (GLWECiphertextFourier { data, basek, k }, scratch) } @@ -202,10 +211,11 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rows: usize, + digits: usize, rank_in: usize, rank_out: usize, ) -> (GLWESwitchingKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, rank_in, rank_out); + let (data, scratch) = self.tmp_gglwe(module, basek, k, rows, digits, rank_in, rank_out); (GLWESwitchingKey(data), scratch) } @@ -215,9 +225,10 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (AutomorphismKey<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, rank, rank); + let (data, scratch) = self.tmp_glwe_ksk(module, basek, k, rows, digits, rank, rank); (AutomorphismKey { key: data, p: 0 }, scratch) } @@ -227,6 +238,7 @@ impl ScratchCore for Scratch { basek: usize, k: usize, rows: usize, + digits: usize, rank: usize, ) -> (TensorKey<&mut [u8], FFT64>, &mut Self) { let mut keys: Vec> = Vec::new(); @@ -235,12 +247,12 @@ impl ScratchCore for Scratch { let mut scratch: &mut Scratch = self; if pairs != 0 { - let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank); + let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); } for _ in 1..pairs { - let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, 1, rank); + let (gglwe, s) = scratch.tmp_glwe_ksk(module, basek, k, rows, digits, 1, rank); scratch = s; keys.push(gglwe); } diff --git a/core/src/tensor_key.rs b/core/src/tensor_key.rs index fca6b2e..c7e9bf4 100644 --- a/core/src/tensor_key.rs +++ b/core/src/tensor_key.rs @@ -8,18 +8,18 @@ pub struct TensorKey { } impl TensorKey, FFT64> { - pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> Self { + pub fn alloc(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self { let mut keys: Vec, FFT64>> = Vec::new(); let pairs: usize = (((rank + 1) * rank) >> 1).max(1); (0..pairs).for_each(|_| { - keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank)); + keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, digits,1, rank)); }); Self { keys: keys } } - pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, rank: usize) -> usize { + pub fn bytes_of(module: &Module, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize { let pairs: usize = (((rank + 1) * rank) >> 1).max(1); - pairs * GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, 1, rank) + pairs * GLWESwitchingKey::, FFT64>::bytes_of(module, basek, k, rows, digits,1, rank) } } @@ -51,6 +51,10 @@ impl TensorKey { pub fn rank_out(&self) -> usize { self.keys[0].rank_out() } + + pub fn digits(&self) -> usize { + self.keys[0].digits() + } } impl TensorKey, FFT64> { diff --git a/core/src/utils.rs b/core/src/utils.rs index c3bc5d5..62164c3 100644 --- a/core/src/utils.rs +++ b/core/src/utils.rs @@ -1,3 +1,3 @@ -pub(crate) fn derive_size(basek: usize, k: usize) -> usize { - (k + basek - 1) / basek +pub(crate) fn div_ceil(a: usize, b: usize) -> usize { + (a + b - 1) / b }