Fix decode_vec_i64 to handle the case k < basek

This commit is contained in:
Janmajaya Mall
2025-07-04 16:03:46 +05:30
parent e8454cd5f1
commit c4a517e9c3

View File

@@ -157,6 +157,9 @@ fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k:
} }
data.copy_from_slice(a.at(col_i, 0)); data.copy_from_slice(a.at(col_i, 0));
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
if k < basek {
data.iter_mut().for_each(|x| *x >>= rem);
} else {
(1..size).for_each(|i| { (1..size).for_each(|i| {
if i == size - 1 && rem != basek { if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem; let k_rem: usize = basek - rem;
@@ -169,6 +172,7 @@ fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k:
}); });
} }
}) })
}
} }
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) { fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
@@ -268,7 +272,7 @@ fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
let slice_size: usize = a.n() * a.cols(); let slice_size: usize = a.n() * a.cols();
(1..size).for_each(|i| { (0..size).for_each(|i| {
let x: i64 = data[i * slice_size]; let x: i64 = data[i * slice_size];
if i == size - 1 && rem != basek { if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem; let k_rem: usize = basek - rem;
@@ -316,18 +320,25 @@ mod tests {
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17; let basek: usize = 17;
let size: usize = 5; let size: usize = 5;
let k: usize = size * basek - 5; for k in [size * basek - 5] {
let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]); let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| { (0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n]; let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| *x = source.next_i64()); have.iter_mut().for_each(|x| {
a.encode_vec_i64(col_i, basek, k, &have, 64); if k < 64 {
*x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
} else {
*x = source.next_i64();
}
});
a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64));
let mut want = vec![i64::default(); n]; let mut want = vec![i64::default(); n];
a.decode_vec_i64(col_i, basek, k, &mut want); a.decode_vec_i64(col_i, basek, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
}) })
} }
}
} }