From ca839756d00c34ccb5425727b89436b8e5f93017 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Tue, 23 Apr 2024 17:31:55 +0530 Subject: [PATCH] add basics --- .gitignore | 1 + Cargo.lock | 246 ++++++++++++++++++++++++ Cargo.toml | 12 ++ LICENSE | 21 +++ rustfmt.toml | 1 + src/backend.rs | 163 ++++++++++++++++ src/decomposer.rs | 164 ++++++++++++++++ src/lib.rs | 101 ++++++++++ src/lwe.rs | 287 ++++++++++++++++++++++++++++ src/main.rs | 3 + src/ntt.rs | 408 ++++++++++++++++++++++++++++++++++++++++ src/num.rs | 3 + src/random.rs | 180 ++++++++++++++++++ src/rgsw.rs | 466 ++++++++++++++++++++++++++++++++++++++++++++++ src/utils.rs | 191 +++++++++++++++++++ 15 files changed, 2247 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 LICENSE create mode 100644 rustfmt.toml create mode 100644 src/backend.rs create mode 100644 src/decomposer.rs create mode 100644 src/lib.rs create mode 100644 src/lwe.rs create mode 100644 src/main.rs create mode 100644 src/ntt.rs create mode 100644 src/num.rs create mode 100644 src/random.rs create mode 100644 src/rgsw.rs create mode 100644 src/utils.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..cb828f8 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,246 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + +[[package]] +name = "bin-rs" +version = "0.1.0" +dependencies = [ + "itertools", + "num-bigint-dig", + "num-traits", + "rand", + "rand_chacha", + "rand_distr", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "either" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" + +[[package]] +name = "getrandom" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin", +] + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand", + "serde", + "smallvec", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "serde" +version = "1.0.198" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.198" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + +[[package]] +name = "syn" +version = "2.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..74fc978 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "bin-rs" +version = "0.1.0" +edition = "2021" + +[dependencies] +itertools = "0.12.0" +num-traits = "0.2.18" +rand = "0.8.5" +rand_chacha = "0.3.1" +rand_distr = "0.4.3" +num-bigint-dig = { version = "0.8.4", features = ["prime"] } diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..bc00c26 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Gauss labs + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..606e292 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +wrap_comments = true \ No newline at end of file diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..2b3f92c --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,163 @@ +use itertools::izip; + +pub trait VectorOps { + type Element; + + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element); + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]); + + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]); + fn elwise_neg_mut(&self, a: &mut [Self::Element]); + /// inplace mutates `a`: a = a + b*c + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]); + + fn modulus(&self) -> Self::Element; +} + +pub trait ArithmeticOps { + type Element; + + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element; + + fn modulus(&self) -> Self::Element; +} + +pub struct ModularOpsU64 { + q: u64, + logq: usize, + barrett_mu: u128, + barrett_alpha: usize, +} + +impl ModularOpsU64 { + pub fn new(q: u64) -> ModularOpsU64 { + let logq = 64 - q.leading_zeros(); + + // barrett calculation + let mu = (1u128 << (logq * 2 + 3)) / (q as u128); + let alpha = logq + 3; + + ModularOpsU64 { + q, + logq: logq as usize, + barrett_alpha: alpha as usize, + barrett_mu: mu, + } + } + + fn add_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.q); + debug_assert!(b < self.q); + + let mut o = a + b; + if o >= self.q { + o -= self.q; + } + o + } + + fn sub_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < self.q); + debug_assert!(b < self.q); + + if a > b { + a - b + } else { + (self.q + a) - b + } + } + + /// returns (a * b) % q + /// + /// - both a and b must be in range [0, 2q) + /// - output is in range [0 , q) + fn mul_mod_fast(&self, a: u64, b: u64) -> u64 { + debug_assert!(a < 2 * self.q); + debug_assert!(b < 2 * self.q); + + let ab = a as u128 * b as u128; + + // ab / (2^{n + \beta}) + // note: \beta is assumed to -2 + let tmp = ab >> (self.logq - 2); + + // k = ((ab / (2^{n + \beta})) * \mu) / 2^{\alpha - (-2)} + let k = (tmp * self.barrett_mu) >> (self.barrett_alpha + 2); + + // ab - k*p + let tmp = k * (self.q as u128); + + let mut out = (ab - tmp) as u64; + + if out >= self.q { + out -= self.q; + } + + return out; + } +} + +impl ArithmeticOps for ModularOpsU64 { + type Element = u64; + + fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.add_mod_fast(*a, *b) + } + + fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.mul_mod_fast(*a, *b) + } + + fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { + self.sub_mod_fast(*a, *b) + } + + fn modulus(&self) -> Self::Element { + self.q + } +} + +impl VectorOps for ModularOpsU64 { + type Element = u64; + + fn elwise_add_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.add_mod_fast(*ai, *bi); + }); + } + + fn elwise_mul_mut(&self, a: &mut [Self::Element], b: &[Self::Element]) { + izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| { + *ai = self.mul_mod_fast(*ai, *bi); + }); + } + + fn elwise_neg_mut(&self, a: &mut [Self::Element]) { + a.iter_mut().for_each(|ai| *ai = self.q - *ai); + } + + fn elwise_scalar_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &Self::Element) { + izip!(out.iter_mut(), a.iter()).for_each(|(oi, ai)| { + *oi = self.mul_mod_fast(*ai, *b); + }); + } + + fn elwise_mul(&self, out: &mut [Self::Element], a: &[Self::Element], b: &[Self::Element]) { + izip!(out.iter_mut(), a.iter(), b.iter()).for_each(|(oi, ai, bi)| { + *oi = self.mul_mod_fast(*ai, *bi); + }); + } + + fn elwise_fma_mut(&self, a: &mut [Self::Element], b: &[Self::Element], c: &[Self::Element]) { + izip!(a.iter_mut(), b.iter(), c.iter()).for_each(|(ai, bi, ci)| { + *ai = self.add_mod_fast(*ai, self.mul_mod_fast(*bi, *ci)); + }); + } + + fn modulus(&self) -> Self::Element { + self.q + } +} diff --git a/src/decomposer.rs b/src/decomposer.rs new file mode 100644 index 0000000..db9238e --- /dev/null +++ b/src/decomposer.rs @@ -0,0 +1,164 @@ +use itertools::Itertools; +use num_traits::{AsPrimitive, One, PrimInt, ToPrimitive, WrappingSub, Zero}; +use std::{fmt::Debug, marker::PhantomData, ops::Rem}; + +use crate::backend::{ArithmeticOps, ModularOpsU64}; + +pub fn gadget_vector(logq: usize, logb: usize, d: usize) -> Vec { + let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); + let ignored_limbs = d_ideal - d; + (ignored_limbs..ignored_limbs + d) + .into_iter() + .map(|i| T::one() << (logb * i)) + .collect_vec() +} + +pub trait Decomposer { + type Element; + //FIXME(Jay): there's no reason why it returns a vec instead of an iterator + fn decompose(&self, v: &Self::Element) -> Vec; + fn d(&self) -> usize; +} + +pub struct DefaultDecomposer { + q: T, + logq: usize, + logb: usize, + d: usize, + ignore_bits: usize, + ignore_limbs: usize, +} + +pub trait NumInfo { + const BITS: u32; +} + +impl NumInfo for u64 { + const BITS: u32 = u64::BITS; +} +impl NumInfo for u32 { + const BITS: u32 = u32::BITS; +} +impl NumInfo for u128 { + const BITS: u32 = u128::BITS; +} + +impl DefaultDecomposer { + pub fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer { + // if q is power of 2, then BITS - leading zeros outputs logq + 1. + let logq = if q & (q - T::one()) == T::zero() { + (T::BITS - q.leading_zeros() - 1) as usize + } else { + (T::BITS - q.leading_zeros()) as usize + }; + + let d_ideal = (logq as f64 / logb as f64).ceil().to_usize().unwrap(); + let ignore_limbs = (d_ideal - d); + let ignore_bits = (d_ideal - d) * logb; + + DefaultDecomposer { + q, + logq, + logb, + d, + ignore_bits, + ignore_limbs, + } + } + + fn recompose(&self, limbs: &[T], modq_op: &Op) -> T + where + Op: ArithmeticOps, + { + let mut value = T::zero(); + for i in self.ignore_limbs..self.ignore_limbs + self.d { + value = modq_op.add( + &value, + &(modq_op.mul(&limbs[i], &(T::one() << (self.logb * i)))), + ) + } + value + } +} + +impl Decomposer for DefaultDecomposer { + type Element = T; + fn decompose(&self, value: &T) -> Vec { + let value = round_value(*value, self.ignore_bits); + + let q = self.q; + let logb = self.logb; + // let b = T::one() << logb; // base + let b_by2 = T::one() << (logb - 1); + // let neg_b_by2_modq = q - b_by2; + let full_mask = (T::one() << logb) - T::one(); + // let half_mask = b_by2 - T::one(); + let mut carry = T::zero(); + let mut out = Vec::::with_capacity(self.d); + for i in 0..self.d { + let mut limb = ((value >> (logb * i)) & full_mask) + carry; + + carry = limb & b_by2; + limb = (q + limb) - (carry << 1); + if limb > q { + limb = limb - q; + } + out.push(limb); + + carry = carry >> (logb - 1); + } + + return out; + } + + fn d(&self) -> usize { + self.d + } +} + +fn round_value(value: T, ignore_bits: usize) -> T { + if ignore_bits == 0 { + return value; + } + + let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1); + (value >> ignore_bits) + ignored_msb +} + +#[cfg(test)] +mod tests { + use rand::{thread_rng, Rng}; + + use crate::{backend::ModularOpsU64, decomposer::round_value, utils::generate_prime}; + + use super::{Decomposer, DefaultDecomposer}; + + #[test] + fn decomposition_works() { + let logq = 50; + let logb = 5; + let d = 10; + + // q is prime of bits logq and i is true, other q = 1< { + type MatElement; + type R: Row; + + fn dimension(&self) -> (usize, usize); + + fn get_row(&self, row_idx: usize) -> impl Iterator { + self.as_ref()[row_idx].as_ref().iter().map(move |r| r) + } + + fn get_row_slice(&self, row_idx: usize) -> &[Self::MatElement] { + self.as_ref()[row_idx].as_ref() + } + + fn iter_rows(&self) -> impl Iterator { + self.as_ref().iter().map(move |r| r) + } + + fn get(&self, row_idx: usize, column_idx: usize) -> &Self::MatElement { + &self.as_ref()[row_idx].as_ref()[column_idx] + } +} + +pub trait MatrixMut: Matrix + AsMut<[::R]> +where + ::R: RowMut, +{ + fn get_row_mut(&mut self, row_index: usize) -> &mut [Self::MatElement] { + self.as_mut()[row_index].as_mut() + } + + fn iter_rows_mut(&mut self) -> impl Iterator { + self.as_mut().iter_mut().map(move |r| r) + } + + fn set(&mut self, row_idx: usize, column_idx: usize, val: ::MatElement) { + self.as_mut()[row_idx].as_mut()[column_idx] = val; + } + + fn split_at_row( + &mut self, + idx: usize, + ) -> (&mut [::R], &mut [::R]) { + self.as_mut().split_at_mut(idx) + } +} + +pub trait MatrixEntity: Matrix // where +// ::MatElement: Zero, +{ + fn zeros(row: usize, col: usize) -> Self; +} + +pub trait Row: AsRef<[Self::Element]> { + type Element; +} + +pub trait RowMut: Row + AsMut<[::Element]> {} + +trait Secret { + type Element; + fn values(&self) -> &[Self::Element]; +} + +impl Matrix for Vec> { + type MatElement = T; + type R = Vec; + + fn dimension(&self) -> (usize, usize) { + (self.len(), self[0].len()) + } +} + +impl MatrixMut for Vec> {} + +impl MatrixEntity for Vec> { + fn zeros(row: usize, col: usize) -> Self { + vec![vec![T::zero(); col]; row] + } +} + +impl Row for Vec { + type Element = T; +} + +impl RowMut for Vec {} diff --git a/src/lwe.rs b/src/lwe.rs new file mode 100644 index 0000000..bb3cf35 --- /dev/null +++ b/src/lwe.rs @@ -0,0 +1,287 @@ +use std::fmt::Debug; + +use itertools::{izip, Itertools}; +use num_traits::{abs, Zero}; + +use crate::{ + backend::{ArithmeticOps, VectorOps}, + decomposer::Decomposer, + lwe, + num::UnsignedInteger, + random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist, DEFAULT_RNG}, + utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, + Matrix, MatrixEntity, MatrixMut, Row, RowMut, Secret, +}; + +trait LweKeySwitchParameters { + fn n_in(&self) -> usize; + fn n_out(&self) -> usize; + fn d_ks(&self) -> usize; +} + +trait LweCiphertext {} + +struct LweSecret { + values: Vec, +} + +impl Secret for LweSecret { + type Element = i32; + fn values(&self) -> &[Self::Element] { + &self.values + } +} + +impl LweSecret { + fn random(hw: usize, n: usize) -> LweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + + LweSecret { values: out } + }) + } +} + +fn lwe_key_switch< + M: Matrix, + Mmut: MatrixMut + MatrixEntity, + Op: VectorOps + ArithmeticOps, + D: Decomposer, +>( + lwe_out: &mut Mmut, + lwe_in: &M, + lwe_ksk: &M, + operator: &Op, + decomposer: &D, +) where + ::R: RowMut, +{ + assert!(lwe_ksk.dimension().0 == ((lwe_in.dimension().1 - 1) * decomposer.d())); + assert!(lwe_out.dimension() == (1, lwe_ksk.dimension().1)); + + let mut scratch_space = Mmut::zeros(1, lwe_out.dimension().1); + + let lwe_in_a_decomposed = lwe_in + .get_row(0) + .skip(1) + .flat_map(|ai| decomposer.decompose(ai)); + izip!(lwe_in_a_decomposed, lwe_ksk.iter_rows()).for_each(|(ai_j, beta_ij_lwe)| { + operator.elwise_scalar_mul(scratch_space.get_row_mut(0), beta_ij_lwe.as_ref(), &ai_j); + operator.elwise_add_mut(lwe_out.get_row_mut(0), scratch_space.get_row_slice(0)) + }); + + let out_b = operator.add(lwe_out.get(0, 0), lwe_in.get(0, 0)); + lwe_out.set(0, 0, out_b); +} + +fn lwe_ksk_keygen< + Mmut: MatrixMut, + S: Secret, + Op: VectorOps + ArithmeticOps, + R: RandomGaussianDist + + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, +>( + lwe_sk_in: &S, + lwe_sk_out: &S, + ksk_out: &mut Mmut, + gadget: &[Mmut::MatElement], + operator: &Op, + rng: &mut R, +) where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::MatElement: Zero + Debug, +{ + assert!( + ksk_out.dimension() + == ( + lwe_sk_in.values().len() * gadget.len(), + lwe_sk_out.values().len() + 1, + ) + ); + + let d = gadget.len(); + + let modulus = VectorOps::modulus(operator); + let mut neg_sk_in_m = Mmut::try_convert_from(lwe_sk_in.values(), &modulus); + operator.elwise_neg_mut(neg_sk_in_m.get_row_mut(0)); + let sk_out_m = Mmut::try_convert_from(lwe_sk_out.values(), &modulus); + + izip!( + neg_sk_in_m.get_row(0), + ksk_out.iter_rows_mut().chunks(d).into_iter() + ) + .for_each(|(neg_sk_in_si, d_ks_lwes)| { + izip!(gadget.iter(), d_ks_lwes.into_iter()).for_each(|(f, lwe)| { + // sample `a` + RandomUniformDist::random_fill(rng, &modulus, &mut lwe.as_mut()[1..]); + + // a * z + let mut az = Mmut::MatElement::zero(); + izip!(lwe.as_ref()[1..].iter(), sk_out_m.get_row(0)).for_each(|(ai, si)| { + let ai_si = operator.mul(ai, si); + az = operator.add(&az, &ai_si); + }); + + // a*z + (-s_i)*\beta^j + e + let mut b = operator.add(&az, &operator.mul(f, neg_sk_in_si)); + let mut e = Mmut::MatElement::zero(); + RandomGaussianDist::random_fill(rng, &modulus, &mut e); + b = operator.add(&b, &e); + + lwe.as_mut()[0] = b; + + // dbg!(&lwe.as_mut(), &f); + }) + }); +} + +/// Encrypts encoded message m as LWE ciphertext +fn encrypt_lwe< + Mmut: MatrixMut + MatrixEntity, + R: RandomGaussianDist + + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, + S: Secret, + Op: ArithmeticOps, +>( + lwe_out: &mut Mmut, + m: Mmut::MatElement, + s: &S, + operator: &Op, + rng: &mut R, +) where + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::MatElement: Zero, + ::R: RowMut, +{ + let s = Mmut::try_convert_from(s.values(), &operator.modulus()); + assert!(s.dimension().0 == (lwe_out.dimension().0)); + assert!(s.dimension().1 == (lwe_out.dimension().1 - 1)); + + // a*s + RandomUniformDist::random_fill(rng, &operator.modulus(), &mut lwe_out.get_row_mut(0)[1..]); + let mut sa = Mmut::MatElement::zero(); + izip!(lwe_out.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + let tmp = operator.mul(ai, si); + sa = operator.add(&tmp, &sa); + }); + + // b = a*s + e + m + let mut e = Mmut::MatElement::zero(); + RandomGaussianDist::random_fill(rng, &operator.modulus(), &mut e); + let b = operator.add(&operator.add(&sa, &e), &m); + lwe_out.set(0, 0, b); +} + +fn decrypt_lwe, S: Secret>( + lwe_ct: &M, + s: &S, + operator: &Op, +) -> M::MatElement +where + M: TryConvertFrom<[S::Element], Parameters = M::MatElement>, + M::MatElement: Zero, +{ + let s = M::try_convert_from(s.values(), &operator.modulus()); + + let mut sa = M::MatElement::zero(); + izip!(lwe_ct.get_row(0).skip(1), s.get_row(0)).for_each(|(ai, si)| { + let tmp = operator.mul(ai, si); + sa = operator.add(&tmp, &sa); + }); + + let b = &lwe_ct.get_row_slice(0)[0]; + operator.sub(b, &sa) +} + +#[cfg(test)] +mod tests { + + use crate::{ + backend::ModularOpsU64, + decomposer::{gadget_vector, DefaultDecomposer}, + lwe::lwe_key_switch, + random::DefaultSecureRng, + }; + + use super::{decrypt_lwe, encrypt_lwe, lwe_ksk_keygen, LweSecret}; + + #[test] + fn encrypt_decrypt_works() { + let logq = 20; + let q = 1u64 << logq; + let lwe_n = 1024; + let logp = 3; + + let modq_op = ModularOpsU64::new(q); + let lwe_sk = LweSecret::random(lwe_n >> 1, lwe_n); + + let mut rng = DefaultSecureRng::new(); + + // encrypt + for m in 0..1u64 << logp { + let encoded_m = m << (logq - logp); + let mut lwe_ct = vec![vec![0u64; lwe_n + 1]]; + encrypt_lwe(&mut lwe_ct, encoded_m, &lwe_sk, &modq_op, &mut rng); + let encoded_m_back = decrypt_lwe(&lwe_ct, &lwe_sk, &modq_op); + let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() + as u64) + % (1u64 << logp); + assert_eq!(m, m_back, "Expected {m} but got {m_back}"); + } + } + + #[test] + fn key_switch_works() { + let logq = 16; + let logp = 3; + let q = 1u64 << logq; + let lwe_in_n = 1024; + let lwe_out_n = 470; + let d_ks = 3; + let logb = 4; + + let lwe_sk_in = LweSecret::random(lwe_in_n >> 1, lwe_in_n); + let lwe_sk_out = LweSecret::random(lwe_out_n >> 1, lwe_out_n); + + let mut rng = DefaultSecureRng::new(); + let modq_op = ModularOpsU64::new(q); + + // genrate ksk + for _ in 0..10 { + let mut ksk = vec![vec![0u64; lwe_out_n + 1]; d_ks * lwe_in_n]; + let gadget = gadget_vector(logq, logb, d_ks); + lwe_ksk_keygen( + &lwe_sk_in, + &lwe_sk_out, + &mut ksk, + &gadget, + &modq_op, + &mut rng, + ); + // println!("{:?}", ksk); + + for m in 0..(1 << logp) { + // encrypt using lwe_sk_in + let encoded_m = m << (logq - logp); + let mut lwe_in_ct = vec![vec![0u64; lwe_in_n + 1]]; + encrypt_lwe(&mut lwe_in_ct, encoded_m, &lwe_sk_in, &modq_op, &mut rng); + + // key switch from lwe_sk_in to lwe_sk_out + let decomposer = DefaultDecomposer::new(1u64 << logq, logb, d_ks); + let mut lwe_out_ct = vec![vec![0u64; lwe_out_n + 1]]; + lwe_key_switch(&mut lwe_out_ct, &lwe_in_ct, &ksk, &modq_op, &decomposer); + + // decrypt lwe_out_ct using lwe_sk_out + let encoded_m_back = decrypt_lwe(&lwe_out_ct, &lwe_sk_out, &modq_op); + let m_back = ((((encoded_m_back as f64) * ((1 << logp) as f64)) / q as f64).round() + as u64) + % (1u64 << logp); + assert_eq!(m, m_back, "Expected {m} but got {m_back}"); + // dbg!(m, m_back); + // dbg!(encoded_m, encoded_m_back); + } + } + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/src/ntt.rs b/src/ntt.rs new file mode 100644 index 0000000..320a28c --- /dev/null +++ b/src/ntt.rs @@ -0,0 +1,408 @@ +use itertools::Itertools; +use rand::{thread_rng, Rng, RngCore}; + +use crate::{ + backend::{ArithmeticOps, ModularOpsU64}, + utils::{mod_exponent, mod_inverse, shoup_representation_fq}, +}; + +pub trait Ntt { + type Element; + fn forward_lazy(&self, v: &mut [Self::Element]); + fn forward(&self, v: &mut [Self::Element]); + fn backward_lazy(&self, v: &mut [Self::Element]); + fn backward(&self, v: &mut [Self::Element]); +} + +/// Forward butterfly routine for Number theoretic transform. Given inputs `x < +/// 4q` and `y < 4q` mutates x and y in place to equal x' and y' where +/// x' = x + wy +/// y' = x - wy +/// and both x' and y' are \in [0, 4q) +/// +/// Implements Algorithm 4 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) +pub unsafe fn forward_butterly( + x: *mut u64, + y: *mut u64, + w: &u64, + w_shoup: &u64, + q: &u64, + q_twice: &u64, +) { + debug_assert!(*x < *q * 4, "{} >= (4q){}", *x, 4 * q); + debug_assert!(*y < *q * 4, "{} >= (4q){}", *y, 4 * q); + + if *x >= *q_twice { + *x = *x - q_twice; + } + + // TODO (Jay): Hot path expected. How expensive is it? + let k = ((*w_shoup as u128 * *y as u128) >> 64) as u64; + let t = w.wrapping_mul(*y).wrapping_sub(k.wrapping_mul(*q)); + + *y = *x + q_twice - t; + *x = *x + t; +} + +/// Inverse butterfly routine of Inverse Number theoretic transform. Given +/// inputs `x < 2q` and `y < 2q` mutates x and y to equal x' and y' where +/// x'= x + y +/// y' = w(x - y) +/// and both x' and y' are \in [0, 2q) +/// +/// Implements Algorithm 3 of [FASTER ARITHMETIC FOR NUMBER-THEORETIC TRANSFORMS](https://arxiv.org/pdf/1205.2926.pdf) +pub unsafe fn inverse_butterfly( + x: *mut u64, + y: *mut u64, + w_inv: &u64, + w_inv_shoup: &u64, + q: &u64, + q_twice: &u64, +) { + debug_assert!(*x < *q_twice, "{} >= (2q){q_twice}", *x); + debug_assert!(*y < *q_twice, "{} >= (2q){q_twice}", *y); + + let mut x_dash = *x + *y; + if x_dash >= *q_twice { + x_dash -= q_twice + } + + let t = *x + q_twice - *y; + let k = ((*w_inv_shoup as u128 * t as u128) >> 64) as u64; // TODO (Jay): Hot path + *y = w_inv.wrapping_mul(t).wrapping_sub(k.wrapping_mul(*q)); + + *x = x_dash; +} + +/// Number theoretic transform of vector `a` where each element can be in range +/// [0, 2q). Outputs NTT(a) where each element is in range [0,2q) +/// +/// Implements Cooley-tukey based forward NTT as given in Algorithm 1 of https://eprint.iacr.org/2016/504.pdf. +pub fn ntt_lazy(a: &mut [u64], psi: &[u64], psi_shoup: &[u64], q: u64, q_twice: u64) { + debug_assert!(a.len() == psi.len()); + + let n = a.len(); + let mut t = n; + + let mut m = 1; + while m < n { + t >>= 1; + + for i in 0..m { + let j_1 = 2 * i * t; + let j_2 = j_1 + t; + + unsafe { + let w = psi.get_unchecked(m + i); + let w_shoup = psi_shoup.get_unchecked(m + i); + for j in j_1..j_2 { + let x = a.get_unchecked_mut(j) as *mut u64; + let y = a.get_unchecked_mut(j + t) as *mut u64; + forward_butterly(x, y, w, w_shoup, &q, &q_twice); + } + } + } + + m <<= 1; + } + + a.iter_mut().for_each(|a0| { + if *a0 >= q_twice { + *a0 -= q_twice + } + }); +} + +/// Inverse number theoretic transform of input vector `a` with each element can +/// be in range [0, 2q). Outputs vector INTT(a) with each element in range [0, +/// 2q) +/// +/// Implements backward number theorectic transform using GS algorithm as given in Algorithm 2 of https://eprint.iacr.org/2016/504.pdf +pub fn ntt_inv_lazy( + a: &mut [u64], + psi_inv: &[u64], + psi_inv_shoup: &[u64], + n_inv: u64, + q: u64, + q_twice: u64, +) { + debug_assert!(a.len() == psi_inv.len()); + + let mut m = a.len(); + let mut t = 1; + while m > 1 { + let mut j_1: usize = 0; + let h = m >> 1; + for i in 0..h { + let j_2 = j_1 + t; + unsafe { + let w_inv = psi_inv.get_unchecked(h + i); + let w_inv_shoup = psi_inv_shoup.get_unchecked(h + i); + + for j in j_1..j_2 { + let x = a.get_unchecked_mut(j) as *mut u64; + let y = a.get_unchecked_mut(j + t) as *mut u64; + inverse_butterfly(x, y, w_inv, w_inv_shoup, &q, &q_twice); + } + } + j_1 = j_1 + 2 * t; + } + t *= 2; + m >>= 1; + } + + a.iter_mut() + .for_each(|a0| *a0 = ((*a0 as u128 * n_inv as u128) % q as u128) as u64); +} + +/// Find n^{th} root of unity in field F_q, if one exists +/// +/// Note: n^{th} root of unity exists if and only if $q = 1 \mod{n}$ +pub(crate) fn find_primitive_root(q: u64, n: u64, rng: &mut R) -> Option { + assert!(n.is_power_of_two(), "{n} is not power of two"); + + // n^th root of unity only exists if n|(q-1) + assert!(q % n == 1, "{n}^th root of unity in F_{q} does not exists"); + + let t = (q - 1) / n; + + for _ in 0..100 { + let mut omega = rng.gen::() % q; + + // \omega = \omega^t. \omega is now n^th root of unity + omega = mod_exponent(omega, t, q); + + // We restrict n to be power of 2. Thus checking whether \omega is primitive + // n^th root of unity is as simple as checking: \omega^{n/2} != 1 + if mod_exponent(omega, n >> 1, q) == 1 { + continue; + } else { + return Some(omega); + } + } + + None +} + +pub struct NttBackendU64 { + q: u64, + q_twice: u64, + n: u64, + n_inv: u64, + psi_powers_bo: Box<[u64]>, + psi_inv_powers_bo: Box<[u64]>, + psi_powers_bo_shoup: Box<[u64]>, + psi_inv_powers_bo_shoup: Box<[u64]>, +} + +impl NttBackendU64 { + pub fn new(q: u64, n: usize) -> Self { + // \psi = 2n^{th} primitive root of unity in F_q + let mut rng = thread_rng(); + let psi = find_primitive_root(q, (n * 2) as u64, &mut rng) + .expect("Unable to find 2n^th root of unity"); + + let psi_inv = mod_inverse(psi, q); + + // assert!( + // ((psi_inv as u128 * psi as u128) % q as u128) == 1, + // "psi:{psi}, psi_inv:{psi_inv}" + // ); + + let modulus = ModularOpsU64::new(q); + + let mut psi_powers = Vec::with_capacity(n as usize); + let mut psi_inv_powers = Vec::with_capacity(n as usize); + let mut running_psi = 1; + let mut running_psi_inv = 1; + for _ in 0..n { + psi_powers.push(running_psi); + psi_inv_powers.push(running_psi_inv); + + running_psi = modulus.mul(&running_psi, &psi); + running_psi_inv = modulus.mul(&running_psi_inv, &psi_inv); + } + + // powers stored in bit reversed order + let mut psi_powers_bo = vec![0u64; n as usize]; + let mut psi_inv_powers_bo = vec![0u64; n as usize]; + let shift_by = n.leading_zeros() + 1; + for i in 0..n as usize { + // i in bit reversed order + let bo_index = i.reverse_bits() >> shift_by; + + psi_powers_bo[bo_index] = psi_powers[i]; + psi_inv_powers_bo[bo_index] = psi_inv_powers[i]; + } + + // shoup representation + let psi_powers_bo_shoup = psi_powers_bo + .iter() + .map(|v| shoup_representation_fq(*v, q)) + .collect_vec(); + let psi_inv_powers_bo_shoup = psi_inv_powers_bo + .iter() + .map(|v| shoup_representation_fq(*v, q)) + .collect_vec(); + + // n^{-1} \mod{q} + let n_inv = mod_inverse(n as u64, q); + + NttBackendU64 { + q, + q_twice: 2 * q, + n: n as u64, + n_inv, + psi_powers_bo: psi_powers_bo.into_boxed_slice(), + psi_inv_powers_bo: psi_inv_powers_bo.into_boxed_slice(), + psi_powers_bo_shoup: psi_powers_bo_shoup.into_boxed_slice(), + psi_inv_powers_bo_shoup: psi_inv_powers_bo_shoup.into_boxed_slice(), + } + } +} + +impl NttBackendU64 { + fn reduce_from_lazy(&self, a: &mut [u64]) { + let q = self.q; + a.iter_mut().for_each(|a0| { + if *a0 >= q { + *a0 = *a0 - q; + } + }); + } +} + +impl Ntt for NttBackendU64 { + type Element = u64; + + fn forward_lazy(&self, v: &mut [Self::Element]) { + ntt_lazy( + v, + &self.psi_powers_bo, + &self.psi_powers_bo_shoup, + self.q, + self.q_twice, + ) + } + + fn forward(&self, v: &mut [Self::Element]) { + ntt_lazy( + v, + &self.psi_powers_bo, + &self.psi_powers_bo_shoup, + self.q, + self.q_twice, + ); + self.reduce_from_lazy(v); + } + + fn backward_lazy(&self, v: &mut [Self::Element]) { + ntt_inv_lazy( + v, + &self.psi_inv_powers_bo, + &self.psi_inv_powers_bo_shoup, + self.n_inv, + self.q, + self.q_twice, + ) + } + + fn backward(&self, v: &mut [Self::Element]) { + ntt_inv_lazy( + v, + &self.psi_inv_powers_bo, + &self.psi_inv_powers_bo_shoup, + self.n_inv, + self.q, + self.q_twice, + ); + self.reduce_from_lazy(v); + } +} + +mod tests { + use itertools::Itertools; + use rand::{thread_rng, Rng}; + use rand_distr::Uniform; + + use super::NttBackendU64; + use crate::{ + backend::{ModularOpsU64, VectorOps}, + ntt::Ntt, + utils::{generate_prime, negacyclic_mul}, + }; + + const Q_60_BITS: u64 = 1152921504606748673; + const N: usize = 1 << 4; + + const K: usize = 128; + + fn random_vec_in_fq(size: usize, q: u64) -> Vec { + thread_rng() + .sample_iter(Uniform::new(0, q)) + .take(size) + .collect_vec() + } + + #[test] + fn native_ntt_backend_works() { + // TODO(Jay): Improve tests. Add tests for different primes and ring size. + let ntt_backend = NttBackendU64::new(Q_60_BITS, N); + for _ in 0..K { + let mut a = random_vec_in_fq(N, Q_60_BITS); + let a_clone = a.clone(); + + ntt_backend.forward(&mut a); + assert_ne!(a, a_clone); + ntt_backend.backward(&mut a); + assert_eq!(a, a_clone); + + ntt_backend.forward_lazy(&mut a); + assert_ne!(a, a_clone); + ntt_backend.backward(&mut a); + assert_eq!(a, a_clone); + + ntt_backend.forward(&mut a); + ntt_backend.backward_lazy(&mut a); + // reduce + a.iter_mut().for_each(|a0| { + if *a0 > Q_60_BITS { + *a0 -= *a0 - Q_60_BITS; + } + }); + assert_eq!(a, a_clone); + } + } + + #[test] + fn native_ntt_negacylic_mul() { + let primes = [40, 50, 60] + .iter() + .map(|bits| generate_prime(*bits, (2 * N) as u64, 1u64 << bits).unwrap()) + .collect_vec(); + + for p in primes.into_iter() { + let ntt_backend = NttBackendU64::new(p, N); + let modulus_backend = ModularOpsU64::new(p); + for _ in 0..K { + let a = random_vec_in_fq(N, p); + let b = random_vec_in_fq(N, p); + + let mut a_clone = a.clone(); + let mut b_clone = b.clone(); + ntt_backend.forward_lazy(&mut a_clone); + ntt_backend.forward_lazy(&mut b_clone); + modulus_backend.elwise_mul_mut(&mut a_clone, &b_clone); + ntt_backend.backward(&mut a_clone); + + let mul = |a: &u64, b: &u64| { + let tmp = *a as u128 * *b as u128; + (tmp % p as u128) as u64 + }; + let expected_out = negacyclic_mul(&a, &b, mul, p); + + assert_eq!(a_clone, expected_out); + } + } + } +} diff --git a/src/num.rs b/src/num.rs new file mode 100644 index 0000000..14522ea --- /dev/null +++ b/src/num.rs @@ -0,0 +1,3 @@ +use num_traits::{Num, PrimInt, WrappingShl, WrappingShr, Zero}; + +pub trait UnsignedInteger: Zero + Num {} diff --git a/src/random.rs b/src/random.rs new file mode 100644 index 0000000..397585e --- /dev/null +++ b/src/random.rs @@ -0,0 +1,180 @@ +use std::cell::RefCell; + +use itertools::izip; +use rand::{distributions::Uniform, thread_rng, CryptoRng, Rng, RngCore, SeedableRng}; +use rand_chacha::ChaCha8Rng; +use rand_distr::Distribution; + +use crate::utils::WithLocal; + +thread_local! { + pub(crate) static DEFAULT_RNG: RefCell = RefCell::new(DefaultSecureRng::new()); +} + +pub trait RandomGaussianDist +where + M: ?Sized, +{ + type Parameters: ?Sized; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M); +} + +pub trait RandomUniformDist +where + M: ?Sized, +{ + type Parameters: ?Sized; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut M); +} + +pub(crate) struct DefaultSecureRng { + rng: ChaCha8Rng, +} + +impl DefaultSecureRng { + pub fn new_seeded(seed: ::Seed) -> DefaultSecureRng { + let rng = ChaCha8Rng::from_seed(seed); + DefaultSecureRng { rng } + } + + pub fn new() -> DefaultSecureRng { + let rng = ChaCha8Rng::from_entropy(); + DefaultSecureRng { rng } + } +} + +impl RandomUniformDist for DefaultSecureRng { + type Parameters = usize; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut usize) { + *container = self.rng.gen_range(0..*parameters); + } +} + +impl RandomUniformDist<[u8]> for DefaultSecureRng { + type Parameters = u8; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u8]) { + self.rng.fill_bytes(container); + } +} + +impl RandomUniformDist<[u32]> for DefaultSecureRng { + type Parameters = u32; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) { + izip!( + (&mut self.rng).sample_iter(Uniform::new(0, parameters)), + container.iter_mut() + ) + .for_each(|(from, to)| { + *to = from; + }); + } +} + +impl RandomUniformDist<[u64]> for DefaultSecureRng { + type Parameters = u64; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { + izip!( + (&mut self.rng).sample_iter(Uniform::new(0, parameters)), + container.iter_mut() + ) + .for_each(|(from, to)| { + *to = from; + }); + } +} + +impl RandomGaussianDist for DefaultSecureRng { + type Parameters = u64; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u64) { + let o = rand_distr::Normal::new(0.0, 3.2f64) + .unwrap() + .sample(&mut self.rng) + .round(); + + // let o = 0.0f64; + + let is_neg = o.is_sign_negative() && o != 0.0; + if is_neg { + *container = parameters - (o.abs() as u64); + } else { + *container = o as u64; + } + } +} + +impl RandomGaussianDist for DefaultSecureRng { + type Parameters = u32; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut u32) { + let o = rand_distr::Normal::new(0.0, 3.2f32) + .unwrap() + .sample(&mut self.rng) + .round(); + + // let o = 0.0f32; + let is_neg = o.is_sign_negative() && o != 0.0; + + if is_neg { + *container = parameters - (o.abs() as u32); + } else { + *container = o as u32; + } + } +} + +impl RandomGaussianDist<[u64]> for DefaultSecureRng { + type Parameters = u64; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u64]) { + izip!( + rand_distr::Normal::new(0.0, 3.2f64) + .unwrap() + .sample_iter(&mut self.rng), + container.iter_mut() + ) + .for_each(|(oi, v)| { + let oi = oi.round(); + let is_neg = oi.is_sign_negative() && oi != 0.0; + if is_neg { + *v = parameters - (oi.abs() as u64); + } else { + *v = oi as u64; + } + }); + } +} + +impl RandomGaussianDist<[u32]> for DefaultSecureRng { + type Parameters = u32; + fn random_fill(&mut self, parameters: &Self::Parameters, container: &mut [u32]) { + izip!( + rand_distr::Normal::new(0.0, 3.2f32) + .unwrap() + .sample_iter(&mut self.rng), + container.iter_mut() + ) + .for_each(|(oi, v)| { + let oi = oi.round(); + let is_neg = oi.is_sign_negative() && oi != 0.0; + if is_neg { + *v = parameters - (oi.abs() as u32); + } else { + *v = oi as u32; + } + }); + } +} + +impl WithLocal for DefaultSecureRng { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R, + { + DEFAULT_RNG.with_borrow(|r| func(r)) + } + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R, + { + DEFAULT_RNG.with_borrow_mut(|r| func(r)) + } +} diff --git a/src/rgsw.rs b/src/rgsw.rs new file mode 100644 index 0000000..72a30cf --- /dev/null +++ b/src/rgsw.rs @@ -0,0 +1,466 @@ +use itertools::izip; + +use crate::{ + backend::VectorOps, + decomposer::{self, Decomposer}, + ntt::Ntt, + random::{DefaultSecureRng, RandomGaussianDist, RandomUniformDist}, + utils::{fill_random_ternary_secret_with_hamming_weight, TryConvertFrom, WithLocal}, + Matrix, MatrixEntity, MatrixMut, RowMut, Secret, +}; + +struct RlweSecret { + values: Vec, +} + +impl Secret for RlweSecret { + type Element = i32; + fn values(&self) -> &[Self::Element] { + &self.values + } +} + +impl RlweSecret { + fn random(hw: usize, n: usize) -> RlweSecret { + DefaultSecureRng::with_local_mut(|rng| { + let mut out = vec![0i32; n]; + fill_random_ternary_secret_with_hamming_weight(&mut out, hw, rng); + + RlweSecret { values: out } + }) + } +} + +/// Encrypts message m as a RGSW ciphertext. +/// +/// - m_eval: is `m` is evaluation domain +/// - out_rgsw: RGSW(m) is stored as single matrix of dimension (d_rgsw * 4, +/// ring_size). The matrix has the following structure [RLWE'_A(-sm) || +/// RLWE'_B(-sm) || RLWE'_A(m) || RLWE'_B(m)]^T +fn encrypt_rgsw< + Mmut: MatrixMut + MatrixEntity, + M: Matrix + Clone, + S: Secret, + R: RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + + RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, + ModOp: VectorOps, + NttOp: Ntt, +>( + out_rgsw: &mut Mmut, + m_eval: &M, + gadget_vector: &[Mmut::MatElement], + s: &S, + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, +{ + let d = gadget_vector.len(); + let q = mod_op.modulus(); + let ring_size = s.values().len(); + assert!(out_rgsw.dimension() == (d * 4, ring_size)); + assert!(m_eval.dimension() == (1, ring_size)); + + // RLWE(-sm), RLWE(-sm) + let (rlwe_dash_nsm, rlwe_dash_m) = out_rgsw.split_at_row(d * 2); + + let mut s_eval = Mmut::try_convert_from(s.values(), &q); + ntt_op.forward(s_eval.get_row_mut(0).as_mut()); + + let mut scratch_space = Mmut::zeros(1, ring_size); + + // RLWE'(-sm) + let (a_rlwe_dash_nsm, b_rlwe_dash_nsm) = rlwe_dash_nsm.split_at_mut(d); + izip!( + a_rlwe_dash_nsm.iter_mut(), + b_rlwe_dash_nsm.iter_mut(), + gadget_vector.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // Sample a_i and transform to evaluation domain + RandomUniformDist::random_fill(rng, &q, ai.as_mut()); + ntt_op.forward(ai.as_mut()); + + // a_i * s + mod_op.elwise_mul( + scratch_space.get_row_mut(0), + ai.as_ref(), + s_eval.get_row_slice(0), + ); + // b_i = e_i + a_i * s + RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); + ntt_op.forward(bi.as_mut()); + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); + + // a_i + \beta_i * m + mod_op.elwise_scalar_mul( + scratch_space.get_row_mut(0), + m_eval.get_row_slice(0), + beta_i, + ); + mod_op.elwise_add_mut(ai.as_mut(), scratch_space.get_row_slice(0)); + }); + + // RLWE(m) + let (a_rlwe_dash_m, b_rlwe_dash_m) = rlwe_dash_m.split_at_mut(d); + izip!( + a_rlwe_dash_m.iter_mut(), + b_rlwe_dash_m.iter_mut(), + gadget_vector.iter() + ) + .for_each(|(ai, bi, beta_i)| { + // Sample e_i and transform to evaluation domain + RandomGaussianDist::random_fill(rng, &q, bi.as_mut()); + ntt_op.forward(bi.as_mut()); + + // beta_i * m + mod_op.elwise_scalar_mul( + scratch_space.get_row_mut(0), + m_eval.get_row_slice(0), + beta_i, + ); + // e_i + beta_i * m + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); + + // Sample a_i and transform to evaluation domain + RandomUniformDist::random_fill(rng, &q, ai.as_mut()); + ntt_op.forward(ai.as_mut()); + + // ai * s + mod_op.elwise_mul( + scratch_space.get_row_mut(0), + ai.as_ref(), + s_eval.get_row_slice(0), + ); + + // b_i = a_i*s + e_i + beta_i*m + mod_op.elwise_add_mut(bi.as_mut(), scratch_space.get_row_slice(0)); + }); +} + +/// Returns RLWE(mm') = RLWE(m) x RGSW(m') +/// +/// - rgsw_in: RGSW(m') in evaluation domain +/// - rlwe_in_decomposed: decomposed RLWE(m) in evaluation domain +/// - rlwe_out: returned RLWE(mm') in evaluation domain +fn rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain< + Mmut: MatrixMut + MatrixEntity, + M: Matrix + Clone, + ModOp: VectorOps, +>( + rgsw_in: &M, + rlwe_in_decomposed_eval: &Mmut, + rlwe_out_eval: &mut Mmut, + mod_op: &ModOp, +) where + ::R: RowMut, +{ + let ring_size = rgsw_in.dimension().1; + let d_rgsw = rgsw_in.dimension().0 / 4; + assert!(rlwe_in_decomposed_eval.dimension() == (2 * d_rgsw, ring_size)); + assert!(rlwe_out_eval.dimension() == (2, ring_size)); + + let (a_rlwe_out, b_rlwe_out) = rlwe_out_eval.split_at_row(1); + + // a * RLWE'(-sm) + let a_rlwe_dash_nsm = rgsw_in.iter_rows().take(d_rgsw); + let b_rlwe_dash_nsm = rgsw_in.iter_rows().skip(d_rgsw).take(d_rgsw); + izip!( + rlwe_in_decomposed_eval.iter_rows().take(d_rgsw), + a_rlwe_dash_nsm + ) + .for_each(|(a, b)| { + mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); + }); + izip!( + rlwe_in_decomposed_eval.iter_rows().take(d_rgsw), + b_rlwe_dash_nsm + ) + .for_each(|(a, b)| { + mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); + }); + + // b * RLWE'(m) + let a_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 2).take(d_rgsw); + let b_rlwe_dash_m = rgsw_in.iter_rows().skip(d_rgsw * 3); + izip!( + rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw), + a_rlwe_dash_m + ) + .for_each(|(a, b)| { + mod_op.elwise_fma_mut(a_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); + }); + izip!( + rlwe_in_decomposed_eval.iter_rows().skip(d_rgsw), + b_rlwe_dash_m + ) + .for_each(|(a, b)| { + mod_op.elwise_fma_mut(b_rlwe_out[0].as_mut(), a.as_ref(), b.as_ref()); + }); +} + +fn decompose_rlwe< + M: Matrix + Clone, + Mmut: MatrixMut + MatrixEntity, + D: Decomposer, +>( + rlwe_in: &M, + decomposer: &D, + rlwe_in_decomposed: &mut Mmut, +) where + M::MatElement: Copy, + ::R: RowMut, +{ + let d_rgsw = decomposer.d(); + let ring_size = rlwe_in.dimension().1; + assert!(rlwe_in_decomposed.dimension() == (2 * d_rgsw, ring_size)); + + // Decompose rlwe_in + for ri in 0..ring_size { + // ai + let ai_decomposed = decomposer.decompose(rlwe_in.get(0, ri)); + for j in 0..d_rgsw { + rlwe_in_decomposed.set(j, ri, ai_decomposed[j]); + } + + // bi + let bi_decomposed = decomposer.decompose(rlwe_in.get(1, ri)); + for j in 0..d_rgsw { + rlwe_in_decomposed.set(j + d_rgsw, ri, bi_decomposed[j]); + } + } +} + +/// Returns RLWE(m0m1) = RLWE(m0) x RGSW(m1) +/// +/// - rlwe_in: is RLWE(m0) with polynomials in coefficient domain +/// - rgsw_in: is RGSW(m1) with polynomials in evaluation domain +/// - rlwe_out: is output RLWE(m0m1) with polynomials in coefficient domain +/// - rlwe_in_decomposed: is a matrix of dimension (d_rgsw * 2, ring_size) used +/// as scratch space to store decomposed RLWE(m0) +fn rlwe_by_rgsw< + M: Matrix + Clone, + Mmut: MatrixMut + MatrixEntity, + D: Decomposer, + ModOp: VectorOps, + NttOp: Ntt, +>( + rlwe_in: &M, + rgsw_in: &M, + rlwe_out: &mut Mmut, + rlwe_in_decomposed: &mut Mmut, + decomposer: &D, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + M::MatElement: Copy, + ::R: RowMut, +{ + decompose_rlwe(rlwe_in, decomposer, rlwe_in_decomposed); + + // transform rlwe_in decomposed to evaluation domain + rlwe_in_decomposed + .iter_rows_mut() + .for_each(|r| ntt_op.forward(r.as_mut())); + + // decomposed RLWE x RGSW + rlwe_in_decomposed_evaluation_domain_mul_rgsw_rlwe_out_evaluation_domain( + rgsw_in, + rlwe_in_decomposed, + rlwe_out, + mod_op, + ); + + // transform rlwe_out to coefficient domain + rlwe_out + .iter_rows_mut() + .for_each(|r| ntt_op.backward(r.as_mut())); +} + +/// Encrypt polynomial m(X) as RLWE ciphertext. +/// +/// - rlwe_out: returned RLWE ciphertext RLWE(m) in coefficient domain. RLWE +/// ciphertext is a matirx with first row consiting polynomial `a` and the +/// second rows consting polynomial `b` +fn encrypt_rlwe< + Mmut: Matrix + MatrixMut + Clone, + ModOp: VectorOps, + NttOp: Ntt, + S: Secret, + R: RandomUniformDist<[Mmut::MatElement], Parameters = Mmut::MatElement> + + RandomGaussianDist<[Mmut::MatElement], Parameters = Mmut::MatElement>, +>( + m: &Mmut, + rlwe_out: &mut Mmut, + s: &S, + mod_op: &ModOp, + ntt_op: &NttOp, + rng: &mut R, +) where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, +{ + let ring_size = s.values().len(); + assert!(rlwe_out.dimension() == (2, ring_size)); + assert!(m.dimension() == (1, ring_size)); + + let q = mod_op.modulus(); + + // sample a + RandomUniformDist::random_fill(rng, &q, rlwe_out.get_row_mut(0)); + + // s * a + let mut sa = Mmut::try_convert_from(s.values(), &q); + ntt_op.forward(sa.get_row_mut(0)); + ntt_op.forward(rlwe_out.get_row_mut(0)); + mod_op.elwise_mul_mut(sa.get_row_mut(0), rlwe_out.get_row_slice(0)); + ntt_op.backward(rlwe_out.get_row_mut(0)); + ntt_op.backward(sa.get_row_mut(0)); + + // sample e + RandomGaussianDist::random_fill(rng, &q, rlwe_out.get_row_mut(1)); + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), m.get_row_slice(0)); + mod_op.elwise_add_mut(rlwe_out.get_row_mut(1), sa.get_row_slice(0)); +} + +/// Decrypts degree 1 RLWE ciphertext RLWE(m) and returns m +/// +/// - rlwe_ct: input degree 1 ciphertext RLWE(m). +fn decrypt_rlwe< + Mmut: MatrixMut + Clone, + M: Matrix, + ModOp: VectorOps, + NttOp: Ntt, + S: Secret, +>( + rlwe_ct: &M, + s: &S, + m_out: &mut Mmut, + ntt_op: &NttOp, + mod_op: &ModOp, +) where + ::R: RowMut, + Mmut: TryConvertFrom<[S::Element], Parameters = Mmut::MatElement>, + Mmut::MatElement: Copy, +{ + let ring_size = s.values().len(); + assert!(rlwe_ct.dimension() == (2, ring_size)); + assert!(m_out.dimension() == (1, ring_size)); + + // transform a to evluation form + m_out + .get_row_mut(0) + .copy_from_slice(rlwe_ct.get_row_slice(0)); + ntt_op.forward(m_out.get_row_mut(0)); + + // -s*a + let mut s = Mmut::try_convert_from(&s.values(), &mod_op.modulus()); + ntt_op.forward(s.get_row_mut(0)); + mod_op.elwise_mul_mut(m_out.get_row_mut(0), s.get_row_slice(0)); + mod_op.elwise_neg_mut(m_out.get_row_mut(0)); + ntt_op.backward(m_out.get_row_mut(0)); + + // m+e = b - s*a + mod_op.elwise_add_mut(m_out.get_row_mut(0), rlwe_ct.get_row_slice(1)); +} + +#[cfg(test)] +mod tests { + use std::vec; + + use itertools::Itertools; + use rand::{thread_rng, Rng}; + + use crate::{ + backend::ModularOpsU64, + decomposer::{gadget_vector, DefaultDecomposer}, + ntt::{self, Ntt, NttBackendU64}, + random::{DefaultSecureRng, RandomUniformDist}, + utils::{generate_prime, negacyclic_mul}, + }; + + use super::{decrypt_rlwe, encrypt_rgsw, encrypt_rlwe, rlwe_by_rgsw, RlweSecret}; + + #[test] + fn rlwe_by_rgsw_works() { + let logq = 50; + let logp = 3; + let ring_size = 1 << 10; + let q = generate_prime(logq, ring_size, 1u64 << logq).unwrap(); + let p = 1u64 << logp; + let d_rgsw = 10; + let logb = 5; + + let mut rng = DefaultSecureRng::new(); + + let s = RlweSecret::random((ring_size >> 1) as usize, ring_size as usize); + + let mut m0 = vec![0u64; ring_size as usize]; + RandomUniformDist::<[u64]>::random_fill(&mut rng, &(1u64 << logp), m0.as_mut_slice()); + let mut m1 = vec![0u64; ring_size as usize]; + m1[thread_rng().gen_range(0..ring_size) as usize] = 1; + + let ntt_op = NttBackendU64::new(q, ring_size as usize); + let mod_op = ModularOpsU64::new(q); + + // Encrypt m1 as RGSW(m1) + let mut rgsw_ct = vec![vec![0u64; ring_size as usize]; d_rgsw * 4]; + let gadget_vector = gadget_vector(logq, logb, d_rgsw); + let mut m1_eval = m1.clone(); + ntt_op.forward(&mut m1_eval); + encrypt_rgsw( + &mut rgsw_ct, + &vec![m1_eval], + &gadget_vector, + &s, + &mod_op, + &ntt_op, + &mut rng, + ); + // println!("RGSW(m1): {:?}", &rgsw_ct); + + // Encrypt m0 as RLWE(m0) + let mut rlwe_in_ct = vec![vec![0u64; ring_size as usize]; 2]; + let encoded_m = m0 + .iter() + .map(|v| (((*v as f64) * q as f64) / (p as f64)).round() as u64) + .collect_vec(); + encrypt_rlwe( + &vec![encoded_m.clone()], + &mut rlwe_in_ct, + &s, + &mod_op, + &ntt_op, + &mut rng, + ); + + // RLWE(m0m1) = RLWE(m0) x RGSW(m1) + let mut rlwe_out_ct = vec![vec![0u64; ring_size as usize]; 2]; + let mut scratch_space = vec![vec![0u64; ring_size as usize]; d_rgsw * 2]; + let decomposer = DefaultDecomposer::new(q, logb, d_rgsw); + rlwe_by_rgsw( + &rlwe_in_ct, + &rgsw_ct, + &mut rlwe_out_ct, + &mut scratch_space, + &decomposer, + &ntt_op, + &mod_op, + ); + + // Decrypt RLWE(m0m1) + let mut encoded_m0m1_back = vec![vec![0u64; ring_size as usize]]; + decrypt_rlwe(&rlwe_out_ct, &s, &mut encoded_m0m1_back, &ntt_op, &mod_op); + let m0m1_back = encoded_m0m1_back[0] + .iter() + .map(|v| (((*v as f64 * p as f64) / (q as f64)).round() as u64) % p) + .collect_vec(); + + let mul_mod = |v0: &u64, v1: &u64| (v0 * v1) % p; + let m0m1 = negacyclic_mul(&m0, &m1, mul_mod, p); + assert_eq!(m0m1, m0m1_back, "Expected {:?} got {:?}", m0m1, m0m1_back); + // dbg!(&m0m1_back, m0m1, q); + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..40d007a --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,191 @@ +use std::usize; + +use itertools::Itertools; +use num_traits::{PrimInt, Signed}; + +use crate::RandomUniformDist; +pub trait WithLocal { + fn with_local(func: F) -> R + where + F: Fn(&Self) -> R; + + fn with_local_mut(func: F) -> R + where + F: Fn(&mut Self) -> R; +} + +pub fn fill_random_ternary_secret_with_hamming_weight< + T: Signed, + R: RandomUniformDist<[u8], Parameters = u8> + RandomUniformDist, +>( + out: &mut [T], + hamming_weight: usize, + rng: &mut R, +) { + let mut bytes = vec![0u8; hamming_weight.div_ceil(8)]; + RandomUniformDist::<[u8]>::random_fill(rng, &0, &mut bytes); + + let size = out.len(); + let mut secret_indices = (0..size).into_iter().map(|i| i).collect_vec(); + let mut bit_index = 0; + let mut byte_index = 0; + for _ in 0..hamming_weight { + let mut s_index = 0usize; + RandomUniformDist::::random_fill(rng, &secret_indices.len(), &mut s_index); + + let curr_bit = (bytes[byte_index] >> bit_index) & 1; + if curr_bit == 1 { + out[secret_indices[s_index]] = T::one(); + } else { + out[secret_indices[s_index]] = -T::one(); + } + + secret_indices[s_index] = *secret_indices.last().unwrap(); + secret_indices.truncate(secret_indices.len()); + + if bit_index == 7 { + bit_index = 0; + byte_index += 1; + } else { + bit_index += 1; + } + } +} + +// TODO (Jay): this is only a workaround. Add a propoer way to perform primality +// tests. +fn is_probably_prime(candidate: u64) -> bool { + num_bigint_dig::prime::probably_prime(&num_bigint_dig::BigUint::from(candidate), 0) +} + +/// Finds prime that satisfy +/// - $prime \lt upper_bound$ +/// - $\log{prime} = num_bits$ +/// - `prime % modulo == 1` +pub fn generate_prime(num_bits: usize, modulo: u64, upper_bound: u64) -> Option { + let leading_zeros = (64 - num_bits) as u32; + + let mut tentative_prime = upper_bound - 1; + while tentative_prime % modulo != 1 && tentative_prime.leading_zeros() == leading_zeros { + tentative_prime -= 1; + } + + while !is_probably_prime(tentative_prime) + && tentative_prime.leading_zeros() == leading_zeros + && tentative_prime >= modulo + { + tentative_prime -= modulo; + } + + if is_probably_prime(tentative_prime) && tentative_prime.leading_zeros() == leading_zeros { + Some(tentative_prime) + } else { + None + } +} + +/// Returns a^b mod q +pub fn mod_exponent(a: u64, mut b: u64, q: u64) -> u64 { + let mod_mul = |v1: &u64, v2: &u64| { + let tmp = *v1 as u128 * *v2 as u128; + (tmp % q as u128) as u64 + }; + + let mut acc = a; + let mut out = 1; + while b != 0 { + let flag = b & 1; + + if flag == 1 { + out = mod_mul(&acc, &out); + } + acc = mod_mul(&acc, &acc); + + b >>= 1; + } + + out +} + +pub fn mod_inverse(a: u64, q: u64) -> u64 { + mod_exponent(a, q - 2, q) +} + +pub fn shoup_representation_fq(v: u64, q: u64) -> u64 { + ((v as u128 * (1u128 << 64)) / q as u128) as u64 +} + +pub fn negacyclic_mul T>( + a: &[T], + b: &[T], + mul: F, + modulus: T, +) -> Vec { + let mut r = vec![T::zero(); a.len()]; + for i in 0..a.len() { + for j in 0..i + 1 { + // println!("i: {j} {}", i - j); + r[i] = (r[i] + mul(&a[j], &b[i - j])) % modulus; + } + + for j in i + 1..a.len() { + // println!("i: {j} {}", a.len() - j + i); + r[i] = (r[i] + modulus - mul(&a[j], &b[a.len() - j + i])) % modulus; + } + // println!("") + } + + return r; +} + +pub trait TryConvertFrom { + type Parameters: ?Sized; + + fn try_convert_from(value: &T, parameters: &Self::Parameters) -> Self; +} + +impl TryConvertFrom<[i32]> for Vec> { + type Parameters = u32; + fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self { + let row0 = value + .iter() + .map(|v| { + let is_neg = v.is_negative(); + let v_u32 = v.abs() as u32; + + assert!(v_u32 < *parameters); + + if is_neg { + parameters - v_u32 + } else { + v_u32 + } + }) + .collect_vec(); + + vec![row0] + } +} + +impl TryConvertFrom<[i32]> for Vec> { + type Parameters = u64; + fn try_convert_from(value: &[i32], parameters: &Self::Parameters) -> Self { + let row0 = value + .iter() + .map(|v| { + let is_neg = v.is_negative(); + let v_u64 = v.abs() as u64; + + assert!(v_u64 < *parameters); + + if is_neg { + parameters - v_u64 + } else { + v_u64 + } + }) + .collect_vec(); + + vec![row0] + } +}