mirror of
https://github.com/arnaucube/poulpy.git
synced 2026-02-10 05:06:44 +01:00
Improve cmux speed
This commit is contained in:
@@ -382,8 +382,8 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
|
||||
|
||||
let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self);
|
||||
|
||||
for i in 0..T::BITS as usize {
|
||||
module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1);
|
||||
for (i, bits) in out_bits.iter_mut().enumerate().take(T::BITS as usize) {
|
||||
module.cmux(bits, &one, &zero, &other.get_bit(i), scratch_1);
|
||||
}
|
||||
|
||||
self.pack(module, out_bits, keys, scratch_1);
|
||||
|
||||
@@ -2,9 +2,13 @@ use core::panic;
|
||||
|
||||
use itertools::Itertools;
|
||||
use poulpy_core::{
|
||||
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
|
||||
GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
|
||||
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
|
||||
};
|
||||
use poulpy_hal::{
|
||||
api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
|
||||
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
|
||||
};
|
||||
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
|
||||
|
||||
use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
|
||||
|
||||
@@ -70,8 +74,18 @@ where
|
||||
{
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
|
||||
assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
|
||||
assert!(
|
||||
inputs.bit_size() >= circuit.input_size(),
|
||||
"inputs.bit_size(): {} < circuit.input_size():{}",
|
||||
inputs.bit_size(),
|
||||
circuit.input_size()
|
||||
);
|
||||
assert!(
|
||||
out.len() >= circuit.output_size(),
|
||||
"out.len(): {} < circuit.output_size(): {}",
|
||||
out.len(),
|
||||
circuit.output_size()
|
||||
);
|
||||
}
|
||||
|
||||
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
|
||||
@@ -164,7 +178,14 @@ pub enum Node {
|
||||
|
||||
pub trait Cmux<BE: Backend>
|
||||
where
|
||||
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
|
||||
Self: Sized
|
||||
+ GLWEExternalProductInternal<BE>
|
||||
+ GLWESub
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
{
|
||||
fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
|
||||
where
|
||||
@@ -172,7 +193,11 @@ where
|
||||
A: GLWEInfos,
|
||||
B: GGSWInfos,
|
||||
{
|
||||
self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
|
||||
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
|
||||
res_dft
|
||||
+ self
|
||||
.glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
|
||||
.max(self.vec_znx_big_normalize_tmp_bytes())
|
||||
}
|
||||
|
||||
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
|
||||
@@ -180,34 +205,73 @@ where
|
||||
R: GLWEToMut,
|
||||
T: GLWEToRef,
|
||||
F: GLWEToRef,
|
||||
S: GGSWPreparedToRef<BE>,
|
||||
S: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
self.glwe_sub(res, t, f);
|
||||
self.glwe_normalize_inplace(res, scratch);
|
||||
self.glwe_external_product_inplace(res, s, scratch);
|
||||
self.glwe_add_inplace(res, f);
|
||||
self.glwe_normalize_inplace(res, scratch);
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
|
||||
let f: GLWE<&[u8]> = f.to_ref();
|
||||
|
||||
let res_base2k: usize = res.base2k().into();
|
||||
let ggsw_base2k: usize = s.base2k().into();
|
||||
|
||||
self.glwe_sub(res, t, &f);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
|
||||
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
|
||||
for j in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res.data_mut(),
|
||||
j,
|
||||
ggsw_base2k,
|
||||
&res_big,
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
|
||||
where
|
||||
R: GLWEToMut,
|
||||
A: GLWEToRef,
|
||||
S: GGSWPreparedToRef<BE>,
|
||||
S: GGSWPreparedToRef<BE> + GGSWInfos,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
self.glwe_sub_inplace(res, a);
|
||||
self.glwe_normalize_inplace(res, scratch);
|
||||
self.glwe_external_product_inplace(res, s, scratch);
|
||||
self.glwe_add_inplace(res, a);
|
||||
self.glwe_normalize_inplace(res, scratch);
|
||||
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
|
||||
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
|
||||
let a: GLWE<&[u8]> = a.to_ref();
|
||||
let res_base2k: usize = res.base2k().into();
|
||||
let ggsw_base2k: usize = s.base2k().into();
|
||||
self.glwe_sub_inplace(res, &a);
|
||||
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
|
||||
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
|
||||
for j in 0..(res.rank() + 1).into() {
|
||||
self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
|
||||
self.vec_znx_big_normalize(
|
||||
res_base2k,
|
||||
res.data_mut(),
|
||||
j,
|
||||
ggsw_base2k,
|
||||
&res_big,
|
||||
j,
|
||||
scratch_1,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<BE: Backend> Cmux<BE> for Module<BE>
|
||||
where
|
||||
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
|
||||
Self: Sized
|
||||
+ GLWEExternalProductInternal<BE>
|
||||
+ GLWESub
|
||||
+ VecZnxBigAddSmallInplace<BE>
|
||||
+ GLWENormalize<BE>
|
||||
+ VecZnxDftBytesOf
|
||||
+ VecZnxBigNormalize<BE>
|
||||
+ VecZnxBigNormalizeTmpBytes,
|
||||
Scratch<BE>: ScratchTakeCore<BE>,
|
||||
{
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user