Improve cmux speed

This commit is contained in:
Pro7ech
2025-11-07 17:56:33 +01:00
parent 836df871fe
commit 75842cd80a
4 changed files with 270 additions and 222 deletions

View File

@@ -2,9 +2,13 @@ use core::panic;
use itertools::Itertools;
use poulpy_core::{
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
GLWECopy, GLWEExternalProductInternal, GLWENormalize, GLWESub, ScratchTakeCore,
layouts::{GGSWInfos, GGSWPrepared, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
};
use poulpy_hal::{
api::{ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftBytesOf},
layouts::{Backend, DataMut, Module, Scratch, VecZnxBig, ZnxZero},
};
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
use crate::tfhe::bdd_arithmetic::{GetGGSWBit, UnsignedInteger};
@@ -70,8 +74,18 @@ where
{
#[cfg(debug_assertions)]
{
assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
assert!(
inputs.bit_size() >= circuit.input_size(),
"inputs.bit_size(): {} < circuit.input_size():{}",
inputs.bit_size(),
circuit.input_size()
);
assert!(
out.len() >= circuit.output_size(),
"out.len(): {} < circuit.output_size(): {}",
out.len(),
circuit.output_size()
);
}
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -164,7 +178,14 @@ pub enum Node {
pub trait Cmux<BE: Backend>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes,
{
fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
@@ -172,7 +193,11 @@ where
A: GLWEInfos,
B: GGSWInfos,
{
self.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), b_infos.size());
res_dft
+ self
.glwe_external_product_internal_tmp_bytes(res_infos, a_infos, b_infos)
.max(self.vec_znx_big_normalize_tmp_bytes())
}
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
@@ -180,34 +205,73 @@ where
R: GLWEToMut,
T: GLWEToRef,
F: GLWEToRef,
S: GGSWPreparedToRef<BE>,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub(res, t, f);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, f);
self.glwe_normalize_inplace(res, scratch);
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
let f: GLWE<&[u8]> = f.to_ref();
let res_base2k: usize = res.base2k().into();
let ggsw_base2k: usize = s.base2k().into();
self.glwe_sub(res, t, &f);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_add_small_inplace(&mut res_big, j, f.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res.data_mut(),
j,
ggsw_base2k,
&res_big,
j,
scratch_1,
);
}
}
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
S: GGSWPreparedToRef<BE>,
S: GGSWPreparedToRef<BE> + GGSWInfos,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let s: &GGSWPrepared<&[u8], BE> = &s.to_ref();
let a: GLWE<&[u8]> = a.to_ref();
let res_base2k: usize = res.base2k().into();
let ggsw_base2k: usize = s.base2k().into();
self.glwe_sub_inplace(res, &a);
let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self, (res.rank() + 1).into(), s.size()); // Todo optimise
let mut res_big: VecZnxBig<&mut [u8], BE> = self.glwe_external_product_internal(res_dft, res, s, scratch_1);
for j in 0..(res.rank() + 1).into() {
self.vec_znx_big_add_small_inplace(&mut res_big, j, a.data(), j);
self.vec_znx_big_normalize(
res_base2k,
res.data_mut(),
j,
ggsw_base2k,
&res_big,
j,
scratch_1,
);
}
}
}
impl<BE: Backend> Cmux<BE> for Module<BE>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
Self: Sized
+ GLWEExternalProductInternal<BE>
+ GLWESub
+ VecZnxBigAddSmallInplace<BE>
+ GLWENormalize<BE>
+ VecZnxDftBytesOf
+ VecZnxBigNormalize<BE>
+ VecZnxBigNormalizeTmpBytes,
Scratch<BE>: ScratchTakeCore<BE>,
{
}