add zero_byte for fhe_uint & fix test for glwe blind selection

This commit is contained in:
Pro7ech
2025-11-05 09:46:00 +01:00
parent 6cf571c0b0
commit 92cfef5b60
3 changed files with 33 additions and 14 deletions

View File

@@ -234,20 +234,14 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
let trace_start = (T::LOG_BITS - T::LOG_BYTES) as usize; let trace_start = (T::LOG_BITS - T::LOG_BYTES) as usize;
let rot: i64 = (T::bit_index(dst << 3) << log_gap) as i64; let rot: i64 = (T::bit_index(dst << 3) << log_gap) as i64;
// Move a to self and align byte module.glwe_copy(self, a);
module.glwe_rotate(-rot, &mut self.bits, a);
// Stores this byte (everything else zeroed) into tmp_trace self.zero_byte(module, dst, keys, scratch);
let (mut tmp_trace, scratch_1) = scratch.take_glwe(a);
module.glwe_trace(&mut tmp_trace, trace_start, self, keys, scratch_1);
// Subtracts to self to zero it
module.glwe_sub_inplace(&mut self.bits, &tmp_trace);
// Isolate the byte to transfer from a // Isolate the byte to transfer from a
let (mut tmp_fhe_uint_byte, scratch_1) = scratch.take_fhe_uint(b); let (mut tmp_fhe_uint_byte, scratch_1) = scratch.take_fhe_uint(b);
// Move a[byte_a] into a[0] // Move a[byte_a] into a[dst]
module.glwe_rotate( module.glwe_rotate(
-((T::bit_index(src << 3) << log_gap) as i64), -((T::bit_index(src << 3) << log_gap) as i64),
&mut tmp_fhe_uint_byte, &mut tmp_fhe_uint_byte,
@@ -257,11 +251,11 @@ impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
// Zeroes all other bytes // Zeroes all other bytes
module.glwe_trace_inplace(&mut tmp_fhe_uint_byte, trace_start, keys, scratch_1); module.glwe_trace_inplace(&mut tmp_fhe_uint_byte, trace_start, keys, scratch_1);
// Moves back self[0] to self[byte_tg]
module.glwe_rotate_inplace(rot, &mut tmp_fhe_uint_byte, scratch_1);
// Add self[0] += a[0] // Add self[0] += a[0]
module.glwe_add_inplace(&mut self.bits, &tmp_fhe_uint_byte); module.glwe_add_inplace(&mut self.bits, &tmp_fhe_uint_byte);
// Moves back self[0] to self[byte_tg]
module.glwe_rotate_inplace(rot, &mut self.bits, scratch);
} }
} }
@@ -313,6 +307,31 @@ impl<D: DataRef, T: UnsignedInteger> GLWEToRef for FheUint<D, T> {
} }
impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> { impl<D: DataMut, T: UnsignedInteger> FheUint<D, T> {
pub fn zero_byte<M, K, H, BE: Backend>(&mut self, module: &M, byte: usize, keys: &H, scratch: &mut Scratch<BE>)
where
H: GLWEAutomorphismKeyHelper<K, BE>,
K: GGLWEPreparedToRef<BE> + GGLWEInfos + GetGaloisElement,
M: ModuleLogN + GLWERotate<BE> + GLWETrace<BE> + GLWESub + GLWEAdd + GLWECopy,
Scratch<BE>: ScratchTakeBDD<T, BE>,
{
let log_gap: usize = module.log_n() - T::LOG_BITS as usize;
let trace_start = (T::LOG_BITS - T::LOG_BYTES) as usize;
let rot: i64 = (T::bit_index(byte << 3) << log_gap) as i64;
// Move a to self and align byte
module.glwe_rotate_inplace(-rot, &mut self.bits, scratch);
// Stores this byte (everything else zeroed) into tmp_trace
let (mut tmp_trace, scratch_1) = scratch.take_glwe(self);
module.glwe_trace(&mut tmp_trace, trace_start, self, keys, scratch_1);
// Subtracts to self to zero it
module.glwe_sub_inplace(&mut self.bits, &tmp_trace);
// Move a to self and align byte
module.glwe_rotate_inplace(rot, &mut self.bits, scratch);
}
pub fn sext<M, H, K, BE>(&mut self, module: &M, byte: usize, keys: &H, scratch: &mut Scratch<BE>) pub fn sext<M, H, K, BE>(&mut self, module: &M, byte: usize, keys: &H, scratch: &mut Scratch<BE>)
where where
M:, M:,

View File

@@ -58,7 +58,7 @@ where
} }
pub fn sext(x: u32, byte: usize) -> u32 { pub fn sext(x: u32, byte: usize) -> u32 {
x | ((x >> (byte << 3)) & 1) * (0xFFFF_FFFF & (0xFFFF_FFFF << (byte << 3))) x | (((x >> (byte << 3)) & 1) * (0xFFFF_FFFF << (byte << 3)))
} }
pub fn test_fhe_uint_splice_u8<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>) pub fn test_fhe_uint_splice_u8<BRA: BlindRotationAlgo, BE: Backend>(test_context: &TestContext<BRA, BE>)

View File

@@ -129,7 +129,7 @@ where
res.decrypt(module, &mut pt, sk_glwe_prep, scratch.borrow()); res.decrypt(module, &mut pt, sk_glwe_prep, scratch.borrow());
let idx = ((k >> bit_start) & mask) as usize; let idx = ((k >> bit_start) & mask) as usize;
if idx.is_multiple_of(3) { if !idx.is_multiple_of(3) {
assert_eq!(0, pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0)); assert_eq!(0, pt.decode_coeff_i64(TorusPrecision(base2k.as_u32()), 0));
} else { } else {
assert_eq!( assert_eq!(