mirror of
https://github.com/arnaucube/fhe-study.git
synced 2026-01-24 04:33:52 +01:00
add arith::{complex, matrix} primitives
This commit is contained in:
@@ -7,3 +7,11 @@ edition = "2024"
|
||||
anyhow = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
rand_distr = { workspace = true }
|
||||
|
||||
# TMP: the next 4 imports are TMP, to solve systems of linear equations. Used
|
||||
# for the CKKS encoding step, probably remvoed once in ckks the encoding is done
|
||||
# as in 2018-1043 or 2018-1073.
|
||||
num = "0.4.3"
|
||||
num-complex = "0.4.6"
|
||||
ndarray = "0.16.1"
|
||||
ndarray-linalg = { version = "0.17.0", features = ["intel-mkl"] }
|
||||
|
||||
329
arith/src/complex.rs
Normal file
329
arith/src/complex.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
//! Complex
|
||||
|
||||
use rand::Rng;
|
||||
use std::ops::{Add, Div, Mul, Neg, Sub};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
|
||||
pub struct C<T> {
|
||||
pub re: T,
|
||||
pub im: T,
|
||||
}
|
||||
|
||||
impl From<f64> for C<f64> {
|
||||
fn from(v: f64) -> Self {
|
||||
Self { re: v, im: 0_f64 }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> C<T>
|
||||
where
|
||||
T: Default + From<i32>,
|
||||
{
|
||||
pub fn new(re: T, im: T) -> Self {
|
||||
Self { re, im }
|
||||
}
|
||||
|
||||
pub fn rand(mut rng: impl Rng, max: u64) -> Self {
|
||||
Self::new(
|
||||
T::from(rng.gen_range(0..max) as i32),
|
||||
T::from(rng.gen_range(0..max) as i32),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn zero() -> C<T> {
|
||||
Self {
|
||||
re: T::from(0),
|
||||
im: T::from(0),
|
||||
}
|
||||
}
|
||||
pub fn one() -> C<T> {
|
||||
Self {
|
||||
re: T::from(1),
|
||||
im: T::from(0),
|
||||
}
|
||||
}
|
||||
pub fn i() -> C<T> {
|
||||
Self {
|
||||
re: T::from(0),
|
||||
im: T::from(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl C<f64> {
|
||||
// cos & sin from Taylor series approximation, details at
|
||||
// https://en.wikipedia.org/wiki/Sine_and_cosine#Series_and_polynomials
|
||||
fn cos(x: f64) -> f64 {
|
||||
let mut r = 1.0;
|
||||
let mut term = 1.0;
|
||||
let mut n = 1;
|
||||
for _ in 0..10 {
|
||||
term *= -(x * x) / ((2 * n - 1) * (2 * n)) as f64;
|
||||
r += term;
|
||||
n += 1;
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
fn sin(x: f64) -> f64 {
|
||||
let mut r = x;
|
||||
let mut term = x;
|
||||
let mut n = 1;
|
||||
|
||||
for _ in 0..10 {
|
||||
term *= -(x * x) / ((2 * n) * (2 * n + 1)) as f64;
|
||||
r += term;
|
||||
n += 1;
|
||||
}
|
||||
|
||||
r
|
||||
}
|
||||
|
||||
// e^(self))
|
||||
pub fn exp(self) -> Self {
|
||||
Self {
|
||||
re: Self::cos(self.im), // TODO WIP review
|
||||
im: Self::sin(self.im),
|
||||
}
|
||||
}
|
||||
pub fn pow(self, k: u32) -> Self {
|
||||
let mut k = k;
|
||||
if k == 0 {
|
||||
return Self::one();
|
||||
}
|
||||
let mut base = self.clone();
|
||||
while k & 1 == 0 {
|
||||
base = base.clone() * base;
|
||||
k >>= 1;
|
||||
}
|
||||
|
||||
if k == 1 {
|
||||
return base;
|
||||
}
|
||||
|
||||
let mut acc = base.clone();
|
||||
while k > 1 {
|
||||
k >>= 1;
|
||||
base = base.clone() * base;
|
||||
if k & 1 == 1 {
|
||||
acc = acc * base.clone();
|
||||
}
|
||||
}
|
||||
acc
|
||||
}
|
||||
pub fn modulus<const Q: u64>(self) -> Self {
|
||||
let q: f64 = Q as f64;
|
||||
let re = (self.re % q + q) % q;
|
||||
let im = (self.im % q + q) % q;
|
||||
Self { re, im }
|
||||
}
|
||||
pub fn modulus_centered<const Q: u64>(self) -> Self {
|
||||
let re = modulus_centered_f64::<Q>(self.re);
|
||||
let im = modulus_centered_f64::<Q>(self.im);
|
||||
Self { re, im }
|
||||
}
|
||||
}
|
||||
|
||||
fn modulus_centered_f64<const Q: u64>(v: f64) -> f64 {
|
||||
let q = Q as f64;
|
||||
let mut res = v % q;
|
||||
if res > q / 2.0 {
|
||||
res = res - q;
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
impl<T: Default> Default for C<T> {
|
||||
fn default() -> Self {
|
||||
C {
|
||||
re: T::default(),
|
||||
im: T::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Add for C<T>
|
||||
where
|
||||
T: Add<Output = T> + Copy,
|
||||
{
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
C {
|
||||
re: self.re + rhs.re,
|
||||
im: self.im + rhs.im,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Sub for C<T>
|
||||
where
|
||||
T: Sub<Output = T> + Copy,
|
||||
{
|
||||
type Output = Self;
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
C {
|
||||
re: self.re - rhs.re,
|
||||
im: self.im - rhs.im,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Mul for C<T>
|
||||
where
|
||||
T: Mul<Output = T> + Sub<Output = T> + Add<Output = T> + Copy,
|
||||
{
|
||||
type Output = Self;
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
C {
|
||||
re: self.re * rhs.re - self.im * rhs.im,
|
||||
im: self.re * rhs.im + self.im * rhs.re,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Neg for C<T>
|
||||
where
|
||||
T: Neg<Output = T> + Copy,
|
||||
{
|
||||
type Output = Self;
|
||||
fn neg(self) -> Self::Output {
|
||||
C {
|
||||
re: -self.re,
|
||||
im: -self.im,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Div for C<T>
|
||||
where
|
||||
T: Copy
|
||||
+ Add<Output = T>
|
||||
+ Sub<Output = T>
|
||||
+ Mul<Output = T>
|
||||
+ Div<Output = T>
|
||||
+ Neg<Output = T>,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, rhs: Self) -> Self {
|
||||
// (a+ib)/(c+id) = (ac + bd)/(c^2 + d^2) + i* (bc -ad)/(c^2 + d^2)
|
||||
let den = rhs.re * rhs.re + rhs.im * rhs.im;
|
||||
C {
|
||||
re: (self.re * rhs.re + self.im * rhs.im) / den,
|
||||
im: (self.im * rhs.re - self.re * rhs.im) / den,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<T> C<T>
|
||||
where
|
||||
T: Neg<Output = T> + Copy,
|
||||
{
|
||||
pub fn conj(&self) -> Self {
|
||||
C {
|
||||
re: self.re,
|
||||
im: -self.im,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Add<Output = T> + Default + Copy> std::iter::Sum for C<T> {
|
||||
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||
iter.fold(C::default(), |acc, x| acc + x)
|
||||
}
|
||||
}
|
||||
|
||||
impl C<f64> {
|
||||
pub fn abs(&self) -> f64 {
|
||||
(self.re * self.re + self.im * self.im).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
// poly mul with complex coefficients
|
||||
pub fn naive_poly_mul<const N: usize>(poly1: &Vec<C<f64>>, poly2: &Vec<C<f64>>) -> Vec<C<f64>> {
|
||||
let mut result: Vec<C<f64>> = vec![C::<f64>::zero(); (N * 2) - 1];
|
||||
for i in 0..N {
|
||||
for j in 0..N {
|
||||
result[i + j] = result[i + j] + poly1[i] * poly2[j];
|
||||
}
|
||||
}
|
||||
|
||||
// apply mod (X^N + 1))
|
||||
// R::<N>::from_vec(result.iter().map(|c| *c as i64).collect())
|
||||
// modulus_i128::<N>(&mut result);
|
||||
// dbg!(&result);
|
||||
// dbg!(R::<N>(array::from_fn(|i| result[i] as i64)).coeffs());
|
||||
|
||||
// R(array::from_fn(|i| result[i] as i64))
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_i() {
|
||||
assert_eq!(C::i(), C::new(0.0, 1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let a = C::new(2.0, 3.0);
|
||||
let b = C::new(1.0, 4.0);
|
||||
assert_eq!(a + b, C::new(3.0, 7.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub() {
|
||||
let a = C::new(5.0, 7.0);
|
||||
let b = C::new(2.0, 3.0);
|
||||
assert_eq!(a - b, C::new(3.0, 4.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mult() {
|
||||
let a = C::new(1.0, 2.0);
|
||||
let b = C::new(3.0, 4.0);
|
||||
assert_eq!(a * b, C::new(-5.0, 10.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_div() {
|
||||
let a: C<f64> = C::new(1.0, 2.0);
|
||||
let b: C<f64> = C::new(3.0, 4.0);
|
||||
let r = a / b;
|
||||
let expected = C::new(0.44, 0.08);
|
||||
let epsilon = 1e-2;
|
||||
assert!((r.re - expected.re).abs() < epsilon);
|
||||
assert!((r.im - expected.im).abs() < epsilon);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conj() {
|
||||
let a = C::new(3.0, -4.0);
|
||||
assert_eq!(a.conj(), C::new(3.0, 4.0));
|
||||
assert_eq!(a.conj().conj(), a);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_neg() {
|
||||
let a = C::new(1.0, -2.0);
|
||||
assert_eq!(-a, C::new(-1.0, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_abs() {
|
||||
let a = C::new(3.0, 4.0);
|
||||
assert!((a.abs() - 5.0).abs() < 1e-10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exp() {
|
||||
// let a = C::new(3.0, 4.0);
|
||||
let pi = C::<f64>::from(std::f64::consts::PI);
|
||||
let n = 4;
|
||||
let a = ((C::<f64>::from(2f64) * pi * C::<f64>::i()) / C::<f64>::new(n as f64, 0f64)).exp();
|
||||
dbg!(&a);
|
||||
assert_eq!(a.exp(), a.exp());
|
||||
}
|
||||
}
|
||||
@@ -4,12 +4,16 @@
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
#![allow(dead_code)] // TMP
|
||||
|
||||
mod naive_ntt; // TODO rm
|
||||
pub mod complex;
|
||||
pub mod matrix;
|
||||
mod naive_ntt; // note: for dev only
|
||||
pub mod ntt;
|
||||
pub mod ring;
|
||||
pub mod ringq;
|
||||
pub mod zq;
|
||||
|
||||
pub use complex::C;
|
||||
pub use matrix::Matrix;
|
||||
pub use ntt::NTT;
|
||||
pub use ring::R;
|
||||
pub use ringq::Rq;
|
||||
|
||||
187
arith/src/matrix.rs
Normal file
187
arith/src/matrix.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use anyhow::{anyhow, Result};
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Matrix<T>(pub Vec<Vec<T>>);
|
||||
|
||||
impl<T> Matrix<T>
|
||||
where
|
||||
T: Copy + Add<Output = T> + Mul<Output = T> + Default + std::fmt::Debug,
|
||||
{
|
||||
// TODO maybe rm this method, move it to tests only
|
||||
pub fn new(rows: usize, cols: usize, value: T) -> Self {
|
||||
Matrix(vec![vec![value; cols]; rows])
|
||||
}
|
||||
|
||||
pub fn add(&self, other: &Matrix<T>) -> Result<Matrix<T>> {
|
||||
if self.0.len() != other.0.len() || self.0[0].len() != other.0[0].len() {
|
||||
return Err(anyhow!("dimensions don't match"));
|
||||
}
|
||||
|
||||
let r = self
|
||||
.0
|
||||
.iter()
|
||||
.zip(&other.0)
|
||||
.map(|(row1, row2)| {
|
||||
row1.iter()
|
||||
.zip(row2)
|
||||
.map(|(a, b)| *a + *b)
|
||||
.collect::<Vec<T>>()
|
||||
})
|
||||
.collect::<Vec<Vec<T>>>();
|
||||
|
||||
Ok(Matrix(r))
|
||||
}
|
||||
|
||||
pub fn mul(&self, other: &Matrix<T>) -> Result<Matrix<T>> {
|
||||
let rows_a = self.0.len();
|
||||
let cols_a = self.0[0].len();
|
||||
let rows_b = other.0.len();
|
||||
let cols_b = other.0[0].len();
|
||||
|
||||
if cols_a != rows_b {
|
||||
return Err(anyhow!("self.n_cols != other.n_rows"));
|
||||
}
|
||||
|
||||
let mut r = vec![vec![T::default(); cols_b]; rows_a];
|
||||
|
||||
for i in 0..rows_a {
|
||||
for j in 0..cols_b {
|
||||
for k in 0..cols_a {
|
||||
r[i][j] = r[i][j] + self.0[i][k] * other.0[k][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Matrix(r))
|
||||
}
|
||||
pub fn mul_vec(&self, v: &Vec<T>) -> Result<Vec<T>> {
|
||||
let rows = self.0.len();
|
||||
let cols = self.0[0].len();
|
||||
|
||||
if cols != v.len() {
|
||||
return Err(anyhow!(
|
||||
"Number of columns in matrix does not match the length of the vector"
|
||||
));
|
||||
}
|
||||
|
||||
let mut r = vec![T::default(); rows];
|
||||
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
r[i] = r[i] + self.0[i][j] * v[j];
|
||||
}
|
||||
}
|
||||
Ok(r)
|
||||
}
|
||||
|
||||
pub fn transpose(&self) -> Matrix<T> {
|
||||
let rows = self.0.len();
|
||||
let cols = self.0[0].len();
|
||||
let mut r = vec![vec![T::default(); rows]; cols];
|
||||
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
r[j][i] = self.0[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
Matrix(r)
|
||||
}
|
||||
|
||||
pub fn scalar_mul(&self, scalar: T) -> Matrix<T> {
|
||||
let r = self
|
||||
.0
|
||||
.iter()
|
||||
.map(|row| row.iter().map(|&val| val * scalar).collect::<Vec<T>>())
|
||||
.collect::<Vec<Vec<T>>>();
|
||||
|
||||
Matrix(r)
|
||||
}
|
||||
}
|
||||
|
||||
// WIP. Currently uses ndarray, ndarray_linalg, num_complex to solve a system of
|
||||
// linear equations A*x=b for x.
|
||||
use crate::C;
|
||||
impl Matrix<C<f64>> {
|
||||
pub fn solve(&self, b: &Vec<C<f64>>) -> Result<Vec<C<f64>>> {
|
||||
use ndarray::{Array1, Array2};
|
||||
use ndarray_linalg::Solve;
|
||||
use num_complex::Complex64;
|
||||
|
||||
let m: Array2<Complex64> = Array2::from_shape_vec(
|
||||
(self.0.len(), self.0[0].len()),
|
||||
self.0
|
||||
.clone()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|e| Complex64::new(e.re, e.im))
|
||||
.collect(),
|
||||
)
|
||||
.unwrap();
|
||||
let v: Array1<Complex64> = Array1::from_shape_vec(
|
||||
b.len(),
|
||||
b.iter().map(|e| Complex64::new(e.re, e.im)).collect(),
|
||||
)
|
||||
.unwrap();
|
||||
let r = m.solve(&v)?;
|
||||
let r: Vec<C<f64>> = r.into_iter().map(|e| C::<f64>::new(e.re, e.im)).collect();
|
||||
Ok(r)
|
||||
}
|
||||
}
|
||||
impl Matrix<f64> {
|
||||
// tmp (rm)
|
||||
pub fn solve(&self, b: Vec<f64>) -> Result<Vec<f64>> {
|
||||
use ndarray::{Array1, Array2};
|
||||
use ndarray_linalg::Solve;
|
||||
|
||||
let m: Array2<f64> = Array2::from_shape_vec(
|
||||
(self.0.len(), self.0[0].len()),
|
||||
self.0.clone().into_iter().flatten().collect(),
|
||||
)
|
||||
.unwrap();
|
||||
let v: Array1<f64> = Array1::from_shape_vec(b.len(), b).unwrap();
|
||||
let r = m.solve(&v)?;
|
||||
let r: Vec<f64> = r.into_iter().map(|e| e).collect();
|
||||
Ok(r)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_add() -> Result<()> {
|
||||
let a = Matrix::new(2, 3, 1);
|
||||
let b = Matrix::new(2, 3, 2);
|
||||
let expected = Matrix::new(2, 3, 3);
|
||||
assert_eq!(a.add(&b).unwrap(), expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul() -> Result<()> {
|
||||
let a = Matrix::new(2, 3, 1);
|
||||
let b = Matrix::new(3, 2, 1);
|
||||
let expected = Matrix::new(2, 2, 3); // 2x3 * 3x2 = 2x2 matrix (with all values 3)
|
||||
assert_eq!(a.mul(&b).unwrap(), expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpose() -> Result<()> {
|
||||
let a = Matrix::new(2, 3, 1);
|
||||
let expected = Matrix::new(3, 2, 1);
|
||||
assert_eq!(a.transpose(), expected);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_mul() -> Result<()> {
|
||||
let a = Matrix::new(2, 3, 1);
|
||||
let expected = Matrix::new(2, 3, 3);
|
||||
assert_eq!(a.scalar_mul(3), expected);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -53,6 +53,29 @@ impl<const N: usize> R<N> {
|
||||
.collect();
|
||||
crate::Rq::<Q, N>::from_vec_f64(r)
|
||||
}
|
||||
|
||||
pub fn infinity_norm(&self) -> u64 {
|
||||
self.coeffs()
|
||||
.iter()
|
||||
// .map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
|
||||
.map(|x| x.abs() as u64)
|
||||
.fold(0, |a, b| a.max(b))
|
||||
}
|
||||
pub fn mod_centered_q<const Q: u64>(&self) -> R<N> {
|
||||
let q = Q as i64;
|
||||
let r = self
|
||||
.0
|
||||
.iter()
|
||||
.map(|v| {
|
||||
let mut res = v % q;
|
||||
if res > q / 2 {
|
||||
res = res - q;
|
||||
}
|
||||
res
|
||||
})
|
||||
.collect::<Vec<i64>>();
|
||||
R::<N>::from_vec(r)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mul_div_round<const Q: u64, const N: usize>(
|
||||
|
||||
@@ -52,6 +52,9 @@ impl<const Q: u64, const N: usize> Rq<Q, N> {
|
||||
pub fn coeffs(&self) -> [Zq<Q>; N] {
|
||||
self.coeffs
|
||||
}
|
||||
pub fn compute_evals(&mut self) {
|
||||
self.evals = Some(NTT::<Q, N>::ntt(self.coeffs));
|
||||
}
|
||||
pub fn to_r(self) -> crate::R<N> {
|
||||
crate::R::<N>::from(self)
|
||||
}
|
||||
@@ -131,6 +134,15 @@ impl<const Q: u64, const N: usize> Rq<Q, N> {
|
||||
pub fn remodule<const P: u64>(&self) -> Rq<P, N> {
|
||||
Rq::<P, N>::from_vec_u64(self.coeffs().iter().map(|m_i| m_i.0).collect())
|
||||
}
|
||||
|
||||
/// perform the mod switch operation from Q to Q', where Q2=Q'
|
||||
fn mod_switch<const Q2: u64>(&self) -> Rq<Q2, N> {
|
||||
Rq::<Q2, N> {
|
||||
coeffs: array::from_fn(|i| self.coeffs[i].mod_switch::<Q2>()),
|
||||
evals: None,
|
||||
}
|
||||
}
|
||||
|
||||
// applies mod(T) to all coefficients of self
|
||||
pub fn coeffs_mod<const T: u64>(&self) -> Self {
|
||||
Rq::<Q, N>::from_vec_u64(
|
||||
@@ -236,6 +248,9 @@ impl<const Q: u64, const N: usize> Rq<Q, N> {
|
||||
.map(|x| if x.0 > (Q / 2) { Q - x.0 } else { x.0 })
|
||||
.fold(0, |a, b| a.max(b))
|
||||
}
|
||||
pub fn mod_centered_q(&self) -> crate::ring::R<N> {
|
||||
self.to_r().mod_centered_q::<Q>()
|
||||
}
|
||||
}
|
||||
pub fn matrix_vec_product<const Q: u64>(m: &Vec<Vec<Zq<Q>>>, v: &Vec<Zq<Q>>) -> Result<Vec<Zq<Q>>> {
|
||||
// assert_eq!(m.len(), m[0].len()); // TODO change to returning err
|
||||
@@ -399,6 +414,7 @@ impl<const Q: u64, const N: usize> ops::Neg for Rq<Q, N> {
|
||||
}
|
||||
}
|
||||
|
||||
// note: this assumes that Q is prime
|
||||
fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>) -> Rq<Q, N> {
|
||||
// reuse evaluations if already computed
|
||||
if !lhs.evals.is_some() {
|
||||
@@ -414,6 +430,8 @@ fn mul_mut<const Q: u64, const N: usize>(lhs: &mut Rq<Q, N>, rhs: &mut Rq<Q, N>)
|
||||
let c = NTT::<Q, { N }>::intt(c_ntt);
|
||||
Rq::new(c, Some(c_ntt))
|
||||
}
|
||||
// note: this assumes that Q is prime
|
||||
// TODO impl karatsuba for non-prime Q
|
||||
fn mul<const Q: u64, const N: usize>(lhs: &Rq<Q, N>, rhs: &Rq<Q, N>) -> Rq<Q, N> {
|
||||
// reuse evaluations if already computed
|
||||
let lhs_evals = if lhs.evals.is_some() {
|
||||
|
||||
@@ -117,6 +117,11 @@ impl<const Q: u64> Zq<Q> {
|
||||
(g, y - (b / a) * x, x)
|
||||
}
|
||||
}
|
||||
|
||||
/// perform the mod switch operation from Q to Q', where Q2=Q'
|
||||
pub fn mod_switch<const Q2: u64>(&self) -> Zq<Q2> {
|
||||
Zq::<Q2>::from_u64(((self.0 as f64 * Q2 as f64) / Q as f64).round() as u64)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const Q: u64> Zq<Q> {
|
||||
|
||||
Reference in New Issue
Block a user