From 18ca4801ae888f933eecdb7ec1c272dbaedddaeb Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 24 Feb 2025 17:19:43 +0100 Subject: [PATCH] implemented Encoding for VecZnxBorrow --- base2k/src/encoding.rs | 355 +++++++++++++++++++++++++---------------- rlwe/src/encryptor.rs | 2 +- 2 files changed, 215 insertions(+), 142 deletions(-) diff --git a/base2k/src/encoding.rs b/base2k/src/encoding.rs index 0380c03..51840ad 100644 --- a/base2k/src/encoding.rs +++ b/base2k/src/encoding.rs @@ -1,5 +1,5 @@ use crate::ffi::znx::znx_zero_i64_ref; -use crate::{Infos, VecZnx, VecZnxApi}; +use crate::{VecZnx, VecZnxBorrow, VecZnxCommon}; use itertools::izip; use rug::{Assign, Float}; use std::cmp::min; @@ -62,103 +62,15 @@ pub trait Encoding { impl Encoding for VecZnx { fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; - - assert!(cols <= self.cols(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.cols()={}", cols, self.cols()); - - let size: usize = min(data.len(), self.n()); - let log_k_rem: usize = log_base2k - (log_k % log_base2k); - - // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. - if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - (0..self.cols()).for_each(|i| unsafe { - znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); - }); - self.at_mut(cols - 1)[..size].copy_from_slice(&data[..size]); - } else { - let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - - (0..steps).for_each(|i| unsafe { - znx_zero_i64_ref(size as u64, self.at_mut(i).as_mut_ptr()); - }); - - (cols - steps..cols) - .rev() - .enumerate() - .for_each(|(i, i_rev)| { - let shift: usize = i * log_base2k; - izip!(self.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) - .for_each(|(y, x)| *y = (x >> shift) & mask); - }) - } - - // Case where self.prec % self.k != 0. - if log_k_rem != log_base2k { - let cols = self.cols(); - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|i| { - self.at_mut(i)[..size] - .iter_mut() - .for_each(|x| *x <<= log_k_rem); - }) - } + encode_vec_i64(self, log_base2k, log_k, data, log_max) } fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!( - data.len() >= self.n, - "invalid data: data.len()={} < self.n()={}", - data.len(), - self.n - ); - data.copy_from_slice(self.at(0)); - let rem: usize = log_base2k - (log_k % log_base2k); - (1..cols).for_each(|i| { - if i == cols - 1 && rem != log_base2k { - let k_rem: usize = log_base2k - rem; - izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(self.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << log_base2k) + x; - }); - } - }) + decode_vec_i64(self, log_base2k, log_k, data) } fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) { - let cols: usize = self.cols(); - assert!( - data.len() >= self.n(), - "invalid data: data.len()={} < self.n()={}", - data.len(), - self.n() - ); - - let prec: u32 = (log_base2k * cols) as u32; - - // 2^{log_base2k} - let base = Float::with_val(prec, (1 << log_base2k) as f64); - - // y[i] = sum x[j][i] * 2^{-log_base2k*j} - (0..cols).for_each(|i| { - if i == 0 { - izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { - y.assign(*x); - *y /= &base; - }); - } else { - izip!(self.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { - *y += Float::with_val(prec, *x); - *y /= &base; - }); - } - }); + decode_vec_float(self, log_base2k, data) } fn encode_coeff_i64( @@ -169,61 +81,222 @@ impl Encoding for VecZnx { value: i64, log_max: usize, ) { - assert!(i < self.n()); - let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!(cols <= self.cols(), "invalid argument log_k: (log_k + self.log_base2k - 1)/self.log_base2k={} > self.cols()={}", cols, self.cols()); - let log_k_rem: usize = log_base2k - (log_k % log_base2k); - let cols = self.cols(); - - // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy - // values on the last limb. - // Else we decompose values base2k. - if log_max + log_k_rem < 63 || log_k_rem == log_base2k { - (0..cols - 1).for_each(|j| self.at_mut(j)[i] = 0); - - self.at_mut(self.cols() - 1)[i] = value; - } else { - let mask: i64 = (1 << log_base2k) - 1; - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - - (0..cols - steps).for_each(|j| self.at_mut(j)[i] = 0); - - (cols - steps..cols) - .rev() - .enumerate() - .for_each(|(j, j_rev)| { - self.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask; - }) - } - - // Case where self.prec % self.k != 0. - if log_k_rem != log_base2k { - let cols = self.cols(); - let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); - (cols - steps..cols).rev().for_each(|j| { - self.at_mut(j)[i] <<= log_k_rem; - }) - } + encode_coeff_i64(self, log_base2k, log_k, i, value, log_max) } fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 { - let cols: usize = (log_k + log_base2k - 1) / log_base2k; - assert!(i < self.n()); - let mut res: i64 = self.data[i]; - let rem: usize = log_base2k - (log_k % log_base2k); - (1..cols).for_each(|i| { - let x = self.data[i * self.n]; - if i == cols - 1 && rem != log_base2k { - let k_rem: usize = log_base2k - rem; - res = (res << k_rem) + (x >> rem); - } else { - res = (res << log_base2k) + x; - } - }); - res + decode_coeff_i64(self, log_base2k, log_k, i) } } +impl Encoding for VecZnxBorrow { + fn encode_vec_i64(&mut self, log_base2k: usize, log_k: usize, data: &[i64], log_max: usize) { + encode_vec_i64(self, log_base2k, log_k, data, log_max) + } + + fn decode_vec_i64(&self, log_base2k: usize, log_k: usize, data: &mut [i64]) { + decode_vec_i64(self, log_base2k, log_k, data) + } + + fn decode_vec_float(&self, log_base2k: usize, data: &mut [Float]) { + decode_vec_float(self, log_base2k, data) + } + + fn encode_coeff_i64( + &mut self, + log_base2k: usize, + log_k: usize, + i: usize, + value: i64, + log_max: usize, + ) { + encode_coeff_i64(self, log_base2k, log_k, i, value, log_max) + } + + fn decode_coeff_i64(&self, log_base2k: usize, log_k: usize, i: usize) -> i64 { + decode_coeff_i64(self, log_base2k, log_k, i) + } +} + +fn encode_vec_i64( + a: &mut T, + log_base2k: usize, + log_k: usize, + data: &[i64], + log_max: usize, +) { + let cols: usize = (log_k + log_base2k - 1) / log_base2k; + + assert!( + cols <= a.cols(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", + cols, + a.cols() + ); + + let size: usize = min(data.len(), a.n()); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); + + // If 2^{log_base2k} * 2^{k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + log_k_rem < 63 || log_k_rem == log_base2k { + (0..a.cols()).for_each(|i| unsafe { + znx_zero_i64_ref(size as u64, a.at_mut(i).as_mut_ptr()); + }); + a.at_mut(cols - 1)[..size].copy_from_slice(&data[..size]); + } else { + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); + + (0..steps).for_each(|i| unsafe { + znx_zero_i64_ref(size as u64, a.at_mut(i).as_mut_ptr()); + }); + + (cols - steps..cols) + .rev() + .enumerate() + .for_each(|(i, i_rev)| { + let shift: usize = i * log_base2k; + izip!(a.at_mut(i_rev)[..size].iter_mut(), data[..size].iter()) + .for_each(|(y, x)| *y = (x >> shift) & mask); + }) + } + + // Case where self.prec % self.k != 0. + if log_k_rem != log_base2k { + let cols = a.cols(); + let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); + (cols - steps..cols).rev().for_each(|i| { + a.at_mut(i)[..size] + .iter_mut() + .for_each(|x| *x <<= log_k_rem); + }) + } +} + +fn decode_vec_i64(a: &T, log_base2k: usize, log_k: usize, data: &mut [i64]) { + let cols: usize = (log_k + log_base2k - 1) / log_base2k; + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + data.copy_from_slice(a.at(0)); + let rem: usize = log_base2k - (log_k % log_base2k); + (1..cols).for_each(|i| { + if i == cols - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; + izip!(a.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(a.at(i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << log_base2k) + x; + }); + } + }) +} + +fn decode_vec_float(a: &T, log_base2k: usize, data: &mut [Float]) { + let cols: usize = a.cols(); + assert!( + data.len() >= a.n(), + "invalid data: data.len()={} < a.n()={}", + data.len(), + a.n() + ); + + let prec: u32 = (log_base2k * cols) as u32; + + // 2^{log_base2k} + let base = Float::with_val(prec, (1 << log_base2k) as f64); + + // y[i] = sum x[j][i] * 2^{-log_base2k*j} + (0..cols).for_each(|i| { + if i == 0 { + izip!(a.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + y.assign(*x); + *y /= &base; + }); + } else { + izip!(a.at(cols - i - 1).iter(), data.iter_mut()).for_each(|(x, y)| { + *y += Float::with_val(prec, *x); + *y /= &base; + }); + } + }); +} + +fn encode_coeff_i64( + a: &mut T, + log_base2k: usize, + log_k: usize, + i: usize, + value: i64, + log_max: usize, +) { + assert!(i < a.n()); + let cols: usize = (log_k + log_base2k - 1) / log_base2k; + assert!( + cols <= a.cols(), + "invalid argument log_k: (log_k + a.log_base2k - 1)/a.log_base2k={} > a.cols()={}", + cols, + a.cols() + ); + let log_k_rem: usize = log_base2k - (log_k % log_base2k); + let cols = a.cols(); + + // If 2^{log_base2k} * 2^{log_k_rem} < 2^{63}-1, then we can simply copy + // values on the last limb. + // Else we decompose values base2k. + if log_max + log_k_rem < 63 || log_k_rem == log_base2k { + (0..cols - 1).for_each(|j| a.at_mut(j)[i] = 0); + + a.at_mut(a.cols() - 1)[i] = value; + } else { + let mask: i64 = (1 << log_base2k) - 1; + let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); + + (0..cols - steps).for_each(|j| a.at_mut(j)[i] = 0); + + (cols - steps..cols) + .rev() + .enumerate() + .for_each(|(j, j_rev)| { + a.at_mut(j_rev)[i] = (value >> (j * log_base2k)) & mask; + }) + } + + // Case where prec % k != 0. + if log_k_rem != log_base2k { + let cols = a.cols(); + let steps: usize = min(cols, (log_max + log_base2k - 1) / log_base2k); + (cols - steps..cols).rev().for_each(|j| { + a.at_mut(j)[i] <<= log_k_rem; + }) + } +} + +fn decode_coeff_i64(a: &T, log_base2k: usize, log_k: usize, i: usize) -> i64 { + let cols: usize = (log_k + log_base2k - 1) / log_base2k; + assert!(i < a.n()); + let data: &[i64] = a.raw(); + let mut res: i64 = data[i]; + let rem: usize = log_base2k - (log_k % log_base2k); + (1..cols).for_each(|i| { + let x = data[i * a.n()]; + if i == cols - 1 && rem != log_base2k { + let k_rem: usize = log_base2k - rem; + res = (res << k_rem) + (x >> rem); + } else { + res = (res << log_base2k) + x; + } + }); + res +} + #[cfg(test)] mod tests { use crate::{Encoding, VecZnx}; diff --git a/rlwe/src/encryptor.rs b/rlwe/src/encryptor.rs index 7156d52..cb69944 100644 --- a/rlwe/src/encryptor.rs +++ b/rlwe/src/encryptor.rs @@ -6,7 +6,7 @@ use crate::plaintext::Plaintext; use base2k::sampling::Sampling; use base2k::{ Module, Scalar, SvpPPol, SvpPPolOps, VecZnx, VecZnxApi, VecZnxBig, VecZnxBigOps, VecZnxBorrow, - VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, cast_mut, + VecZnxDft, VecZnxDftOps, VecZnxOps, VmpPMat, VmpPMatOps, }; use sampling::source::{Source, new_seed};