mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add Hardware Abstraction Layer (#56)
This commit is contained in:
committed by
GitHub
parent
833520b163
commit
0e0745065e
@@ -1,18 +1,47 @@
|
||||
use backend::{
|
||||
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc,
|
||||
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero,
|
||||
use backend::hal::{
|
||||
api::{
|
||||
ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
|
||||
TakeVecZnxSlice, VecZnxAddInplace, VecZnxAllocBytes, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
|
||||
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
|
||||
VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace,
|
||||
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero,
|
||||
},
|
||||
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol},
|
||||
};
|
||||
use itertools::izip;
|
||||
|
||||
use crate::{
|
||||
GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
|
||||
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
|
||||
GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, TakeGLWECt,
|
||||
blind_rotation::{key::BlindRotationKeyCGGIExec, lut::LookUpTable},
|
||||
dist::Distribution,
|
||||
lwe::ciphertext::LWECiphertextToRef,
|
||||
};
|
||||
|
||||
pub fn cggi_blind_rotate_scratch_space(
|
||||
module: &Module<FFT64>,
|
||||
pub trait CCGIBlindRotationFamily<B: Backend> = VecZnxBigAllocBytes
|
||||
+ VecZnxDftAllocBytes
|
||||
+ SvpPPolAllocBytes
|
||||
+ VmpApplyTmpBytes
|
||||
+ VecZnxBigNormalizeTmpBytes
|
||||
+ VecZnxDftToVecZnxBigTmpBytes
|
||||
+ VecZnxDftToVecZnxBig<B>
|
||||
+ VecZnxDftAdd<B>
|
||||
+ VecZnxDftAddInplace<B>
|
||||
+ VecZnxDftFromVecZnx<B>
|
||||
+ VecZnxDftZero<B>
|
||||
+ SvpApply<B>
|
||||
+ VecZnxDftSubABInplace<B>
|
||||
+ VecZnxBigAddSmallInplace<B>
|
||||
+ GLWEExternalProductFamily<B>
|
||||
+ VecZnxRotate
|
||||
+ VecZnxAddInplace
|
||||
+ VecZnxSubABInplace
|
||||
+ VecZnxNormalize<B>
|
||||
+ VecZnxNormalizeInplace<B>
|
||||
+ VecZnxCopy
|
||||
+ VecZnxMulXpMinusOneInplace;
|
||||
|
||||
pub fn cggi_blind_rotate_scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
block_size: usize,
|
||||
extension_factor: usize,
|
||||
basek: usize,
|
||||
@@ -20,22 +49,24 @@ pub fn cggi_blind_rotate_scratch_space(
|
||||
k_brk: usize,
|
||||
rows: usize,
|
||||
rank: usize,
|
||||
) -> usize {
|
||||
) -> usize
|
||||
where
|
||||
Module<B>: CCGIBlindRotationFamily<B> + VecZnxAllocBytes,
|
||||
{
|
||||
let brk_size: usize = k_brk.div_ceil(basek);
|
||||
|
||||
if block_size > 1 {
|
||||
let cols: usize = rank + 1;
|
||||
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
|
||||
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
|
||||
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
|
||||
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
|
||||
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
|
||||
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
|
||||
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
|
||||
let acc_dft_add: usize = vmp_res;
|
||||
let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1);
|
||||
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
|
||||
let acc: usize;
|
||||
if extension_factor > 1 {
|
||||
acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor;
|
||||
acc = module.vec_znx_alloc_bytes(cols, k_res.div_ceil(basek)) * extension_factor;
|
||||
} else {
|
||||
acc = 0;
|
||||
}
|
||||
@@ -44,26 +75,30 @@ pub fn cggi_blind_rotate_scratch_space(
|
||||
+ acc_dft
|
||||
+ acc_dft_add
|
||||
+ vmp_res
|
||||
+ xai_plus_y
|
||||
+ xai_plus_y_dft
|
||||
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
|
||||
+ vmp_xai
|
||||
+ (vmp
|
||||
| (acc_big
|
||||
+ (module.vec_znx_big_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes())));
|
||||
} else {
|
||||
2 * GLWECiphertext::bytes_of(module, basek, k_res, rank)
|
||||
GLWECiphertext::bytes_of(module, basek, k_res, rank)
|
||||
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
|
||||
module: &Module<FFT64>,
|
||||
pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
module: &Module<B>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
DataBrk: AsRef<[u8]>,
|
||||
DataRes: DataMut,
|
||||
DataIn: DataRef,
|
||||
DataBrk: DataRef,
|
||||
Module<B>: CCGIBlindRotationFamily<B>,
|
||||
Scratch<B>:
|
||||
TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx<B> + ScratchAvailable + TakeVecZnxSlice<B>,
|
||||
{
|
||||
match brk.dist {
|
||||
Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
|
||||
@@ -82,37 +117,36 @@ pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
module: &Module<FFT64>,
|
||||
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
module: &Module<B>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
DataBrk: AsRef<[u8]>,
|
||||
DataRes: DataMut,
|
||||
DataIn: DataRef,
|
||||
DataBrk: DataRef,
|
||||
Module<B>: CCGIBlindRotationFamily<B>,
|
||||
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice<B>,
|
||||
{
|
||||
let extension_factor: usize = lut.extension_factor();
|
||||
let basek: usize = res.basek();
|
||||
let rows: usize = brk.rows();
|
||||
let cols: usize = res.rank() + 1;
|
||||
|
||||
let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size());
|
||||
let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows);
|
||||
let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
|
||||
let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
|
||||
let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1);
|
||||
|
||||
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
|
||||
let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, module, cols, res.size());
|
||||
let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, module, cols, rows);
|
||||
let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size());
|
||||
let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(module, 1, brk.size());
|
||||
|
||||
(0..extension_factor).for_each(|i| {
|
||||
acc[i].zero();
|
||||
});
|
||||
|
||||
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
|
||||
let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
|
||||
if let Some(b) = &brk.x_pow_a {
|
||||
x_pow_a = b
|
||||
} else {
|
||||
@@ -149,9 +183,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
.for_each(|(ai, ski)| {
|
||||
(0..extension_factor).for_each(|i| {
|
||||
(0..cols).for_each(|j| {
|
||||
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft[i], j, &acc[i], j);
|
||||
});
|
||||
acc_add_dft[i].zero();
|
||||
module.vec_znx_dft_zero(&mut acc_add_dft[i])
|
||||
});
|
||||
|
||||
// TODO: first & last iterations can be optimized
|
||||
@@ -162,19 +196,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
|
||||
// vmp_res = DFT(acc) * BRK[i]
|
||||
(0..extension_factor).for_each(|i| {
|
||||
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6);
|
||||
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch5);
|
||||
});
|
||||
|
||||
// Trivial case: no rotation between polynomials, we can directly multiply with (X^{-ai} - 1)
|
||||
if ai_lo == 0 {
|
||||
// Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1)
|
||||
// Sets acc_add_dft[i] = (acc[i] * sk) * X^{-ai} - (acc[i] * sk)
|
||||
if ai_hi != 0 {
|
||||
// DFT X^{-ai}
|
||||
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0);
|
||||
(0..extension_factor).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
|
||||
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0);
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -184,32 +218,13 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
// ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the
|
||||
// computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai}
|
||||
} else {
|
||||
// Sets acc_add_dft[i] = acc[i] * sk
|
||||
|
||||
// Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk
|
||||
if (ai_hi + 1) & (two_n - 1) != 0 {
|
||||
for i in 0..ai_lo {
|
||||
(0..cols).for_each(|k| {
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk
|
||||
if ai_hi != 0 {
|
||||
for i in ai_lo..extension_factor {
|
||||
(0..cols).for_each(|k: usize| {
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
|
||||
if (ai_hi + 1) & (two_n - 1) != 0 {
|
||||
for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
|
||||
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -219,8 +234,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
// Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
|
||||
for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
|
||||
(0..cols).for_each(|k| {
|
||||
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
|
||||
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -228,11 +244,11 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
});
|
||||
|
||||
{
|
||||
let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size());
|
||||
let (mut acc_add_big, scratch7) = scratch5.take_vec_znx_big(module, 1, brk.size());
|
||||
|
||||
(0..extension_factor).for_each(|j| {
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
|
||||
module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
|
||||
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i);
|
||||
module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7);
|
||||
});
|
||||
@@ -245,17 +261,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
|
||||
module: &Module<FFT64>,
|
||||
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
module: &Module<B>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
DataBrk: AsRef<[u8]>,
|
||||
DataRes: DataMut,
|
||||
DataIn: DataRef,
|
||||
DataBrk: DataRef,
|
||||
Module<B>: CCGIBlindRotationFamily<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B>,
|
||||
{
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
|
||||
@@ -280,15 +298,12 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
|
||||
|
||||
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
|
||||
|
||||
let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows);
|
||||
let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1);
|
||||
let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
|
||||
let (mut acc_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, rows);
|
||||
let (mut vmp_res, scratch2) = scratch1.take_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut acc_add_dft, scratch3) = scratch2.take_vec_znx_dft(module, cols, brk.size());
|
||||
let (mut vmp_xai, scratch4) = scratch3.take_vec_znx_dft(module, 1, brk.size());
|
||||
|
||||
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
|
||||
|
||||
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
|
||||
let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
|
||||
if let Some(b) = &brk.x_pow_a {
|
||||
x_pow_a = b
|
||||
} else {
|
||||
@@ -301,50 +316,50 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
|
||||
)
|
||||
.for_each(|(ai, ski)| {
|
||||
(0..cols).for_each(|j| {
|
||||
module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j);
|
||||
module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft, j, &out_mut.data, j);
|
||||
});
|
||||
|
||||
acc_add_dft.zero();
|
||||
module.vec_znx_dft_zero(&mut acc_add_dft);
|
||||
|
||||
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
|
||||
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;
|
||||
|
||||
// vmp_res = DFT(acc) * BRK[i]
|
||||
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5);
|
||||
|
||||
// DFT(X^ai -1)
|
||||
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0);
|
||||
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch4);
|
||||
|
||||
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
|
||||
(0..cols).for_each(|i| {
|
||||
module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i);
|
||||
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i);
|
||||
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0);
|
||||
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i);
|
||||
});
|
||||
});
|
||||
|
||||
{
|
||||
let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size());
|
||||
let (mut acc_add_big, scratch5) = scratch4.take_vec_znx_big(module, 1, brk.size());
|
||||
|
||||
(0..cols).for_each(|i| {
|
||||
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6);
|
||||
module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft, i, scratch5);
|
||||
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i);
|
||||
module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6);
|
||||
module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch5);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
|
||||
module: &Module<FFT64>,
|
||||
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk, B: Backend>(
|
||||
module: &Module<B>,
|
||||
res: &mut GLWECiphertext<DataRes>,
|
||||
lwe: &LWECiphertext<DataIn>,
|
||||
lut: &LookUpTable,
|
||||
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
|
||||
scratch: &mut Scratch,
|
||||
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
DataRes: AsRef<[u8]> + AsMut<[u8]>,
|
||||
DataIn: AsRef<[u8]>,
|
||||
DataBrk: AsRef<[u8]>,
|
||||
DataRes: DataMut,
|
||||
DataIn: DataRef,
|
||||
DataBrk: DataRef,
|
||||
Module<B>: CCGIBlindRotationFamily<B>,
|
||||
Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx<B> + ScratchAvailable,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
@@ -401,28 +416,24 @@ pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
|
||||
module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
|
||||
|
||||
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
|
||||
let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||
let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||
let (mut acc_tmp, scratch1) = scratch.take_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
|
||||
|
||||
// TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
|
||||
// TODO: first iteration can be optimized to be a gglwe product
|
||||
izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| {
|
||||
// acc_tmp = sk[i] * acc
|
||||
acc_tmp.external_product(module, &out_mut, ski, scratch2);
|
||||
acc_tmp.external_product(module, &out_mut, ski, scratch1);
|
||||
|
||||
// acc_tmp = (sk[i] * acc) * X^{ai}
|
||||
acc_tmp_rot.rotate(module, *ai, &acc_tmp);
|
||||
// acc_tmp = (sk[i] * acc) * (X^{ai} - 1)
|
||||
acc_tmp.mul_xp_minus_one_inplace(module, *ai);
|
||||
|
||||
// acc = acc + (sk[i] * acc) * X^{ai}
|
||||
out_mut.add_inplace(module, &acc_tmp_rot);
|
||||
|
||||
// acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1)
|
||||
out_mut.sub_inplace_ab(module, &acc_tmp);
|
||||
// acc = acc + (sk[i] * acc) * (X^{ai} - 1)
|
||||
out_mut.add_inplace(module, &acc_tmp);
|
||||
});
|
||||
|
||||
// We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
|
||||
// on top of each others, thus ~ 2^{63-basek} additions are supported before overflow.
|
||||
out_mut.normalize_inplace(module, scratch2);
|
||||
out_mut.normalize_inplace(module, scratch1);
|
||||
}
|
||||
|
||||
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
|
||||
|
||||
@@ -1,39 +1,185 @@
|
||||
use backend::{
|
||||
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch,
|
||||
ZnxView, ZnxViewMut,
|
||||
use backend::hal::{
|
||||
api::{
|
||||
MatZnxAlloc, ScalarZnxAlloc, ScratchAvailable, SvpPPolAlloc, SvpPrepare, TakeVecZnx, TakeVecZnxDft,
|
||||
VecZnxAddScalarInplace, VecZnxAllocBytes, ZnxView, ZnxViewMut,
|
||||
},
|
||||
layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, ScalarZnxToRef, Scratch, SvpPPol, WriterTo},
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret};
|
||||
use crate::{
|
||||
Distribution, GGSWCiphertext, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWLayoutFamily, GLWESecretExec, Infos, LWESecret,
|
||||
};
|
||||
|
||||
pub struct BlindRotationKeyCGGI<D, B: Backend> {
|
||||
pub(crate) data: Vec<GGSWCiphertext<D, B>>,
|
||||
pub struct BlindRotationKeyCGGI<D: Data> {
|
||||
pub(crate) keys: Vec<GGSWCiphertext<D>>,
|
||||
pub(crate) dist: Distribution,
|
||||
pub(crate) x_pow_a: Option<Vec<ScalarZnxDft<Vec<u8>, B>>>,
|
||||
}
|
||||
|
||||
// pub struct BlindRotationKeyFHEW<B: Backend> {
|
||||
// pub(crate) data: Vec<GGSWCiphertext<Vec<u8>, B>>,
|
||||
// pub(crate) auto: Vec<GLWEAutomorphismKey<Vec<u8>, B>>,
|
||||
//}
|
||||
impl<D: Data> PartialEq for BlindRotationKeyCGGI<D> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
if self.keys.len() != other.keys.len() {
|
||||
return false;
|
||||
}
|
||||
for (a, b) in self.keys.iter().zip(other.keys.iter()) {
|
||||
if a != b {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
self.dist == other.dist
|
||||
}
|
||||
}
|
||||
|
||||
impl BlindRotationKeyCGGI<Vec<u8>, FFT64> {
|
||||
pub fn allocate(module: &Module<FFT64>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
let mut data: Vec<GGSWCiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(n_lwe);
|
||||
impl<D: Data> Eq for BlindRotationKeyCGGI<D> {}
|
||||
|
||||
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
|
||||
|
||||
impl<D: DataMut> ReaderFrom for BlindRotationKeyCGGI<D> {
|
||||
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
|
||||
match Distribution::read_from(reader) {
|
||||
Ok(dist) => self.dist = dist,
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
|
||||
if self.keys.len() != len {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
format!("self.keys.len()={} != read len={}", self.keys.len(), len),
|
||||
));
|
||||
}
|
||||
for key in &mut self.keys {
|
||||
key.read_from(reader)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataRef> WriterTo for BlindRotationKeyCGGI<D> {
|
||||
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
|
||||
match self.dist.write_to(writer) {
|
||||
Ok(()) => {}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
|
||||
for key in &self.keys {
|
||||
key.write_to(writer)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl BlindRotationKeyCGGI<Vec<u8>> {
|
||||
pub fn alloc<B: Backend>(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self
|
||||
where
|
||||
Module<B>: MatZnxAlloc,
|
||||
{
|
||||
let mut data: Vec<GGSWCiphertext<Vec<u8>>> = Vec::with_capacity(n_lwe);
|
||||
(0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank)));
|
||||
Self {
|
||||
data,
|
||||
keys: data,
|
||||
dist: Distribution::NONE,
|
||||
x_pow_a: None::<Vec<ScalarZnxDft<Vec<u8>, FFT64>>>,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
|
||||
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
|
||||
where
|
||||
Module<B>: GGSWEncryptSkFamily<B> + VecZnxAllocBytes,
|
||||
{
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
|
||||
impl<D: DataRef> BlindRotationKeyCGGI<D> {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn n(&self) -> usize {
|
||||
self.keys[0].n()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn rows(&self) -> usize {
|
||||
self.keys[0].rows()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn k(&self) -> usize {
|
||||
self.keys[0].k()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn size(&self) -> usize {
|
||||
self.keys[0].size()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn rank(&self) -> usize {
|
||||
self.keys[0].rank()
|
||||
}
|
||||
|
||||
pub(crate) fn basek(&self) -> usize {
|
||||
self.keys[0].basek()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn block_size(&self) -> usize {
|
||||
match self.dist {
|
||||
Distribution::BinaryBlock(value) => value,
|
||||
_ => 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut> BlindRotationKeyCGGI<D> {
|
||||
pub fn generate_from_sk<DataSkGLWE, DataSkLWE, B: Backend>(
|
||||
&mut self,
|
||||
module: &Module<B>,
|
||||
sk_glwe: &GLWESecretExec<DataSkGLWE, B>,
|
||||
sk_lwe: &LWESecret<DataSkLWE>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch<B>,
|
||||
) where
|
||||
DataSkGLWE: DataRef,
|
||||
DataSkLWE: DataRef,
|
||||
Module<B>: GGSWEncryptSkFamily<B> + ScalarZnxAlloc + VecZnxAddScalarInplace,
|
||||
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx<B>,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.keys.len(), sk_lwe.n());
|
||||
assert_eq!(sk_glwe.n(), module.n());
|
||||
assert_eq!(sk_glwe.rank(), self.keys[0].rank());
|
||||
match sk_lwe.dist {
|
||||
Distribution::BinaryBlock(_)
|
||||
| Distribution::BinaryFixed(_)
|
||||
| Distribution::BinaryProb(_)
|
||||
| Distribution::ZERO => {}
|
||||
_ => panic!(
|
||||
"invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
self.dist = sk_lwe.dist;
|
||||
|
||||
let mut pt: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
|
||||
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref();
|
||||
|
||||
self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| {
|
||||
pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
|
||||
ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct BlindRotationKeyCGGIExec<D: Data, B: Backend> {
|
||||
pub(crate) data: Vec<GGSWCiphertextExec<D, B>>,
|
||||
pub(crate) dist: Distribution,
|
||||
pub(crate) x_pow_a: Option<Vec<SvpPPol<Vec<u8>, B>>>,
|
||||
}
|
||||
|
||||
impl<D: Data, B: Backend> BlindRotationKeyCGGIExec<D, B> {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn n(&self) -> usize {
|
||||
self.data[0].n()
|
||||
@@ -71,52 +217,66 @@ impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
|
||||
pub fn generate_from_sk<DataSkGLWE, DataSkLWE>(
|
||||
&mut self,
|
||||
module: &Module<FFT64>,
|
||||
sk_glwe: &FourierGLWESecret<DataSkGLWE, FFT64>,
|
||||
sk_lwe: &LWESecret<DataSkLWE>,
|
||||
source_xa: &mut Source,
|
||||
source_xe: &mut Source,
|
||||
sigma: f64,
|
||||
scratch: &mut Scratch,
|
||||
) where
|
||||
DataSkGLWE: AsRef<[u8]>,
|
||||
DataSkLWE: AsRef<[u8]>,
|
||||
pub trait BlindRotationKeyCGGIExecLayoutFamily<B: Backend> = GGSWLayoutFamily<B> + SvpPPolAlloc<B> + SvpPrepare<B>;
|
||||
|
||||
impl<B: Backend> BlindRotationKeyCGGIExec<Vec<u8>, B> {
|
||||
pub fn alloc(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self
|
||||
where
|
||||
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B>,
|
||||
{
|
||||
let mut data: Vec<GGSWCiphertextExec<Vec<u8>, B>> = Vec::with_capacity(n_lwe);
|
||||
(0..n_lwe).for_each(|_| data.push(GGSWCiphertextExec::alloc(module, basek, k, rows, 1, rank)));
|
||||
Self {
|
||||
data,
|
||||
dist: Distribution::NONE,
|
||||
x_pow_a: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from<DataOther>(module: &Module<B>, other: &BlindRotationKeyCGGI<DataOther>, scratch: &mut Scratch<B>) -> Self
|
||||
where
|
||||
DataOther: DataRef,
|
||||
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B> + ScalarZnxAlloc,
|
||||
{
|
||||
let mut brk: BlindRotationKeyCGGIExec<Vec<u8>, B> = Self::alloc(
|
||||
module,
|
||||
other.keys.len(),
|
||||
other.basek(),
|
||||
other.k(),
|
||||
other.rows(),
|
||||
other.rank(),
|
||||
);
|
||||
brk.prepare(module, other, scratch);
|
||||
brk
|
||||
}
|
||||
}
|
||||
|
||||
impl<D: DataMut, B: Backend> BlindRotationKeyCGGIExec<D, B> {
|
||||
pub fn prepare<DataOther>(&mut self, module: &Module<B>, other: &BlindRotationKeyCGGI<DataOther>, scratch: &mut Scratch<B>)
|
||||
where
|
||||
DataOther: DataRef,
|
||||
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B> + ScalarZnxAlloc,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert_eq!(self.data.len(), sk_lwe.n());
|
||||
assert_eq!(sk_glwe.n(), module.n());
|
||||
assert_eq!(sk_glwe.rank(), self.data[0].rank());
|
||||
match sk_lwe.dist {
|
||||
Distribution::BinaryBlock(_)
|
||||
| Distribution::BinaryFixed(_)
|
||||
| Distribution::BinaryProb(_)
|
||||
| Distribution::ZERO => {}
|
||||
_ => panic!(
|
||||
"invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
|
||||
),
|
||||
}
|
||||
assert_eq!(self.data.len(), other.keys.len());
|
||||
}
|
||||
|
||||
self.dist = sk_lwe.dist;
|
||||
self.data
|
||||
.iter_mut()
|
||||
.zip(other.keys.iter())
|
||||
.for_each(|(ggsw_exec, other)| {
|
||||
ggsw_exec.prepare(module, other, scratch);
|
||||
});
|
||||
|
||||
let mut pt: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref();
|
||||
self.dist = other.dist;
|
||||
|
||||
self.data.iter_mut().enumerate().for_each(|(i, ggsw)| {
|
||||
pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
|
||||
ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
|
||||
});
|
||||
|
||||
match sk_lwe.dist {
|
||||
match other.dist {
|
||||
Distribution::BinaryBlock(_) => {
|
||||
let mut x_pow_a: Vec<ScalarZnxDft<Vec<u8>, FFT64>> = Vec::with_capacity(module.n() << 1);
|
||||
let mut buf: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
|
||||
let mut x_pow_a: Vec<SvpPPol<Vec<u8>, B>> = Vec::with_capacity(module.n() << 1);
|
||||
let mut buf: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
|
||||
(0..module.n() << 1).for_each(|i| {
|
||||
let mut res: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
|
||||
let mut res: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(1);
|
||||
set_xai_plus_y(module, i, 0, &mut res, &mut buf);
|
||||
x_pow_a.push(res);
|
||||
});
|
||||
@@ -127,10 +287,11 @@ impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_xai_plus_y<A, B>(module: &Module<FFT64>, ai: usize, y: i64, res: &mut ScalarZnxDft<A, FFT64>, buf: &mut ScalarZnx<B>)
|
||||
pub fn set_xai_plus_y<A, C, B: Backend>(module: &Module<B>, ai: usize, y: i64, res: &mut SvpPPol<A, B>, buf: &mut ScalarZnx<C>)
|
||||
where
|
||||
A: AsRef<[u8]> + AsMut<[u8]>,
|
||||
B: AsRef<[u8]> + AsMut<[u8]>,
|
||||
A: DataMut,
|
||||
C: DataMut,
|
||||
Module<B>: SvpPrepare<B>,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned};
|
||||
use backend::hal::{
|
||||
api::{
|
||||
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
|
||||
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxViewMut,
|
||||
},
|
||||
layouts::{Backend, Module, ScratchOwned, VecZnx},
|
||||
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
|
||||
};
|
||||
|
||||
pub struct LookUpTable {
|
||||
pub(crate) data: Vec<VecZnx<Vec<u8>>>,
|
||||
@@ -7,7 +14,10 @@ pub struct LookUpTable {
|
||||
}
|
||||
|
||||
impl LookUpTable {
|
||||
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, extension_factor: usize) -> Self {
|
||||
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, extension_factor: usize) -> Self
|
||||
where
|
||||
Module<B>: VecZnxAlloc,
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -19,7 +29,7 @@ impl LookUpTable {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
|
||||
(0..extension_factor).for_each(|_| {
|
||||
data.push(module.new_vec_znx(1, size));
|
||||
data.push(module.vec_znx_alloc(1, size));
|
||||
});
|
||||
Self { data, basek, k }
|
||||
}
|
||||
@@ -36,7 +46,11 @@ impl LookUpTable {
|
||||
self.data.len() * self.data[0].n()
|
||||
}
|
||||
|
||||
pub fn set(&mut self, module: &Module<FFT64>, f: &Vec<i64>, k: usize) {
|
||||
pub fn set<B: Backend>(&mut self, module: &Module<B>, f: &Vec<i64>, k: usize)
|
||||
where
|
||||
Module<B>: VecZnxRotateInplace + VecZnxNormalizeInplace<B> + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy,
|
||||
B: ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
|
||||
{
|
||||
assert!(f.len() <= module.n());
|
||||
|
||||
let basek: usize = self.basek;
|
||||
@@ -74,16 +88,22 @@ impl LookUpTable {
|
||||
// Rotates half the step to the left
|
||||
let half_step: usize = domain_size.div_round(f_len << 1);
|
||||
|
||||
lut_full.rotate(-(half_step as i64));
|
||||
module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0);
|
||||
|
||||
let mut tmp_bytes: Vec<u8> = alloc_aligned(lut_full.n() * size_of::<i64>());
|
||||
lut_full.normalize(self.basek, 0, &mut tmp_bytes);
|
||||
let n_large: usize = lut_full.n();
|
||||
|
||||
module.vec_znx_normalize_inplace(
|
||||
self.basek,
|
||||
&mut lut_full,
|
||||
0,
|
||||
ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(n_large)).borrow(),
|
||||
);
|
||||
|
||||
if self.extension_factor() > 1 {
|
||||
(0..self.extension_factor()).for_each(|i| {
|
||||
module.switch_degree(&mut self.data[i], 0, &lut_full, 0);
|
||||
module.vec_znx_switch_degree(&mut self.data[i], 0, &lut_full, 0);
|
||||
if i < self.extension_factor() {
|
||||
lut_full.rotate(-1);
|
||||
module.vec_znx_rotate_inplace(-1, &mut lut_full, 0);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
@@ -92,7 +112,10 @@ impl LookUpTable {
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) fn rotate(&mut self, k: i64) {
|
||||
pub(crate) fn rotate<B: Backend>(&mut self, module: &Module<B>, k: i64)
|
||||
where
|
||||
Module<B>: VecZnxRotateInplace,
|
||||
{
|
||||
let extension_factor: usize = self.extension_factor();
|
||||
let two_n: usize = 2 * self.data[0].n();
|
||||
let two_n_ext: usize = two_n * extension_factor;
|
||||
@@ -103,11 +126,11 @@ impl LookUpTable {
|
||||
let k_lo: usize = k_pos % extension_factor;
|
||||
|
||||
(0..extension_factor - k_lo).for_each(|i| {
|
||||
self.data[i].rotate(k_hi as i64);
|
||||
module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0);
|
||||
});
|
||||
|
||||
(extension_factor - k_lo..extension_factor).for_each(|i| {
|
||||
self.data[i].rotate(k_hi as i64 + 1);
|
||||
module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0);
|
||||
});
|
||||
|
||||
self.data.rotate_right(k_lo as usize);
|
||||
|
||||
@@ -2,9 +2,9 @@ pub mod cggi;
|
||||
pub mod key;
|
||||
pub mod lut;
|
||||
|
||||
pub use cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space};
|
||||
pub use key::BlindRotationKeyCGGI;
|
||||
pub use cggi::{CCGIBlindRotationFamily, cggi_blind_rotate, cggi_blind_rotate_scratch_space};
|
||||
pub use key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily};
|
||||
pub use lut::LookUpTable;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_fft64;
|
||||
mod test;
|
||||
|
||||
179
core/src/blind_rotation/test/cggi.rs
Normal file
179
core/src/blind_rotation/test/cggi.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
use backend::{
|
||||
hal::{
|
||||
api::{
|
||||
MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal,
|
||||
VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxRotateInplace,
|
||||
VecZnxSwithcDegree, ZnxView,
|
||||
},
|
||||
layouts::{Backend, Module, ScratchOwned},
|
||||
oep::{
|
||||
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
|
||||
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl,
|
||||
},
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret,
|
||||
GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, LWESecret,
|
||||
blind_rotation::{
|
||||
cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n},
|
||||
key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec},
|
||||
lut::LookUpTable,
|
||||
},
|
||||
lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn standard() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(512);
|
||||
blind_rotatio_test(&module, 224, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_binary() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(512);
|
||||
blind_rotatio_test(&module, 224, 7, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_binary_extended() {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(512);
|
||||
blind_rotatio_test(&module, 224, 7, 2);
|
||||
}
|
||||
|
||||
pub(crate) trait CGGITestModuleFamily<B: Backend> = CCGIBlindRotationFamily<B>
|
||||
+ GLWESecretFamily<B>
|
||||
+ GLWEDecryptFamily<B>
|
||||
+ BlindRotationKeyCGGIExecLayoutFamily<B>
|
||||
+ VecZnxAlloc
|
||||
+ ScalarZnxAlloc
|
||||
+ VecZnxFillUniform
|
||||
+ VecZnxAddNormal
|
||||
+ VecZnxAllocBytes
|
||||
+ VecZnxAddScalarInplace
|
||||
+ VecZnxEncodeCoeffsi64
|
||||
+ VecZnxRotateInplace
|
||||
+ VecZnxSwithcDegree
|
||||
+ MatZnxAlloc;
|
||||
pub(crate) trait CGGITestScratchFamily<B: Backend> = VecZnxDftAllocBytesImpl<B>
|
||||
+ VecZnxBigAllocBytesImpl<B>
|
||||
+ ScratchOwnedAllocImpl<B>
|
||||
+ ScratchOwnedBorrowImpl<B>
|
||||
+ TakeVecZnxDftImpl<B>
|
||||
+ TakeVecZnxBigImpl<B>
|
||||
+ TakeVecZnxDftSliceImpl<B>
|
||||
+ ScratchAvailableImpl<B>
|
||||
+ TakeVecZnxImpl<B>
|
||||
+ TakeVecZnxSliceImpl<B>;
|
||||
|
||||
fn blind_rotatio_test<B: Backend>(module: &Module<B>, n_lwe: usize, block_size: usize, extension_factor: usize)
|
||||
where
|
||||
Module<B>: CGGITestModuleFamily<B>,
|
||||
B: CGGITestScratchFamily<B>,
|
||||
{
|
||||
let basek: usize = 19;
|
||||
|
||||
let k_lwe: usize = 24;
|
||||
let k_brk: usize = 3 * basek;
|
||||
let rows_brk: usize = 2; // Ensures first limb is noise-free.
|
||||
let k_lut: usize = 1 * basek;
|
||||
let k_res: usize = 2 * basek;
|
||||
let rank: usize = 1;
|
||||
|
||||
let message_modulus: usize = 1 << 4;
|
||||
|
||||
let mut source_xs: Source = Source::new([2u8; 32]);
|
||||
let mut source_xe: Source = Source::new([2u8; 32]);
|
||||
let mut source_xa: Source = Source::new([1u8; 32]);
|
||||
|
||||
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
|
||||
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_glwe_dft: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk_glwe);
|
||||
|
||||
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
|
||||
sk_lwe.fill_binary_block(block_size, &mut source_xs);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::<B>::alloc(BlindRotationKeyCGGI::generate_from_sk_scratch_space(
|
||||
module, basek, k_brk, rank,
|
||||
));
|
||||
|
||||
let mut scratch_br: ScratchOwned<B> = ScratchOwned::<B>::alloc(cggi_blind_rotate_scratch_space(
|
||||
module,
|
||||
block_size,
|
||||
extension_factor,
|
||||
basek,
|
||||
k_res,
|
||||
k_brk,
|
||||
rows_brk,
|
||||
rank,
|
||||
));
|
||||
|
||||
let mut brk: BlindRotationKeyCGGI<Vec<u8>> = BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k_brk, rows_brk, rank);
|
||||
|
||||
brk.generate_from_sk(
|
||||
module,
|
||||
&sk_glwe_dft,
|
||||
&sk_lwe,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
3.2,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
|
||||
|
||||
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
|
||||
|
||||
let x: i64 = 2;
|
||||
let bits: usize = 8;
|
||||
|
||||
module.encode_coeff_i64(basek, &mut pt_lwe.data, 0, bits, 0, x, bits);
|
||||
|
||||
lwe.encrypt_sk(
|
||||
module,
|
||||
&pt_lwe,
|
||||
&sk_lwe,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
3.2,
|
||||
);
|
||||
|
||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||
f.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = 2 * (i as i64) + 1);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
|
||||
lut.set(module, &f, message_modulus);
|
||||
|
||||
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, basek, k_res, rank);
|
||||
|
||||
let brk_exec: BlindRotationKeyCGGIExec<Vec<u8>, B> = BlindRotationKeyCGGIExec::from(module, &brk, scratch_br.borrow());
|
||||
|
||||
cggi_blind_rotate(module, &mut res, &lwe, &lut, &brk_exec, scratch_br.borrow());
|
||||
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_res);
|
||||
|
||||
res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow());
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
|
||||
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref());
|
||||
|
||||
let pt_want: i64 = (lwe_2n[0]
|
||||
+ lwe_2n[1..]
|
||||
.iter()
|
||||
.zip(sk_lwe.data.at(0, 0))
|
||||
.map(|(x, y)| x * y)
|
||||
.sum::<i64>())
|
||||
& (2 * lut.domain_size() - 1) as i64;
|
||||
|
||||
lut.rotate(module, pt_want);
|
||||
|
||||
// First limb should be exactly equal (test are parameterized such that the noise does not reach
|
||||
// the first limb)
|
||||
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
use std::vec;
|
||||
|
||||
use backend::{FFT64, Module, ZnxView};
|
||||
use backend::{
|
||||
hal::{
|
||||
api::{ModuleNew, ZnxView},
|
||||
layouts::Module,
|
||||
},
|
||||
implementation::cpu_spqlios::FFT64,
|
||||
};
|
||||
|
||||
use crate::blind_rotation::lut::{DivRound, LookUpTable};
|
||||
|
||||
@@ -23,7 +29,7 @@ fn standard() {
|
||||
lut.set(&module, &f, log_scale);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
lut.rotate(half_step);
|
||||
lut.rotate(&module, half_step);
|
||||
|
||||
let step: usize = lut.domain_size().div_round(message_modulus);
|
||||
|
||||
@@ -33,7 +39,7 @@ fn standard() {
|
||||
f[i / step] % message_modulus as i64,
|
||||
lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64
|
||||
);
|
||||
lut.rotate(-1);
|
||||
lut.rotate(&module, -1);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -57,7 +63,7 @@ fn extended() {
|
||||
lut.set(&module, &f, log_scale);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
lut.rotate(half_step);
|
||||
lut.rotate(&module, half_step);
|
||||
|
||||
let step: usize = lut.domain_size().div_round(message_modulus);
|
||||
|
||||
@@ -67,7 +73,7 @@ fn extended() {
|
||||
f[i / step] % message_modulus as i64,
|
||||
lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64
|
||||
);
|
||||
lut.rotate(-1);
|
||||
lut.rotate(&module, -1);
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView};
|
||||
use sampling::source::Source;
|
||||
|
||||
use crate::{
|
||||
FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret,
|
||||
blind_rotation::{
|
||||
cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n},
|
||||
key::BlindRotationKeyCGGI,
|
||||
lut::LookUpTable,
|
||||
},
|
||||
lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn standard() {
|
||||
blind_rotatio_test(224, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_binary() {
|
||||
blind_rotatio_test(224, 7, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn block_binary_extended() {
|
||||
blind_rotatio_test(224, 7, 2);
|
||||
}
|
||||
|
||||
fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) {
|
||||
let module: Module<FFT64> = Module::<FFT64>::new(512);
|
||||
let basek: usize = 19;
|
||||
|
||||
let k_lwe: usize = 24;
|
||||
let k_brk: usize = 3 * basek;
|
||||
let rows_brk: usize = 2; // Ensures first limb is noise-free.
|
||||
let k_lut: usize = 1 * basek;
|
||||
let k_res: usize = 2 * basek;
|
||||
let rank: usize = 1;
|
||||
|
||||
let message_modulus: usize = 1 << 4;
|
||||
|
||||
let mut source_xs: Source = Source::new([2u8; 32]);
|
||||
let mut source_xe: Source = Source::new([2u8; 32]);
|
||||
let mut source_xa: Source = Source::new([1u8; 32]);
|
||||
|
||||
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
|
||||
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
|
||||
let sk_glwe_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk_glwe);
|
||||
|
||||
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
|
||||
sk_lwe.fill_binary_block(block_size, &mut source_xs);
|
||||
|
||||
let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space(
|
||||
&module, basek, k_brk, rank,
|
||||
));
|
||||
|
||||
let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space(
|
||||
&module,
|
||||
block_size,
|
||||
extension_factor,
|
||||
basek,
|
||||
k_res,
|
||||
k_brk,
|
||||
rows_brk,
|
||||
rank,
|
||||
));
|
||||
|
||||
let mut brk: BlindRotationKeyCGGI<Vec<u8>, FFT64> =
|
||||
BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
|
||||
|
||||
brk.generate_from_sk(
|
||||
&module,
|
||||
&sk_glwe_dft,
|
||||
&sk_lwe,
|
||||
&mut source_xa,
|
||||
&mut source_xe,
|
||||
3.2,
|
||||
scratch.borrow(),
|
||||
);
|
||||
|
||||
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
|
||||
|
||||
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
|
||||
|
||||
let x: i64 = 2;
|
||||
let bits: usize = 8;
|
||||
|
||||
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
|
||||
|
||||
lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);
|
||||
|
||||
let mut f: Vec<i64> = vec![0i64; message_modulus];
|
||||
f.iter_mut()
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = 2 * (i as i64) + 1);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
|
||||
lut.set(&module, &f, message_modulus);
|
||||
|
||||
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_res, rank);
|
||||
|
||||
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
|
||||
|
||||
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_res);
|
||||
|
||||
res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow());
|
||||
|
||||
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
|
||||
|
||||
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref());
|
||||
|
||||
let pt_want: i64 = (lwe_2n[0]
|
||||
+ lwe_2n[1..]
|
||||
.iter()
|
||||
.zip(sk_lwe.data.at(0, 0))
|
||||
.map(|(x, y)| x * y)
|
||||
.sum::<i64>())
|
||||
& (2 * lut.domain_size() - 1) as i64;
|
||||
|
||||
lut.rotate(pt_want);
|
||||
|
||||
// First limb should be exactly equal (test are parameterized such that the noise does not reach
|
||||
// the first limb)
|
||||
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
|
||||
}
|
||||
Reference in New Issue
Block a user