Merge pull request #54 from phantomzone-org/jay/fhe-vm-fixes

Jay/fhe vm fixes
This commit is contained in:
Jean-Philippe Bossuat
2025-07-15 20:19:22 +02:00
committed by GitHub
12 changed files with 131 additions and 154 deletions

View File

@@ -269,12 +269,12 @@ fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k
let size: usize = (k + basek - 1) / basek;
let data: &[i64] = a.raw();
let mut res: i64 = data[i];
let mut res: i64 = 0;
let rem: usize = basek - (k % basek);
let slice_size: usize = a.n() * a.cols();
(0..size).for_each(|i| {
let x: i64 = data[i * slice_size];
if i == size - 1 && rem != basek {
(0..size).for_each(|j| {
let x: i64 = data[j * slice_size + i];
if j == size - 1 && rem != basek {
let k_rem: usize = basek - rem;
res = (res << k_rem) + (x >> rem);
} else {
@@ -320,7 +320,7 @@ mod tests {
let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17;
let size: usize = 5;
for k in [size * basek - 5] {
for k in [1, basek / 2, size * basek - 5] {
let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut();

View File

@@ -1,7 +1,6 @@
use backend::{
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView,
ZnxViewMut, ZnxZero,
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero,
};
use itertools::izip;

View File

@@ -34,7 +34,7 @@ pub trait Infos {
/// Returns the number of size per polynomial.
fn size(&self) -> usize {
let size: usize = self.inner().size();
debug_assert_eq!(size, self.k().div_ceil(self.basek()));
debug_assert!(size >= self.k().div_ceil(self.basek()));
size
}

View File

@@ -20,6 +20,8 @@ pub struct GGSWCiphertext<C, B: Backend> {
impl GGSWCiphertext<Vec<u8>, FFT64> {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
let size: usize = k.div_ceil(basek);
debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
debug_assert!(
size > digits,
"invalid ggsw: ceil(k/basek): {} <= digits: {}",

View File

@@ -79,7 +79,7 @@ impl<DataSelf: AsMut<[u8]> + AsRef<[u8]>> SetMetaData for GLWECiphertext<DataSel
}
}
pub trait GLWECiphertextToRef {
pub trait GLWECiphertextToRef: Infos {
fn to_ref(&self) -> GLWECiphertext<&[u8]>;
}
@@ -93,7 +93,7 @@ impl<D: AsRef<[u8]>> GLWECiphertextToRef for GLWECiphertext<D> {
}
}
pub trait GLWECiphertextToMut {
pub trait GLWECiphertextToMut: Infos {
fn to_mut(&mut self) -> GLWECiphertext<&mut [u8]>;
}

View File

@@ -2,7 +2,7 @@ use backend::{
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, Scratch, VecZnxBig, VecZnxBigOps, VecZnxDftAlloc, VecZnxDftOps, VecZnxScratch,
};
use crate::{FourierGLWECiphertext, GGSWCiphertext, GLWECiphertext, Infos};
use crate::{GGSWCiphertext, GLWECiphertext, Infos};
impl GLWECiphertext<Vec<u8>> {
pub fn external_product_scratch_space(
@@ -14,21 +14,21 @@ impl GLWECiphertext<Vec<u8>> {
digits: usize,
rank: usize,
) -> usize {
let res_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_ggsw, rank);
let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
let out_size: usize = k_out.div_ceil(basek);
let ggsw_size: usize = k_ggsw.div_ceil(basek);
let vmp: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size)
+ module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size, // rows
rank + 1, // cols in
rank + 1, // cols out
ggsw_size,
);
let res_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, ggsw_size);
let a_dft: usize = module.bytes_of_vec_znx_dft(rank + 1, in_size);
let vmp: usize = module.vmp_apply_tmp_bytes(
out_size,
in_size,
in_size, // rows
rank + 1, // cols in
rank + 1, // cols out
ggsw_size,
);
let normalize: usize = module.vec_znx_normalize_tmp_bytes();
res_dft + (vmp | normalize)
res_dft + a_dft + (vmp | normalize)
}
pub fn external_product_inplace_scratch_space(

View File

@@ -2,11 +2,11 @@ use backend::{FFT64, Module, Scratch, VecZnx, VecZnxOps, ZnxZero};
use crate::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, Infos, SetMetaData};
pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
pub trait GLWEOps: GLWECiphertextToMut + SetMetaData + Sized {
fn add<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where
A: GLWECiphertextToRef + Infos,
B: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef,
B: GLWECiphertextToRef,
{
#[cfg(debug_assertions)]
{
@@ -47,7 +47,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
});
self.set_basek(a.basek());
self.set_k(a.k().max(b.k()));
self.set_k(set_k_binary(self, a, b));
}
fn add_inplace<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -69,13 +69,13 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
self.set_k(set_k_unary(self, a))
}
fn sub<A, B>(&mut self, module: &Module<FFT64>, a: &A, b: &B)
where
A: GLWECiphertextToRef + Infos,
B: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef,
B: GLWECiphertextToRef,
{
#[cfg(debug_assertions)]
{
@@ -117,7 +117,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
});
self.set_basek(a.basek());
self.set_k(a.k().max(b.k()));
self.set_k(set_k_binary(self, a, b));
}
fn sub_inplace_ab<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -139,7 +139,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
self.set_k(set_k_unary(self, a))
}
fn sub_inplace_ba<A>(&mut self, module: &Module<FFT64>, a: &A)
@@ -161,7 +161,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k().max(self.k()));
self.set_k(set_k_unary(self, a))
}
fn rotate<A>(&mut self, module: &Module<FFT64>, k: i64, a: &A)
@@ -183,7 +183,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
});
self.set_basek(a.basek());
self.set_k(a.k());
self.set_k(set_k_unary(self, a))
}
fn rotate_inplace(&mut self, module: &Module<FFT64>, k: i64) {
@@ -217,7 +217,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
});
self.set_k(a.k());
self.set_k(a.k().min(self.size() * self.basek()));
self.set_basek(a.basek());
}
@@ -229,7 +229,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
fn normalize<A>(&mut self, module: &Module<FFT64>, a: &A, scratch: &mut Scratch)
where
A: GLWECiphertextToRef + Infos,
A: GLWECiphertextToRef,
{
#[cfg(debug_assertions)]
{
@@ -245,7 +245,7 @@ pub trait GLWEOps: GLWECiphertextToMut + Infos + SetMetaData {
module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
});
self.set_basek(a.basek());
self.set_k(a.k());
self.set_k(a.k().min(self.k()));
}
fn normalize_inplace(&mut self, module: &Module<FFT64>, scratch: &mut Scratch) {
@@ -265,3 +265,33 @@ impl GLWECiphertext<Vec<u8>> {
VecZnx::rsh_scratch_space(module.n())
}
}
// c = op(a, b)
fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
// If either operands is a ciphertext
if a.rank() != 0 || b.rank() != 0 {
// If a is a plaintext (but b ciphertext)
let k = if a.rank() == 0 {
b.k()
// If b is a plaintext (but a ciphertext)
} else if b.rank() == 0 {
a.k()
// If a & b are both ciphertexts
} else {
a.k().min(b.k())
};
k.min(c.k())
// If a & b are both plaintexts
} else {
c.k()
}
}
// a = op(a, b)
fn set_k_unary(a: &impl Infos, b: &impl Infos) -> usize {
if a.rank() != 0 || b.rank() != 0 {
a.k().min(b.k())
} else {
a.k()
}
}

View File

@@ -65,7 +65,7 @@ impl GLWEPacker {
}
/// Implicit reset of the internal state (to be called before a new packing procedure).
pub fn reset(&mut self) {
fn reset(&mut self) {
for i in 0..self.accumulators.len() {
self.accumulators[i].value = false;
self.accumulators[i].control = false;
@@ -82,9 +82,7 @@ impl GLWEPacker {
GLWECiphertext::trace_galois_elements(module)
}
/// Adds a GLWE ciphertext to the [StreamPacker]. And propagates
/// intermediate results among the [Accumulator]s.
///
/// Adds a GLWE ciphertext to the [StreamPacker].
/// #Arguments
///
/// * `module`: static backend FFT tables.
@@ -96,11 +94,16 @@ impl GLWEPacker {
pub fn add<DataA: AsRef<[u8]>, DataAK: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
res: &mut Vec<GLWECiphertext<Vec<u8>>>,
a: Option<&GLWECiphertext<DataA>>,
auto_keys: &HashMap<i64, GLWEAutomorphismKey<DataAK, FFT64>>,
scratch: &mut Scratch,
) {
assert!(
self.counter < module.n(),
"Packing limit of {} reached",
module.n() >> self.log_batch
);
pack_core(
module,
a,
@@ -110,35 +113,18 @@ impl GLWEPacker {
scratch,
);
self.counter += 1 << self.log_batch;
if self.counter == module.n() {
res.push(
self.accumulators[module.log_n() - self.log_batch - 1]
.data
.clone(),
);
self.reset();
}
}
/// Flushes all accumlators and appends the result to `res`.
pub fn flush<DataAK: AsRef<[u8]>>(
&mut self,
module: &Module<FFT64>,
res: &mut Vec<GLWECiphertext<Vec<u8>>>,
auto_keys: &HashMap<i64, GLWEAutomorphismKey<DataAK, FFT64>>,
scratch: &mut Scratch,
) {
if self.counter != 0 {
while self.counter != 0 {
self.add(
module,
res,
None::<&GLWECiphertext<Vec<u8>>>,
auto_keys,
scratch,
);
}
}
/// Flush result to`res`.
pub fn flush<Data: AsMut<[u8]> + AsRef<[u8]>>(&mut self, module: &Module<FFT64>, res: &mut GLWECiphertext<Data>) {
assert!(self.counter == module.n());
// Copy result GLWE into res GLWE
res.copy(
module,
&self.accumulators[module.log_n() - self.log_batch - 1].data,
);
self.reset();
}
}

View File

@@ -1,8 +1,7 @@
use backend::{Decoding, Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxZero};
use itertools::izip;
use backend::{FFT64, FillUniform, Module, ScratchOwned, Stats};
use sampling::source::Source;
use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos};
use crate::{FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWEOps, GLWEPlaintext, GLWEPublicKey, GLWESecret, Infos};
#[test]
fn encrypt_sk() {
@@ -35,7 +34,8 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma:
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct, rank);
let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_pt);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_pt);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_pt);
let mut source_xs: Source = Source::new([0u8; 32]);
let mut source_xe: Source = Source::new([0u8; 32]);
@@ -50,17 +50,13 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma:
sk.fill_ternary_prob(0.5, &mut source_xs);
let sk_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0xFF);
pt.data.encode_vec_i64(0, basek, k_pt, &data_want, 10);
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
ct.encrypt_sk(
&module,
&pt,
&pt_want,
&sk_dft,
&mut source_xa,
&mut source_xe,
@@ -68,26 +64,14 @@ fn test_encrypt_sk(log_n: usize, basek: usize, k_ct: usize, k_pt: usize, sigma:
scratch.borrow(),
);
pt.data.zero();
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
ct.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
pt_want.sub_inplace_ab(&module, &pt_have);
let mut data_have: Vec<i64> = vec![0i64; module.n()];
let noise_have: f64 = pt_want.data.std(0, basek) * (ct.k() as f64).exp2();
let noise_want: f64 = sigma;
pt.data
.decode_vec_i64(0, basek, pt.size() * basek, &mut data_have);
// TODO: properly assert the decryption noise through std(dec(ct) - pt)
let scale: f64 = (1 << (pt.size() * basek - k_pt)) as f64;
izip!(data_want.iter(), data_have.iter()).for_each(|(a, b)| {
let b_scaled = (*b as f64) / scale;
assert!(
(*a as f64 - b_scaled).abs() < 0.1,
"{} {}",
*a as f64,
b_scaled
)
});
assert!(noise_have <= noise_want + 0.2);
}
fn test_encrypt_zero_sk(log_n: usize, basek: usize, k_ct: usize, sigma: f64, rank: usize) {
@@ -127,6 +111,7 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma:
let module: Module<FFT64> = Module::<FFT64>::new(1 << log_n);
let mut ct: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_ct, rank);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
let mut source_xs: Source = Source::new([0u8; 32]);
@@ -147,13 +132,9 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma:
| GLWECiphertext::encrypt_pk_scratch_space(&module, basek, pk.k()),
);
let mut data_want: Vec<i64> = vec![0i64; module.n()];
data_want
.iter_mut()
.for_each(|x| *x = source_xa.next_i64() & 0);
pt_want.data.encode_vec_i64(0, basek, k_ct, &data_want, 10);
pt_want
.data
.fill_uniform(basek, 0, pt_want.size(), &mut source_xa);
ct.encrypt_pk(
&module,
@@ -165,11 +146,9 @@ fn test_encrypt_pk(log_n: usize, basek: usize, k_ct: usize, k_pk: usize, sigma:
scratch.borrow(),
);
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
ct.decrypt(&module, &mut pt_have, &sk_dft, scratch.borrow());
module.vec_znx_sub_ab_inplace(&mut pt_want.data, 0, &pt_have.data, 0);
pt_want.sub_inplace_ab(&module, &pt_have);
let noise_have: f64 = pt_want.data.std(0, basek).log2();
let noise_want: f64 = ((((rank as f64) + 1.0) * module.n() as f64 * 0.5 * sigma * sigma).sqrt()).log2() - (k_ct as f64);

View File

@@ -74,8 +74,6 @@ fn apply() {
scratch.borrow(),
);
let mut res: Vec<GLWECiphertext<Vec<u8>>> = Vec::new();
(0..module.n() >> log_batch).for_each(|i| {
ct.encrypt_sk(
&module,
@@ -90,11 +88,10 @@ fn apply() {
pt.rotate_inplace(&module, -(1 << log_batch)); // X^-batch * pt
if reverse_bits_msb(i, log_n as u32) % 5 == 0 {
packer.add(&module, &mut res, Some(&ct), &auto_keys, scratch.borrow());
packer.add(&module, Some(&ct), &auto_keys, scratch.borrow());
} else {
packer.add(
&module,
&mut res,
None::<&GLWECiphertext<Vec<u8>>>,
&auto_keys,
scratch.borrow(),
@@ -102,36 +99,29 @@ fn apply() {
}
});
packer.flush(&module, &mut res, &auto_keys, scratch.borrow());
packer.reset();
let mut res = GLWECiphertext::alloc(&module, basek, k_ct, rank);
packer.flush(&module, &mut res);
let mut pt_want: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_ct);
res.iter().enumerate().for_each(|(i, res_i)| {
let mut data: Vec<i64> = vec![0i64; module.n()];
data.iter_mut().enumerate().for_each(|(i, x)| {
if i % 5 == 0 {
*x = reverse_bits_msb(i, log_n as u32) as i64;
}
});
pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32);
res_i.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
if i & 1 == 0 {
pt.sub_inplace_ab(&module, &pt_want);
} else {
pt.add_inplace(&module, &pt_want);
let mut data: Vec<i64> = vec![0i64; module.n()];
data.iter_mut().enumerate().for_each(|(i, x)| {
if i % 5 == 0 {
*x = reverse_bits_msb(i, log_n as u32) as i64;
}
let noise_have = pt.data.std(0, basek).log2();
// println!("noise_have: {}", noise_have);
assert!(
noise_have < -((k_ct - basek) as f64),
"noise: {}",
noise_have
);
});
pt_want.data.encode_vec_i64(0, basek, pt_k, &data, 32);
res.decrypt(&module, &mut pt, &sk_dft, scratch.borrow());
pt.sub_inplace_ab(&module, &pt_want);
let noise_have = pt.data.std(0, basek).log2();
// println!("noise_have: {}", noise_have);
assert!(
noise_have < -((k_ct - basek) as f64),
"noise: {}",
noise_have
);
}
#[inline(always)]

View File

@@ -19,10 +19,10 @@ pub use ggsw::GGSWCiphertext;
pub use glwe::{GLWECiphertext, GLWEOps, GLWEPacker, GLWEPlaintext, GLWEPublicKey, GLWESecret};
pub use lwe::{LWECiphertext, LWESecret};
pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef};
pub use backend;
pub use backend::Scratch;
pub use backend::ScratchOwned;
pub(crate) use glwe::{GLWECiphertextToMut, GLWECiphertextToRef};
use crate::dist::Distribution;

View File

@@ -1,6 +1,6 @@
use rand_chacha::ChaCha8Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_core::{OsRng, RngCore, TryRngCore};
use rand_core::RngCore;
const MAXF64: f64 = 9007199254740992.0;
@@ -8,12 +8,6 @@ pub struct Source {
source: ChaCha8Rng,
}
pub fn new_seed() -> [u8; 32] {
let mut seed = [0u8; 32];
OsRng.try_fill_bytes(&mut seed).unwrap();
seed
}
impl Source {
pub fn new(seed: [u8; 32]) -> Source {
Source {
@@ -21,14 +15,11 @@ impl Source {
}
}
pub fn new_seed(&mut self) -> [u8; 32] {
let mut seed: [u8; 32] = [0u8; 32];
self.source.fill_bytes(&mut seed);
seed
}
pub fn branch(&mut self) -> Self {
Source::new(self.new_seed())
let mut seed = [0; 32];
self.source.fill_bytes(&mut seed);
Source::new(seed)
}
#[inline(always)]