mirror of https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00

Merge remote-tracking branch 'origin/main' into jay/fhe-vm-fixes
@@ -240,11 +240,11 @@ impl Scratch {
     ) -> (Vec<VecZnxDft<&mut [u8], B>>, &mut Self) {
         let mut scratch: &mut Scratch = self;
         let mut slice: Vec<VecZnxDft<&mut [u8], B>> = Vec::with_capacity(slice_size);
-        for _ in 0..slice_size{
+        for _ in 0..slice_size {
             let (znx, new_scratch) = scratch.tmp_vec_znx_dft(module, cols, size);
             scratch = new_scratch;
             slice.push(znx);
-        };
+        }
         (slice, scratch)
     }

@@ -279,11 +279,11 @@ impl Scratch {
     ) -> (Vec<VecZnx<&mut [u8]>>, &mut Self) {
         let mut scratch: &mut Scratch = self;
         let mut slice: Vec<VecZnx<&mut [u8]>> = Vec::with_capacity(slice_size);
-        for _ in 0..slice_size{
+        for _ in 0..slice_size {
             let (znx, new_scratch) = scratch.tmp_vec_znx(module, cols, size);
             scratch = new_scratch;
             slice.push(znx);
-        };
+        }
        (slice, scratch)
     }

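The two slice helpers above (evidently the `tmp_slice_vec_znx_dft` / `tmp_slice_vec_znx` used later in this commit) thread the scratch borrow through the loop: each `tmp_vec_znx*` call consumes the current `&mut Scratch` and hands back a fresh one for the remaining space, so the loop must rebind `scratch` on every iteration. A minimal standalone sketch of this reborrow pattern, with a hypothetical `Arena` type standing in for the crate's `Scratch`:

```rust
// Sketch only: `Arena` is a stand-in for `Scratch`, carving fixed-size
// chunks off one backing buffer while keeping the borrow checker happy.
struct Arena<'a> {
    buf: &'a mut [u8],
}

impl<'a> Arena<'a> {
    // Split off `n` bytes; the remainder comes back as a new Arena.
    fn take(self, n: usize) -> (&'a mut [u8], Arena<'a>) {
        let (head, tail) = self.buf.split_at_mut(n);
        (head, Arena { buf: tail })
    }

    // The loop shape used by the slice helpers: rebind the arena returned
    // by each call before taking the next chunk.
    fn take_many(self, count: usize, n: usize) -> (Vec<&'a mut [u8]>, Arena<'a>) {
        let mut arena = self;
        let mut slices = Vec::with_capacity(count);
        for _ in 0..count {
            let (head, rest) = arena.take(n);
            arena = rest;
            slices.push(head);
        }
        (slices, arena)
    }
}

fn main() {
    let mut storage = [0u8; 64];
    let (parts, _rest) = Arena { buf: &mut storage }.take_many(4, 8);
    assert_eq!(parts.len(), 4);
}
```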
@@ -1,5 +1,5 @@
 use backend::{
-    FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
+    FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps,
     Scratch, VecZnxAlloc, VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView,
     ZnxViewMut, ZnxZero,
 };
@@ -30,7 +30,7 @@ pub fn cggi_blind_rotate_scratch_space(
     let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
     let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
     let acc_dft_add: usize = vmp_res;
-    let xai_plus_y: usize = module.bytes_of_scalar_znx(1);
+    let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1);
     let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
     let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)

@@ -54,16 +54,17 @@ pub fn cggi_blind_rotate_scratch_space(
     }
 }

-pub fn cggi_blind_rotate<DataRes, DataIn>(
+pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
     module: &Module<FFT64>,
     res: &mut GLWECiphertext<DataRes>,
     lwe: &LWECiphertext<DataIn>,
     lut: &LookUpTable,
-    brk: &BlindRotationKeyCGGI<FFT64>,
+    brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
     scratch: &mut Scratch,
 ) where
     DataRes: AsRef<[u8]> + AsMut<[u8]>,
     DataIn: AsRef<[u8]>,
+    DataBrk: AsRef<[u8]>,
 {
     match brk.dist {
         Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
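The new `DataBrk` parameter follows the `DataRes`/`DataIn` convention already in the signature: the blind rotation key becomes generic over the container holding its bytes, so the same function accepts keys backed by owned or borrowed storage without copying. A sketch of the pattern with a hypothetical `Key` type (not this crate's API):

```rust
// Sketch only: one generic type covers Key<Vec<u8>> and Key<&[u8]>.
struct Key<D> {
    data: D,
}

fn key_len<D: AsRef<[u8]>>(key: &Key<D>) -> usize {
    key.data.as_ref().len()
}

fn main() {
    let owned = Key { data: vec![0u8; 16] }; // Key<Vec<u8>>
    let bytes = [0u8; 8];
    let borrowed = Key { data: &bytes[..] }; // Key<&[u8]>
    assert_eq!(key_len(&owned), 16);
    assert_eq!(key_len(&borrowed), 8);
}
```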
@@ -82,16 +83,17 @@ pub fn cggi_blind_rotate<DataRes, DataIn>(
     }
 }

-pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
+pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
     module: &Module<FFT64>,
     res: &mut GLWECiphertext<DataRes>,
     lwe: &LWECiphertext<DataIn>,
     lut: &LookUpTable,
-    brk: &BlindRotationKeyCGGI<FFT64>,
+    brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
     scratch: &mut Scratch,
 ) where
     DataRes: AsRef<[u8]> + AsMut<[u8]>,
     DataIn: AsRef<[u8]>,
+    DataBrk: AsRef<[u8]>,
 {
     let extension_factor: usize = lut.extension_factor();
     let basek: usize = res.basek();
@@ -102,25 +104,35 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
     let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows);
     let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
     let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
-    let (mut xai_plus_y, scratch5) = scratch4.tmp_scalar_znx(module, 1);
+    let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
     let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1);

+    minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
+
     (0..extension_factor).for_each(|i| {
         acc[i].zero();
     });

+    let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
+    if let Some(b) = &brk.x_pow_a {
+        x_pow_a = b
+    } else {
+        panic!("invalid key: x_pow_a has not been initialized")
+    }
+
     let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
     let lwe_ref: LWECiphertext<&[u8]> = lwe.to_ref();

+    let two_n: usize = 2 * module.n();
     let two_n_ext: usize = 2 * lut.domain_size();

     negate_and_mod_switch_2n(two_n_ext, &mut lwe_2n, &lwe_ref);

     let a: &[i64] = &lwe_2n[1..];
-    let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) % two_n_ext as i64) as usize;
+    let b_pos: usize = ((lwe_2n[0] + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize;

     let b_hi: usize = b_pos / extension_factor;
-    let b_lo: usize = b_pos % extension_factor;
+    let b_lo: usize = b_pos & (extension_factor - 1);

     for (i, j) in (0..b_lo).zip(extension_factor - b_lo..extension_factor) {
         module.vec_znx_rotate(b_hi as i64 + 1, &mut acc[i], 0, &lut.data[j], 0);
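The `%`-to-`&` rewrites above (and the matching ones further down) assume `two_n_ext` and `extension_factor` are powers of two, which is the standard setting here since both derive from ring degrees. For `m = 2^k` and non-negative `x`, `x % m == x & (m - 1)`, and the preceding `+ two_n_ext as i64` shift keeps the operand non-negative. A quick check of the identity:

```rust
// Sanity check for the bitmask reduction used in place of `%`.
fn main() {
    let m: i64 = 2048; // a power of two, standing in for 2N
    for x in [-5i64, 0, 1, 2047, 2049] {
        let shifted = x + m; // non-negative as long as x > -m
        assert_eq!(shifted % m, shifted & (m - 1));
    }
}
```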
@@ -145,9 +157,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(

         // TODO: first & last iterations can be optimized
         izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
-            let ai_pos: usize = ((aii + two_n_ext as i64) % two_n_ext as i64) as usize;
+            let ai_pos: usize = ((aii + two_n_ext as i64) & (two_n_ext - 1) as i64) as usize;
             let ai_hi: usize = ai_pos / extension_factor;
-            let ai_lo: usize = ai_pos % extension_factor;
+            let ai_lo: usize = ai_pos & (extension_factor - 1);

             // vmp_res = DFT(acc) * BRK[i]
             (0..extension_factor).for_each(|i| {
@@ -156,48 +168,62 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(

             // Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
             if ai_lo == 0 {
-                // DFT X^{-ai}
-                set_xai_plus_y(module, ai_hi, -1, &mut xai_plus_y_dft, &mut xai_plus_y);

                 // Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1)
-                (0..extension_factor).for_each(|j| {
-                    (0..cols).for_each(|i| {
-                        module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
-                        module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
+                if ai_hi != 0 {
+                    // DFT X^{-ai}
+                    module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0);
+                    (0..extension_factor).for_each(|j| {
+                        (0..cols).for_each(|i| {
+                            module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
+                            module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
+                        });
                     });
-                });
+                }

             // Non trivial case: rotation between polynomials
             // In this case we can't directly multiply with (X^{-ai} - 1) because of the
             // ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the
             // computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai}
             } else {
                 // Sets acc_add_dft[i] = acc[i] * sk
-                (0..extension_factor).for_each(|i| {
-                    (0..cols).for_each(|k| {
-                        module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
-                    })
-                });
-
-                // DFT X^{-ai}
-                set_xai_plus_y(module, ai_hi + 1, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
-
-                // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
-                for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
-                    (0..cols).for_each(|k| {
-                        module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
-                        module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
-                    });
-                }
-
-                // DFT X^{-ai}
-                set_xai_plus_y(module, ai_hi, 0, &mut xai_plus_y_dft, &mut xai_plus_y);
+                // Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk
+                if (ai_hi + 1) & (two_n - 1) != 0 {
+                    for i in 0..ai_lo {
+                        (0..cols).for_each(|k| {
+                            module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+                        });
+                    }
+                }
+
+                // Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk
+                if ai_hi != 0 {
+                    for i in ai_lo..extension_factor {
+                        (0..cols).for_each(|k: usize| {
+                            module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
+                        });
+                    }
+                }
+
+                // Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
+                if (ai_hi + 1) & (two_n - 1) != 0 {
+                    for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
+                        (0..cols).for_each(|k| {
+                            module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0);
+                            module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
+                        });
+                    }
+                }

                 // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
-                for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
-                    (0..cols).for_each(|k| {
-                        module.svp_apply_inplace(&mut vmp_res[j], k, &xai_plus_y_dft, 0);
-                        module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
-                    });
-                }
+                if ai_hi != 0 {
+                    // Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
+                    for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
+                        (0..cols).for_each(|k| {
+                            module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0);
+                            module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
+                        });
+                    }
+                }
             }
         });
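This rewrite leans on linearity of the DFT: `DFT(X^{-ai} - 1) = DFT(X^{-ai}) + DFT(-1)`, so one pointwise `vec_znx_dft_add` over the cached `x_pow_a` table and the precomputed `minus_one` replaces a fresh `svp_prepare` per LWE coefficient. The new `if ai_hi != 0` and `if (ai_hi + 1) & (two_n - 1) != 0` guards also appear to skip slice groups whose net multiplier degenerates, since `X^0 - 1 = 0` contributes nothing. A toy check of the linearity argument, with a stand-in linear map instead of the real FFT:

```rust
// L(p + q) == L(p) + L(q) for any linear L; a toy L suffices to illustrate
// why the transform of X^a - 1 can be assembled from cached transforms.
fn l(p: &[f64]) -> Vec<f64> {
    p.windows(2).map(|w| w[0] + w[1]).collect() // toy linear map
}

fn main() {
    let x_a = [0.0, 1.0, 0.0, 0.0]; // X^1 in coefficient form
    let minus_one = [-1.0, 0.0, 0.0, 0.0]; // the constant -1
    let sum: Vec<f64> = x_a.iter().zip(&minus_one).map(|(a, b)| a + b).collect();
    let lhs = l(&sum); // L(X^1 - 1)
    let rhs: Vec<f64> = l(&x_a)
        .iter()
        .zip(l(&minus_one).iter())
        .map(|(a, b)| a + b)
        .collect(); // L(X^1) + L(-1)
    assert_eq!(lhs, rhs);
}
```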
@@ -220,49 +246,17 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn>(
         });
     }

-fn set_xai_plus_y(
-    module: &Module<FFT64>,
-    ai: usize,
-    y: i64,
-    res: &mut ScalarZnxDft<&mut [u8], FFT64>,
-    buf: &mut ScalarZnx<&mut [u8]>,
-) {
-    let n: usize = module.n();
-
-    {
-        let raw: &mut [i64] = buf.at_mut(0, 0);
-        if ai < n {
-            raw[ai] = 1;
-        } else {
-            raw[(ai - n) & (n - 1)] = -1;
-        }
-        raw[0] += y;
-    }
-
-    module.svp_prepare(res, 0, buf, 0);
-
-    {
-        let raw: &mut [i64] = buf.at_mut(0, 0);
-
-        if ai < n {
-            raw[ai] = 0;
-        } else {
-            raw[(ai - n) & (n - 1)] = 0;
-        }
-        raw[0] = 0;
-    }
-}
-
-pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
+pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
     module: &Module<FFT64>,
     res: &mut GLWECiphertext<DataRes>,
     lwe: &LWECiphertext<DataIn>,
     lut: &LookUpTable,
-    brk: &BlindRotationKeyCGGI<FFT64>,
+    brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
     scratch: &mut Scratch,
 ) where
     DataRes: AsRef<[u8]> + AsMut<[u8]>,
     DataIn: AsRef<[u8]>,
+    DataBrk: AsRef<[u8]>,
 {
     let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
     let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
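Worth noting: `set_xai_plus_y` is moved rather than dropped. It reappears below as a `pub fn` in the blind-rotation-key file, where key generation now calls it once per power of X to fill the new `x_pow_a` table that these rotation routines consume.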
@@ -290,9 +284,18 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
     let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows);
     let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size());
     let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size());
-    let (mut xai_plus_y, scratch4) = scratch3.tmp_scalar_znx(module, 1);
+    let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1);
     let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);

+    minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
+
+    let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
+    if let Some(b) = &brk.x_pow_a {
+        x_pow_a = b
+    } else {
+        panic!("invalid key: x_pow_a has not been initialized")
+    }
+
     izip!(
         a.chunks_exact(block_size),
         brk.data.chunks_exact(block_size)
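A small style note on the unwrap above: the `let x_pow_a; if let ... else { panic!(...) }` block is equivalent to the single expression `brk.x_pow_a.as_ref().expect("invalid key: x_pow_a has not been initialized")`, which panics with the same message; the explicit form in the diff behaves identically.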
@@ -305,13 +308,13 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
         acc_add_dft.zero();

         izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
-            let ai_pos: usize = ((aii + two_n as i64) % two_n as i64) as usize;
+            let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;

             // vmp_res = DFT(acc) * BRK[i]
             module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5);

             // DFT(X^ai -1)
-            set_xai_plus_y(module, ai_pos, -1, &mut xai_plus_y_dft, &mut xai_plus_y);
+            module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0);

             // DFT(X^ai -1) * (DFT(acc) * BRK[i])
             (0..cols).for_each(|i| {
@@ -320,10 +323,6 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
             });
         });

-        (0..cols).for_each(|i| {
-            module.vec_znx_dft_add_inplace(&mut acc_dft, i, &acc_add_dft, i);
-        });
-
         {
             let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size());

@@ -336,16 +335,17 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn>(
     });
 }

-pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn>(
+pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
     module: &Module<FFT64>,
     res: &mut GLWECiphertext<DataRes>,
     lwe: &LWECiphertext<DataIn>,
     lut: &LookUpTable,
-    brk: &BlindRotationKeyCGGI<FFT64>,
+    brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
     scratch: &mut Scratch,
 ) where
     DataRes: AsRef<[u8]> + AsMut<[u8]>,
     DataIn: AsRef<[u8]>,
+    DataBrk: AsRef<[u8]>,
 {
     #[cfg(debug_assertions)]
     {
@@ -1,5 +1,5 @@
-use backend::{Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxToRef, Scratch, ZnxView, ZnxViewMut};
+use backend::{
+    Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch,
+    ZnxView, ZnxViewMut,
+};
 use sampling::source::Source;

 use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret};

-pub struct BlindRotationKeyCGGI<B: Backend> {
-    pub(crate) data: Vec<GGSWCiphertext<Vec<u8>, B>>,
+pub struct BlindRotationKeyCGGI<D, B: Backend> {
+    pub(crate) data: Vec<GGSWCiphertext<D, B>>,
     pub(crate) dist: Distribution,
+    pub(crate) x_pow_a: Option<Vec<ScalarZnxDft<Vec<u8>, B>>>,
 }

 // pub struct BlindRotationKeyFHEW<B: Backend> {
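The new `x_pow_a` field is an `Option` so that `allocate` stays cheap and the table is only filled during key generation (and, as the later hunk shows, only for block-binary keys). A minimal sketch of this cache-on-the-key pattern, with hypothetical names:

```rust
// Sketch only: allocate leaves the cache empty; keygen fills it once.
struct Key {
    cache: Option<Vec<i64>>,
}

impl Key {
    fn allocate() -> Self {
        Key { cache: None } // nothing precomputed yet
    }

    fn generate(&mut self) {
        // stand-in for precomputing DFT(X^i) for i in 0..2N
        self.cache = Some((0..8).collect());
    }

    fn cached(&self) -> &[i64] {
        self.cache.as_deref().expect("cache has not been initialized")
    }
}

fn main() {
    let mut k = Key::allocate();
    k.generate();
    assert_eq!(k.cached().len(), 8);
}
```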
@@ -13,20 +17,61 @@ pub struct BlindRotationKeyCGGI<B: Backend> {
 // pub(crate) auto: Vec<GLWEAutomorphismKey<Vec<u8>, B>>,
 //}

-impl BlindRotationKeyCGGI<FFT64> {
+impl BlindRotationKeyCGGI<Vec<u8>, FFT64> {
     pub fn allocate(module: &Module<FFT64>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
         let mut data: Vec<GGSWCiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(n_lwe);
         (0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank)));
         Self {
             data,
             dist: Distribution::NONE,
+            x_pow_a: None::<Vec<ScalarZnxDft<Vec<u8>, FFT64>>>,
         }
     }

     pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
         GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
     }
+}
+
+impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
+    #[allow(dead_code)]
+    pub(crate) fn n(&self) -> usize {
+        self.data[0].n()
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn rows(&self) -> usize {
+        self.data[0].rows()
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn k(&self) -> usize {
+        self.data[0].k()
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn size(&self) -> usize {
+        self.data[0].size()
+    }
+
+    #[allow(dead_code)]
+    pub(crate) fn rank(&self) -> usize {
+        self.data[0].rank()
+    }
+
+    pub(crate) fn basek(&self) -> usize {
+        self.data[0].basek()
+    }
+
+    pub(crate) fn block_size(&self) -> usize {
+        match self.dist {
+            Distribution::BinaryBlock(value) => value,
+            _ => 1,
+        }
+    }
+}
+
+impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
     pub fn generate_from_sk<DataSkGLWE, DataSkLWE>(
         &mut self,
         module: &Module<FFT64>,
@@ -64,42 +109,51 @@ impl BlindRotationKeyCGGI<FFT64> {
         self.data.iter_mut().enumerate().for_each(|(i, ggsw)| {
             pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
             ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
-        })
-    }
-
-    pub(crate) fn block_size(&self) -> usize {
-        match self.dist {
-            Distribution::BinaryBlock(value) => value,
-            _ => 1,
-        }
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn n(&self) -> usize {
-        self.data[0].n()
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn rows(&self) -> usize {
-        self.data[0].rows()
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn k(&self) -> usize {
-        self.data[0].k()
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn size(&self) -> usize {
-        self.data[0].size()
-    }
-
-    #[allow(dead_code)]
-    pub(crate) fn rank(&self) -> usize {
-        self.data[0].rank()
-    }
-
-    pub(crate) fn basek(&self) -> usize {
-        self.data[0].basek()
-    }
-}
+        });
+
+        match sk_lwe.dist {
+            Distribution::BinaryBlock(_) => {
+                let mut x_pow_a: Vec<ScalarZnxDft<Vec<u8>, FFT64>> = Vec::with_capacity(module.n() << 1);
+                let mut buf: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
+                (0..module.n() << 1).for_each(|i| {
+                    let mut res: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
+                    set_xai_plus_y(module, i, 0, &mut res, &mut buf);
+                    x_pow_a.push(res);
+                });
+                self.x_pow_a = Some(x_pow_a);
+            }
+            _ => {}
+        }
+    }
+}
+
+pub fn set_xai_plus_y<A, B>(module: &Module<FFT64>, ai: usize, y: i64, res: &mut ScalarZnxDft<A, FFT64>, buf: &mut ScalarZnx<B>)
+where
+    A: AsRef<[u8]> + AsMut<[u8]>,
+    B: AsRef<[u8]> + AsMut<[u8]>,
+{
+    let n: usize = module.n();
+
+    {
+        let raw: &mut [i64] = buf.at_mut(0, 0);
+        if ai < n {
+            raw[ai] = 1;
+        } else {
+            raw[(ai - n) & (n - 1)] = -1;
+        }
+        raw[0] += y;
+    }
+
+    module.svp_prepare(res, 0, buf, 0);
+
+    {
+        let raw: &mut [i64] = buf.at_mut(0, 0);
+
+        if ai < n {
+            raw[ai] = 0;
+        } else {
+            raw[(ai - n) & (n - 1)] = 0;
+        }
+        raw[0] = 0;
+    }
+}
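The relocated `set_xai_plus_y` encodes a monomial of Z[X]/(X^N + 1) (optionally plus a constant `y`) into a scratch polynomial, transforms it with `svp_prepare`, then zeroes the scratch for reuse. The index wrap with a sign flip reflects the negacyclic relation X^(i+N) = -X^i. A standalone illustration of that coefficient encoding:

```rust
// X^i in Z[X]/(X^N + 1) for i in [0, 2N): one nonzero coefficient,
// with a sign flip once the exponent passes N.
fn monomial(n: usize, i: usize) -> Vec<i64> {
    let mut c = vec![0i64; n];
    if i < n {
        c[i] = 1;
    } else {
        c[(i - n) & (n - 1)] = -1; // the mask mirrors the original code
    }
    c
}

fn main() {
    let n = 4;
    assert_eq!(monomial(n, 1), vec![0, 1, 0, 0]); // X^1
    assert_eq!(monomial(n, 5), vec![0, -1, 0, 0]); // X^5 = -X^1 mod (X^4 + 1)
}
```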
@@ -24,6 +24,10 @@ impl LookUpTable {
         Self { data, basek, k }
     }

+    pub fn log_extension_factor(&self) -> usize {
+        (usize::BITS - (self.extension_factor() - 1).leading_zeros()) as _
+    }
+
     pub fn extension_factor(&self) -> usize {
         self.data.len()
     }
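The new accessor uses the standard bit trick for a base-2 logarithm: for `v >= 1`, `usize::BITS - (v - 1).leading_zeros()` equals `ceil(log2(v))`, which is exact for the power-of-two extension factors used here. A quick check:

```rust
// ceil(log2(v)) via leading_zeros, as in LookUpTable::log_extension_factor.
fn log2_ceil(v: usize) -> usize {
    (usize::BITS - (v - 1).leading_zeros()) as _
}

fn main() {
    assert_eq!(log2_ceil(1), 0);
    assert_eq!(log2_ceil(2), 1);
    assert_eq!(log2_ceil(8), 3);
    assert_eq!(log2_ceil(5), 3); // rounds up for non-powers of two
}
```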
@@ -39,8 +39,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)

     let message_modulus: usize = 1 << 4;

-    let mut source_xs: Source = Source::new([1u8; 32]);
-    let mut source_xe: Source = Source::new([1u8; 32]);
+    let mut source_xs: Source = Source::new([2u8; 32]);
+    let mut source_xe: Source = Source::new([2u8; 32]);
     let mut source_xa: Source = Source::new([1u8; 32]);

     let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
@@ -65,7 +65,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
         rank,
     ));

-    let mut brk: BlindRotationKeyCGGI<FFT64> = BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
+    let mut brk: BlindRotationKeyCGGI<Vec<u8>, FFT64> =
+        BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);

     brk.generate_from_sk(
         &module,
@@ -86,14 +87,8 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)

     pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);

-    // println!("{}", pt_lwe.data);
-
     lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);

-    lwe.decrypt(&mut pt_lwe, &sk_lwe);
-
-    // println!("{}", pt_lwe.data);
-
     let mut f: Vec<i64> = vec![0i64; message_modulus];
     f.iter_mut()
         .enumerate()
@@ -106,14 +101,10 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)

     cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());

-    println!("out_mut.data: {}", res.data);
-
     let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_res);

     res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow());

-    println!("pt_have: {}", pt_have.data);
-
     let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space

     negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref());
@@ -124,23 +115,11 @@ fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize)
         .zip(sk_lwe.data.at(0, 0))
         .map(|(x, y)| x * y)
         .sum::<i64>())
-        % (2 * lut.domain_size()) as i64;
+        & (2 * lut.domain_size() - 1) as i64;

-    // println!("pt_want: {}", pt_want);
-
     lut.rotate(pt_want);

-    // lut.data.iter().for_each(|d| {
-    //     println!("{}", d);
-    // });
-
     // First limb should be exactly equal (test are parameterized such that the noise does not reach
     // the first limb)
     assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
-
-    // Then checks the noise
-    // module.vec_znx_sub_ab_inplace(&mut lut.data[0], 0, &pt_have.data, 0);
-    // let noise: f64 = lut.data[0].std(0, basek);
-    // println!("noise: {}", noise);
-    // assert!(noise < 1e-3);
 }