From 1cc38b304236b1f2e1981bb67f5626018d600b30 Mon Sep 17 00:00:00 2001 From: Jean-Philippe Bossuat Date: Mon, 20 Jan 2025 11:44:27 +0100 Subject: [PATCH] [ring]: added ring degree switching --- math/src/dft/ntt.rs | 18 +++++- math/src/ring/impl_u64/mod.rs | 1 + math/src/ring/impl_u64/ring_switch.rs | 43 +++++++++++++ math/tests/automorphism.rs | 6 +- math/tests/ring_switch.rs | 89 +++++++++++++++++++++++++++ 5 files changed, 153 insertions(+), 4 deletions(-) create mode 100644 math/src/ring/impl_u64/ring_switch.rs create mode 100644 math/tests/ring_switch.rs diff --git a/math/src/dft/ntt.rs b/math/src/dft/ntt.rs index a2fdfc2..ff54e7d 100644 --- a/math/src/dft/ntt.rs +++ b/math/src/dft/ntt.rs @@ -100,9 +100,17 @@ impl Table { let n: usize = a.len(); assert!( n & n - 1 == 0, - "invalid x.len()= {} must be a power of two", + "invalid a.len()={} must be a power of two", n ); + + assert!( + n <= self.psi_forward_rev.len(), + "invalid a.len()={} > psi_forward_rev.len()={}", + n, + self.psi_forward_rev.len() + ); + let log_n: u32 = usize::BITS - ((n as usize) - 1).leading_zeros(); let start: u32 = SKIPSTART as u32; @@ -204,6 +212,14 @@ impl Table { "invalid x.len()= {} must be a power of two", n ); + + assert!( + n <= self.psi_backward_rev.len(), + "invalid a.len()={} > psi_backward_rev.len()={}", + n, + self.psi_backward_rev.len() + ); + let log_n = usize::BITS - ((n as usize) - 1).leading_zeros(); let start: u32 = SKIPEND as u32; diff --git a/math/src/ring/impl_u64/mod.rs b/math/src/ring/impl_u64/mod.rs index e57f0cb..abacff1 100644 --- a/math/src/ring/impl_u64/mod.rs +++ b/math/src/ring/impl_u64/mod.rs @@ -2,5 +2,6 @@ pub mod automorphism; pub mod rescaling_rns; pub mod ring; pub mod ring_rns; +pub mod ring_switch; pub mod sampling; pub mod utils; diff --git a/math/src/ring/impl_u64/ring_switch.rs b/math/src/ring/impl_u64/ring_switch.rs new file mode 100644 index 0000000..beaa415 --- /dev/null +++ b/math/src/ring/impl_u64/ring_switch.rs @@ -0,0 +1,43 @@ +use crate::poly::Poly; +use crate::ring::Ring; + +impl Ring { + pub fn switch_degree( + &self, + a: &Poly, + buf: &mut Poly, + b: &mut Poly, + ) { + let (n_in, n_out) = (a.n(), b.n()); + + if n_in > n_out { + let (gap_in, gap_out) = (1, n_in / n_out); + if NTT { + self.intt::(&a, buf); + b.0.iter_mut() + .step_by(gap_in) + .zip(buf.0.iter().step_by(gap_out)) + .for_each(|(x_out, x_in)| *x_out = *x_in); + self.ntt_inplace::(b); + } else { + b.0.iter_mut() + .step_by(gap_in) + .zip(a.0.iter().step_by(gap_out)) + .for_each(|(x_out, x_in)| *x_out = *x_in); + } + } else { + let gap: usize = n_out / n_in; + + if NTT { + a.0.iter() + .enumerate() + .for_each(|(i, &c)| (0..gap).for_each(|j| b.0[i * gap + j] = c)); + } else { + b.0.iter_mut() + .step_by(gap) + .zip(a.0.iter()) + .for_each(|(x_out, x_in)| *x_out = *x_in); + } + } + } +} diff --git a/math/tests/automorphism.rs b/math/tests/automorphism.rs index ec7f7f2..95812d0 100644 --- a/math/tests/automorphism.rs +++ b/math/tests/automorphism.rs @@ -19,10 +19,10 @@ fn automorphism_u64() { }); sub_test("test_automorphism_from_perm_u64::", || { - test_automorphism_from_perm_u64::(&ring, nth_root) + test_automorphism_from_perm_u64::(&ring) }); sub_test("test_automorphism_from_perm_u64::", || { - test_automorphism_from_perm_u64::(&ring, nth_root) + test_automorphism_from_perm_u64::(&ring) }); } @@ -62,7 +62,7 @@ fn test_automorphism_native_u64(ring: &Ring, nth_root: usi izip!(p0.0, p1.0).for_each(|(a, b)| assert_eq!(a, b)); } -fn test_automorphism_from_perm_u64(ring: &Ring, nth_root: usize) { +fn test_automorphism_from_perm_u64(ring: &Ring) { let n: usize = ring.n(); let q: u64 = ring.modulus.q; diff --git a/math/tests/ring_switch.rs b/math/tests/ring_switch.rs new file mode 100644 index 0000000..6be79ec --- /dev/null +++ b/math/tests/ring_switch.rs @@ -0,0 +1,89 @@ +use itertools::izip; +use math::automorphism::AutoPerm; +use math::poly::Poly; +use math::ring::Ring; + +#[test] +fn ring_switch_u64() { + let n: usize = 1 << 4; + let q_base: u64 = 65537u64; + let q_power: usize = 1usize; + let ring_small: Ring = Ring::new(n, q_base, q_power); + let ring_large = Ring::new(2 * n, q_base, q_power); + + sub_test("test_ring_switch_small_to_large_u64::", || { + test_ring_switch_small_to_large_u64::(&ring_small, &ring_large) + }); + sub_test("test_ring_switch_small_to_large_u64::", || { + test_ring_switch_small_to_large_u64::(&ring_small, &ring_large) + }); + sub_test("test_ring_switch_large_to_small_u64::", || { + test_ring_switch_large_to_small_u64::(&ring_small, &ring_large) + }); + sub_test("test_ring_switch_large_to_small_u64::", || { + test_ring_switch_large_to_small_u64::(&ring_small, &ring_large) + }); +} + +fn sub_test(name: &str, f: F) { + println!("Running {}", name); + f(); +} + +fn test_ring_switch_small_to_large_u64( + ring_small: &Ring, + ring_large: &Ring, +) { + let mut a: Poly = ring_small.new_poly(); + let mut buf: Poly = ring_small.new_poly(); + let mut b: Poly = ring_large.new_poly(); + + a.0.iter_mut().enumerate().for_each(|(i, x)| *x = i as u64); + + if NTT { + ring_small.ntt_inplace::(&mut a); + } + + ring_large.switch_degree::(&a, &mut buf, &mut b); + + if NTT { + ring_small.intt_inplace::(&mut a); + ring_large.intt_inplace::(&mut b); + } + + let gap: usize = ring_large.n() / ring_small.n(); + + b.0.iter() + .step_by(gap) + .zip(a.0.iter()) + .for_each(|(x_out, x_in)| assert_eq!(x_out, x_in)); +} + +fn test_ring_switch_large_to_small_u64( + ring_small: &Ring, + ring_large: &Ring, +) { + let mut a: Poly = ring_large.new_poly(); + let mut buf: Poly = ring_large.new_poly(); + let mut b: Poly = ring_small.new_poly(); + + a.0.iter_mut().enumerate().for_each(|(i, x)| *x = i as u64); + + if NTT { + ring_large.ntt_inplace::(&mut a); + } + + ring_large.switch_degree::(&a, &mut buf, &mut b); + + if NTT { + ring_large.intt_inplace::(&mut a); + ring_small.intt_inplace::(&mut b); + } + + let gap: usize = ring_large.n() / ring_small.n(); + + a.0.iter() + .step_by(gap) + .zip(b.0.iter()) + .for_each(|(x_out, x_in)| assert_eq!(x_out, x_in)); +}