From 836df871fe4ddca838df3e2e371f30253de27c5f Mon Sep 17 00:00:00 2001 From: Pro7ech Date: Fri, 7 Nov 2025 16:30:47 +0100 Subject: [PATCH] Add normalize in cmux & uint_prepared to uint --- .../bdd_arithmetic/ciphertexts/fhe_uint.rs | 29 ++++++++++++++++++- .../src/tfhe/bdd_arithmetic/eval.rs | 15 ++++++---- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs index 5ae7145..6bd8be1 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/ciphertexts/fhe_uint.rs @@ -13,7 +13,7 @@ use poulpy_hal::{ }; use std::{collections::HashMap, marker::PhantomData}; -use crate::tfhe::bdd_arithmetic::{FromBits, ToBits, UnsignedInteger}; +use crate::tfhe::bdd_arithmetic::{Cmux, FheUintPrepared, FromBits, GetGGSWBit, ToBits, UnsignedInteger}; /// An FHE ciphertext encrypting the bits of an [UnsignedInteger]. pub struct FheUint { @@ -362,6 +362,33 @@ impl GLWEToRef for FheUint { } impl FheUint { + pub fn from_fhe_uint_prepared( + &mut self, + module: &M, + other: &FheUintPrepared, + keys: &H, + scratch: &mut Scratch, + ) where + DR: DataRef, + M: Cmux + ModuleLogN + GLWEPacking + GLWECopy, + Scratch: ScratchTakeCore, + K: GGLWEPreparedToRef + GetGaloisElement + GGLWEInfos, + H: GLWEAutomorphismKeyHelper, + { + let zero: GLWE> = GLWE::alloc_from_infos(self); + let mut one: GLWE> = GLWE::alloc_from_infos(self); + one.data_mut() + .encode_coeff_i64(self.base2k().into(), 0, 2, 0, 1); + + let (mut out_bits, scratch_1) = scratch.take_glwe_slice(T::BITS as usize, self); + + for i in 0..T::BITS as usize { + module.cmux(&mut out_bits[i], &one, &zero, &other.get_bit(i), scratch_1); + } + + self.pack(module, out_bits, keys, scratch_1); + } + pub fn zero_byte(&mut self, module: &M, byte: usize, keys: &H, scratch: &mut Scratch) where H: GLWEAutomorphismKeyHelper, diff --git a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs index 5369903..dbd1862 100644 --- a/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs +++ b/poulpy-schemes/src/tfhe/bdd_arithmetic/eval.rs @@ -2,8 +2,7 @@ use core::panic; use itertools::Itertools; use poulpy_core::{ - GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore, - layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef}, + GLWEAdd, GLWECopy, GLWEExternalProduct, GLWENormalize, GLWESub, ScratchTakeCore, layouts::{GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef} }; use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero}; @@ -71,8 +70,8 @@ where { #[cfg(debug_assertions)] { - assert!(inputs.bit_size() >= circuit.input_size()); - assert!(out.len() >= circuit.output_size()); + assert!(inputs.bit_size() >= circuit.input_size(), "inputs.bit_size(): {} < circuit.input_size():{}", inputs.bit_size(), circuit.input_size()); + assert!(out.len() >= circuit.output_size(), "out.len(): {} < circuit.output_size(): {}", out.len(), circuit.output_size()); } for (i, out_i) in out.iter_mut().enumerate().take(circuit.output_size()) { @@ -165,7 +164,7 @@ pub enum Node { pub trait Cmux where - Self: GLWEExternalProduct + GLWESub + GLWEAdd, + Self: GLWEExternalProduct + GLWESub + GLWEAdd + GLWENormalize, { fn cmux_tmp_bytes(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize where @@ -185,8 +184,10 @@ where Scratch: ScratchTakeCore, { self.glwe_sub(res, t, f); + self.glwe_normalize_inplace(res, scratch); self.glwe_external_product_inplace(res, s, scratch); self.glwe_add_inplace(res, f); + self.glwe_normalize_inplace(res, scratch); } fn cmux_inplace(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch) @@ -197,14 +198,16 @@ where Scratch: ScratchTakeCore, { self.glwe_sub_inplace(res, a); + self.glwe_normalize_inplace(res, scratch); self.glwe_external_product_inplace(res, s, scratch); self.glwe_add_inplace(res, a); + self.glwe_normalize_inplace(res, scratch); } } impl Cmux for Module where - Self: GLWEExternalProduct + GLWESub + GLWEAdd, + Self: GLWEExternalProduct + GLWESub + GLWEAdd + GLWENormalize, Scratch: ScratchTakeCore, { }