bug fixes

This commit is contained in:
Jean-Philippe Bossuat
2025-05-26 13:55:21 +02:00
parent cb284a4c4c
commit 4c3a568108
7 changed files with 57 additions and 46 deletions

35
Cargo.lock generated
View File

@@ -145,25 +145,22 @@ dependencies = [
[[package]] [[package]]
name = "criterion" name = "criterion"
version = "0.5.1" version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" checksum = "3bf7af66b0989381bd0be551bd7cc91912a655a58c6918420c9527b1fd8b4679"
dependencies = [ dependencies = [
"anes", "anes",
"cast", "cast",
"ciborium", "ciborium",
"clap", "clap",
"criterion-plot", "criterion-plot",
"is-terminal", "itertools 0.13.0",
"itertools 0.10.5",
"num-traits", "num-traits",
"once_cell",
"oorandom", "oorandom",
"plotters", "plotters",
"rayon", "rayon",
"regex", "regex",
"serde", "serde",
"serde_derive",
"serde_json", "serde_json",
"tinytemplate", "tinytemplate",
"walkdir", "walkdir",
@@ -254,23 +251,6 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "hermit-abi"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc"
[[package]]
name = "is-terminal"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.52.0",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.10.5" version = "0.10.5"
@@ -280,6 +260,15 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itertools" name = "itertools"
version = "0.14.0" version = "0.14.0"

View File

@@ -9,4 +9,4 @@ rand_chacha = "0.9.0"
rand_core = "0.9.3" rand_core = "0.9.3"
rand_distr = "0.5.1" rand_distr = "0.5.1"
itertools = "0.14.0" itertools = "0.14.0"
criterion = "0.5.1" criterion = "0.6.0"

View File

@@ -263,14 +263,14 @@ fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k
assert!(col_i < a.cols()) assert!(col_i < a.cols())
} }
let cols: usize = (k + basek - 1) / basek; let size: usize = (k + basek - 1) / basek;
let data: &[i64] = a.raw(); let data: &[i64] = a.raw();
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
let slice_size: usize = a.n() * a.size(); let slice_size: usize = a.n() * a.size();
(1..cols).for_each(|i| { (1..size).for_each(|i| {
let x = data[i * slice_size]; let x: i64 = data[i * slice_size];
if i == cols - 1 && rem != basek { if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem; let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem); res = (res << k_rem) + (x >> rem);
} else { } else {

View File

@@ -1,7 +1,7 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftOps, ScalarZnxOps, Backend, FFT64, MatZnxDft, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDftAlloc,
ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ScalarZnxDftOps, ScalarZnxOps, ScalarZnxToRef, Scratch, VecZnx, VecZnxBigAlloc, VecZnxDftAlloc, VecZnxDftOps, VecZnxDftToMut,
ZnxZero, VecZnxDftToRef, VecZnxOps, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -107,7 +107,7 @@ where
impl AutomorphismKey<Vec<u8>, FFT64> { impl AutomorphismKey<Vec<u8>, FFT64> {
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize { pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, rank: usize, size: usize) -> usize {
GGLWECiphertext::generate_from_sk_scratch_space(module, rank, size) GGLWECiphertext::generate_from_sk_scratch_space(module, rank, size) + module.bytes_of_scalar_znx_dft(rank)
} }
pub fn generate_from_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize { pub fn generate_from_pk_scratch_space(module: &Module<FFT64>, rank: usize, pk_size: usize) -> usize {
@@ -188,6 +188,15 @@ where
assert_eq!(sk.n(), module.n()); assert_eq!(sk.n(), module.n());
assert_eq!(self.rank_out(), self.rank_in()); assert_eq!(self.rank_out(), self.rank_in());
assert_eq!(sk.rank(), self.rank()); assert_eq!(sk.rank(), self.rank());
assert!(
scratch.available() >= AutomorphismKey::generate_from_sk_scratch_space(module, self.rank(), self.size()),
"scratch.available(): {} < AutomorphismKey::generate_from_sk_scratch_space(module, self.rank()={}, \
self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
AutomorphismKey::generate_from_sk_scratch_space(module, self.rank(), self.size())
)
} }
let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank()); let (sk_out_dft_data, scratch_1) = scratch.tmp_scalar_znx_dft(module, sk.rank());

View File

@@ -1,7 +1,7 @@
use backend::{ use backend::{
Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft, Backend, FFT64, MatZnxDft, MatZnxDftAlloc, MatZnxDftOps, MatZnxDftToMut, MatZnxDftToRef, Module, ScalarZnx, ScalarZnxDft,
ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef, VecZnxOps, ZnxInfos, ScalarZnxDftToRef, ScalarZnxToRef, Scratch, VecZnxAlloc, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftToMut, VecZnxDftToRef,
ZnxZero, VecZnxOps, ZnxInfos, ZnxZero,
}; };
use sampling::source::Source; use sampling::source::Source;
@@ -115,6 +115,15 @@ where
assert_eq!(self.n(), module.n()); assert_eq!(self.n(), module.n());
assert_eq!(sk_dft.n(), module.n()); assert_eq!(sk_dft.n(), module.n());
assert_eq!(pt.n(), module.n()); assert_eq!(pt.n(), module.n());
assert!(
scratch.available() >= GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank(), self.size()),
"scratch.available: {} < GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank()={}, \
self.size()={}): {}",
scratch.available(),
self.rank(),
self.size(),
GGLWECiphertext::generate_from_sk_scratch_space(module, self.rank(), self.size())
)
} }
let rows: usize = self.rows(); let rows: usize = self.rows();

View File

@@ -97,9 +97,7 @@ where
impl GLWECiphertext<Vec<u8>> { impl GLWECiphertext<Vec<u8>> {
pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize { pub fn encrypt_sk_scratch_space(module: &Module<FFT64>, ct_size: usize) -> usize {
module.vec_znx_big_normalize_tmp_bytes() module.vec_znx_big_normalize_tmp_bytes() + module.bytes_of_vec_znx_dft(1, ct_size) + module.bytes_of_vec_znx(1, ct_size)
+ module.bytes_of_vec_znx_dft(1, ct_size)
+ module.bytes_of_vec_znx_big(1, ct_size)
} }
pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, pk_size: usize) -> usize { pub fn encrypt_pk_scratch_space(module: &Module<FFT64>, pk_size: usize) -> usize {
((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1)) ((module.bytes_of_vec_znx_dft(1, pk_size) + module.bytes_of_vec_znx_big(1, pk_size)) | module.bytes_of_scalar_znx(1))
@@ -538,7 +536,7 @@ where
1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, lhs, i), 1 => module.vec_znx_big_add_small_inplace(&mut res_big, i, lhs, i),
2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, lhs, i), 2 => module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, lhs, i),
3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, lhs, i), 3 => module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, lhs, i),
_=>{}, _ => {}
} }
module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1); module.vec_znx_big_normalize(basek, self, i, &res_big, i, scratch1);
}); });
@@ -636,6 +634,12 @@ where
assert_eq!(pt.n(), module.n()); assert_eq!(pt.n(), module.n());
assert!(col < self.rank() + 1); assert!(col < self.rank() + 1);
} }
assert!(
scratch.available() >= GLWECiphertext::encrypt_sk_scratch_space(module, self.size()),
"scratch.available(): {} < GLWECiphertext::encrypt_sk_scratch_space: {}",
scratch.available(),
GLWECiphertext::encrypt_sk_scratch_space(module, self.size())
)
} }
let basek: usize = self.basek(); let basek: usize = self.basek();