Add Hardware Abstraction Layer (#56)

Jean-Philippe Bossuat
2025-08-08 19:22:42 +02:00
committed by GitHub
parent 833520b163
commit 0e0745065e
194 changed files with 17397 additions and 11955 deletions


@@ -1,18 +1,47 @@
use backend::{
FFT64, MatZnxDftOps, MatZnxDftScratch, Module, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, Scratch, VecZnxAlloc,
VecZnxBigAlloc, VecZnxBigOps, VecZnxBigScratch, VecZnxDftAlloc, VecZnxDftOps, VecZnxOps, ZnxView, ZnxViewMut, ZnxZero,
use backend::hal::{
api::{
ScratchAvailable, SvpApply, SvpPPolAllocBytes, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, TakeVecZnxDftSlice,
TakeVecZnxSlice, VecZnxAddInplace, VecZnxAllocBytes, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAdd, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxMulXpMinusOneInplace,
VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxSubABInplace, VmpApplyTmpBytes, ZnxView, ZnxZero,
},
layouts::{Backend, DataMut, DataRef, Module, Scratch, SvpPPol},
};
use itertools::izip;
use crate::{
GLWECiphertext, GLWECiphertextToMut, GLWEOps, Infos, LWECiphertext, ScratchCore,
blind_rotation::{key::BlindRotationKeyCGGI, lut::LookUpTable},
GLWECiphertext, GLWECiphertextToMut, GLWEExternalProductFamily, GLWEOps, Infos, LWECiphertext, TakeGLWECt,
blind_rotation::{key::BlindRotationKeyCGGIExec, lut::LookUpTable},
dist::Distribution,
lwe::ciphertext::LWECiphertextToRef,
};
pub fn cggi_blind_rotate_scratch_space(
module: &Module<FFT64>,
pub trait CCGIBlindRotationFamily<B: Backend> = VecZnxBigAllocBytes
+ VecZnxDftAllocBytes
+ SvpPPolAllocBytes
+ VmpApplyTmpBytes
+ VecZnxBigNormalizeTmpBytes
+ VecZnxDftToVecZnxBigTmpBytes
+ VecZnxDftToVecZnxBig<B>
+ VecZnxDftAdd<B>
+ VecZnxDftAddInplace<B>
+ VecZnxDftFromVecZnx<B>
+ VecZnxDftZero<B>
+ SvpApply<B>
+ VecZnxDftSubABInplace<B>
+ VecZnxBigAddSmallInplace<B>
+ GLWEExternalProductFamily<B>
+ VecZnxRotate
+ VecZnxAddInplace
+ VecZnxSubABInplace
+ VecZnxNormalize<B>
+ VecZnxNormalizeInplace<B>
+ VecZnxCopy
+ VecZnxMulXpMinusOneInplace;
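Note that `pub trait Alias = ...;` is the unstable trait_alias feature (nightly only). As a minimal sketch of what the alias buys downstream code, a hypothetical helper can take the whole HAL surface as a single bound instead of repeating the ~20 individual traits (br_normalize_tmp_bytes is illustrative, not part of this commit):
fn br_normalize_tmp_bytes<B: Backend>(module: &Module<B>) -> usize
where
    Module<B>: CCGIBlindRotationFamily<B>,
{
    // Any member of the alias is callable through the single bound;
    // here, the scratch size needed for big-coefficient normalization.
    module.vec_znx_big_normalize_tmp_bytes(module.n())
}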
pub fn cggi_blind_rotate_scratch_space<B: Backend>(
module: &Module<B>,
block_size: usize,
extension_factor: usize,
basek: usize,
@@ -20,22 +49,24 @@ pub fn cggi_blind_rotate_scratch_space(
k_brk: usize,
rows: usize,
rank: usize,
) -> usize {
) -> usize
where
Module<B>: CCGIBlindRotationFamily<B> + VecZnxAllocBytes,
{
let brk_size: usize = k_brk.div_ceil(basek);
if block_size > 1 {
let cols: usize = rank + 1;
let acc_dft: usize = module.bytes_of_vec_znx_dft(cols, rows) * extension_factor;
let acc_big: usize = module.bytes_of_vec_znx_big(1, brk_size);
let vmp_res: usize = module.bytes_of_vec_znx_dft(cols, brk_size) * extension_factor;
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
let acc_dft_add: usize = vmp_res;
let xai_plus_y: usize = module.bytes_of_scalar_znx_dft(1);
let xai_plus_y_dft: usize = module.bytes_of_scalar_znx_dft(1);
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
let acc: usize;
if extension_factor > 1 {
acc = module.bytes_of_vec_znx(cols, k_res.div_ceil(basek)) * extension_factor;
acc = module.vec_znx_alloc_bytes(cols, k_res.div_ceil(basek)) * extension_factor;
} else {
acc = 0;
}
@@ -44,26 +75,30 @@ pub fn cggi_blind_rotate_scratch_space(
+ acc_dft
+ acc_dft_add
+ vmp_res
+ xai_plus_y
+ xai_plus_y_dft
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_idft_tmp_bytes())));
+ vmp_xai
+ (vmp
| (acc_big
+ (module.vec_znx_big_normalize_tmp_bytes(module.n()) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes())));
} else {
2 * GLWECiphertext::bytes_of(module, basek, k_res, rank)
GLWECiphertext::bytes_of(module, basek, k_res, rank)
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
}
}
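Callers size a ScratchOwned buffer from this function before rotating; a sketch condensed from the test module at the bottom of this commit (parameter values are the test's):
let mut scratch_br: ScratchOwned<B> = ScratchOwned::<B>::alloc(cggi_blind_rotate_scratch_space(
    module, block_size, extension_factor, basek, k_res, k_brk, rows, rank,
));
let scratch: &mut Scratch<B> = scratch_br.borrow();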
pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
scratch: &mut Scratch<B>,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Scratch<B>:
TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx<B> + ScratchAvailable + TakeVecZnxSlice<B>,
{
match brk.dist {
Distribution::BinaryBlock(_) | Distribution::BinaryFixed(_) | Distribution::BinaryProb(_) | Distribution::ZERO => {
@@ -82,37 +117,36 @@ pub fn cggi_blind_rotate<DataRes, DataIn, DataBrk>(
}
}
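For context, the expected call sequence around this dispatcher, mirroring the test at the bottom of this commit: prepare an executable key from the serializable one, then rotate (a sketch; `brk`, `res`, `lwe`, `lut`, and `scratch` assumed in scope):
let brk_exec: BlindRotationKeyCGGIExec<Vec<u8>, B> =
    BlindRotationKeyCGGIExec::from(module, &brk, scratch.borrow());
cggi_blind_rotate(module, &mut res, &lwe, &lut, &brk_exec, scratch.borrow());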
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
scratch: &mut Scratch<B>,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Scratch<B>: TakeVecZnxDftSlice<B> + TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnxSlice<B>,
{
let extension_factor: usize = lut.extension_factor();
let basek: usize = res.basek();
let rows: usize = brk.rows();
let cols: usize = res.rank() + 1;
let (mut acc, scratch1) = scratch.tmp_slice_vec_znx(extension_factor, module, cols, res.size());
let (mut acc_dft, scratch2) = scratch1.tmp_slice_vec_znx_dft(extension_factor, module, cols, rows);
let (mut vmp_res, scratch3) = scratch2.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
let (mut acc_add_dft, scratch4) = scratch3.tmp_slice_vec_znx_dft(extension_factor, module, cols, brk.size());
let (mut minus_one, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
let (mut xai_plus_y_dft, scratch6) = scratch5.tmp_scalar_znx_dft(module, 1);
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
let (mut acc, scratch1) = scratch.take_vec_znx_slice(extension_factor, module, cols, res.size());
let (mut acc_dft, scratch2) = scratch1.take_vec_znx_dft_slice(extension_factor, module, cols, rows);
let (mut vmp_res, scratch3) = scratch2.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size());
let (mut acc_add_dft, scratch4) = scratch3.take_vec_znx_dft_slice(extension_factor, module, cols, brk.size());
let (mut vmp_xai, scratch5) = scratch4.take_vec_znx_dft(module, 1, brk.size());
(0..extension_factor).for_each(|i| {
acc[i].zero();
});
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
if let Some(b) = &brk.x_pow_a {
x_pow_a = b
} else {
@@ -149,9 +183,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
.for_each(|(ai, ski)| {
(0..extension_factor).for_each(|i| {
(0..cols).for_each(|j| {
module.vec_znx_dft(1, 0, &mut acc_dft[i], j, &acc[i], j);
module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft[i], j, &acc[i], j);
});
acc_add_dft[i].zero();
module.vec_znx_dft_zero(&mut acc_add_dft[i])
});
// TODO: first & last iterations can be optimized
@@ -162,19 +196,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
// vmp_res = DFT(acc) * BRK[i]
(0..extension_factor).for_each(|i| {
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch6);
module.vmp_apply(&mut vmp_res[i], &acc_dft[i], &skii.data, scratch5);
});
// Trivial case: there is no rotation between polynomials, so we can directly multiply by (X^{-ai} - 1)
if ai_lo == 0 {
// Sets acc_add_dft[i] = (acc[i] * sk) * (X^{-ai} - 1)
// Sets acc_add_dft[i] = (acc[i] * sk) * X^{-ai} - (acc[i] * sk)
if ai_hi != 0 {
// DFT X^{-ai}
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_hi], 0, &minus_one, 0);
(0..extension_factor).for_each(|j| {
(0..cols).for_each(|i| {
module.svp_apply_inplace(&mut vmp_res[j], i, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], i);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[j], i, &vmp_xai, 0);
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[j], i, &vmp_res[j], i);
});
});
}
@@ -184,32 +218,13 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
// ring homomorphism R^{N} -> prod R^{N/extension_factor}, so we split the
// computation in two steps: acc_add_dft = (acc * sk) * (-1) + (acc * sk) * X^{-ai}
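// Concretely (assuming ai_hi = ai / extension_factor and ai_lo = ai % extension_factor):
// with extension_factor = 2 and ai = 3 we get ai_lo = 1, ai_hi = 1, so slot j = 1
// wraps around to slot i = 0 and picks up the extra factor x_pow_a[ai_hi + 1], while
// slot j = 0 moves to slot i = 1 with x_pow_a[ai_hi] only.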
} else {
// Sets acc_add_dft[i] = acc[i] * sk
// Sets acc_add_dft[0..ai_lo] -= acc[..ai_lo] * sk
if (ai_hi + 1) & (two_n - 1) != 0 {
for i in 0..ai_lo {
(0..cols).for_each(|k| {
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
// Sets acc_add_dft[ai_lo..extension_factor] -= acc[ai_lo..extension_factor] * sk
if ai_hi != 0 {
for i in ai_lo..extension_factor {
(0..cols).for_each(|k: usize| {
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
// Sets acc_add_dft[0..ai_lo] += (acc[extension_factor - ai_lo..extension_factor] * sk) * X^{-ai+1}
if (ai_hi + 1) & (two_n - 1) != 0 {
for (i, j) in (0..ai_lo).zip(extension_factor - ai_lo..extension_factor) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi + 1], 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi + 1], 0, &vmp_res[j], k);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
@@ -219,8 +234,9 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
// Sets acc_add_dft[ai_lo..extension_factor] += (acc[0..extension_factor - ai_lo] * sk) * X^{-ai}
for (i, j) in (ai_lo..extension_factor).zip(0..extension_factor - ai_lo) {
(0..cols).for_each(|k| {
module.svp_apply_inplace(&mut vmp_res[j], k, &x_pow_a[ai_hi], 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_res[j], k);
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_hi], 0, &vmp_res[j], k);
module.vec_znx_dft_add_inplace(&mut acc_add_dft[i], k, &vmp_xai, 0);
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft[i], k, &vmp_res[i], k);
});
}
}
@@ -228,11 +244,11 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
});
{
let (mut acc_add_big, scratch7) = scratch6.tmp_vec_znx_big(module, 1, brk.size());
let (mut acc_add_big, scratch7) = scratch5.take_vec_znx_big(module, 1, brk.size());
(0..extension_factor).for_each(|j| {
(0..cols).for_each(|i| {
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft[j], i, scratch7);
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &acc[j], i);
module.vec_znx_big_normalize(basek, &mut acc[j], i, &acc_add_big, 0, scratch7);
});
@@ -245,17 +261,19 @@ pub(crate) fn cggi_blind_rotate_block_binary_extended<DataRes, DataIn, DataBrk>(
});
}
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
scratch: &mut Scratch<B>,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B>,
{
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
let mut out_mut: GLWECiphertext<&mut [u8]> = res.to_mut();
@@ -280,15 +298,12 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
let (mut acc_dft, scratch1) = scratch.tmp_vec_znx_dft(module, cols, rows);
let (mut vmp_res, scratch2) = scratch1.tmp_vec_znx_dft(module, cols, brk.size());
let (mut acc_add_dft, scratch3) = scratch2.tmp_vec_znx_dft(module, cols, brk.size());
let (mut minus_one, scratch4) = scratch3.tmp_scalar_znx_dft(module, 1);
let (mut xai_plus_y_dft, scratch5) = scratch4.tmp_scalar_znx_dft(module, 1);
let (mut acc_dft, scratch1) = scratch.take_vec_znx_dft(module, cols, rows);
let (mut vmp_res, scratch2) = scratch1.take_vec_znx_dft(module, cols, brk.size());
let (mut acc_add_dft, scratch3) = scratch2.take_vec_znx_dft(module, cols, brk.size());
let (mut vmp_xai, scratch4) = scratch3.take_vec_znx_dft(module, 1, brk.size());
minus_one.raw_mut()[..module.n() >> 1].fill(-1.0);
let x_pow_a: &Vec<ScalarZnxDft<Vec<u8>, FFT64>>;
let x_pow_a: &Vec<SvpPPol<Vec<u8>, B>>;
if let Some(b) = &brk.x_pow_a {
x_pow_a = b
} else {
@@ -301,50 +316,50 @@ pub(crate) fn cggi_blind_rotate_block_binary<DataRes, DataIn, DataBrk>(
)
.for_each(|(ai, ski)| {
(0..cols).for_each(|j| {
module.vec_znx_dft(1, 0, &mut acc_dft, j, &out_mut.data, j);
module.vec_znx_dft_from_vec_znx(1, 0, &mut acc_dft, j, &out_mut.data, j);
});
acc_add_dft.zero();
module.vec_znx_dft_zero(&mut acc_add_dft);
izip!(ai.iter(), ski.iter()).for_each(|(aii, skii)| {
let ai_pos: usize = ((aii + two_n as i64) & (two_n - 1) as i64) as usize;
// vmp_res = DFT(acc) * BRK[i]
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch5);
// DFT(X^ai -1)
module.vec_znx_dft_add(&mut xai_plus_y_dft, 0, &x_pow_a[ai_pos], 0, &minus_one, 0);
module.vmp_apply(&mut vmp_res, &acc_dft, &skii.data, scratch4);
// DFT(X^ai -1) * (DFT(acc) * BRK[i])
(0..cols).for_each(|i| {
module.svp_apply_inplace(&mut vmp_res, i, &xai_plus_y_dft, 0);
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_res, i);
module.svp_apply(&mut vmp_xai, 0, &x_pow_a[ai_pos], 0, &vmp_res, i);
module.vec_znx_dft_add_inplace(&mut acc_add_dft, i, &vmp_xai, 0);
module.vec_znx_dft_sub_ab_inplace(&mut acc_add_dft, i, &vmp_res, i);
});
});
{
let (mut acc_add_big, scratch6) = scratch5.tmp_vec_znx_big(module, 1, brk.size());
let (mut acc_add_big, scratch5) = scratch4.take_vec_znx_big(module, 1, brk.size());
(0..cols).for_each(|i| {
module.vec_znx_idft(&mut acc_add_big, 0, &acc_add_dft, i, scratch6);
module.vec_znx_dft_to_vec_znx_big(&mut acc_add_big, 0, &acc_add_dft, i, scratch5);
module.vec_znx_big_add_small_inplace(&mut acc_add_big, 0, &out_mut.data, i);
module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch6);
module.vec_znx_big_normalize(basek, &mut out_mut.data, i, &acc_add_big, 0, scratch5);
});
}
});
}
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
module: &Module<FFT64>,
pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk, B: Backend>(
module: &Module<B>,
res: &mut GLWECiphertext<DataRes>,
lwe: &LWECiphertext<DataIn>,
lut: &LookUpTable,
brk: &BlindRotationKeyCGGI<DataBrk, FFT64>,
scratch: &mut Scratch,
brk: &BlindRotationKeyCGGIExec<DataBrk, B>,
scratch: &mut Scratch<B>,
) where
DataRes: AsRef<[u8]> + AsMut<[u8]>,
DataIn: AsRef<[u8]>,
DataBrk: AsRef<[u8]>,
DataRes: DataMut,
DataIn: DataRef,
DataBrk: DataRef,
Module<B>: CCGIBlindRotationFamily<B>,
Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B> + TakeVecZnx<B> + ScratchAvailable,
{
#[cfg(debug_assertions)]
{
@@ -401,28 +416,24 @@ pub(crate) fn cggi_blind_rotate_binary_standard<DataRes, DataIn, DataBrk>(
module.vec_znx_rotate(b, &mut out_mut.data, 0, &lut.data[0], 0);
// ACC + [sum DFT(X^ai -1) * (DFT(ACC) x BRKi)]
let (mut acc_tmp, scratch1) = scratch.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
let (mut acc_tmp_rot, scratch2) = scratch1.tmp_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
let (mut acc_tmp, scratch1) = scratch.take_glwe_ct(module, basek, out_mut.k(), out_mut.rank());
// TODO: see if faster by skipping normalization in external product and keeping acc in big coeffs
// TODO: first iteration can be optimized to be a gglwe product
izip!(a.iter(), brk.data.iter()).for_each(|(ai, ski)| {
// acc_tmp = sk[i] * acc
acc_tmp.external_product(module, &out_mut, ski, scratch2);
acc_tmp.external_product(module, &out_mut, ski, scratch1);
// acc_tmp = (sk[i] * acc) * X^{ai}
acc_tmp_rot.rotate(module, *ai, &acc_tmp);
// acc_tmp = (sk[i] * acc) * (X^{ai} - 1)
acc_tmp.mul_xp_minus_one_inplace(module, *ai);
// acc = acc + (sk[i] * acc) * X^{ai}
out_mut.add_inplace(module, &acc_tmp_rot);
// acc = acc + (sk[i] * acc) * X^{ai} - (sk[i] * acc) = acc + (sk[i] * acc) * (X^{ai} - 1)
out_mut.sub_inplace_ab(module, &acc_tmp);
// acc = acc + (sk[i] * acc) * (X^{ai} - 1)
out_mut.add_inplace(module, &acc_tmp);
});
// We can normalize only at the end because we add normalized values in [-2^{basek-1}, 2^{basek-1}]
// on top of each other, thus ~2^{63-basek} additions are supported before overflow.
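// e.g. with basek = 19 (as in the tests), about 2^{44} accumulations fit in an i64 limb.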
out_mut.normalize_inplace(module, scratch2);
out_mut.normalize_inplace(module, scratch1);
}
pub(crate) fn negate_and_mod_switch_2n(n: usize, res: &mut [i64], lwe: &LWECiphertext<&[u8]>) {
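(The body of this helper is truncated in this view.) Conceptually it maps each LWE coefficient into Z_{2n} and negates it, so rotations land in the 2n-periodic cyclotomic domain. A self-contained sketch of that standard step over plain slices — not this crate's implementation, which reads limbs from the LWECiphertext layout:
// Standard CGGI mod switch: a_i mod q  ->  -round(a_i * 2n / q) mod 2n.
// Assumes two_n is a power of two; q and the slice layout are illustrative.
fn negate_and_mod_switch_2n_sketch(two_n: u64, q: u64, a: &[i64], res: &mut [i64]) {
    for (r, &ai) in res.iter_mut().zip(a.iter()) {
        let ai_pos: u128 = ai.rem_euclid(q as i64) as u128;
        // Round to the nearest multiple of q / 2n, then negate modulo 2n.
        let switched: i64 = ((ai_pos * two_n as u128 + (q as u128 >> 1)) / q as u128) as i64;
        *r = (two_n as i64 - switched) & (two_n as i64 - 1);
    }
}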


@@ -1,39 +1,185 @@
use backend::{
Backend, FFT64, Module, ScalarZnx, ScalarZnxAlloc, ScalarZnxDft, ScalarZnxDftAlloc, ScalarZnxDftOps, ScalarZnxToRef, Scratch,
ZnxView, ZnxViewMut,
use backend::hal::{
api::{
MatZnxAlloc, ScalarZnxAlloc, ScratchAvailable, SvpPPolAlloc, SvpPrepare, TakeVecZnx, TakeVecZnxDft,
VecZnxAddScalarInplace, VecZnxAllocBytes, ZnxView, ZnxViewMut,
},
layouts::{Backend, Data, DataMut, DataRef, Module, ReaderFrom, ScalarZnx, ScalarZnxToRef, Scratch, SvpPPol, WriterTo},
};
use sampling::source::Source;
use crate::{Distribution, FourierGLWESecret, GGSWCiphertext, Infos, LWESecret};
use crate::{
Distribution, GGSWCiphertext, GGSWCiphertextExec, GGSWEncryptSkFamily, GGSWLayoutFamily, GLWESecretExec, Infos, LWESecret,
};
pub struct BlindRotationKeyCGGI<D, B: Backend> {
pub(crate) data: Vec<GGSWCiphertext<D, B>>,
pub struct BlindRotationKeyCGGI<D: Data> {
pub(crate) keys: Vec<GGSWCiphertext<D>>,
pub(crate) dist: Distribution,
pub(crate) x_pow_a: Option<Vec<ScalarZnxDft<Vec<u8>, B>>>,
}
// pub struct BlindRotationKeyFHEW<B: Backend> {
// pub(crate) data: Vec<GGSWCiphertext<Vec<u8>, B>>,
// pub(crate) auto: Vec<GLWEAutomorphismKey<Vec<u8>, B>>,
//}
impl<D: Data> PartialEq for BlindRotationKeyCGGI<D> {
fn eq(&self, other: &Self) -> bool {
if self.keys.len() != other.keys.len() {
return false;
}
for (a, b) in self.keys.iter().zip(other.keys.iter()) {
if a != b {
return false;
}
}
self.dist == other.dist
}
}
impl BlindRotationKeyCGGI<Vec<u8>, FFT64> {
pub fn allocate(module: &Module<FFT64>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
let mut data: Vec<GGSWCiphertext<Vec<u8>, FFT64>> = Vec::with_capacity(n_lwe);
impl<D: Data> Eq for BlindRotationKeyCGGI<D> {}
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
impl<D: DataMut> ReaderFrom for BlindRotationKeyCGGI<D> {
fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
match Distribution::read_from(reader) {
Ok(dist) => self.dist = dist,
Err(e) => return Err(e),
}
let len: usize = reader.read_u64::<LittleEndian>()? as usize;
if self.keys.len() != len {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("self.keys.len()={} != read len={}", self.keys.len(), len),
));
}
for key in &mut self.keys {
key.read_from(reader)?;
}
Ok(())
}
}
impl<D: DataRef> WriterTo for BlindRotationKeyCGGI<D> {
fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
match self.dist.write_to(writer) {
Ok(()) => {}
Err(e) => return Err(e),
}
writer.write_u64::<LittleEndian>(self.keys.len() as u64)?;
for key in &self.keys {
key.write_to(writer)?;
}
Ok(())
}
}
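A round-trip sketch for these two traits (assumes module and the key dimensions are in scope; read_from requires a pre-allocated key of matching shape, and error handling is elided with unwrap):
let mut bytes: Vec<u8> = Vec::new();
brk.write_to(&mut bytes).unwrap();
let mut restored: BlindRotationKeyCGGI<Vec<u8>> =
    BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k, rows, rank);
restored.read_from(&mut bytes.as_slice()).unwrap();
assert_eq!(brk, restored); // PartialEq above compares keys and distribution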
impl BlindRotationKeyCGGI<Vec<u8>> {
pub fn alloc<B: Backend>(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self
where
Module<B>: MatZnxAlloc,
{
let mut data: Vec<GGSWCiphertext<Vec<u8>>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| data.push(GGSWCiphertext::alloc(module, basek, k, rows, 1, rank)));
Self {
data,
keys: data,
dist: Distribution::NONE,
x_pow_a: None::<Vec<ScalarZnxDft<Vec<u8>, FFT64>>>,
}
}
pub fn generate_from_sk_scratch_space(module: &Module<FFT64>, basek: usize, k: usize, rank: usize) -> usize {
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
where
Module<B>: GGSWEncryptSkFamily<B> + VecZnxAllocBytes,
{
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
}
}
impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
impl<D: DataRef> BlindRotationKeyCGGI<D> {
#[allow(dead_code)]
pub(crate) fn n(&self) -> usize {
self.keys[0].n()
}
#[allow(dead_code)]
pub(crate) fn rows(&self) -> usize {
self.keys[0].rows()
}
#[allow(dead_code)]
pub(crate) fn k(&self) -> usize {
self.keys[0].k()
}
#[allow(dead_code)]
pub(crate) fn size(&self) -> usize {
self.keys[0].size()
}
#[allow(dead_code)]
pub(crate) fn rank(&self) -> usize {
self.keys[0].rank()
}
pub(crate) fn basek(&self) -> usize {
self.keys[0].basek()
}
#[allow(dead_code)]
pub(crate) fn block_size(&self) -> usize {
match self.dist {
Distribution::BinaryBlock(value) => value,
_ => 1,
}
}
}
impl<D: DataMut> BlindRotationKeyCGGI<D> {
pub fn generate_from_sk<DataSkGLWE, DataSkLWE, B: Backend>(
&mut self,
module: &Module<B>,
sk_glwe: &GLWESecretExec<DataSkGLWE, B>,
sk_lwe: &LWESecret<DataSkLWE>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch<B>,
) where
DataSkGLWE: DataRef,
DataSkLWE: DataRef,
Module<B>: GGSWEncryptSkFamily<B> + ScalarZnxAlloc + VecZnxAddScalarInplace,
Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx<B>,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.keys.len(), sk_lwe.n());
assert_eq!(sk_glwe.n(), module.n());
assert_eq!(sk_glwe.rank(), self.keys[0].rank());
match sk_lwe.dist {
Distribution::BinaryBlock(_)
| Distribution::BinaryFixed(_)
| Distribution::BinaryProb(_)
| Distribution::ZERO => {}
_ => panic!(
"invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
),
}
}
self.dist = sk_lwe.dist;
let mut pt: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref();
self.keys.iter_mut().enumerate().for_each(|(i, ggsw)| {
pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
});
}
}
#[derive(PartialEq, Eq)]
pub struct BlindRotationKeyCGGIExec<D: Data, B: Backend> {
pub(crate) data: Vec<GGSWCiphertextExec<D, B>>,
pub(crate) dist: Distribution,
pub(crate) x_pow_a: Option<Vec<SvpPPol<Vec<u8>, B>>>,
}
impl<D: Data, B: Backend> BlindRotationKeyCGGIExec<D, B> {
#[allow(dead_code)]
pub(crate) fn n(&self) -> usize {
self.data[0].n()
@@ -71,52 +217,66 @@ impl<D: AsRef<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
}
}
impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
pub fn generate_from_sk<DataSkGLWE, DataSkLWE>(
&mut self,
module: &Module<FFT64>,
sk_glwe: &FourierGLWESecret<DataSkGLWE, FFT64>,
sk_lwe: &LWESecret<DataSkLWE>,
source_xa: &mut Source,
source_xe: &mut Source,
sigma: f64,
scratch: &mut Scratch,
) where
DataSkGLWE: AsRef<[u8]>,
DataSkLWE: AsRef<[u8]>,
pub trait BlindRotationKeyCGGIExecLayoutFamily<B: Backend> = GGSWLayoutFamily<B> + SvpPPolAlloc<B> + SvpPrepare<B>;
impl<B: Backend> BlindRotationKeyCGGIExec<Vec<u8>, B> {
pub fn alloc(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self
where
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B>,
{
let mut data: Vec<GGSWCiphertextExec<Vec<u8>, B>> = Vec::with_capacity(n_lwe);
(0..n_lwe).for_each(|_| data.push(GGSWCiphertextExec::alloc(module, basek, k, rows, 1, rank)));
Self {
data,
dist: Distribution::NONE,
x_pow_a: None,
}
}
pub fn from<DataOther>(module: &Module<B>, other: &BlindRotationKeyCGGI<DataOther>, scratch: &mut Scratch<B>) -> Self
where
DataOther: DataRef,
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B> + ScalarZnxAlloc,
{
let mut brk: BlindRotationKeyCGGIExec<Vec<u8>, B> = Self::alloc(
module,
other.keys.len(),
other.basek(),
other.k(),
other.rows(),
other.rank(),
);
brk.prepare(module, other, scratch);
brk
}
}
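from is a convenience over alloc + prepare; the explicit form lets a caller keep one executable key and refresh it in place (a sketch, with `brk` and `scratch` assumed in scope):
// One-shot conversion:
let brk_exec = BlindRotationKeyCGGIExec::from(module, &brk, scratch.borrow());
// Equivalent explicit form, reusing a pre-allocated executable key:
let mut brk_exec2 = BlindRotationKeyCGGIExec::alloc(
    module, brk.keys.len(), brk.basek(), brk.k(), brk.rows(), brk.rank(),
);
brk_exec2.prepare(module, &brk, scratch.borrow());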
impl<D: DataMut, B: Backend> BlindRotationKeyCGGIExec<D, B> {
pub fn prepare<DataOther>(&mut self, module: &Module<B>, other: &BlindRotationKeyCGGI<DataOther>, scratch: &mut Scratch<B>)
where
DataOther: DataRef,
Module<B>: BlindRotationKeyCGGIExecLayoutFamily<B> + ScalarZnxAlloc,
{
#[cfg(debug_assertions)]
{
assert_eq!(self.data.len(), sk_lwe.n());
assert_eq!(sk_glwe.n(), module.n());
assert_eq!(sk_glwe.rank(), self.data[0].rank());
match sk_lwe.dist {
Distribution::BinaryBlock(_)
| Distribution::BinaryFixed(_)
| Distribution::BinaryProb(_)
| Distribution::ZERO => {}
_ => panic!(
"invalid GLWESecret distribution: must be BinaryBlock, BinaryFixed or BinaryProb (or ZERO for debugging)"
),
}
assert_eq!(self.data.len(), other.keys.len());
}
self.dist = sk_lwe.dist;
self.data
.iter_mut()
.zip(other.keys.iter())
.for_each(|(ggsw_exec, other)| {
ggsw_exec.prepare(module, other, scratch);
});
let mut pt: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let sk_ref: ScalarZnx<&[u8]> = sk_lwe.data.to_ref();
self.dist = other.dist;
self.data.iter_mut().enumerate().for_each(|(i, ggsw)| {
pt.at_mut(0, 0)[0] = sk_ref.at(0, 0)[i];
ggsw.encrypt_sk(module, &pt, sk_glwe, source_xa, source_xe, sigma, scratch);
});
match sk_lwe.dist {
match other.dist {
Distribution::BinaryBlock(_) => {
let mut x_pow_a: Vec<ScalarZnxDft<Vec<u8>, FFT64>> = Vec::with_capacity(module.n() << 1);
let mut buf: ScalarZnx<Vec<u8>> = module.new_scalar_znx(1);
let mut x_pow_a: Vec<SvpPPol<Vec<u8>, B>> = Vec::with_capacity(module.n() << 1);
let mut buf: ScalarZnx<Vec<u8>> = module.scalar_znx_alloc(1);
(0..module.n() << 1).for_each(|i| {
let mut res: ScalarZnxDft<Vec<u8>, FFT64> = module.new_scalar_znx_dft(1);
let mut res: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(1);
set_xai_plus_y(module, i, 0, &mut res, &mut buf);
x_pow_a.push(res);
});
@@ -127,10 +287,11 @@ impl<D: AsRef<[u8]> + AsMut<[u8]>> BlindRotationKeyCGGI<D, FFT64> {
}
}
pub fn set_xai_plus_y<A, B>(module: &Module<FFT64>, ai: usize, y: i64, res: &mut ScalarZnxDft<A, FFT64>, buf: &mut ScalarZnx<B>)
pub fn set_xai_plus_y<A, C, B: Backend>(module: &Module<B>, ai: usize, y: i64, res: &mut SvpPPol<A, B>, buf: &mut ScalarZnx<C>)
where
A: AsRef<[u8]> + AsMut<[u8]>,
B: AsRef<[u8]> + AsMut<[u8]>,
A: DataMut,
C: DataMut,
Module<B>: SvpPrepare<B>,
{
let n: usize = module.n();


@@ -1,4 +1,11 @@
use backend::{FFT64, Module, VecZnx, VecZnxAlloc, VecZnxOps, ZnxInfos, ZnxViewMut, alloc_aligned};
use backend::hal::{
api::{
ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAlloc, VecZnxCopy, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
VecZnxRotateInplace, VecZnxSwithcDegree, ZnxInfos, ZnxViewMut,
},
layouts::{Backend, Module, ScratchOwned, VecZnx},
oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
};
pub struct LookUpTable {
pub(crate) data: Vec<VecZnx<Vec<u8>>>,
@@ -7,7 +14,10 @@ pub struct LookUpTable {
}
impl LookUpTable {
pub fn alloc(module: &Module<FFT64>, basek: usize, k: usize, extension_factor: usize) -> Self {
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, extension_factor: usize) -> Self
where
Module<B>: VecZnxAlloc,
{
#[cfg(debug_assertions)]
{
assert!(
@@ -19,7 +29,7 @@ impl LookUpTable {
let size: usize = k.div_ceil(basek);
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
(0..extension_factor).for_each(|_| {
data.push(module.new_vec_znx(1, size));
data.push(module.vec_znx_alloc(1, size));
});
Self { data, basek, k }
}
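A build sketch mirroring the tests below: encode f(i) = 2i + 1 over a message modulus of 16, then program the table (assumes a module satisfying the bounds on set):
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
let f: Vec<i64> = (0..16i64).map(|i| 2 * i + 1).collect();
lut.set(module, &f, 16);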
@@ -36,7 +46,11 @@ impl LookUpTable {
self.data.len() * self.data[0].n()
}
pub fn set(&mut self, module: &Module<FFT64>, f: &Vec<i64>, k: usize) {
pub fn set<B: Backend>(&mut self, module: &Module<B>, f: &Vec<i64>, k: usize)
where
Module<B>: VecZnxRotateInplace + VecZnxNormalizeInplace<B> + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy,
B: ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
{
assert!(f.len() <= module.n());
let basek: usize = self.basek;
@@ -74,16 +88,22 @@ impl LookUpTable {
// Rotates half the step to the left
let half_step: usize = domain_size.div_round(f_len << 1);
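// e.g. domain_size = 1024 with f_len = 16 gives half_step = 1024 / 32 = 32.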
lut_full.rotate(-(half_step as i64));
module.vec_znx_rotate_inplace(-(half_step as i64), &mut lut_full, 0);
let mut tmp_bytes: Vec<u8> = alloc_aligned(lut_full.n() * size_of::<i64>());
lut_full.normalize(self.basek, 0, &mut tmp_bytes);
let n_large: usize = lut_full.n();
module.vec_znx_normalize_inplace(
self.basek,
&mut lut_full,
0,
ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(n_large)).borrow(),
);
if self.extension_factor() > 1 {
(0..self.extension_factor()).for_each(|i| {
module.switch_degree(&mut self.data[i], 0, &lut_full, 0);
module.vec_znx_switch_degree(&mut self.data[i], 0, &lut_full, 0);
if i < self.extension_factor() - 1 {
lut_full.rotate(-1);
module.vec_znx_rotate_inplace(-1, &mut lut_full, 0);
}
});
} else {
@@ -92,7 +112,10 @@ impl LookUpTable {
}
#[allow(dead_code)]
pub(crate) fn rotate(&mut self, k: i64) {
pub(crate) fn rotate<B: Backend>(&mut self, module: &Module<B>, k: i64)
where
Module<B>: VecZnxRotateInplace,
{
let extension_factor: usize = self.extension_factor();
let two_n: usize = 2 * self.data[0].n();
let two_n_ext: usize = two_n * extension_factor;
@@ -103,11 +126,11 @@ impl LookUpTable {
let k_lo: usize = k_pos % extension_factor;
(0..extension_factor - k_lo).for_each(|i| {
self.data[i].rotate(k_hi as i64);
module.vec_znx_rotate_inplace(k_hi as i64, &mut self.data[i], 0);
});
(extension_factor - k_lo..extension_factor).for_each(|i| {
self.data[i].rotate(k_hi as i64 + 1);
module.vec_znx_rotate_inplace(k_hi as i64 + 1, &mut self.data[i], 0);
});
self.data.rotate_right(k_lo as usize);


@@ -2,9 +2,9 @@ pub mod cggi;
pub mod key;
pub mod lut;
pub use cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space};
pub use key::BlindRotationKeyCGGI;
pub use cggi::{CCGIBlindRotationFamily, cggi_blind_rotate, cggi_blind_rotate_scratch_space};
pub use key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec, BlindRotationKeyCGGIExecLayoutFamily};
pub use lut::LookUpTable;
#[cfg(test)]
pub mod test_fft64;
mod test;


@@ -0,0 +1,179 @@
use backend::{
hal::{
api::{
MatZnxAlloc, ModuleNew, ScalarZnxAlloc, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal,
VecZnxAddScalarInplace, VecZnxAlloc, VecZnxAllocBytes, VecZnxEncodeCoeffsi64, VecZnxFillUniform, VecZnxRotateInplace,
VecZnxSwithcDegree, ZnxView,
},
layouts::{Backend, Module, ScratchOwned},
oep::{
ScratchAvailableImpl, ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl,
TakeVecZnxDftSliceImpl, TakeVecZnxImpl, TakeVecZnxSliceImpl, VecZnxBigAllocBytesImpl, VecZnxDftAllocBytesImpl,
},
},
implementation::cpu_spqlios::FFT64,
};
use sampling::source::Source;
use crate::{
BlindRotationKeyCGGIExecLayoutFamily, CCGIBlindRotationFamily, GLWECiphertext, GLWEDecryptFamily, GLWEPlaintext, GLWESecret,
GLWESecretExec, GLWESecretFamily, Infos, LWECiphertext, LWESecret,
blind_rotation::{
cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n},
key::{BlindRotationKeyCGGI, BlindRotationKeyCGGIExec},
lut::LookUpTable,
},
lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef},
};
#[test]
fn standard() {
let module: Module<FFT64> = Module::<FFT64>::new(512);
blind_rotatio_test(&module, 224, 1, 1);
}
#[test]
fn block_binary() {
let module: Module<FFT64> = Module::<FFT64>::new(512);
blind_rotatio_test(&module, 224, 7, 1);
}
#[test]
fn block_binary_extended() {
let module: Module<FFT64> = Module::<FFT64>::new(512);
blind_rotatio_test(&module, 224, 7, 2);
}
pub(crate) trait CGGITestModuleFamily<B: Backend> = CCGIBlindRotationFamily<B>
+ GLWESecretFamily<B>
+ GLWEDecryptFamily<B>
+ BlindRotationKeyCGGIExecLayoutFamily<B>
+ VecZnxAlloc
+ ScalarZnxAlloc
+ VecZnxFillUniform
+ VecZnxAddNormal
+ VecZnxAllocBytes
+ VecZnxAddScalarInplace
+ VecZnxEncodeCoeffsi64
+ VecZnxRotateInplace
+ VecZnxSwithcDegree
+ MatZnxAlloc;
pub(crate) trait CGGITestScratchFamily<B: Backend> = VecZnxDftAllocBytesImpl<B>
+ VecZnxBigAllocBytesImpl<B>
+ ScratchOwnedAllocImpl<B>
+ ScratchOwnedBorrowImpl<B>
+ TakeVecZnxDftImpl<B>
+ TakeVecZnxBigImpl<B>
+ TakeVecZnxDftSliceImpl<B>
+ ScratchAvailableImpl<B>
+ TakeVecZnxImpl<B>
+ TakeVecZnxSliceImpl<B>;
fn blind_rotatio_test<B: Backend>(module: &Module<B>, n_lwe: usize, block_size: usize, extension_factor: usize)
where
Module<B>: CGGITestModuleFamily<B>,
B: CGGITestScratchFamily<B>,
{
let basek: usize = 19;
let k_lwe: usize = 24;
let k_brk: usize = 3 * basek;
let rows_brk: usize = 2; // Ensures first limb is noise-free.
let k_lut: usize = 1 * basek;
let k_res: usize = 2 * basek;
let rank: usize = 1;
let message_modulus: usize = 1 << 4;
let mut source_xs: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([2u8; 32]);
let mut source_xa: Source = Source::new([1u8; 32]);
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(module, rank);
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
let sk_glwe_dft: GLWESecretExec<Vec<u8>, B> = GLWESecretExec::from(module, &sk_glwe);
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
sk_lwe.fill_binary_block(block_size, &mut source_xs);
let mut scratch: ScratchOwned<B> = ScratchOwned::<B>::alloc(BlindRotationKeyCGGI::generate_from_sk_scratch_space(
module, basek, k_brk, rank,
));
let mut scratch_br: ScratchOwned<B> = ScratchOwned::<B>::alloc(cggi_blind_rotate_scratch_space(
module,
block_size,
extension_factor,
basek,
k_res,
k_brk,
rows_brk,
rank,
));
let mut brk: BlindRotationKeyCGGI<Vec<u8>> = BlindRotationKeyCGGI::alloc(module, n_lwe, basek, k_brk, rows_brk, rank);
brk.generate_from_sk(
module,
&sk_glwe_dft,
&sk_lwe,
&mut source_xa,
&mut source_xe,
3.2,
scratch.borrow(),
);
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
let x: i64 = 2;
let bits: usize = 8;
module.encode_coeff_i64(basek, &mut pt_lwe.data, 0, bits, 0, x, bits);
lwe.encrypt_sk(
module,
&pt_lwe,
&sk_lwe,
&mut source_xa,
&mut source_xe,
3.2,
);
let mut f: Vec<i64> = vec![0i64; message_modulus];
f.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = 2 * (i as i64) + 1);
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
lut.set(module, &f, message_modulus);
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(module, basek, k_res, rank);
let brk_exec: BlindRotationKeyCGGIExec<Vec<u8>, B> = BlindRotationKeyCGGIExec::from(module, &brk, scratch_br.borrow());
cggi_blind_rotate(module, &mut res, &lwe, &lut, &brk_exec, scratch_br.borrow());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(module, basek, k_res);
res.decrypt(module, &mut pt_have, &sk_glwe_dft, scratch.borrow());
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref());
let pt_want: i64 = (lwe_2n[0]
+ lwe_2n[1..]
.iter()
.zip(sk_lwe.data.at(0, 0))
.map(|(x, y)| x * y)
.sum::<i64>())
& (2 * lut.domain_size() - 1) as i64;
lut.rotate(module, pt_want);
// First limb should be exactly equal (tests are parameterized such that the noise does not reach
// the first limb)
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
}


@@ -1,6 +1,12 @@
use std::vec;
use backend::{FFT64, Module, ZnxView};
use backend::{
hal::{
api::{ModuleNew, ZnxView},
layouts::Module,
},
implementation::cpu_spqlios::FFT64,
};
use crate::blind_rotation::lut::{DivRound, LookUpTable};
@@ -23,7 +29,7 @@ fn standard() {
lut.set(&module, &f, log_scale);
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
lut.rotate(half_step);
lut.rotate(&module, half_step);
let step: usize = lut.domain_size().div_round(message_modulus);
@@ -33,7 +39,7 @@ fn standard() {
f[i / step] % message_modulus as i64,
lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64
);
lut.rotate(-1);
lut.rotate(&module, -1);
});
});
}
@@ -57,7 +63,7 @@ fn extended() {
lut.set(&module, &f, log_scale);
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
lut.rotate(half_step);
lut.rotate(&module, half_step);
let step: usize = lut.domain_size().div_round(message_modulus);
@@ -67,7 +73,7 @@ fn extended() {
f[i / step] % message_modulus as i64,
lut.data[0].raw()[0] / (1 << (log_scale % basek)) as i64
);
lut.rotate(-1);
lut.rotate(&module, -1);
});
});
}


@@ -1,125 +0,0 @@
use backend::{Encoding, FFT64, Module, ScratchOwned, ZnxView};
use sampling::source::Source;
use crate::{
FourierGLWESecret, GLWECiphertext, GLWEPlaintext, GLWESecret, Infos, LWECiphertext, LWESecret,
blind_rotation::{
cggi::{cggi_blind_rotate, cggi_blind_rotate_scratch_space, negate_and_mod_switch_2n},
key::BlindRotationKeyCGGI,
lut::LookUpTable,
},
lwe::{LWEPlaintext, ciphertext::LWECiphertextToRef},
};
#[test]
fn standard() {
blind_rotatio_test(224, 1, 1);
}
#[test]
fn block_binary() {
blind_rotatio_test(224, 7, 1);
}
#[test]
fn block_binary_extended() {
blind_rotatio_test(224, 7, 2);
}
fn blind_rotatio_test(n_lwe: usize, block_size: usize, extension_factor: usize) {
let module: Module<FFT64> = Module::<FFT64>::new(512);
let basek: usize = 19;
let k_lwe: usize = 24;
let k_brk: usize = 3 * basek;
let rows_brk: usize = 2; // Ensures first limb is noise-free.
let k_lut: usize = 1 * basek;
let k_res: usize = 2 * basek;
let rank: usize = 1;
let message_modulus: usize = 1 << 4;
let mut source_xs: Source = Source::new([2u8; 32]);
let mut source_xe: Source = Source::new([2u8; 32]);
let mut source_xa: Source = Source::new([1u8; 32]);
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(&module, rank);
sk_glwe.fill_ternary_prob(0.5, &mut source_xs);
let sk_glwe_dft: FourierGLWESecret<Vec<u8>, FFT64> = FourierGLWESecret::from(&module, &sk_glwe);
let mut sk_lwe: LWESecret<Vec<u8>> = LWESecret::alloc(n_lwe);
sk_lwe.fill_binary_block(block_size, &mut source_xs);
let mut scratch: ScratchOwned = ScratchOwned::new(BlindRotationKeyCGGI::generate_from_sk_scratch_space(
&module, basek, k_brk, rank,
));
let mut scratch_br: ScratchOwned = ScratchOwned::new(cggi_blind_rotate_scratch_space(
&module,
block_size,
extension_factor,
basek,
k_res,
k_brk,
rows_brk,
rank,
));
let mut brk: BlindRotationKeyCGGI<Vec<u8>, FFT64> =
BlindRotationKeyCGGI::allocate(&module, n_lwe, basek, k_brk, rows_brk, rank);
brk.generate_from_sk(
&module,
&sk_glwe_dft,
&sk_lwe,
&mut source_xa,
&mut source_xe,
3.2,
scratch.borrow(),
);
let mut lwe: LWECiphertext<Vec<u8>> = LWECiphertext::alloc(n_lwe, basek, k_lwe);
let mut pt_lwe: LWEPlaintext<Vec<u8>> = LWEPlaintext::alloc(basek, k_lwe);
let x: i64 = 2;
let bits: usize = 8;
pt_lwe.data.encode_coeff_i64(0, basek, bits, 0, x, bits);
lwe.encrypt_sk(&pt_lwe, &sk_lwe, &mut source_xa, &mut source_xe, 3.2);
let mut f: Vec<i64> = vec![0i64; message_modulus];
f.iter_mut()
.enumerate()
.for_each(|(i, x)| *x = 2 * (i as i64) + 1);
let mut lut: LookUpTable = LookUpTable::alloc(&module, basek, k_lut, extension_factor);
lut.set(&module, &f, message_modulus);
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(&module, basek, k_res, rank);
cggi_blind_rotate(&module, &mut res, &lwe, &lut, &brk, scratch_br.borrow());
let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(&module, basek, k_res);
res.decrypt(&module, &mut pt_have, &sk_glwe_dft, scratch.borrow());
let mut lwe_2n: Vec<i64> = vec![0i64; lwe.n() + 1]; // TODO: from scratch space
negate_and_mod_switch_2n(2 * lut.domain_size(), &mut lwe_2n, &lwe.to_ref());
let pt_want: i64 = (lwe_2n[0]
+ lwe_2n[1..]
.iter()
.zip(sk_lwe.data.at(0, 0))
.map(|(x, y)| x * y)
.sum::<i64>())
& (2 * lut.domain_size() - 1) as i64;
lut.rotate(pt_want);
// First limb should be exactly equal (tests are parameterized such that the noise does not reach
// the first limb)
assert_eq!(pt_have.data.at(0, 0), lut.data[0].at(0, 0));
}