mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 13:16:44 +01:00
Add Zn type
This commit is contained in:
@@ -23,7 +23,6 @@ use crate::tfhe::blind_rotation::{
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn cggi_blind_rotate_scratch_space<B: Backend>(
|
||||
module: &Module<B>,
|
||||
n: usize,
|
||||
block_size: usize,
|
||||
extension_factor: usize,
|
||||
basek: usize,
|
||||
@@ -44,14 +43,14 @@ where
|
||||
|
||||
if block_size > 1 {
|
||||
let cols: usize = rank + 1;
|
||||
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(n, cols, rows) * extension_factor;
|
||||
let acc_big: usize = module.vec_znx_big_alloc_bytes(n, 1, brk_size);
|
||||
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(n, cols, brk_size) * extension_factor;
|
||||
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(n, 1, brk_size);
|
||||
let acc_dft: usize = module.vec_znx_dft_alloc_bytes(cols, rows) * extension_factor;
|
||||
let acc_big: usize = module.vec_znx_big_alloc_bytes(1, brk_size);
|
||||
let vmp_res: usize = module.vec_znx_dft_alloc_bytes(cols, brk_size) * extension_factor;
|
||||
let vmp_xai: usize = module.vec_znx_dft_alloc_bytes(1, brk_size);
|
||||
let acc_dft_add: usize = vmp_res;
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(n, brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
let vmp: usize = module.vmp_apply_tmp_bytes(brk_size, rows, rows, 2, 2, brk_size); // GGSW product: (1 x 2) x (2 x 2)
|
||||
let acc: usize = if extension_factor > 1 {
|
||||
VecZnx::alloc_bytes(n, cols, k_res.div_ceil(basek)) * extension_factor
|
||||
VecZnx::alloc_bytes(module.n(), cols, k_res.div_ceil(basek)) * extension_factor
|
||||
} else {
|
||||
0
|
||||
};
|
||||
@@ -60,10 +59,10 @@ where
|
||||
+ acc_dft_add
|
||||
+ vmp_res
|
||||
+ vmp_xai
|
||||
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes(n) | module.vec_znx_dft_to_vec_znx_big_tmp_bytes(n))))
|
||||
+ (vmp | (acc_big + (module.vec_znx_big_normalize_tmp_bytes() | module.vec_znx_dft_to_vec_znx_big_tmp_bytes())))
|
||||
} else {
|
||||
GLWECiphertext::bytes_of(n, basek, k_res, rank)
|
||||
+ GLWECiphertext::external_product_scratch_space(module, n, basek, k_res, k_res, k_brk, 1, rank)
|
||||
GLWECiphertext::bytes_of(module.n(), basek, k_res, rank)
|
||||
+ GLWECiphertext::external_product_scratch_space(module, basek, k_res, k_res, k_brk, 1, rank)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -38,11 +38,11 @@ impl BlindRotationKeyAlloc for BlindRotationKey<Vec<u8>, CGGI> {
|
||||
}
|
||||
|
||||
impl BlindRotationKey<Vec<u8>, CGGI> {
|
||||
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
|
||||
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
|
||||
where
|
||||
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
|
||||
{
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(module, n, basek, k, rank)
|
||||
GGSWCiphertext::encrypt_sk_scratch_space(module, basek, k, rank)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,11 +108,11 @@ impl<B: Backend> BlindRotationKeyPreparedAlloc<B> for BlindRotationKeyPrepared<V
|
||||
where
|
||||
Module<B>: VmpPMatAlloc<B> + VmpPrepare<B>,
|
||||
{
|
||||
fn alloc(module: &Module<B>, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
fn alloc(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self {
|
||||
let mut data: Vec<GGSWCiphertextPrepared<Vec<u8>, B>> = Vec::with_capacity(n_lwe);
|
||||
(0..n_lwe).for_each(|_| {
|
||||
data.push(GGSWCiphertextPrepared::alloc(
|
||||
module, n_glwe, basek, k, rows, 1, rank,
|
||||
module, basek, k, rows, 1, rank,
|
||||
))
|
||||
});
|
||||
Self {
|
||||
@@ -139,11 +139,11 @@ impl BlindRotationKeyCompressed<Vec<u8>, CGGI> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
|
||||
pub fn generate_from_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
|
||||
where
|
||||
Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
|
||||
{
|
||||
GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, n, basek, k, rank)
|
||||
GGSWCiphertextCompressed::encrypt_sk_scratch_space(module, basek, k, rank)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ use poulpy_core::{
|
||||
use crate::tfhe::blind_rotation::{BlindRotationAlgo, BlindRotationKey, utils::set_xai_plus_y};
|
||||
|
||||
pub trait BlindRotationKeyPreparedAlloc<B: Backend> {
|
||||
fn alloc(module: &Module<B>, n_glwe: usize, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self;
|
||||
fn alloc(module: &Module<B>, n_lwe: usize, basek: usize, k: usize, rows: usize, rank: usize) -> Self;
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
@@ -74,7 +74,6 @@ where
|
||||
fn prepare_alloc(&self, module: &Module<B>, scratch: &mut Scratch<B>) -> BlindRotationKeyPrepared<Vec<u8>, BRA, B> {
|
||||
let mut brk: BlindRotationKeyPrepared<Vec<u8>, BRA, B> = BlindRotationKeyPrepared::alloc(
|
||||
module,
|
||||
self.n(),
|
||||
self.keys.len(),
|
||||
self.basek(),
|
||||
self.k(),
|
||||
@@ -112,7 +111,7 @@ where
|
||||
let mut x_pow_a: Vec<SvpPPol<Vec<u8>, B>> = Vec::with_capacity(n << 1);
|
||||
let mut buf: ScalarZnx<Vec<u8>> = ScalarZnx::alloc(n, 1);
|
||||
(0..n << 1).for_each(|i| {
|
||||
let mut res: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(n, 1);
|
||||
let mut res: SvpPPol<Vec<u8>, B> = module.svp_ppol_alloc(1);
|
||||
set_xai_plus_y(module, i, 0, &mut res, &mut buf);
|
||||
x_pow_a.push(res);
|
||||
});
|
||||
|
||||
@@ -22,7 +22,7 @@ pub struct LookUpTable {
|
||||
}
|
||||
|
||||
impl LookUpTable {
|
||||
pub fn alloc(n: usize, basek: usize, k: usize, extension_factor: usize) -> Self {
|
||||
pub fn alloc<B: Backend>(module: &Module<B>, basek: usize, k: usize, extension_factor: usize) -> Self {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(
|
||||
@@ -34,7 +34,7 @@ impl LookUpTable {
|
||||
let size: usize = k.div_ceil(basek);
|
||||
let mut data: Vec<VecZnx<Vec<u8>>> = Vec::with_capacity(extension_factor);
|
||||
(0..extension_factor).for_each(|_| {
|
||||
data.push(VecZnx::alloc(n, 1, size));
|
||||
data.push(VecZnx::alloc(module.n(), 1, size));
|
||||
});
|
||||
Self {
|
||||
data,
|
||||
@@ -121,16 +121,6 @@ impl LookUpTable {
|
||||
let drift: usize = step >> 1;
|
||||
|
||||
// Rotates half the step to the left
|
||||
module.vec_znx_rotate_inplace(-(drift as i64), &mut lut_full, 0);
|
||||
|
||||
let n_large: usize = lut_full.n();
|
||||
|
||||
module.vec_znx_normalize_inplace(
|
||||
self.basek,
|
||||
&mut lut_full,
|
||||
0,
|
||||
ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes(n_large)).borrow(),
|
||||
);
|
||||
|
||||
if self.extension_factor() > 1 {
|
||||
(0..self.extension_factor()).for_each(|i| {
|
||||
@@ -143,6 +133,14 @@ impl LookUpTable {
|
||||
module.vec_znx_copy(&mut self.data[0], 0, &lut_full, 0);
|
||||
}
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());
|
||||
|
||||
self.data.iter_mut().for_each(|a| {
|
||||
module.vec_znx_normalize_inplace(self.basek, a, 0, scratch.borrow());
|
||||
});
|
||||
|
||||
self.rotate(module, -(drift as i64));
|
||||
|
||||
self.drift = drift
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ use poulpy_hal::{
|
||||
VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftSubABInplace, VecZnxDftToVecZnxBig, VecZnxDftToVecZnxBigConsume,
|
||||
VecZnxDftToVecZnxBigTmpBytes, VecZnxDftZero, VecZnxFillUniform, VecZnxMulXpMinusOneInplace, VecZnxNormalize,
|
||||
VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxSub, VecZnxSubABInplace,
|
||||
VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnxView,
|
||||
VecZnxSwithcDegree, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, VmpPMatAlloc, VmpPrepare, ZnAddNormal, ZnFillUniform,
|
||||
ZnNormalizeInplace, ZnxView,
|
||||
},
|
||||
layouts::{Backend, Module, ScratchOwned},
|
||||
oep::{
|
||||
@@ -65,7 +66,10 @@ where
|
||||
+ VmpPMatAlloc<B>
|
||||
+ VmpPrepare<B>
|
||||
+ VmpApply<B>
|
||||
+ VmpApplyAdd<B>,
|
||||
+ VmpApplyAdd<B>
|
||||
+ ZnFillUniform
|
||||
+ ZnAddNormal
|
||||
+ ZnNormalizeInplace<B>,
|
||||
B: Backend
|
||||
+ VecZnxDftAllocBytesImpl<B>
|
||||
+ VecZnxBigAllocBytesImpl<B>
|
||||
@@ -96,7 +100,7 @@ where
|
||||
let mut source_xa: Source = Source::new([1u8; 32]);
|
||||
|
||||
let mut scratch: ScratchOwned<B> = ScratchOwned::<B>::alloc(BlindRotationKey::generate_from_sk_scratch_space(
|
||||
module, n, basek, k_brk, rank,
|
||||
module, basek, k_brk, rank,
|
||||
));
|
||||
|
||||
let mut sk_glwe: GLWESecret<Vec<u8>> = GLWESecret::alloc(n, rank);
|
||||
@@ -108,7 +112,6 @@ where
|
||||
|
||||
let mut scratch_br: ScratchOwned<B> = ScratchOwned::<B>::alloc(cggi_blind_rotate_scratch_space(
|
||||
module,
|
||||
n,
|
||||
block_size,
|
||||
extension_factor,
|
||||
basek,
|
||||
@@ -148,7 +151,7 @@ where
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = f(i as i64));
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor);
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
|
||||
lut.set(module, &f_vec, log_message_modulus + 1);
|
||||
|
||||
let mut res: GLWECiphertext<Vec<u8>> = GLWECiphertext::alloc(n, basek, k_res, rank);
|
||||
|
||||
@@ -13,7 +13,6 @@ where
|
||||
Module<B>: VecZnxRotateInplace + VecZnxNormalizeInplace<B> + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 20;
|
||||
let k_lut: usize = 40;
|
||||
let message_modulus: usize = 16;
|
||||
@@ -26,7 +25,7 @@ where
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i as i64) - 8);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor);
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
|
||||
lut.set(module, &f, log_scale);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
@@ -49,7 +48,6 @@ where
|
||||
Module<B>: VecZnxRotateInplace + VecZnxNormalizeInplace<B> + VecZnxNormalizeTmpBytes + VecZnxSwithcDegree + VecZnxCopy,
|
||||
B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
|
||||
{
|
||||
let n: usize = module.n();
|
||||
let basek: usize = 20;
|
||||
let k_lut: usize = 40;
|
||||
let message_modulus: usize = 16;
|
||||
@@ -62,7 +60,7 @@ where
|
||||
.enumerate()
|
||||
.for_each(|(i, x)| *x = (i as i64) - 8);
|
||||
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(n, basek, k_lut, extension_factor);
|
||||
let mut lut: LookUpTable = LookUpTable::alloc(module, basek, k_lut, extension_factor);
|
||||
lut.set(module, &f, log_scale);
|
||||
|
||||
let half_step: i64 = lut.domain_size().div_round(message_modulus << 1) as i64;
|
||||
|
||||
Reference in New Issue
Block a user