refactoring

This commit is contained in:
Jean-Philippe Bossuat
2025-01-27 16:23:23 +01:00
parent c30f598776
commit 1ac719ce7e
21 changed files with 113 additions and 88 deletions

3
.gitmodules vendored
View File

@@ -1,3 +1,6 @@
[submodule "spqlios/spqlios-arithmetic"]
path = spqlios/spqlios-arithmetic
url = https://github.com/tfhe/spqlios-arithmetic
[submodule "base2k/spqlios-arithmetic"] [submodule "base2k/spqlios-arithmetic"]
path = base2k/spqlios-arithmetic path = base2k/spqlios-arithmetic
url = https://github.com/tfhe/spqlios-arithmetic url = https://github.com/tfhe/spqlios-arithmetic

63
Cargo.lock generated
View File

@@ -49,6 +49,20 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "base2k"
version = "0.1.0"
dependencies = [
"bindgen",
"criterion",
"itertools 0.14.0",
"rand",
"rand_core",
"rand_distr",
"sampling",
"utils",
]
[[package]] [[package]]
name = "bindgen" name = "bindgen"
version = "0.71.1" version = "0.71.1"
@@ -362,24 +376,6 @@ version = "0.4.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24"
[[package]]
name = "math"
version = "0.1.0"
dependencies = [
"criterion",
"itertools 0.14.0",
"num",
"num-bigint",
"num-integer",
"num-traits",
"primality-test",
"prime_factorization",
"rand_distr",
"sampling",
"sprs",
"utils",
]
[[package]] [[package]]
name = "matrixmultiply" name = "matrixmultiply"
version = "0.3.9" version = "0.3.9"
@@ -725,6 +721,24 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "rns"
version = "0.1.0"
dependencies = [
"criterion",
"itertools 0.14.0",
"num",
"num-bigint",
"num-integer",
"num-traits",
"primality-test",
"prime_factorization",
"rand_distr",
"sampling",
"sprs",
"utils",
]
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.0" version = "2.1.0"
@@ -798,19 +812,6 @@ version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67"
[[package]]
name = "spqlios"
version = "0.1.0"
dependencies = [
"bindgen",
"criterion",
"itertools 0.14.0",
"rand",
"rand_core",
"rand_distr",
"sampling",
]
[[package]] [[package]]
name = "sprs" name = "sprs"
version = "0.11.2" version = "0.11.2"

View File

@@ -1,2 +1,11 @@
[workspace] [workspace]
members = ["base2k", "rns", "sampling", "utils"] members = ["base2k", "rns", "sampling", "utils"]
[workspace.dependencies]
rand = "0.8.4"
rand_chacha = "0.3.1"
rand_core = "0.6.4"
rand_distr = "0.4.3"
itertools = "0.14.0"
criterion = "0.5.1"

View File

@@ -1,18 +1,19 @@
[package] [package]
name = "spqlios" name = "base2k"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
rand = "0.8.5" criterion = {workspace = true}
rand_core = "0.6.4" itertools = {workspace = true}
itertools = "0.14.0" rand = {workspace = true}
criterion = "0.5.1" rand_distr = {workspace = true}
rand_distr = "0.4.3" rand_core = {workspace = true}
sampling = { path = "../sampling" } sampling = { path = "../sampling" }
utils = { path = "../utils" }
[build-dependencies] [build-dependencies]
bindgen = "0.71.1" bindgen ="0.71.1"
[[bench]] [[bench]]
name = "fft" name = "fft"

View File

@@ -1,9 +1,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use base2k::bindings::{
use spqlios::bindings::{
new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp, new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp,
reim_fft_precomp_get_buffer, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp, reim_fft_precomp_get_buffer, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp,
reim_ifft_precomp_get_buffer, reim_ifft_precomp_get_buffer,
}; };
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use std::ffi::c_void; use std::ffi::c_void;
fn fft(c: &mut Criterion) { fn fft(c: &mut Criterion) {

View File

@@ -7,9 +7,9 @@ use std::time::SystemTime;
fn main() { fn main() {
// Path to the C header file // Path to the C header file
let header_paths = [ let header_paths: [&str; 2] = [
"lib/spqlios/coeffs/coeffs_arithmetic.h", "spqlios-arithmetic/spqlios/coeffs/coeffs_arithmetic.h",
"lib/spqlios/arithmetic/vec_znx_arithmetic.h", "spqlios-arithmetic/spqlios/arithmetic/vec_znx_arithmetic.h",
]; ];
let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap()); let out_path: PathBuf = PathBuf::from(env::var("OUT_DIR").unwrap());
@@ -46,7 +46,10 @@ fn main() {
println!( println!(
"cargo:rustc-link-search=native={}", "cargo:rustc-link-search=native={}",
absolute("./lib/build/spqlios").unwrap().to_str().unwrap() absolute("./spqlios-arithmetic/build/spqlios")
.unwrap()
.to_str()
.unwrap()
); );
println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios" println!("cargo:rustc-link-lib=static=spqlios"); //"cargo:rustc-link-lib=dylib=spqlios"
} }

View File

@@ -1,8 +1,11 @@
use base2k::bindings::{
new_reim_fft_precomp, new_reim_ifft_precomp, reim_fft, reim_fft_precomp_get_buffer,
reim_fftvec_mul_simple, reim_from_znx64_simple, reim_ifft, reim_ifft_precomp_get_buffer,
reim_to_znx64_simple,
};
use std::ffi::c_void; use std::ffi::c_void;
use std::time::Instant; use std::time::Instant;
use spqlios::bindings::*;
fn main() { fn main() {
let log_bound: usize = 19; let log_bound: usize = 19;

View File

@@ -1,8 +1,8 @@
use base2k::module::{Module, FFT64};
use base2k::scalar::Scalar;
use base2k::vector::Vector;
use itertools::izip; use itertools::izip;
use sampling::source::Source; use sampling::source::Source;
use spqlios::module::{Module, FFT64};
use spqlios::scalar::Scalar;
use spqlios::vector::Vector;
fn main() { fn main() {
let n: usize = 16; let n: usize = 16;
@@ -23,7 +23,7 @@ fn main() {
s.fill_ternary_prob(0.5, &mut source); s.fill_ternary_prob(0.5, &mut source);
// Buffer to store s in the DFT domain // Buffer to store s in the DFT domain
let mut s_ppol: spqlios::module::SVPPOL = module.svp_new_ppol(); let mut s_ppol: base2k::module::SVPPOL = module.svp_new_ppol();
// s_ppol <- DFT(s) // s_ppol <- DFT(s)
module.svp_prepare(&mut s_ppol, &s); module.svp_prepare(&mut s_ppol, &s);
@@ -33,13 +33,13 @@ fn main() {
a.fill_uniform(&mut source); a.fill_uniform(&mut source);
// Scratch space for DFT values // Scratch space for DFT values
let mut buf_dft: spqlios::module::VECZNXDFT = module.new_vec_znx_dft(a.limbs()); let mut buf_dft: base2k::module::VECZNXDFT = module.new_vec_znx_dft(a.limbs());
// Applies buf_dft <- s * a // Applies buf_dft <- s * a
module.svp_apply_dft(&mut buf_dft, &s_ppol, &a); module.svp_apply_dft(&mut buf_dft, &s_ppol, &a);
// Alias scratch space // Alias scratch space
let mut buf_big: spqlios::module::VECZNXBIG = buf_dft.as_vec_znx_big(); let mut buf_big: base2k::module::VECZNXBIG = buf_dft.as_vec_znx_big();
// buf_big <- IDFT(buf_dft) (not normalized) // buf_big <- IDFT(buf_dft) (not normalized)
module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.limbs()); module.vec_znx_idft_tmp_a(&mut buf_big, &mut buf_dft, a.limbs());

View File

@@ -14,14 +14,20 @@ pub mod bindings {
} }
pub mod vec_znx_arithmetic; pub mod vec_znx_arithmetic;
#[allow(unused_imports)]
pub use vec_znx_arithmetic::*; pub use vec_znx_arithmetic::*;
pub mod vec_znx_big_arithmetic; pub mod vec_znx_big_arithmetic;
#[allow(unused_imports)]
pub use vec_znx_big_arithmetic::*; pub use vec_znx_big_arithmetic::*;
pub mod vec_znx_dft; pub mod vec_znx_dft;
#[allow(unused_imports)]
pub use vec_znx_dft::*; pub use vec_znx_dft::*;
pub mod scalar_vector_product; pub mod scalar_vector_product;
#[allow(unused_imports)]
pub use scalar_vector_product::*; pub use scalar_vector_product::*;
#[allow(dead_code)]
fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] { fn cast_mut_u64_to_mut_u8_slice(data: &mut [u64]) -> &mut [u8] {
let ptr: *mut u8 = data.as_mut_ptr() as *mut u8; let ptr: *mut u8 = data.as_mut_ptr() as *mut u8;
let len: usize = data.len() * std::mem::size_of::<u64>(); let len: usize = data.len() * std::mem::size_of::<u64>();

View File

@@ -1,6 +1,6 @@
use rand::distributions::{Distribution, WeightedIndex};
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use rand_core::RngCore; use rand_core::RngCore;
use rand_distr::{Distribution, WeightedIndex};
use sampling::source::Source; use sampling::source::Source;
pub struct Scalar(pub Vec<i64>); pub struct Scalar(pub Vec<i64>);

View File

@@ -1,5 +1,5 @@
[package] [package]
name = "math" name = "rns"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
@@ -10,11 +10,10 @@ num-bigint = "0.4.6"
num-traits = "0.2.19" num-traits = "0.2.19"
num-integer ="0.1.46" num-integer ="0.1.46"
prime_factorization = "1.0.5" prime_factorization = "1.0.5"
itertools = "0.14.0"
criterion = "0.5.1"
rand_distr = "0.4.3"
sprs = "0.11.2" sprs = "0.11.2"
criterion = {workspace = true}
itertools = {workspace = true}
rand_distr = {workspace = true}
sampling = { path = "../sampling" } sampling = { path = "../sampling" }
utils = { path = "../utils" } utils = { path = "../utils" }

View File

@@ -1,7 +1,7 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::modulus::WordOps; use rns::modulus::WordOps;
use math::poly::Poly; use rns::poly::Poly;
use math::ring::Ring; use rns::ring::Ring;
fn ntt(c: &mut Criterion) { fn ntt(c: &mut Criterion) {
fn runner<'a, const INPLACE: bool, const LAZY: bool>( fn runner<'a, const INPLACE: bool, const LAZY: bool>(

View File

@@ -1,8 +1,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use math::modulus::montgomery::Montgomery; use rns::modulus::montgomery::Montgomery;
use math::modulus::{WordOps, ONCE}; use rns::modulus::{WordOps, ONCE};
use math::poly::Poly; use rns::poly::Poly;
use math::ring::Ring; use rns::ring::Ring;
fn a_add_b_into_b(c: &mut Criterion) { fn a_add_b_into_b(c: &mut Criterion) {
fn runner(ring: Ring<u64>) -> Box<dyn FnMut()> { fn runner(ring: Ring<u64>) -> Box<dyn FnMut()> {

View File

@@ -1,11 +1,11 @@
use criterion::{criterion_group, criterion_main, Criterion}; use criterion::{criterion_group, criterion_main, Criterion};
use math::poly::PolyRNS; use rns::poly::PolyRNS;
use math::ring::RingRNS; use rns::ring::RingRNS;
fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) { fn div_floor_by_last_modulus_ntt_true(c: &mut Criterion) {
fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> { fn runner(r: RingRNS<u64>) -> Box<dyn FnMut()> {
let a: PolyRNS<u64> = r.new_polyrns(); let a: PolyRNS<u64> = r.new_polyrns();
let mut b: [math::poly::Poly<u64>; 2] = [r.new_poly(), r.new_poly()]; let mut b: [rns::poly::Poly<u64>; 2] = [r.new_poly(), r.new_poly()];
let mut c: PolyRNS<u64> = r.new_polyrns(); let mut c: PolyRNS<u64> = r.new_polyrns();
Box::new(move || r.div_by_last_modulus::<false, true>(&a, &mut b, &mut c)) Box::new(move || r.div_by_last_modulus::<false, true>(&a, &mut b, &mut c))

View File

@@ -1,6 +1,6 @@
use criterion::{criterion_group, criterion_main, Criterion}; use criterion::{criterion_group, criterion_main, Criterion};
use math::poly::PolyRNS; use rns::poly::PolyRNS;
use math::ring::RingRNS; use rns::ring::RingRNS;
use sampling::source::Source; use sampling::source::Source;
fn fill_uniform(c: &mut Criterion) { fn fill_uniform(c: &mut Criterion) {

View File

@@ -1,6 +1,6 @@
use math::dft::ntt::Table; use rns::dft::ntt::Table;
use math::modulus::prime::Prime; use rns::modulus::prime::Prime;
use math::ring::Ring; use rns::ring::Ring;
fn main() { fn main() {
// Example usage of `Prime<u64>` // Example usage of `Prime<u64>`
@@ -37,8 +37,8 @@ fn main() {
let r: Ring<u64> = Ring::<u64>::new(n as usize, q_base, q_power); let r: Ring<u64> = Ring::<u64>::new(n as usize, q_base, q_power);
let mut p0: math::poly::Poly<u64> = r.new_poly(); let mut p0: rns::poly::Poly<u64> = r.new_poly();
let mut p1: math::poly::Poly<u64> = r.new_poly(); let mut p1: rns::poly::Poly<u64> = r.new_poly();
for i in 0..p0.n() { for i in 0..p0.n() {
p0.0[i] = i as u64 p0.0[i] = i as u64

View File

@@ -1,7 +1,7 @@
use itertools::izip; use itertools::izip;
use math::automorphism::AutoPerm; use rns::automorphism::AutoPerm;
use math::poly::Poly; use rns::poly::Poly;
use math::ring::Ring; use rns::ring::Ring;
#[test] #[test]
fn automorphism_u64() { fn automorphism_u64() {

View File

@@ -1,7 +1,7 @@
use itertools::izip; use itertools::izip;
use math::modulus::{WordOps, ONCE}; use rns::modulus::{WordOps, ONCE};
use math::poly::Poly; use rns::poly::Poly;
use math::ring::Ring; use rns::ring::Ring;
use sampling::source::Source; use sampling::source::Source;
#[test] #[test]

View File

@@ -1,8 +1,8 @@
use itertools::izip; use itertools::izip;
use math::num_bigint::Div;
use math::poly::{Poly, PolyRNS};
use math::ring::RingRNS;
use num_bigint::BigInt; use num_bigint::BigInt;
use rns::num_bigint::Div;
use rns::poly::{Poly, PolyRNS};
use rns::ring::RingRNS;
use sampling::source::Source; use sampling::source::Source;
#[test] #[test]

View File

@@ -1,5 +1,5 @@
use math::poly::Poly; use rns::poly::Poly;
use math::ring::Ring; use rns::ring::Ring;
#[test] #[test]
fn ring_switch_u64() { fn ring_switch_u64() {

View File

@@ -4,5 +4,5 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
rand_chacha = "0.3.1" rand_chacha = { workspace = true }
rand_core = "0.6.4" rand_core = { workspace = true }