Add inplace ggsw-based blind-rotation

This commit is contained in:
Pro7ech
2025-10-28 16:03:26 +01:00
parent 8c1cc354e3
commit 37c76b6420
2 changed files with 70 additions and 28 deletions

View File

@@ -25,16 +25,49 @@ where
R: GLWEInfos, R: GLWEInfos,
K: GGSWInfos, K: GGSWInfos,
{ {
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos)
}
#[allow(clippy::too_many_arguments)]
/// res <- res * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn ggsw_blind_rotation_inplace<R, K>(
&self,
res: &mut R,
fhe_uint: &K,
sign: bool,
bit_rsh: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GGSWToMut,
K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() {
self.glwe_blind_rotation_inplace(
&mut res.at_mut(row, col),
fhe_uint,
sign,
bit_rsh,
bit_mask,
bit_lsh,
scratch,
);
}
}
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
/// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. /// res <- a * X^{((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn ggsw_to_ggsw_blind_rotation<R, A, K>( fn ggsw_blind_rotation<R, A, K>(
&self, &self,
res: &mut R, res: &mut R,
a: &A, a: &A,
k: &K, fhe_uint: &K,
sign: bool, sign: bool,
bit_rsh: usize, bit_rsh: usize,
bit_mask: usize, bit_mask: usize,
@@ -54,10 +87,10 @@ where
for col in 0..(res.rank() + 1).into() { for col in 0..(res.rank() + 1).into() {
for row in 0..res.dnum().into() { for row in 0..res.dnum().into() {
self.glwe_to_glwe_blind_rotation( self.glwe_blind_rotation(
&mut res.at_mut(row, col), &mut res.at_mut(row, col),
&a.at(row, col), &a.at(row, col),
k, fhe_uint,
sign, sign,
bit_rsh, bit_rsh,
bit_mask, bit_mask,
@@ -73,7 +106,7 @@ where
R: GLWEInfos, R: GLWEInfos,
K: GGSWInfos, K: GGSWInfos,
{ {
self.glwe_to_glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos) self.glwe_blind_rotation_tmp_bytes(res_infos, k_infos) + GLWE::bytes_of_from_infos(res_infos)
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@@ -81,7 +114,7 @@ where
&self, &self,
res: &mut R, res: &mut R,
test_vector: &S, test_vector: &S,
k: &K, fhe_uint: &K,
sign: bool, sign: bool,
bit_rsh: usize, bit_rsh: usize,
bit_mask: usize, bit_mask: usize,
@@ -113,10 +146,10 @@ where
); );
self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1); self.vec_znx_normalize_inplace(base2k, tmp_glwe.data_mut(), col, scratch_1);
self.glwe_to_glwe_blind_rotation( self.glwe_blind_rotation(
&mut res.at_mut(row, col), &mut res.at_mut(row, col),
&tmp_glwe, &tmp_glwe,
k, fhe_uint,
sign, sign,
bit_rsh, bit_rsh,
bit_mask, bit_mask,
@@ -139,7 +172,7 @@ pub trait GLWEBlindRotation<T: UnsignedInteger, BE: Backend>
where where
Self: GLWECopy + GLWERotate<BE> + Cmux<BE>, Self: GLWECopy + GLWERotate<BE> + Cmux<BE>,
{ {
fn glwe_to_glwe_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize fn glwe_blind_rotation_tmp_bytes<R, K>(&self, res_infos: &R, k_infos: &K) -> usize
where where
R: GLWEInfos, R: GLWEInfos,
K: GGSWInfos, K: GGSWInfos,
@@ -148,12 +181,10 @@ where
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
/// res <- a * X^{sign * ((k>>bit_rsh) % 2^bit_mask) << bit_lsh}. fn glwe_blind_rotation_inplace<R, K>(
fn glwe_to_glwe_blind_rotation<R, A, K>(
&self, &self,
res: &mut R, res: &mut R,
a: &A, fhe_uint: &K,
k: &K,
sign: bool, sign: bool,
bit_rsh: usize, bit_rsh: usize,
bit_mask: usize, bit_mask: usize,
@@ -161,31 +192,20 @@ where
scratch: &mut Scratch<BE>, scratch: &mut Scratch<BE>,
) where ) where
R: GLWEToMut, R: GLWEToMut,
A: GLWEToRef,
K: GetGGSWBit<BE>, K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>, Scratch<BE>: ScratchTakeCore<BE>,
{ {
assert!(bit_rsh + bit_mask <= T::WORD_SIZE); assert!(bit_rsh + bit_mask <= T::WORD_SIZE);
let mut res: GLWE<&mut [u8]> = res.to_mut(); let mut res: GLWE<&mut [u8]> = res.to_mut();
let a: &GLWE<&[u8]> = &a.to_ref();
let (mut tmp_res, scratch_1) = scratch.take_glwe(&res); let (mut tmp_res, scratch_1) = scratch.take_glwe(&res);
// a <- a ; b <- a * X^{-2^{i + bit_lsh}}
match sign {
true => self.glwe_rotate(1 << bit_lsh, &mut res, a),
false => self.glwe_rotate(-1 << bit_lsh, &mut res, a),
}
// b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(&mut res, a, &k.get_bit(bit_rsh), scratch_1);
// a_is_res = true => (a, b) = (&mut res, &mut tmp_res) // a_is_res = true => (a, b) = (&mut res, &mut tmp_res)
// a_is_res = false => (a, b) = (&mut tmp_res, &mut res) // a_is_res = false => (a, b) = (&mut tmp_res, &mut res)
let mut a_is_res: bool = true; let mut a_is_res: bool = true;
for i in 1..bit_mask { for i in 0..bit_mask {
let (a, b) = if a_is_res { let (a, b) = if a_is_res {
(&mut res, &mut tmp_res) (&mut res, &mut tmp_res)
} else { } else {
@@ -199,7 +219,7 @@ where
} }
// b <- (b - a) * GGSW(b[i]) + a // b <- (b - a) * GGSW(b[i]) + a
self.cmux_inplace(b, a, &k.get_bit(i + bit_rsh), scratch_1); self.cmux_inplace(b, a, &fhe_uint.get_bit(i + bit_rsh), scratch_1);
// ping-pong roles for next iter // ping-pong roles for next iter
a_is_res = !a_is_res; a_is_res = !a_is_res;
@@ -210,4 +230,26 @@ where
self.glwe_copy(&mut res, &tmp_res); self.glwe_copy(&mut res, &tmp_res);
} }
} }
#[allow(clippy::too_many_arguments)]
/// res <- a * X^{sign * ((k>>bit_rsh) % 2^bit_mask) << bit_lsh}.
fn glwe_blind_rotation<R, A, K>(
&self,
res: &mut R,
a: &A,
fhe_uint: &K,
sign: bool,
bit_rsh: usize,
bit_mask: usize,
bit_lsh: usize,
scratch: &mut Scratch<BE>,
) where
R: GLWEToMut,
A: GLWEToRef,
K: GetGGSWBit<BE>,
Scratch<BE>: ScratchTakeCore<BE>,
{
self.glwe_copy(res, a);
self.glwe_blind_rotation_inplace(res, fhe_uint, sign, bit_rsh, bit_mask, bit_lsh, scratch);
}
} }

View File

@@ -101,7 +101,7 @@ where
// How many bits to take // How many bits to take
let bit_size: usize = (32 - bit_start).min(digit); let bit_size: usize = (32 - bit_start).min(digit);
module.glwe_to_glwe_blind_rotation( module.glwe_blind_rotation(
&mut res, &mut res,
&test_glwe, &test_glwe,
&k_enc_prep, &k_enc_prep,