From 8209fb4e4058df2ab322df3f6f109e360221eaf1 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Wed, 28 May 2025 15:59:49 +0200 Subject: [PATCH] Replaced manual core structs scratch allocation by new API on Scratch --- core/src/automorphism.rs | 24 +--- core/src/elem.rs | 2 +- core/src/gglwe_ciphertext.rs | 9 +- core/src/ggsw_ciphertext.rs | 165 ++++++++++------------------ core/src/glwe_ciphertext_fourier.rs | 33 ++---- core/src/glwe_packing.rs | 27 +---- core/src/glwe_plaintext.rs | 6 +- core/src/keyswitch_key.rs | 51 ++------- core/src/lib.rs | 12 +- 9 files changed, 95 insertions(+), 234 deletions(-) diff --git a/core/src/automorphism.rs b/core/src/automorphism.rs index 3e309f9..1df916e 100644 --- a/core/src/automorphism.rs +++ b/core/src/automorphism.rs @@ -5,13 +5,8 @@ use backend::{ use sampling::source::Source; use crate::{ - elem::{GetRow, Infos, SetRow}, - gglwe_ciphertext::GGLWECiphertext, - ggsw_ciphertext::GGSWCiphertext, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - keys::{SecretKey, SecretKeyFourier}, - keyswitch_key::GLWESwitchingKey, + GGLWECiphertext, GGSWCiphertext, GLWECiphertext, GLWECiphertextFourier, GLWESwitchingKey, GetRow, Infos, ScratchCore, + SecretKey, SetRow, }; pub struct AutomorphismKey { @@ -179,12 +174,7 @@ impl + AsRef<[u8]>> AutomorphismKey { ) } - let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); - - let mut sk_out_dft: SecretKeyFourier<&mut [u8], FFT64> = SecretKeyFourier { - data: sk_out_dft_data, - dist: sk.dist, - }; + let (mut sk_out_dft, scratch_1) = scratch.tmp_sk_fourier(module, sk.rank()); { (0..self.rank()).for_each(|i| { @@ -249,13 +239,7 @@ impl + AsRef<[u8]>> AutomorphismKey { let cols_out: usize = rhs.rank_out() + 1; - let (tmp_dft_data, scratch1) = scratch.tmp_vec_znx_dft(module, cols_out, lhs.size()); - - let mut tmp_dft: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_data, - basek: lhs.basek(), - k: lhs.k(), - }; + let (mut tmp_dft, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/elem.rs b/core/src/elem.rs index 131e9d7..fdfd5bd 100644 --- a/core/src/elem.rs +++ b/core/src/elem.rs @@ -1,6 +1,6 @@ use backend::{Backend, Module, ZnxInfos}; -use crate::{glwe_ciphertext_fourier::GLWECiphertextFourier, utils::derive_size}; +use crate::{GLWECiphertextFourier, derive_size}; pub trait Infos { type Inner: ZnxInfos; diff --git a/core/src/gglwe_ciphertext.rs b/core/src/gglwe_ciphertext.rs index 0417707..dd4a6ff 100644 --- a/core/src/gglwe_ciphertext.rs +++ b/core/src/gglwe_ciphertext.rs @@ -4,14 +4,7 @@ use backend::{ }; use sampling::source::Source; -use crate::{ - elem::{GetRow, Infos, SetRow}, - glwe_ciphertext::GLWECiphertext, - glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, - keys::SecretKeyFourier, - utils::derive_size, -}; +use crate::{GLWECiphertext, GLWECiphertextFourier, GLWEPlaintext, GetRow, Infos, SecretKeyFourier, SetRow, derive_size}; pub struct GGLWECiphertext { pub(crate) data: MatZnxDft, diff --git a/core/src/ggsw_ciphertext.rs b/core/src/ggsw_ciphertext.rs index 4144cf4..43e9f34 100644 --- a/core/src/ggsw_ciphertext.rs +++ b/core/src/ggsw_ciphertext.rs @@ -6,11 +6,11 @@ use backend::{ use sampling::source::Source; use crate::{ + ScratchCore, automorphism::AutomorphismKey, elem::{GetRow, Infos, SetRow}, glwe_ciphertext::GLWECiphertext, glwe_ciphertext_fourier::GLWECiphertextFourier, - glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, tensor_key::TensorKey, @@ -198,55 +198,38 @@ impl + AsRef<[u8]>> GGSWCiphertext { assert_eq!(sk_dft.n(), module.n()); } - let size: usize = self.size(); let basek: usize = self.basek(); let k: usize = self.k(); - let cols: usize = self.rank() + 1; + let rank: usize = self.rank(); - let (tmp_znx_pt, scratch_1) = scratch.tmp_vec_znx(module, 1, size); - let (tmp_znx_ct, scrach_2) = scratch_1.tmp_vec_znx(module, cols, size); - - let mut vec_znx_pt: GLWEPlaintext<&mut [u8]> = GLWEPlaintext { - data: tmp_znx_pt, - basek: basek, - k: k, - }; - - let mut vec_znx_ct: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_znx_ct, - basek: basek, - k, - }; + let (mut tmp_pt, scratch1) = scratch.tmp_glwe_pt(module, basek, k); + let (mut tmp_ct, scratch2) = scratch1.tmp_glwe_ct(module, basek, k, rank); (0..self.rows()).for_each(|row_i| { - vec_znx_pt.data.zero(); + tmp_pt.data.zero(); // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt - module.vec_znx_add_scalar_inplace(&mut vec_znx_pt.data, 0, row_i, pt, 0); - module.vec_znx_normalize_inplace(basek, &mut vec_znx_pt.data, 0, scrach_2); + module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, row_i, pt, 0); + module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch2); - (0..cols).for_each(|col_j| { + (0..rank + 1).for_each(|col_j| { // rlwe encrypt of vec_znx_pt into vec_znx_ct - vec_znx_ct.encrypt_sk_private( + tmp_ct.encrypt_sk_private( module, - Some((&vec_znx_pt, col_j)), + Some((&tmp_pt, col_j)), sk_dft, source_xa, source_xe, sigma, - scrach_2, + scratch2, ); // Switch vec_znx_ct into DFT domain { - let (mut vec_znx_dft_ct, _) = scrach_2.tmp_vec_znx_dft(module, cols, size); - - (0..cols).for_each(|i| { - module.vec_znx_dft(&mut vec_znx_dft_ct, i, &vec_znx_ct.data, i); - }); - - module.vmp_prepare_row(&mut self.data, row_i, col_j, &vec_znx_dft_ct); + let (mut tmp_ct_dft, _) = scratch2.tmp_glwe_fourier(module, basek, k, rank); + tmp_ct.dft(module, &mut tmp_ct_dft); + self.set_row(module, row_i, col_j, &tmp_ct_dft); } }); }); @@ -349,26 +332,22 @@ impl + AsRef<[u8]>> GGSWCiphertext { tsk: &TensorKey, scratch: &mut Scratch, ) { - let cols: usize = self.rank() + 1; - - let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); - let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: self.basek(), - k: self.k(), - }; + let rank: usize = self.rank(); + let cols: usize = rank + 1; + let basek: usize = self.basek(); + let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); // Keyswitch the j-th row of the col 0 (0..lhs.rows()).for_each(|row_i| { // Key-switch column 0, i.e. // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0s0' + a1s1' + a2s2') + M[i], a0, a1, a2) - lhs.keyswitch_internal_col0(module, row_i, &mut res, ksk, scratch2); + lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, ksk, scratch2); // Isolates DFT(a[i]) (0..cols).for_each(|col_i| { - module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i); + module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i); }); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); @@ -379,14 +358,10 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 2: (-(c0s0' + c1s1' + c2s2') , c0 , c1 + M[i], c2 ) // col 3: (-(d0s0' + d1s1' + d2s2') , d0 , d1 , d2 + M[i]) (1..cols).for_each(|col_j| { - self.expand_row(module, col_j, &mut res.data, &ci_dft, tsk, scratch2); - - let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft(&mut res_dft, i, &res.data, i); - }); - - module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft); + self.expand_row(module, col_j, &mut tmp_res.data, &ci_dft, tsk, scratch2); + let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + tmp_res.dft(module, &mut tmp_res_dft); + self.set_row(module, row_i, col_j, &tmp_res_dft); }); }) } @@ -448,28 +423,24 @@ impl + AsRef<[u8]>> GGSWCiphertext { ) }; - let cols: usize = self.rank() + 1; + let rank: usize = self.rank(); + let cols: usize = rank + 1; + let basek: usize = self.basek(); - let (res_data, scratch1) = scratch.tmp_vec_znx(&module, cols, self.size()); - let mut res: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_data, - basek: self.basek(), - k: self.k(), - }; - - let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); + let (mut tmp_res, scratch1) = scratch.tmp_glwe_ct(module, basek, self.k(), rank); + let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size()); // Keyswitch the j-th row of the col 0 (0..lhs.rows()).for_each(|row_i| { // Key-switch column 0, i.e. // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) - lhs.keyswitch_internal_col0(module, row_i, &mut res, &auto_key.key, scratch2); + lhs.keyswitch_internal_col0(module, row_i, &mut tmp_res, &auto_key.key, scratch2); // Isolates DFT(AUTO(a[i])) (0..cols).for_each(|col_i| { // (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2) -> (-(a0s0 + a1s1 + a2s2) + pi(M[i]), a0, a1, a2) - module.vec_znx_automorphism_inplace(auto_key.p(), &mut res.data, col_i); - module.vec_znx_dft(&mut ci_dft, col_i, &res.data, col_i); + module.vec_znx_automorphism_inplace(auto_key.p(), &mut tmp_res.data, col_i); + module.vec_znx_dft(&mut ci_dft, col_i, &tmp_res.data, col_i); }); module.vmp_prepare_row(&mut self.data, row_i, 0, &ci_dft); @@ -480,14 +451,17 @@ impl + AsRef<[u8]>> GGSWCiphertext { // col 2: (-(c0s0 + c1s1 + c2s2) , c0 , c1 + pi(M[i]), c2 ) // col 3: (-(d0s0 + d1s1 + d2s2) , d0 , d1 , d2 + pi(M[i])) (1..cols).for_each(|col_j| { - self.expand_row(module, col_j, &mut res.data, &ci_dft, tensor_key, scratch2); - - let (mut res_dft, _) = scratch2.tmp_vec_znx_dft(module, cols, self.size()); - (0..cols).for_each(|i| { - module.vec_znx_dft(&mut res_dft, i, &res.data, i); - }); - - module.vmp_prepare_row(&mut self.data, row_i, col_j, &res_dft); + self.expand_row( + module, + col_j, + &mut tmp_res.data, + &ci_dft, + tensor_key, + scratch2, + ); + let (mut tmp_res_dft, _) = scratch2.tmp_glwe_fourier(module, basek, self.k(), rank); + tmp_res.dft(module, &mut tmp_res_dft); + self.set_row(module, row_i, col_j, &tmp_res_dft); }); }) } @@ -530,35 +504,22 @@ impl + AsRef<[u8]>> GGSWCiphertext { ); } - let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank() + 1, lhs.size()); - - let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_in_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); - - let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_out_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp_ct_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_ct_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - lhs.get_row(module, row_j, col_i, &mut tmp_in); - tmp_out.external_product(module, &tmp_in, rhs, scratch2); - self.set_row(module, row_j, col_i, &tmp_out); + lhs.get_row(module, row_j, col_i, &mut tmp_ct_in); + tmp_ct_out.external_product(module, &tmp_ct_in, rhs, scratch2); + self.set_row(module, row_j, col_i, &tmp_ct_out); }); }); - tmp_out.data.zero(); + tmp_ct_out.data.zero(); (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| { (0..self.rank() + 1).for_each(|col_j| { - self.set_row(module, row_i, col_j, &tmp_out); + self.set_row(module, row_i, col_j, &tmp_ct_out); }); }); } @@ -580,19 +541,13 @@ impl + AsRef<[u8]>> GGSWCiphertext { ); } - let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); - - let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank() + 1).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { - self.get_row(module, row_j, col_i, &mut tmp); - tmp.external_product_inplace(module, rhs, scratch1); - self.set_row(module, row_j, col_i, &tmp); + self.get_row(module, row_j, col_i, &mut tmp_ct); + tmp_ct.external_product_inplace(module, rhs, scratch1); + self.set_row(module, row_j, col_i, &tmp_ct); }); }); } @@ -622,15 +577,9 @@ impl> GGSWCiphertext { ) ) } - - let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); - let mut tmp_dft_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_dft_in_data, - basek: self.basek(), - k: self.k(), - }; - self.get_row(module, row_i, 0, &mut tmp_dft_in); - res.keyswitch_from_fourier(module, &tmp_dft_in, ksk, scratch2); + let (mut tmp_dft_dft, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); + self.get_row(module, row_i, 0, &mut tmp_dft_dft); + res.keyswitch_from_fourier(module, &tmp_dft_dft, ksk, scratch1); } } diff --git a/core/src/glwe_ciphertext_fourier.rs b/core/src/glwe_ciphertext_fourier.rs index cb2ec21..19f37f1 100644 --- a/core/src/glwe_ciphertext_fourier.rs +++ b/core/src/glwe_ciphertext_fourier.rs @@ -5,7 +5,7 @@ use backend::{ use sampling::source::Source; use crate::{ - elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, + ScratchCore, elem::Infos, ggsw_ciphertext::GGSWCiphertext, glwe_ciphertext::GLWECiphertext, glwe_plaintext::GLWEPlaintext, keys::SecretKeyFourier, keyswitch_key::GLWESwitchingKey, utils::derive_size, }; @@ -119,15 +119,9 @@ impl + AsRef<[u8]>> GLWECiphertextFourier sigma: f64, scratch: &mut Scratch, ) { - let (vec_znx_tmp, scratch_1) = scratch.tmp_vec_znx(module, self.rank() + 1, self.size()); - let mut ct_idft = GLWECiphertext { - data: vec_znx_tmp, - basek: self.basek, - k: self.k, - }; - ct_idft.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch_1); - - ct_idft.dft(module, self); + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.encrypt_zero_sk(module, sk_dft, source_xa, source_xe, sigma, scratch1); + tmp_ct.dft(module, self); } pub fn keyswitch, DataRhs: AsRef<[u8]>>( @@ -137,22 +131,9 @@ impl + AsRef<[u8]>> GLWECiphertextFourier rhs: &GLWESwitchingKey, scratch: &mut Scratch, ) { - let cols_out: usize = rhs.rank_out() + 1; - - // Space fr normalized VMP result outside of DFT domain - let (res_idft_data, scratch1) = scratch.tmp_vec_znx(module, cols_out, lhs.size()); - - let mut res_idft: GLWECiphertext<&mut [u8]> = GLWECiphertext::<&mut [u8]> { - data: res_idft_data, - basek: lhs.basek, - k: lhs.k, - }; - - res_idft.keyswitch_from_fourier(module, lhs, rhs, scratch1); - - (0..cols_out).for_each(|i| { - module.vec_znx_dft(&mut self.data, i, &res_idft.data, i); - }); + let (mut tmp_ct, scratch1) = scratch.tmp_glwe_ct(module, self.basek(), self.k(), self.rank()); + tmp_ct.keyswitch_from_fourier(module, lhs, rhs, scratch1); + tmp_ct.dft(module, self); } pub fn keyswitch_inplace>( diff --git a/core/src/glwe_packing.rs b/core/src/glwe_packing.rs index b391aeb..54583ad 100644 --- a/core/src/glwe_packing.rs +++ b/core/src/glwe_packing.rs @@ -1,4 +1,4 @@ -use crate::{automorphism::AutomorphismKey, elem::Infos, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps}; +use crate::{ScratchCore, automorphism::AutomorphismKey, elem::Infos, glwe_ciphertext::GLWECiphertext, glwe_ops::GLWEOps}; use std::collections::HashMap; use backend::{FFT64, Module, Scratch, VecZnxAlloc}; @@ -223,8 +223,6 @@ fn combine, DataAK: AsRef<[u8]>>( let basek: usize = a.basek(); let k: usize = a.k(); let rank: usize = a.rank(); - let cols: usize = rank + 1; - let size: usize = a.size(); let gal_el: i64; @@ -245,20 +243,9 @@ fn combine, DataAK: AsRef<[u8]>>( a.rsh(1, scratch); if let Some(b) = b { - let (tmp_b_data, scratch_1) = scratch.tmp_vec_znx(module, cols, size); - let mut tmp_b: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_b_data, - k: k, - basek: basek, - }; - + let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank); { - let (tmp_a_data, scratch_2) = scratch_1.tmp_vec_znx(module, cols, size); - let mut tmp_a: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_a_data, - k: k, - basek: basek, - }; + let (mut tmp_a, scratch_2) = scratch_1.tmp_glwe_ct(module, basek, k, rank); //TODO can we skip tmp_a by reordering X^k ? // tmp_a = b * X^t tmp_a.rotate(module, 1 << (log_n - i - 1), b); @@ -294,13 +281,7 @@ fn combine, DataAK: AsRef<[u8]>>( } } else { if let Some(b) = b { - let (tmp_b_data, scratch_1) = scratch.tmp_vec_znx(module, cols, size); - let mut tmp_b: GLWECiphertext<&mut [u8]> = GLWECiphertext { - data: tmp_b_data, - k: k, - basek: basek, - }; - + let (mut tmp_b, scratch_1) = scratch.tmp_glwe_ct(module, basek, k, rank); tmp_b.rotate(module, 1 << (log_n - i - 1), b); tmp_b.rsh(1, scratch_1); diff --git a/core/src/glwe_plaintext.rs b/core/src/glwe_plaintext.rs index 3bf0060..3ffe32c 100644 --- a/core/src/glwe_plaintext.rs +++ b/core/src/glwe_plaintext.rs @@ -1,4 +1,4 @@ -use backend::{Backend, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; +use backend::{Backend, FFT64, Module, VecZnx, VecZnxAlloc, VecZnxToMut, VecZnxToRef}; use crate::{ elem::{Infos, SetMetaData}, @@ -47,6 +47,10 @@ impl GLWEPlaintext> { k, } } + + pub fn byte_of(module: &Module, basek: usize, k: usize) -> usize { + module.bytes_of_vec_znx(1, derive_size(basek, k)) + } } impl> GLWECiphertextToRef for GLWEPlaintext { diff --git a/core/src/keyswitch_key.rs b/core/src/keyswitch_key.rs index 645b24b..2503176 100644 --- a/core/src/keyswitch_key.rs +++ b/core/src/keyswitch_key.rs @@ -2,6 +2,7 @@ use backend::{Backend, FFT64, MatZnxDft, MatZnxDftOps, Module, Scratch, VecZnxDf use sampling::source::Source; use crate::{ + ScratchCore, elem::{GetRow, Infos, SetRow}, gglwe_ciphertext::GGLWECiphertext, ggsw_ciphertext::GGSWCiphertext, @@ -184,21 +185,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); - - let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_in_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); - - let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_out_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -234,13 +222,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); - - let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -283,21 +265,8 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (tmp_in_data, scratch1) = scratch.tmp_vec_znx_dft(module, lhs.rank_out() + 1, lhs.size()); - - let mut tmp_in: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_in_data, - basek: lhs.basek(), - k: lhs.k(), - }; - - let (tmp_out_data, scratch2) = scratch1.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); - - let mut tmp_out: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_out_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp_in, scratch1) = scratch.tmp_glwe_fourier(module, lhs.basek(), lhs.k(), lhs.rank()); + let (mut tmp_out, scratch2) = scratch1.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { @@ -333,13 +302,7 @@ impl + AsRef<[u8]>> GLWESwitchingKey { ); } - let (tmp_data, scratch1) = scratch.tmp_vec_znx_dft(module, self.rank_out() + 1, self.size()); - - let mut tmp: GLWECiphertextFourier<&mut [u8], FFT64> = GLWECiphertextFourier::<&mut [u8], FFT64> { - data: tmp_data, - basek: self.basek(), - k: self.k(), - }; + let (mut tmp, scratch1) = scratch.tmp_glwe_fourier(module, self.basek(), self.k(), self.rank()); (0..self.rank_in()).for_each(|col_i| { (0..self.rows()).for_each(|row_j| { diff --git a/core/src/lib.rs b/core/src/lib.rs index b13ab8b..ab27539 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -38,7 +38,8 @@ use utils::derive_size; pub(crate) const SIX_SIGMA: f64 = 6.0; pub trait ScratchCore { - fn tmp_glwe(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn tmp_glwe_ct(&mut self, module: &Module, basek: usize, k: usize, rank: usize) -> (GLWECiphertext<&mut [u8]>, &mut Self); + fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self); fn tmp_gglwe( &mut self, module: &Module, @@ -100,7 +101,7 @@ pub trait ScratchCore { } impl ScratchCore for Scratch { - fn tmp_glwe( + fn tmp_glwe_ct( &mut self, module: &Module, basek: usize, @@ -111,6 +112,11 @@ impl ScratchCore for Scratch { (GLWECiphertext { data, basek, k }, scratch) } + fn tmp_glwe_pt(&mut self, module: &Module, basek: usize, k: usize) -> (GLWEPlaintext<&mut [u8]>, &mut Self) { + let (data, scratch) = self.tmp_vec_znx(module, 1, derive_size(basek, k)); + (GLWEPlaintext { data, basek, k }, scratch) + } + fn tmp_gglwe( &mut self, module: &Module, @@ -190,7 +196,7 @@ impl ScratchCore for Scratch { } fn tmp_sk_fourier(&mut self, module: &Module, rank: usize) -> (SecretKeyFourier<&mut [u8], FFT64>, &mut Self) { - let (data, scratch) = self.tmp_scalar_znx_dft(module, rank + 1); + let (data, scratch) = self.tmp_scalar_znx_dft(module, rank); ( SecretKeyFourier { data,