Add normalize in cmux & uint_prepared to uint

This commit is contained in:
Pro7ech
2025-11-07 16:30:47 +01:00
parent f13d61443c
commit 836df871fe
2 changed files with 37 additions and 7 deletions

View File

@@ -2,8 +2,7 @@ use core::panic;
use itertools::Itertools;
use poulpy_core::{
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore,
layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}
};
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
@@ -71,8 +70,8 @@ where
{
#[cfg(debug_assertions)]
{
assert!(inputs.bit_size() >= circuit.input_size());
assert!(out.len() >= circuit.output_size());
assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size());
assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size());
}
for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) {
@@ -165,7 +164,7 @@ pub enum Node {
pub trait Cmux<BE: Backend>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
{
fn cmux_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
where
@@ -185,8 +184,10 @@ where
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub(res, t, f);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, f);
self.glwe_normalize_inplace(res, scratch);
}
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
@@ -197,14 +198,16 @@ where
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_sub_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, a);
self.glwe_normalize_inplace(res, scratch);
}
}
impl<BE: Backend> Cmux<BE> for Module<BE>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd + GLWENormalize<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
}