mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Update to custom fheuint prepare
This commit is contained in:
@@ -6,10 +6,11 @@ use poulpy_hal::{
|
|||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl,
|
||||||
},
|
},
|
||||||
reference::fft64::vmp::{
|
reference::fft64::vmp::{
|
||||||
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
|
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
|
||||||
|
vmp_zero,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -141,3 +142,12 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Avx {
|
|||||||
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl VmpZeroImpl<Self> for FFT64Avx {
|
||||||
|
fn vmp_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<Self>,
|
||||||
|
{
|
||||||
|
vmp_zero(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,10 +6,11 @@ use poulpy_hal::{
|
|||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl,
|
||||||
},
|
},
|
||||||
reference::fft64::vmp::{
|
reference::fft64::vmp::{
|
||||||
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
|
vmp_apply_dft_to_dft, vmp_apply_dft_to_dft_add, vmp_apply_dft_to_dft_tmp_bytes, vmp_prepare, vmp_prepare_tmp_bytes,
|
||||||
|
vmp_zero,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -141,3 +142,12 @@ unsafe impl VmpApplyDftToDftAddTmpBytesImpl<Self> for FFT64Ref {
|
|||||||
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
vmp_apply_dft_to_dft_tmp_bytes(a_size, b_rows, b_cols_in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl VmpZeroImpl<Self> for FFT64Ref {
|
||||||
|
fn vmp_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<Self>,
|
||||||
|
{
|
||||||
|
vmp_zero(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,8 +6,9 @@ use poulpy_hal::{
|
|||||||
},
|
},
|
||||||
oep::{
|
oep::{
|
||||||
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl, VmpApplyDftToDftTmpBytesImpl,
|
||||||
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl,
|
VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl, VmpPrepareTmpBytesImpl, VmpZeroImpl,
|
||||||
},
|
},
|
||||||
|
reference::fft64::vmp::vmp_zero,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::cpu_spqlios::{
|
use crate::cpu_spqlios::{
|
||||||
@@ -269,3 +270,12 @@ unsafe impl VmpApplyDftToDftAddImpl<Self> for FFT64Spqlios {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
unsafe impl VmpZeroImpl<Self> for FFT64Spqlios {
|
||||||
|
fn vmp_zero_impl<R>(_module: &Module<Self>, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<Self>,
|
||||||
|
{
|
||||||
|
vmp_zero(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use poulpy_hal::{
|
use poulpy_hal::{
|
||||||
api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
|
api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero},
|
||||||
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
|
layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -51,7 +51,7 @@ impl<D: Data, B: Backend> GGSWInfos for GGSWPrepared<D, B> {
|
|||||||
|
|
||||||
pub trait GGSWPreparedFactory<B: Backend>
|
pub trait GGSWPreparedFactory<B: Backend>
|
||||||
where
|
where
|
||||||
Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B>,
|
Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>,
|
||||||
{
|
{
|
||||||
fn alloc_ggsw_prepared(
|
fn alloc_ggsw_prepared(
|
||||||
&self,
|
&self,
|
||||||
@@ -163,7 +163,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> GGSWPreparedFactory<B> for Module<B> where
|
impl<B: Backend> GGSWPreparedFactory<B> for Module<B> where
|
||||||
Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B>
|
Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,6 +223,13 @@ impl<D: DataMut, B: Backend> GGSWPrepared<D, B> {
|
|||||||
{
|
{
|
||||||
module.ggsw_prepare(self, other, scratch);
|
module.ggsw_prepare(self, other, scratch);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn zero<M>(&mut self, module: &M)
|
||||||
|
where
|
||||||
|
M: GGSWPreparedFactory<B>,
|
||||||
|
{
|
||||||
|
module.vmp_zero(&mut self.data);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait GGSWPreparedToMut<B: Backend> {
|
pub trait GGSWPreparedToMut<B: Backend> {
|
||||||
|
|||||||
@@ -155,11 +155,7 @@ where
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn take_ggsw_slice<A>(
|
fn take_ggsw_slice<A>(&mut self, size: usize, infos: &A) -> (Vec<GGSW<&mut [u8]>>, &mut Self)
|
||||||
&mut self,
|
|
||||||
size: usize,
|
|
||||||
infos: &A,
|
|
||||||
) -> (Vec<GGSW<&mut [u8]>>, &mut Self)
|
|
||||||
where
|
where
|
||||||
A: GGSWInfos,
|
A: GGSWInfos,
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -111,3 +111,9 @@ pub trait VmpApplyDftToDftAdd<B: Backend> {
|
|||||||
A: VecZnxDftToRef<B>,
|
A: VecZnxDftToRef<B>,
|
||||||
C: VmpPMatToRef<B>;
|
C: VmpPMatToRef<B>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait VmpZero<B: Backend> {
|
||||||
|
fn vmp_zero<R>(&self, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<B>;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
api::{
|
api::{
|
||||||
VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
|
VmpApplyDft, VmpApplyDftTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftAddTmpBytes,
|
||||||
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes,
|
VmpApplyDftToDftTmpBytes, VmpPMatAlloc, VmpPMatBytesOf, VmpPMatFromBytes, VmpPrepare, VmpPrepareTmpBytes, VmpZero,
|
||||||
},
|
},
|
||||||
layouts::{
|
layouts::{
|
||||||
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
|
Backend, MatZnxToRef, Module, Scratch, VecZnxDftToMut, VecZnxDftToRef, VecZnxToRef, VmpPMatOwned, VmpPMatToMut,
|
||||||
@@ -10,7 +10,7 @@ use crate::{
|
|||||||
oep::{
|
oep::{
|
||||||
VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
|
VmpApplyDftImpl, VmpApplyDftTmpBytesImpl, VmpApplyDftToDftAddImpl, VmpApplyDftToDftAddTmpBytesImpl, VmpApplyDftToDftImpl,
|
||||||
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
|
VmpApplyDftToDftTmpBytesImpl, VmpPMatAllocBytesImpl, VmpPMatAllocImpl, VmpPMatFromBytesImpl, VmpPrepareImpl,
|
||||||
VmpPrepareTmpBytesImpl,
|
VmpPrepareTmpBytesImpl, VmpZeroImpl,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -161,3 +161,15 @@ where
|
|||||||
B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
|
B::vmp_apply_dft_to_dft_add_impl(self, res, a, b, scale, scratch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<B> VmpZero<B> for Module<B>
|
||||||
|
where
|
||||||
|
B: Backend + VmpZeroImpl<B>,
|
||||||
|
{
|
||||||
|
fn vmp_zero<R>(&self, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<B>,
|
||||||
|
{
|
||||||
|
B::vmp_zero_impl(self, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -145,3 +145,13 @@ pub unsafe trait VmpApplyDftToDftAddImpl<B: Backend> {
|
|||||||
A: VecZnxDftToRef<B>,
|
A: VecZnxDftToRef<B>,
|
||||||
C: VmpPMatToRef<B>;
|
C: VmpPMatToRef<B>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// # THIS TRAIT IS AN OPEN EXTENSION POINT (unsafe)
|
||||||
|
/// * See TODO.
|
||||||
|
/// * See [crate::api::VmpZero] for corresponding public API.
|
||||||
|
/// # Safety [crate::doc::backend_safety] for safety contract.
|
||||||
|
pub unsafe trait VmpZeroImpl<B: Backend> {
|
||||||
|
fn vmp_zero_impl<R>(module: &Module<B>, res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<B>;
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::{
|
use crate::{
|
||||||
cast_mut,
|
cast_mut,
|
||||||
layouts::{MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
|
layouts::{DataViewMut, MatZnx, MatZnxToRef, VecZnx, VecZnxToRef, VmpPMatToMut, ZnxView, ZnxViewMut},
|
||||||
oep::VecZnxDftAllocBytesImpl,
|
oep::VecZnxDftAllocBytesImpl,
|
||||||
reference::fft64::{
|
reference::fft64::{
|
||||||
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
|
reim::{ReimDFTExecute, ReimFFTTable, ReimFromZnx, ReimZero},
|
||||||
@@ -157,6 +157,13 @@ pub fn vmp_apply_dft_to_dft_tmp_bytes(a_size: usize, prows: usize, pcols_in: usi
|
|||||||
(16 + 8 * row_max * pcols_in) * size_of::<f64>()
|
(16 + 8 * row_max * pcols_in) * size_of::<f64>()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn vmp_zero<R, BE: Backend>(res: &mut R)
|
||||||
|
where
|
||||||
|
R: VmpPMatToMut<BE>,
|
||||||
|
{
|
||||||
|
res.to_mut().data_mut().fill(0);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
|
pub fn vmp_apply_dft_to_dft<R, A, M, BE>(res: &mut R, a: &A, pmat: &M, tmp_bytes: &mut [f64])
|
||||||
where
|
where
|
||||||
BE: Backend<ScalarPrep = f64>
|
BE: Backend<ScalarPrep = f64>
|
||||||
|
|||||||
@@ -217,11 +217,12 @@ pub trait FheUintPrepare<BRA: BlindRotationAlgo, T: UnsignedInteger, BE: Backend
|
|||||||
DB: DataRef,
|
DB: DataRef,
|
||||||
DK: DataRef,
|
DK: DataRef,
|
||||||
K: BDDKeyHelper<DK, BRA, BE>;
|
K: BDDKeyHelper<DK, BRA, BE>;
|
||||||
fn fhe_uint_prepare_partial<DM, DB, DK, K>(
|
fn fhe_uint_prepare_custom<DM, DB, DK, K>(
|
||||||
&self,
|
&self,
|
||||||
res: &mut FheUintPrepared<DM, T, BE>,
|
res: &mut FheUintPrepared<DM, T, BE>,
|
||||||
bits: &FheUint<DB, T>,
|
bits: &FheUint<DB, T>,
|
||||||
count: usize,
|
bit_start: usize,
|
||||||
|
bit_end: usize,
|
||||||
key: &K,
|
key: &K,
|
||||||
scratch: &mut Scratch<BE>,
|
scratch: &mut Scratch<BE>,
|
||||||
) where
|
) where
|
||||||
@@ -261,22 +262,15 @@ where
|
|||||||
DK: DataRef,
|
DK: DataRef,
|
||||||
K: BDDKeyHelper<DK, BRA, BE>,
|
K: BDDKeyHelper<DK, BRA, BE>,
|
||||||
{
|
{
|
||||||
let (cbt, ks) = key.get_cbt_key();
|
self.fhe_uint_prepare_custom(res, bits, 0, T::BITS as usize, key, scratch);
|
||||||
|
|
||||||
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE
|
|
||||||
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res);
|
|
||||||
for (bit, dst) in res.bits.iter_mut().enumerate() {
|
|
||||||
bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1);
|
|
||||||
cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
|
|
||||||
dst.prepare(self, &tmp_ggsw, scratch_1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn fhe_uint_prepare_partial<DM, DB, DK, K>(
|
fn fhe_uint_prepare_custom<DM, DB, DK, K>(
|
||||||
&self,
|
&self,
|
||||||
res: &mut FheUintPrepared<DM, T, BE>,
|
res: &mut FheUintPrepared<DM, T, BE>,
|
||||||
bits: &FheUint<DB, T>,
|
bits: &FheUint<DB, T>,
|
||||||
count: usize,
|
bit_start: usize,
|
||||||
|
bit_end: usize,
|
||||||
key: &K,
|
key: &K,
|
||||||
scratch: &mut Scratch<BE>,
|
scratch: &mut Scratch<BE>,
|
||||||
) where
|
) where
|
||||||
@@ -289,11 +283,20 @@ where
|
|||||||
|
|
||||||
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE
|
let mut lwe: LWE<Vec<u8>> = LWE::alloc_from_infos(bits); //TODO: add TakeLWE
|
||||||
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res);
|
let (mut tmp_ggsw, scratch_1) = scratch.take_ggsw(res);
|
||||||
for (bit, dst) in res.bits[0..count].iter_mut().enumerate() { // TODO: set the rest of the bits to a prepared zero GGSW
|
for (bit, dst) in res.bits[bit_start..bit_end].iter_mut().enumerate() {
|
||||||
|
// TODO: set the rest of the bits to a prepared zero GGSW
|
||||||
bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1);
|
bits.get_bit_lwe(self, bit, &mut lwe, ks, scratch_1);
|
||||||
cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
|
cbt.execute_to_constant(self, &mut tmp_ggsw, &lwe, 1, 1, scratch_1);
|
||||||
dst.prepare(self, &tmp_ggsw, scratch_1);
|
dst.prepare(self, &tmp_ggsw, scratch_1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for i in 0..bit_start {
|
||||||
|
res.bits[i].zero(self);
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in bit_end..T::BITS as usize {
|
||||||
|
res.bits[i].zero(self);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,8 +312,15 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> FheUintPrepared<D, T, BE> {
|
|||||||
{
|
{
|
||||||
module.fhe_uint_prepare(self, other, key, scratch);
|
module.fhe_uint_prepare(self, other, key, scratch);
|
||||||
}
|
}
|
||||||
pub fn prepare_partial<BRA, M, O, K, DK>(&mut self, module: &M, other: &FheUint<O, T>, count: usize, key: &K, scratch: &mut Scratch<BE>)
|
pub fn prepare_partial<BRA, M, O, K, DK>(
|
||||||
where
|
&mut self,
|
||||||
|
module: &M,
|
||||||
|
other: &FheUint<O, T>,
|
||||||
|
bit_start: usize,
|
||||||
|
bit_end: usize,
|
||||||
|
key: &K,
|
||||||
|
scratch: &mut Scratch<BE>,
|
||||||
|
) where
|
||||||
BRA: BlindRotationAlgo,
|
BRA: BlindRotationAlgo,
|
||||||
O: DataRef,
|
O: DataRef,
|
||||||
DK: DataRef,
|
DK: DataRef,
|
||||||
@@ -318,6 +328,6 @@ impl<D: DataMut, T: UnsignedInteger, BE: Backend> FheUintPrepared<D, T, BE> {
|
|||||||
M: FheUintPrepare<BRA, T, BE>,
|
M: FheUintPrepare<BRA, T, BE>,
|
||||||
Scratch<BE>: ScratchTakeCore<BE>,
|
Scratch<BE>: ScratchTakeCore<BE>,
|
||||||
{
|
{
|
||||||
module.fhe_uint_prepare_partial(self, other, count, key, scratch);
|
module.fhe_uint_prepare_custom(self, other, bit_start, bit_end, key, scratch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user