Add bdd rotation

This commit is contained in:
Jean-Philippe Bossuat
2025-10-24 18:13:43 +02:00
parent 96d8f4cfc4
commit d989867c91
13 changed files with 177 additions and 32 deletions

View File

@@ -1,6 +1,6 @@
[package]
name = "poulpy-schemes"
version = "0.2.0"
version = "0.3.0"
edition = "2024"
license = "Apache-2.0"
readme = "README.md"

View File

@@ -0,0 +1,42 @@
use poulpy_core::{
GLWECopy, GLWERotate, ScratchTakeCore,
layouts::{GLWE, GLWEToMut},
};
use poulpy_hal::layouts::{Backend, Scratch};
use crate::tfhe::bdd_arithmetic::{Cmux, GetGGSWBit, UnsignedInteger};
pub trait BDDRotation<T: UnsignedInteger, BE: Backend>
where
Self: GLWECopy + GLWERotate<BE> + Cmux<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
/// Homomorphic multiplication of res by X^{k[bit_start..bit_start + bit_size] * bit_step}.
fn bdd_rotate<R, K, D>(
&self,
res: &mut R,
k: K,
bit_start: usize,
bit_size: usize,
bit_step: usize,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
K: GetGGSWBit<T, BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
let (mut tmp_res, scratch_1) = scratch.take_glwe(res);
self.glwe_copy(&mut tmp_res, res);
for i in 1..bit_size {
// res' = res * X^2^(i * bit_step)
self.glwe_rotate(1 << (i + bit_step), &mut tmp_res, res);
// res = (res - res') * GGSW(b[i]) + res'
self.cmux_inplace(res, &tmp_res, &k.get_bit(i + bit_start), scratch_1);
}
}
}

View File

@@ -39,6 +39,17 @@ impl<D: DataRef, T: UnsignedInteger> GLWEInfos for FheUintBlocks<D, T> {
}
}
impl<D: Data, T: UnsignedInteger> FheUintBlocks<D, T> {
pub fn new(blocks: Vec<GLWE<D>>) -> Self {
assert_eq!(blocks.len(), T::WORD_SIZE);
Self {
blocks,
_base: 1,
_phantom: PhantomData,
}
}
}
impl<T: UnsignedInteger> FheUintBlocks<Vec<u8>, T> {
pub fn alloc_from_infos<A, BE: Backend>(module: &Module<BE>, infos: &A) -> Self
where

View File

@@ -3,6 +3,7 @@ use std::marker::PhantomData;
use poulpy_core::layouts::{
Base2K, Dnum, Dsize, GGSWInfos, GGSWPreparedFactory, GLWEInfos, LWEInfos, Rank, TorusPrecision, prepared::GGSWPrepared,
};
use poulpy_core::layouts::{GGSWPreparedToMut, GGSWPreparedToRef};
use poulpy_core::{GGSWEncryptSk, ScratchTakeCore, layouts::GLWESecretPreparedToRef};
use poulpy_hal::layouts::{Backend, Data, DataRef, Module};
@@ -28,6 +29,28 @@ impl<T: UnsignedInteger, BE: Backend> FheUintBlocksPreparedFactory<T, BE> for Mo
{
}
pub trait GetGGSWBit<T: UnsignedInteger, BE: Backend> {
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE>;
}
impl<D: DataRef, T: UnsignedInteger, BE: Backend> GetGGSWBit<T, BE> for FheUintBlocksPrepared<D, T, BE> {
fn get_bit(&self, bit: usize) -> GGSWPrepared<&[u8], BE> {
assert!(bit <= self.blocks.len());
self.blocks[bit].to_ref()
}
}
pub trait GetGGSWBitMut<T: UnsignedInteger, BE: Backend> {
fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE>;
}
impl<D: DataMut, T: UnsignedInteger, BE: Backend> GetGGSWBitMut<T, BE> for FheUintBlocksPrepared<D, T, BE> {
fn get_bit(&mut self, bit: usize) -> GGSWPrepared<&mut [u8], BE> {
assert!(bit <= self.blocks.len());
self.blocks[bit].to_mut()
}
}
pub trait FheUintBlocksPreparedFactory<T: UnsignedInteger, BE: Backend>
where
Self: Sized + GGSWPreparedFactory<BE>,

View File

@@ -3,12 +3,9 @@ use core::panic;
use itertools::Itertools;
use poulpy_core::{
GLWEAdd, GLWECopy, GLWEExternalProduct, GLWESub, ScratchTakeCore,
layouts::{
GLWE, LWEInfos,
prepared::{GGSWPrepared, GGSWPreparedToRef},
},
layouts::{GLWE, GLWEToMut, GLWEToRef, LWEInfos, prepared::GGSWPreparedToRef},
};
use poulpy_hal::layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero};
use poulpy_hal::layouts::{Backend, DataMut, Module, Scratch, ZnxZero};
use crate::tfhe::bdd_arithmetic::UnsignedInteger;
@@ -146,30 +143,38 @@ pub enum Node {
None,
}
pub trait Cmux<BE: Backend> {
fn cmux<O, T, F, S>(&self, out: &mut GLWE<O>, t: &GLWE<T>, f: &GLWE<F>, s: &GGSWPrepared<S, BE>, scratch: &mut Scratch<BE>)
pub trait Cmux<BE: Backend>
where
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn cmux<R, T, F, S>(&self, res: &mut R, t: &T, f: &F, s: &S, scratch: &mut Scratch<BE>)
where
O: DataMut,
T: DataRef,
F: DataRef,
S: DataRef;
R: GLWEToMut,
T: GLWEToRef,
F: GLWEToRef,
S: GGSWPreparedToRef<BE>,
{
self.glwe_sub(res, t, f);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, f);
}
fn cmux_inplace<R, A, S>(&self, res: &mut R, a: &A, s: &S, scratch: &mut Scratch<BE>)
where
R: GLWEToMut,
A: GLWEToRef,
S: GGSWPreparedToRef<BE>,
{
self.glwe_sub_inplace(res, a);
self.glwe_external_product_inplace(res, s, scratch);
self.glwe_add_inplace(res, a);
}
}
impl<BE: Backend> Cmux<BE> for Module<BE>
where
Module<BE>: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Self: GLWEExternalProduct<BE> + GLWESub + GLWEAdd,
Scratch<BE>: ScratchTakeCore<BE>,
{
fn cmux<O, T, F, S>(&self, out: &mut GLWE<O>, t: &GLWE<T>, f: &GLWE<F>, s: &GGSWPrepared<S, BE>, scratch: &mut Scratch<BE>)
where
O: DataMut,
T: DataRef,
F: DataRef,
S: DataRef,
{
// let mut out: GLWECiphertext<&mut [u8]> = out.to_mut();
self.glwe_sub(out, t, f);
self.glwe_external_product_inplace(out, s, scratch);
self.glwe_add_inplace(out, f);
}
}

View File

@@ -1,10 +1,12 @@
mod bdd_2w_to_1w;
mod bdd_rotation;
mod ciphertexts;
mod circuits;
mod eval;
mod key;
pub use bdd_2w_to_1w::*;
pub use bdd_rotation::*;
pub use ciphertexts::*;
pub(crate) use circuits::*;
pub(crate) use eval::*;