bug fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-05-26 19:16:43 +02:00
parent e5d6a6f828
commit dec3481a6f
3 changed files with 14 additions and 11 deletions

View File

@@ -267,7 +267,7 @@ fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k
let data: &[i64] = a.raw(); let data: &[i64] = a.raw();
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
let slice_size: usize = a.n() * a.size(); let slice_size: usize = a.n() * a.cols();
(1..size).for_each(|i| { (1..size).for_each(|i| {
let x: i64 = data[i * slice_size]; let x: i64 = data[i * slice_size];
if i == size - 1 && rem != basek { if i == size - 1 && rem != basek {

View File

@@ -142,14 +142,13 @@ impl GGSWCiphertext<Vec<u8>, FFT64> {
tensor_key_size: usize, tensor_key_size: usize,
rank: usize, rank: usize,
) -> usize { ) -> usize {
GGSWCiphertext::keyswitch_scratch_space( let cols: usize = rank + 1;
module, let res: usize = module.bytes_of_vec_znx(cols, out_size);
out_size, let res_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
in_size, let ci_dft: usize = module.bytes_of_vec_znx_dft(cols, out_size);
auto_key_size, let ks_internal: usize = GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, out_size, in_size, auto_key_size, rank);
tensor_key_size, let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_size, tensor_key_size, rank);
rank, res + ci_dft + (ks_internal | expand | res_dft)
)
} }
pub fn automorphism_inplace_scratch_space( pub fn automorphism_inplace_scratch_space(
@@ -288,6 +287,8 @@ where
{ {
let cols: usize = self.rank() + 1; let cols: usize = self.rank() + 1;
assert!(scratch.available() >= GGSWCiphertext::expand_row_scratch_space(module, self.size(), tsk.size(), tsk.rank()));
// Example for rank 3: // Example for rank 3:
// //
// Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is // Note: M is a vector (m, Bm, B^2m, B^3m, ...), so each column is
@@ -383,7 +384,7 @@ where
k: self.k(), k: self.k(),
}; };
let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, self.size()); let (mut ci_dft, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, lhs.size());
// Keyswitch the j-th row of the col 0 // Keyswitch the j-th row of the col 0
(0..lhs.rows()).for_each(|row_i| { (0..lhs.rows()).for_each(|row_i| {
@@ -467,6 +468,7 @@ where
self.rank(), self.rank(),
tensor_key.rank() tensor_key.rank()
); );
assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self.size(), lhs.size(), auto_key.size(), tensor_key.size(), self.rank()))
}; };
let cols: usize = self.rank() + 1; let cols: usize = self.rank() + 1;
@@ -646,6 +648,7 @@ where
{ {
assert_eq!(self.rank(), ksk.rank()); assert_eq!(self.rank(), ksk.rank());
assert_eq!(res.rank(), ksk.rank()); assert_eq!(res.rank(), ksk.rank());
assert!(scratch.available() >= GGSWCiphertext::keyswitch_internal_col0_scratch_space(module, res.size(), self.size(), ksk.size(), ksk.rank()))
} }
let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size()); let (tmp_dft_in_data, scratch2) = scratch.tmp_vec_znx_dft(module, self.rank() + 1, self.size());

View File

@@ -17,7 +17,7 @@ pub struct TensorKey<C, B: Backend> {
impl TensorKey<Vec<u8>, FFT64> { impl TensorKey<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self { pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new(); let mut keys: Vec<GLWESwitchingKey<Vec<u8>, FFT64>> = Vec::new();
let pairs: usize = ((rank + 1) * rank) >> 1; let pairs: usize = (((rank + 1) * rank) >> 1).max(1);
(0..pairs).for_each(|_| { (0..pairs).for_each(|_| {
keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank)); keys.push(GLWESwitchingKey::alloc(module, basek, k, rows, 1, rank));
}); });