mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added missing tests for CGGI & added standard blind rotation
This commit is contained in:
@@ -6,7 +6,7 @@ use backend::{
|
|||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext,
|
GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
|
||||||
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
|
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
|
||||||
lwe::ciphertext::LWECiphertextToRef,
|
lwe::ciphertext::LWECiphertextToRef,
|
||||||
};
|
};
|
||||||
@@ -63,7 +63,7 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
|
|||||||
} else if brk.block_size() > 1 {
|
} else if brk.block_size() > 1 {
|
||||||
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
|
cggi_blind_rotate_block_binary(module, res, lwe, lut, brk, scratch);
|
||||||
} else {
|
} else {
|
||||||
todo!("implement this case")
|
cggi_blind_rotate_standard(module, res, lwe, lut, brk, scratch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,8 +121,7 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
|||||||
a.chunks_exact(block_size),
|
a.chunks_exact(block_size),
|
||||||
brk.data.chunks_exact(block_size)
|
brk.data.chunks_exact(block_size)
|
||||||
)
|
)
|
||||||
.enumerate()
|
.for_each(|(ai, ski)| {
|
||||||
.for_each(|(i, (ai, ski))| {
|
|
||||||
(0..extension_factor).for_each(|i| {
|
(0..extension_factor).for_each(|i| {
|
||||||
(0..cols).for_each(|j| {
|
(0..cols).for_each(|j| {
|
||||||
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
|
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
|
||||||
@@ -323,6 +322,96 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn cggi_blind_rotate_standard<DataRes, DataIn>(
|
||||||
|
module: &Module<FFT64>,
|
||||||
|
res: &mut GLWECiphertext<DataRes>,
|
||||||
|
lwe: &LWECiphertext<DataIn>,
|
||||||
|
lut: &LookUpTable,
|
||||||
|
brk: &BlindRotationKeyCGGI<FFT64>,
|
||||||
|
scratch: &mut Scratch,
|
||||||
|
) where
|
||||||
|
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||||
|
DataIn: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
{
|
||||||
|
assert_eq!(
|
||||||
|
res.n(),
|
||||||
|
module.n(),
|
||||||
|
"res.n(): {} != brk.n(): {}",
|
||||||
|
res.n(),
|
||||||
|
module.n()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lut.domain_size(),
|
||||||
|
module.n(),
|
||||||
|
"lut.n(): {} != brk.n(): {}",
|
||||||
|
lut.domain_size(),
|
||||||
|
module.n()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
brk.n(),
|
||||||
|
module.n(),
|
||||||
|
"brk.n(): {} != brk.n(): {}",
|
||||||
|
brk.n(),
|
||||||
|
module.n()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
res.rank(),
|
||||||
|
brk.rank(),
|
||||||
|
"res.rank(): {} != brk.rank(): {}",
|
||||||
|
res.rank(),
|
||||||
|
brk.rank()
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
lwe.n(),
|
||||||
|
brk.data.len(),
|
||||||
|
"lwe.n(): {} != brk.data.len(): {}",
|
||||||
|
lwe.n(),
|
||||||
|
brk.data.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||||
|
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
|
||||||
|
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
|
||||||
|
let basek: usize = brk.basek();
|
||||||
|
|
||||||
|
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe_ref);
|
||||||
|
|
||||||
|
let a: &[i64] = &lwe_2n[1..];
|
||||||
|
let b: i64 = lwe_2n[0];
|
||||||
|
|
||||||
|
out_mut.data.zero();
|
||||||
|
|
||||||
|
// Initialize out to X^{b} * LUT(X)
|
||||||
|
module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
|
||||||
|
|
||||||
|
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
|
||||||
|
let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||||
|
let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||||
|
|
||||||
|
// TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
|
||||||
|
// TODO: first iteration can be optimized to be a gglwe product
|
||||||
|
izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| {
|
||||||
|
// acc_tmp = sk[i] * acc
|
||||||
|
acc_tmp.external_product(module, &out_mut, ski, scratch2);
|
||||||
|
|
||||||
|
// acc_tmp = (sk[i] * acc) * X^{ai}
|
||||||
|
acc_tmp_rot.rotate(module, *ai, &acc_tmp);
|
||||||
|
|
||||||
|
// acc = acc + (sk[i] * acc) * X^{ai}
|
||||||
|
out_mut.add_inplace(module, &acc_tmp_rot);
|
||||||
|
|
||||||
|
// acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1)
|
||||||
|
out_mut.sub_inplace_ab(module, &acc_tmp);
|
||||||
|
});
|
||||||
|
|
||||||
|
// We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
|
||||||
|
// on top of each others, thus ~ 2^{63-basek} additions are supported before overflow.
|
||||||
|
out_mut.normalize_inplace(module, scratch2);
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
||||||
let basek: usize = lwe.basek();
|
let basek: usize = lwe.basek();
|
||||||
|
|
||||||
|
|||||||
@@ -74,6 +74,11 @@ impl BlindRotationKeyCGGI<FFT64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub(crate) fn n(&self) -> usize {
|
||||||
|
self.data[0].n()
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub(crate) fn rows(&self) -> usize {
|
pub(crate) fn rows(&self) -> usize {
|
||||||
self.data[0].rows()
|
self.data[0].rows()
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
use std::time::Instant;
|
use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView};
|
||||||
|
|
||||||
use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView};
|
|
||||||
use sampling::source::Source;
|
use sampling::source::Source;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -14,22 +12,31 @@ use crate::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn blind_rotation() {
|
fn standard() {
|
||||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
blind_rotatio_test(224, 1, 1);
|
||||||
let basek: usize = 19;
|
}
|
||||||
|
|
||||||
let n_lwe: usize = 1071;
|
#[test]
|
||||||
|
fn block_binary() {
|
||||||
|
blind_rotatio_test(224, 7, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn block_binary_extended() {
|
||||||
|
blind_rotatio_test(224, 7, 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) {
|
||||||
|
let module: Module<FFT64> = Module::<FFT64>::new(512);
|
||||||
|
let basek: usize = 19;
|
||||||
|
|
||||||
let k_lwe: usize = 24;
|
let k_lwe: usize = 24;
|
||||||
let k_brk: usize = 3 * basek;
|
let k_brk: usize = 3 * basek;
|
||||||
let rows_brk: usize = 1;
|
let rows_brk: usize = 2; // Ensures first limb is noise-free.
|
||||||
let k_lut: usize = 2 * basek;
|
let k_lut: usize = 2 * basek;
|
||||||
let rank: usize = 1;
|
let rank: usize = 1;
|
||||||
let block_size: usize = 7;
|
|
||||||
|
|
||||||
let extension_factor: usize = 2;
|
let message_modulus: usize = 1 << 4;
|
||||||
|
|
||||||
let message_modulus: usize = 1 << 6;
|
|
||||||
|
|
||||||
let mut source_xs: Source = Source::new([1u8; 32]);
|
let mut source_xs: Source = Source::new([1u8; 32]);
|
||||||
let mut source_xe: Source = Source::new([1u8; 32]);
|
let mut source_xe: Source = Source::new([1u8; 32]);
|
||||||
@@ -56,7 +63,6 @@ fn blind_rotation() {
|
|||||||
rank,
|
rank,
|
||||||
));
|
));
|
||||||
|
|
||||||
let start: Instant = Instant::now();
|
|
||||||
let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
|
let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
|
||||||
|
|
||||||
brk.generate_from_sk(
|
brk.generate_from_sk(
|
||||||
@@ -69,9 +75,6 @@ fn blind_rotation() {
|
|||||||
scratch.borrow(),
|
scratch.borrow(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let duration: std::time::Duration = start.elapsed();
|
|
||||||
println!("brk-gen: {} ms", duration.as_millis());
|
|
||||||
|
|
||||||
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
|
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
|
||||||
|
|
||||||
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
|
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
|
||||||
@@ -81,13 +84,13 @@ fn blind_rotation() {
|
|||||||
|
|
||||||
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
|
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
|
||||||
|
|
||||||
println!("{}", pt_lwe.data);
|
// println!("{}", pt_lwe.data);
|
||||||
|
|
||||||
lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);
|
lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);
|
||||||
|
|
||||||
lwe.decrypt(&mut pt_lwe, &sk_lwe);
|
lwe.decrypt(&mut pt_lwe, &sk_lwe);
|
||||||
|
|
||||||
println!("{}", pt_lwe.data);
|
// println!("{}", pt_lwe.data);
|
||||||
|
|
||||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||||
f.iter_mut()
|
f.iter_mut()
|
||||||
@@ -99,13 +102,9 @@ fn blind_rotation() {
|
|||||||
|
|
||||||
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank);
|
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank);
|
||||||
|
|
||||||
let start: Instant = Instant::now();
|
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
|
||||||
(0..32).for_each(|_| {
|
|
||||||
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
|
|
||||||
});
|
|
||||||
|
|
||||||
let duration: std::time::Duration = start.elapsed();
|
println!("out_mut.data: {}", res.data);
|
||||||
println!("blind-rotate: {} ms", duration.as_millis());
|
|
||||||
|
|
||||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut);
|
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_lut);
|
||||||
|
|
||||||
@@ -125,20 +124,21 @@ fn blind_rotation() {
|
|||||||
.sum::<i64>())
|
.sum::<i64>())
|
||||||
% (2 * lut.domain_size()) as i64;
|
% (2 * lut.domain_size()) as i64;
|
||||||
|
|
||||||
println!("pt_want: {}", pt_want);
|
// println!("pt_want: {}", pt_want);
|
||||||
|
|
||||||
lut.rotate(pt_want);
|
lut.rotate(pt_want);
|
||||||
|
|
||||||
lut.data.iter().for_each(|d| {
|
// lut.data.iter().for_each(|d| {
|
||||||
println!("{}", d);
|
// println!("{}", d);
|
||||||
});
|
// });
|
||||||
|
|
||||||
// First limb should be exactly equal (test are parameterized such that the noise does not reach
|
// First limb should be exactly equal (test are parameterized such that the noise does not reach
|
||||||
// the first limb)
|
// the first limb)
|
||||||
// assert_eq!(pt_have.data.at_mut(0, 0), lut.data[0].at_mut(0, 0));
|
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
|
||||||
|
|
||||||
// Then checks the noise
|
// Then checks the noise
|
||||||
module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0);
|
// module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0);
|
||||||
let noise: f64 = lut.data[0].std(0, basek);
|
// let noise: f64 = lut.data[0].std(0, basek);
|
||||||
println!("noise: {}", noise);
|
// println!("noise: {}", noise);
|
||||||
|
// assert!(noise < 1e-3);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user