streamrepacker wip

This commit is contained in:
Jean-Philippe Bossuat
2025-01-11 15:07:22 +01:00
parent 74bfb52ada
commit a8bca16047
3 changed files with 171 additions and 15 deletions

View File

@@ -1,8 +1,10 @@
use crate::modulus::barrett::Barrett; use crate::modulus::barrett::Barrett;
use crate::modulus::montgomery::Montgomery;
use crate::modulus::ONCE; use crate::modulus::ONCE;
use crate::poly::Poly; use crate::poly::Poly;
use crate::ring::Ring; use crate::ring::Ring;
use std::cmp::min; use std::cmp::min;
use std::rc::Rc;
impl Ring<u64> { impl Ring<u64> {
@@ -53,7 +55,7 @@ impl Ring<u64> {
} }
}); });
let x_pow2: Vec<Poly<u64>> = self.gen_x_pow_2::<true, false>(log_n); let x_pow2: Vec<Poly<Montgomery<u64>>> = self.gen_x_pow_2::<true, false>(log_n);
let mut tmpa: Poly<u64> = self.new_poly(); let mut tmpa: Poly<u64> = self.new_poly();
let mut tmpb: Poly<u64> = self.new_poly(); let mut tmpb: Poly<u64> = self.new_poly();
@@ -115,3 +117,112 @@ fn max_gap(vec: &[usize]) -> usize {
} }
gap gap
} }
pub struct StreamRepacker{
accumulators: Vec<Accumulator>,
buf0: Poly<u64>,
buf1: Poly<u64>,
buf_auto: Poly<u64>,
x_pow_2: Vec<Poly<Montgomery<u64>>>,
n_inv: Barrett<u64>,
counter: usize,
}
pub struct Accumulator{
buf: [Option<Rc<Poly<u64>>>; 2],
control: bool,
}
impl Accumulator{
pub fn new(r: &Ring<u64>) -> Self{
Self { buf: [Some(Rc::new(r.new_poly())), None], control: false }
}
}
impl StreamRepacker{
pub fn new(r: &Ring<u64>) -> Self{
let mut accumulators: Vec<Accumulator> = Vec::<Accumulator>::new();
(0..r.log_n()).for_each(|_|
accumulators.push(Accumulator::new(r))
);
Self{
accumulators: accumulators,
buf0: r.new_poly(),
buf1: r.new_poly(),
buf_auto: r.new_poly(),
x_pow_2: r.gen_x_pow_2::<true, false>(r.log_n()),
n_inv: r.modulus.barrett.prepare(r.modulus.inv(r.n() as u64)),
counter:0,
}
}
fn merge_ab(&mut self, r: &Ring<u64>, a: &Poly<u64>, b: &Poly<u64>, i: usize) -> &Poly<u64>{
let tmp_a: &mut Poly<u64> = &mut self.buf0;
let tmp_b: &mut Poly<u64> = &mut self.buf1;
r.a_mul_b_montgomery_into_c::<ONCE>(a, &self.x_pow_2[r.log_n()-i-1], tmp_a);
r.a_sub_b_into_c::<1, ONCE>(a, tmp_a, tmp_b);
r.a_add_b_into_b::<ONCE>(a, tmp_a);
if i == 0{
r.a_mul_b_scalar_barrett_into_a::<ONCE>(&self.n_inv, tmp_a);
r.a_mul_b_scalar_barrett_into_a::<ONCE>(&self.n_inv, tmp_b);
}
let log_nth_root = r.log_n()+1;
let nth_root = 1<<log_nth_root;
let gal_el: usize = r.galois_element((1 << i) >> 1, i == 0, log_nth_root);
r.a_apply_automorphism_add_b_into_b::<ONCE, true>(tmp_b, gal_el, nth_root, tmp_a);
tmp_a
}
fn merge_a(&mut self, r: &Ring<u64>, a: &Poly<u64>, i: usize) -> &Poly<u64>{
let tmp_a: &mut Poly<u64> = &mut self.buf0;
let log_nth_root = r.log_n()+1;
let nth_root = 1<<log_nth_root;
let gal_el: usize = r.galois_element((1 << i) >> 1, i == 0, log_nth_root);
if i == 0{
r.a_mul_b_scalar_barrett_into_a::<ONCE>(&self.n_inv, tmp_a);
r.a_apply_automorphism_into_b::<true>(tmp_a, gal_el, nth_root, &mut self.buf_auto)
r.a_add_b_into_c::<ONCE>(&self.buf_auto, a, tmp_a);
}else{
r.a_apply_automorphism_into_b::<true>(a, gal_el, nth_root, tmp_a);
r.a_add_b_into_b::<ONCE>(a, tmp_a);
}
tmp_a
}
fn merge_b(&mut self, r: &Ring<u64>, b: &Poly<u64>, i: usize) -> &Poly<u64>{
let tmp_a: &mut Poly<u64> = &mut self.buf0;
let tmp_b: &mut Poly<u64> = &mut self.buf1;
let log_nth_root = r.log_n()+1;
let nth_root = 1<<log_nth_root;
let gal_el: usize = r.galois_element((1 << i) >> 1, i == 0, log_nth_root);
if i == 0{
r.a_mul_b_scalar_barrett_into_c::<ONCE>(&self.n_inv, b, tmp_b);
r.a_mul_b_montgomery_into_a::<ONCE>(&self.x_pow_2[r.log_n()-i-1], tmp_b);
}else{
r.a_mul_b_montgomery_into_c::<ONCE>(b, &self.x_pow_2[r.log_n()-i-1], tmp_b);
}
r.a_apply_automorphism_into_b::<true>(tmp_b, gal_el, nth_root, &mut self.buf_auto);
r.a_sub_b_into_a::<1, ONCE>(&self.buf_auto, tmp_b);
tmp_b
}
}

View File

@@ -4,29 +4,26 @@ use crate::modulus::barrett::Barrett;
use crate::modulus::ONCE; use crate::modulus::ONCE;
impl Ring<u64>{ impl Ring<u64>{
pub fn trace_inplace<const NTT:bool>(&self, log_start: usize, log_end: usize, a: &mut Poly<u64>){ pub fn trace_inplace<const NTT:bool>(&self, step_start: usize, a: &mut Poly<u64>){
assert!(log_end <= self.log_n(), "invalid argument log_end: log_end={} > self.log_n()={}", log_end, self.log_n()); assert!(step_start <= self.log_n(), "invalid argument step_start: step_start={} > self.log_n()={}", step_start, self.log_n());
assert!(log_end > log_start, "invalid argument log_start: log_start={} > log_end={}", log_start, log_end);
let log_steps = log_end - log_start; let log_steps: usize = self.log_n() - step_start;
let log_nth_root = self.log_n()+1;
let nth_root: usize= 1<<log_nth_root;
if log_steps > 0 { if log_steps > 0 {
let n_inv: Barrett<u64> = self.modulus.barrett.prepare(self.modulus.inv(1<<log_steps)); let n_inv: Barrett<u64> = self.modulus.barrett.prepare(self.modulus.inv(1<<log_steps));
self.a_mul_b_scalar_barrett_into_a::<ONCE>(&n_inv, a); self.a_mul_b_scalar_barrett_into_a::<ONCE>(&n_inv, a);
if !NTT{
self.ntt_inplace::<false>(a);
}
let mut tmp: Poly<u64> = self.new_poly(); let mut tmp: Poly<u64> = self.new_poly();
(log_start..log_end).for_each(|i|{ (step_start..self.log_n()).for_each(|i|{
let gal_el: usize = self.galois_element((1 << i) >> 1, i == 0, log_nth_root);
self.a_apply_automorphism_into_b::<NTT>(a, gal_el, nth_root, &mut tmp);
self.a_add_b_into_b::<ONCE>(&tmp, a);
}); });
} }
} }
} }

View File

@@ -1,6 +1,5 @@
use itertools::izip; use itertools::izip;
use math::poly::Poly; use math::poly::Poly;
use math::ring::impl_u64::ring;
use math::ring::Ring; use math::ring::Ring;
#[test] #[test]
@@ -135,3 +134,52 @@ fn test_packing_sparse_u64<const NTT: bool>(ring: &Ring<u64>) {
}); });
} }
} }
#[test]
fn trace_u64() {
let n: usize = 1 << 5;
let q_base: u64 = 65537u64;
let q_power: usize = 1usize;
let ring: Ring<u64> = Ring::new(n, q_base, q_power);
sub_test("test_trace::<NTT:false>", || {
test_trace_u64::<false>(&ring)
});
sub_test("test_trace::<NTT:true>", || {
test_trace_u64::<true>(&ring)
});
}
fn test_trace_u64<const NTT: bool>(ring: &Ring<u64>) {
let n: usize = ring.n();
let mut poly: Poly<u64> = ring.new_poly();
poly.0.iter_mut().enumerate().for_each(|(i, x)|{
*x = (i+1) as u64
});
if NTT{
ring.ntt_inplace::<false>(&mut poly);
}
let step_start: usize = 2;
ring.trace_inplace::<NTT>(step_start, &mut poly);
if NTT{
ring.intt_inplace::<false>(&mut poly);
}
let gap: usize = 1<<(ring.log_n() - step_start);
poly.0.iter().enumerate().for_each(|(i, x)| {
if i % gap == 0 {
assert_eq!(*x, 1 + i as u64)
} else {
assert_eq!(*x, 0u64)
}
});
}