mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Added LWE-GLWE conversion & LWE Keyswitch, improved LUT generation
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use backend::{
|
||||
MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftOps, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, ZnxZero, FFT64
|
||||
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
|
||||
Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView,
|
||||
ZnxViewMut, ZnxZero,
|
||||
};
|
||||
use itertools::izip;
|
||||
|
||||
use crate::{
|
||||
FourierGLWECiphertext, FourierGLWESecret, GLWECiphertext, GLWECiphertextToMut, GLWEPlaintext, Infos, LWECiphertext,
|
||||
ScratchCore,
|
||||
GLWECiphertext, GLWECiphertextToMut, Infos, LWECiphertext,
|
||||
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
|
||||
lwe::ciphertext::LWECiphertextToRef,
|
||||
};
|
||||
@@ -21,25 +20,31 @@ pub fn cggi_blind_rotate_scratch_space(
|
||||
rows: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
let lut_size: usize = k_lut.div_ceil(basek);
|
||||
let cols: usize = rank + 1;
|
||||
let brk_size: usize = k_brk.div_ceil(basek);
|
||||
|
||||
let acc_dft: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor;
|
||||
let acc_big: usize = module.bytes_of_vec_znx_big(rank + 1, brk_size);
|
||||
let acc_dft_add: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor;
|
||||
let vmp_res: usize = FourierGLWECiphertext::bytes_of(module, basek, k_brk, rank) * extension_factor;
|
||||
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
|
||||
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
|
||||
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
|
||||
let acc_dft_add: usize = vmp_res;
|
||||
let xai_plus_y: usize = module.bytes_of_scalar_znx(1);
|
||||
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(lut_size, lut_size, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
|
||||
let acc: usize;
|
||||
if extension_factor > 1 {
|
||||
acc = GLWECiphertext::bytes_of(module, basek, k_lut, rank) * extension_factor;
|
||||
acc = module.bytes_of_vec_znx(cols, k_lut.div_ceil(basek)) * extension_factor;
|
||||
} else {
|
||||
acc = 0;
|
||||
}
|
||||
|
||||
return acc + acc_big + acc_dft + acc_dft_add + vmp_res + xai_plus_y + xai_plus_y_dft + (vmp | module.vec_znx_big_normalize_tmp_bytes());
|
||||
return acc
|
||||
+ acc_dft
|
||||
+ acc_dft_add
|
||||
+ vmp_res
|
||||
+ xai_plus_y
|
||||
+ xai_plus_y_dft
|
||||
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
|
||||
}
|
||||
|
||||
pub fn cggi_blind_rotate<DataRes, DataIn>(
|
||||
@@ -62,6 +67,7 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: ENSURE DOMAIN EXTENSION AS
|
||||
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
module: &Module<FFT64>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
@@ -75,27 +81,25 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
{
|
||||
let extension_factor: usize = lut.extension_factor();
|
||||
let basek: usize = res.basek();
|
||||
let rows: usize = brk.rows();
|
||||
let cols: usize = res.rank() + 1;
|
||||
|
||||
let (mut acc, scratch1) = scratch.tmp_vec_glwe_ct(extension_factor, module, basek, res.k(), res.rank());
|
||||
let (mut acc_dft, scratch2) = scratch1.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank());
|
||||
let (mut vmp_res, scratch3) = scratch2.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank());
|
||||
let (mut acc_add_dft, scratch4) = scratch3.tmp_vec_fourier_glwe_ct(extension_factor, module, basek, brk.k(), res.rank());
|
||||
|
||||
(0..extension_factor).for_each(|i| {
|
||||
acc[i].data.zero();
|
||||
});
|
||||
|
||||
let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size());
|
||||
let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows);
|
||||
let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
|
||||
let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1);
|
||||
let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1);
|
||||
let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size());
|
||||
|
||||
(0..extension_factor).for_each(|i| {
|
||||
acc[i].zero();
|
||||
});
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
|
||||
|
||||
let two_n_ext: usize = 2 * lut.domain_size();
|
||||
|
||||
let cols: usize = res.rank() + 1;
|
||||
|
||||
negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref);
|
||||
|
||||
let a: &[i64] = &lwe_2n[1..];
|
||||
@@ -105,10 +109,10 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
let b_lo: usize = b_pos % extension_factor;
|
||||
|
||||
for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) {
|
||||
module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0);
|
||||
module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0);
|
||||
}
|
||||
for (i, j) in (b_lo..extension_factor).zip(0..extension_factor - b_lo) {
|
||||
module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0);
|
||||
module.vec_znx_rotate(b_hi as i64, &mut acc[i], 0, &lut.data[j], 0);
|
||||
}
|
||||
|
||||
let block_size: usize = brk.block_size();
|
||||
@@ -121,9 +125,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
.for_each(|(i, (ai, ski))| {
|
||||
(0..extension_factor).for_each(|i| {
|
||||
(0..cols).for_each(|j| {
|
||||
module.vec_znx_dft(1, 0, &mut acc_dft[i].data, j, &acc[i].data, j);
|
||||
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
|
||||
});
|
||||
acc_add_dft[i].data.zero();
|
||||
acc_add_dft[i].zero();
|
||||
});
|
||||
|
||||
// TODO: first & last iterations can be optimized
|
||||
@@ -134,25 +138,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
|
||||
// vmp_res = DFT(acc) * BRK[i]
|
||||
(0..extension_factor).for_each(|i| {
|
||||
module.vmp_apply(&mut vmp_res[i].data, &acc_dft[i].data, &skii.data, scratch7);
|
||||
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6);
|
||||
});
|
||||
|
||||
// Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
|
||||
if ai_lo == 0 {
|
||||
// DFT X^{-ai}
|
||||
set_xai_plus_y(
|
||||
module,
|
||||
ai_hi as i64,
|
||||
-1,
|
||||
&mut xai_plus_y_dft,
|
||||
&mut xai_plus_y,
|
||||
);
|
||||
set_xai_plus_y(module, ai_hi, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
|
||||
|
||||
// Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1)
|
||||
(0..extension_factor).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j].data, i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j].data, i, &vmp_res[j].data, i);
|
||||
module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
|
||||
});
|
||||
});
|
||||
// Non trivial case: rotation between polynomials
|
||||
@@ -163,74 +161,83 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
|
||||
// Sets acc_add_dft[i] = acc[i] * sk
|
||||
(0..extension_factor).for_each(|i| {
|
||||
(0..cols).for_each(|k| {
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i].data, k, &vmp_res[i].data, k);
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
|
||||
})
|
||||
});
|
||||
|
||||
// DFT X^{-ai+1}
|
||||
set_xai_plus_y(
|
||||
module,
|
||||
ai_hi as i64 + 1,
|
||||
0,
|
||||
&mut xai_plus_y_dft,
|
||||
&mut xai_plus_y,
|
||||
);
|
||||
// DFT X^{-ai}
|
||||
set_xai_plus_y(module, ai_hi + 1, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
|
||||
|
||||
// Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
|
||||
for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
|
||||
module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i].data, 0, &lut.data[j], 0);
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k);
|
||||
module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
|
||||
});
|
||||
}
|
||||
|
||||
// DFT X^{-ai}
|
||||
set_xai_plus_y(
|
||||
module,
|
||||
ai_hi as i64,
|
||||
0,
|
||||
&mut xai_plus_y_dft,
|
||||
&mut xai_plus_y,
|
||||
);
|
||||
set_xai_plus_y(module, ai_hi, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
|
||||
|
||||
// Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
|
||||
for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
|
||||
module.vec_znx_rotate(b_hi as i64, &mut acc[i].data, 0, &lut.data[j], 0);
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j].data, k, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i].data, k, &vmp_res[j].data, k);
|
||||
module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
(0..extension_factor).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft_add_inplace(&mut acc_dft[j].data, i, &acc_add_dft[j].data, i);
|
||||
module.vec_znx_idft(&mut acc_add_big, 0, &acc_dft[j].data, i, scratch7);
|
||||
module.vec_znx_big_normalize(basek, &mut acc[j].data, i, &acc_add_big, 0, scratch7);
|
||||
{
|
||||
let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size());
|
||||
|
||||
(0..extension_factor).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
|
||||
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i);
|
||||
module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_copy(&mut res.data, i, &acc[0].data, i);
|
||||
module.vec_znx_copy(&mut res.data, i, &acc[0], i);
|
||||
});
|
||||
}
|
||||
|
||||
fn set_xai_plus_y(
|
||||
module: &Module<FFT64>,
|
||||
k: i64,
|
||||
ai: usize,
|
||||
y: i64,
|
||||
res: &mut ScalarZnxDft<&mut [u8], FFT64>,
|
||||
buf: &mut ScalarZnx<&mut [u8]>,
|
||||
) {
|
||||
buf.zero();
|
||||
buf.at_mut(0, 0)[0] = 1;
|
||||
module.vec_znx_rotate_inplace(k, buf, 0);
|
||||
buf.at_mut(0, 0)[0] += y;
|
||||
let n: usize = module.n();
|
||||
|
||||
{
|
||||
let raw: &mut [i64] = buf.at_mut(0, 0);
|
||||
if ai < n {
|
||||
raw[ai] = 1;
|
||||
} else {
|
||||
raw[(ai - n) & (n - 1)] = -1;
|
||||
}
|
||||
raw[0] += y;
|
||||
}
|
||||
|
||||
module.svp_prepare(res, 0, buf, 0);
|
||||
|
||||
{
|
||||
let raw: &mut [i64] = buf.at_mut(0, 0);
|
||||
|
||||
if ai < n {
|
||||
raw[ai] = 0;
|
||||
} else {
|
||||
raw[(ai - n) & (n - 1)] = 0;
|
||||
}
|
||||
raw[0] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
|
||||
@@ -244,11 +251,12 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
{
|
||||
let basek: usize = res.basek();
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
|
||||
let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();
|
||||
let two_n: usize = module.n() << 1;
|
||||
let basek: usize = brk.basek();
|
||||
let rows = brk.rows();
|
||||
|
||||
let cols: usize = out_mut.rank() + 1;
|
||||
|
||||
@@ -260,48 +268,59 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
|
||||
out_mut.data.zero();
|
||||
|
||||
// Initialize out to X^{b} * LUT(X)
|
||||
module.vec_znx_rotate(-b, &mut out_mut.data, 0, &lut.data[0], 0);
|
||||
module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
|
||||
|
||||
let block_size: usize = brk.block_size();
|
||||
|
||||
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
|
||||
|
||||
let (mut acc_dft, scratch1) = scratch.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank());
|
||||
let (mut acc_add_dft, scratch2) = scratch1.tmp_fourier_glwe_ct(module, brk.basek(), out_mut.k(), out_mut.rank());
|
||||
let (mut vmp_res, scratch3) = scratch2.tmp_fourier_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||
let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows);
|
||||
let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1);
|
||||
let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
|
||||
|
||||
let start: Instant = Instant::now();
|
||||
izip!(
|
||||
a.chunks_exact(block_size),
|
||||
brk.data.chunks_exact(block_size)
|
||||
)
|
||||
.for_each(|(ai, ski)| {
|
||||
out_mut.dft(module, &mut acc_dft);
|
||||
acc_add_dft.data.zero();
|
||||
(0..cols).for_each(|j| {
|
||||
module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j);
|
||||
});
|
||||
|
||||
acc_add_dft.zero();
|
||||
|
||||
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
|
||||
let ai_pos: usize = ((aii + two_n as i64) % two_n as i64) as usize;
|
||||
|
||||
// vmp_res = DFT(acc) * BRK[i]
|
||||
module.vmp_apply(&mut vmp_res.data, &acc_dft.data, &skii.data, scratch5);
|
||||
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5);
|
||||
|
||||
// DFT(X^ai -1)
|
||||
set_xai_plus_y(module, *aii, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
|
||||
set_xai_plus_y(module, ai_pos, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
|
||||
|
||||
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
|
||||
(0..cols).for_each(|i| {
|
||||
module.svp_apply_inplace(&mut vmp_res.data, i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft.data, i, &vmp_res.data, i);
|
||||
module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i);
|
||||
});
|
||||
});
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_dft_add_inplace(&mut acc_dft.data, i, &acc_add_dft.data, i);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_dft, i, &acc_add_dft, i);
|
||||
});
|
||||
|
||||
acc_dft.idft(module, &mut out_mut, scratch5);
|
||||
{
|
||||
let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size());
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6);
|
||||
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i);
|
||||
module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6);
|
||||
});
|
||||
}
|
||||
});
|
||||
let duration: std::time::Duration = start.elapsed();
|
||||
}
|
||||
|
||||
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
||||
@@ -315,7 +334,7 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe
|
||||
if basek > log2n {
|
||||
let diff: usize = basek - log2n;
|
||||
res.iter_mut().for_each(|x| {
|
||||
*x = div_signed_by_pow2(x, diff);
|
||||
*x = div_ceil_signed_by_pow2(x, diff);
|
||||
})
|
||||
} else {
|
||||
let rem: usize = basek - (log2n % basek);
|
||||
@@ -336,7 +355,22 @@ pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphe
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn div_signed_by_pow2(x: &i64, k: usize) -> i64 {
|
||||
fn div_round_by_pow2(x: &i64, k: usize) -> i64 {
|
||||
if x >= &0 {
|
||||
(x + (1 << (k - 1))) >> k
|
||||
} else {
|
||||
(x + (-1 << (k - 1))) >> k
|
||||
}
|
||||
}
|
||||
|
||||
// #[inline(always)]
|
||||
// fn div_floor_signed_by_pow2(x: &i64, k: usize) -> i64{
|
||||
// let bias: i64 = (1 << k) - 1;
|
||||
// (x + ((x >> 63) & bias)) >> k
|
||||
// }
|
||||
|
||||
#[inline(always)]
|
||||
fn div_ceil_signed_by_pow2(x: &i64, k: usize) -> i64 {
|
||||
let bias: i64 = (1 << k) - 1;
|
||||
(x + ((x >> 63) & bias)) >> k
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use backend::{FFT64, Module, ScalarZnx, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxView, ZnxViewMut, alloc_aligned};
|
||||
use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned};
|
||||
|
||||
pub struct LookUpTable {
|
||||
pub(crate) data: Vec<VecZnx<Vec<u8>>>,
|
||||
@@ -24,17 +24,19 @@ impl LookUpTable {
|
||||
self.data.len() * self.data[0].n()
|
||||
}
|
||||
|
||||
pub fn set(&mut self, module: &Module<FFT64>, f: fn(i64) -> i64, message_modulus: usize) {
|
||||
pub fn set(&mut self, module: &Module<FFT64>, f: &Vec<i64>, k: usize) {
|
||||
assert!(f.len() <= module.n());
|
||||
|
||||
let basek: usize = self.basek;
|
||||
|
||||
// Get the number minimum limb to store the message modulus
|
||||
let limbs: usize = message_modulus.div_ceil(1 << basek);
|
||||
let limbs: usize = k.div_ceil(1 << basek);
|
||||
|
||||
// Scaling factor
|
||||
let scale: i64 = (1 << (basek * limbs - 1)).div_round(message_modulus) as i64;
|
||||
let scale: i64 = (1 << (basek * limbs - 1)).div_round(k) as i64;
|
||||
|
||||
// Updates function
|
||||
let f_scaled = |x: i64| (f(x) % message_modulus as i64) * scale;
|
||||
// #elements in lookup table
|
||||
let f_len: usize = f.len();
|
||||
|
||||
// If LUT size > module.n()
|
||||
let domain_size: usize = self.domain_size();
|
||||
@@ -43,29 +45,17 @@ impl LookUpTable {
|
||||
|
||||
// Equivalent to AUTO([f(0), -f(n-1), -f(n-2), ..., -f(1)], -1)
|
||||
let mut lut_full: VecZnx<Vec<u8>> = VecZnx::new::<i64>(domain_size, 1, size);
|
||||
{
|
||||
let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1);
|
||||
|
||||
let start: usize = 0;
|
||||
let end: usize = (domain_size).div_round(message_modulus);
|
||||
let lut_at: &mut [i64] = lut_full.at_mut(0, limbs - 1);
|
||||
|
||||
let y: i64 = f_scaled(0);
|
||||
(start..end).for_each(|i| {
|
||||
lut_at[i] = y;
|
||||
});
|
||||
|
||||
(1..message_modulus).for_each(|x| {
|
||||
let start: usize = (x * domain_size).div_round(message_modulus);
|
||||
let end: usize = ((x + 1) * domain_size).div_round(message_modulus);
|
||||
let y: i64 = f_scaled(x as i64);
|
||||
(start..end).for_each(|i| {
|
||||
lut_at[i] = y;
|
||||
})
|
||||
});
|
||||
}
|
||||
f.iter().enumerate().for_each(|(i, fi)| {
|
||||
let start: usize = (i * domain_size).div_round(f_len);
|
||||
let end: usize = ((i + 1) * domain_size).div_round(f_len);
|
||||
lut_at[start..end].fill(fi * scale);
|
||||
});
|
||||
|
||||
// Rotates half the step to the left
|
||||
let half_step: usize = domain_size.div_round(message_modulus << 1);
|
||||
let half_step: usize = domain_size.div_round(f_len << 1);
|
||||
|
||||
lut_full.rotate(-(half_step as i64));
|
||||
|
||||
@@ -84,30 +74,6 @@ impl LookUpTable {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_raw<D>(&mut self, module: &Module<FFT64>, lut: &ScalarZnx<D>)
|
||||
where
|
||||
D: AsRef<[u8]>,
|
||||
{
|
||||
let domain_size: usize = self.domain_size();
|
||||
|
||||
let size: usize = self.k.div_ceil(self.basek);
|
||||
|
||||
let mut lut_full: VecZnx<Vec<u8>> = VecZnx::new::<i64>(domain_size, 1, size);
|
||||
|
||||
lut_full.at_mut(0, 0).copy_from_slice(lut.raw());
|
||||
|
||||
if self.extension_factor() > 1 {
|
||||
(0..self.extension_factor()).for_each(|i| {
|
||||
module.switch_degree(&mut self.data[i], 0, &lut_full, 0);
|
||||
if i < self.extension_factor() {
|
||||
lut_full.rotate(-1);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn rotate(&mut self, k: i64) {
|
||||
let extension_factor: usize = self.extension_factor();
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use backend::{Encoding, FFT64, Module, ScalarZnx, ScratchOwned, Stats, VecZnxOps, ZnxView, ZnxViewMut};
|
||||
use backend::{Encoding, FFT64, Module, ScratchOwned, Stats, VecZnxOps, ZnxView};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
@@ -16,13 +16,13 @@ use crate::{
|
||||
#[test]
|
||||
fn blind_rotation() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(2048);
|
||||
let basek: usize = 18;
|
||||
let basek: usize = 19;
|
||||
|
||||
let n_lwe: usize = 1071;
|
||||
|
||||
let k_lwe: usize = 24;
|
||||
let k_brk: usize = 3 * basek;
|
||||
let rows_brk: usize = 2;
|
||||
let rows_brk: usize = 1;
|
||||
let k_lut: usize = 2 * basek;
|
||||
let rank: usize = 1;
|
||||
let block_size: usize = 7;
|
||||
@@ -42,22 +42,19 @@ fn blind_rotation() {
|
||||
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
|
||||
sk_lwe.fill_binary_block(block_size, &mut source_xs);
|
||||
|
||||
sk_lwe.data.raw_mut()[0] = 0;
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space(
|
||||
&module, basek, k_brk, rank,
|
||||
));
|
||||
|
||||
println!("sk_lwe: {:?}", sk_lwe.data.raw());
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(
|
||||
BlindRotationKeyCGGI::generate_from_sk_scratch_space(&module, basek, k_brk, rank)
|
||||
| cggi_blind_rotate_scratch_space(
|
||||
&module,
|
||||
extension_factor,
|
||||
basek,
|
||||
k_lut,
|
||||
k_brk,
|
||||
rows_brk,
|
||||
rank,
|
||||
),
|
||||
);
|
||||
let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space(
|
||||
&module,
|
||||
extension_factor,
|
||||
basek,
|
||||
k_lut,
|
||||
k_brk,
|
||||
rows_brk,
|
||||
rank,
|
||||
));
|
||||
|
||||
let start: Instant = Instant::now();
|
||||
let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
|
||||
@@ -79,7 +76,7 @@ fn blind_rotation() {
|
||||
|
||||
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
|
||||
|
||||
let x: i64 = 1;
|
||||
let x: i64 = 2;
|
||||
let bits: usize = 8;
|
||||
|
||||
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
|
||||
@@ -92,18 +89,19 @@ fn blind_rotation() {
|
||||
|
||||
println!("{}", pt_lwe.data);
|
||||
|
||||
fn lut_fn(x: i64) -> i64 {
|
||||
2 * x + 1
|
||||
}
|
||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||
f.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = 2 * (i as i64) + 1);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
|
||||
lut.set(&module, lut_fn, message_modulus);
|
||||
lut.set(&module, &f, message_modulus);
|
||||
|
||||
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_lut, rank);
|
||||
|
||||
let start: Instant = Instant::now();
|
||||
(0..1).for_each(|_| {
|
||||
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch.borrow());
|
||||
(0..32).for_each(|_| {
|
||||
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
|
||||
});
|
||||
|
||||
let duration: std::time::Duration = start.elapsed();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::vec;
|
||||
|
||||
use backend::{FFT64, Module, ZnxView};
|
||||
|
||||
use crate::blind_rotation::lut::{DivRound, LookUpTable};
|
||||
@@ -12,12 +14,13 @@ fn standard() {
|
||||
|
||||
let scale: usize = (1 << (basek - 1)) / message_modulus;
|
||||
|
||||
fn lut_fn(x: i64) -> i64 {
|
||||
x - 8
|
||||
}
|
||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||
f.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i as i64) - 8);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
|
||||
lut.set(&module, lut_fn, message_modulus);
|
||||
lut.set(&module, &f, message_modulus);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
lut.rotate(half_step);
|
||||
@@ -27,7 +30,7 @@ fn standard() {
|
||||
(0..lut.domain_size()).step_by(step).for_each(|i| {
|
||||
(0..step).for_each(|_| {
|
||||
assert_eq!(
|
||||
lut_fn((i / step) as i64) % message_modulus as i64,
|
||||
f[i / step] % message_modulus as i64,
|
||||
lut.data[0].raw()[0] / scale as i64
|
||||
);
|
||||
lut.rotate(-1);
|
||||
@@ -45,12 +48,13 @@ fn extended() {
|
||||
|
||||
let scale: usize = (1 << (basek - 1)) / message_modulus;
|
||||
|
||||
fn lut_fn(x: i64) -> i64 {
|
||||
x - 8
|
||||
}
|
||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||
f.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i as i64) - 8);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
|
||||
lut.set(&module, lut_fn, message_modulus);
|
||||
lut.set(&module, &f, message_modulus);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
lut.rotate(half_step);
|
||||
@@ -60,7 +64,7 @@ fn extended() {
|
||||
(0..lut.domain_size()).step_by(step).for_each(|i| {
|
||||
(0..step).for_each(|_| {
|
||||
assert_eq!(
|
||||
lut_fn((i / step) as i64) % message_modulus as i64,
|
||||
f[i / step] % message_modulus as i64,
|
||||
lut.data[0].raw()[0] / scale as i64
|
||||
);
|
||||
lut.rotate(-1);
|
||||
|
||||
Reference in New Issue
Block a user