From c4a517e9c3ef4eb4214d6894e035b685f6f665b9 Mon Sep 17 00:00:00 2001 From: Janmajaya Mall Date: Fri, 4 Jul 2025 16:03:46 +0530 Subject: [PATCH] Fix `decode_vec_i64` to handle the case `k < basek` --- backend/src/encoding.rs | 63 ++++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/backend/src/encoding.rs b/backend/src/encoding.rs index 73b86a3..55bba09 100644 --- a/backend/src/encoding.rs +++ b/backend/src/encoding.rs @@ -157,18 +157,22 @@ fn decode_vec_i64>(a: &VecZnx, col_i: usize, basek: usize, k: } data.copy_from_slice(a.at(col_i, 0)); let rem: usize = basek - (k % basek); - (1..size).for_each(|i| { - if i == size - 1 && rem != basek { - let k_rem: usize = basek - rem; - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << k_rem) + (x >> rem); - }); - } else { - izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { - *y = (*y << basek) + x; - }); - } - }) + if k < basek { + data.iter_mut().for_each(|x| *x >>= rem); + } else { + (1..size).for_each(|i| { + if i == size - 1 && rem != basek { + let k_rem: usize = basek - rem; + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << k_rem) + (x >> rem); + }); + } else { + izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { + *y = (*y << basek) + x; + }); + } + }) + } } fn decode_vec_float>(a: &VecZnx, col_i: usize, basek: usize, data: &mut [Float]) { @@ -268,7 +272,7 @@ fn decode_coeff_i64>(a: &VecZnx, col_i: usize, basek: usize, k let mut res: i64 = data[i]; let rem: usize = basek - (k % basek); let slice_size: usize = a.n() * a.cols(); - (1..size).for_each(|i| { + (0..size).for_each(|i| { let x: i64 = data[i * slice_size]; if i == size - 1 && rem != basek { let k_rem: usize = basek - rem; @@ -316,18 +320,25 @@ mod tests { let module: Module = Module::::new(n); let basek: usize = 17; let size: usize = 5; - let k: usize = size * basek - 5; - let mut a: VecZnx<_> = module.new_vec_znx(2, size); - let mut source = Source::new([0u8; 32]); - let raw: &mut [i64] = a.raw_mut(); - raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); - (0..a.cols()).for_each(|col_i| { - let mut have: Vec = vec![i64::default(); n]; - have.iter_mut().for_each(|x| *x = source.next_i64()); - a.encode_vec_i64(col_i, basek, k, &have, 64); - let mut want = vec![i64::default(); n]; - a.decode_vec_i64(col_i, basek, k, &mut want); - izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); - }) + for k in [size * basek - 5] { + let mut a: VecZnx<_> = module.new_vec_znx(2, size); + let mut source = Source::new([0u8; 32]); + let raw: &mut [i64] = a.raw_mut(); + raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); + (0..a.cols()).for_each(|col_i| { + let mut have: Vec = vec![i64::default(); n]; + have.iter_mut().for_each(|x| { + if k < 64 { + *x = source.next_u64n(1 << k, (1 << k) - 1) as i64; + } else { + *x = source.next_i64(); + } + }); + a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64)); + let mut want = vec![i64::default(); n]; + a.decode_vec_i64(col_i, basek, k, &mut want); + izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); + }) + } } }