Fix decode_vec_i64 to handle the case k < basek

This commit is contained in:
Janmajaya Mall
2025-07-04 16:03:46 +05:30
parent e8454cd5f1
commit c4a517e9c3

View File

@@ -157,18 +157,22 @@ fn decode_vec_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k:
} }
data.copy_from_slice(a.at(col_i, 0)); data.copy_from_slice(a.at(col_i, 0));
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
(1..size).for_each(|i| { if k < basek {
if i == size - 1 && rem != basek { data.iter_mut().for_each(|x| *x >>= rem);
let k_rem: usize = basek - rem; } else {
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { (1..size).for_each(|i| {
*y = (*y << k_rem) + (x >> rem); if i == size - 1 && rem != basek {
}); let k_rem: usize = basek - rem;
} else { izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| { *y = (*y << k_rem) + (x >> rem);
*y = (*y << basek) + x; });
}); } else {
} izip!(a.at(col_i, i).iter(), data.iter_mut()).for_each(|(x, y)| {
}) *y = (*y << basek) + x;
});
}
})
}
} }
fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) { fn decode_vec_float<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, data: &mut [Float]) {
@@ -268,7 +272,7 @@ fn decode_coeff_i64<D: AsRef<[u8]>>(a: &VecZnx<D>, col_i: usize, basek: usize, k
let mut res: i64 = data[i]; let mut res: i64 = data[i];
let rem: usize = basek - (k % basek); let rem: usize = basek - (k % basek);
let slice_size: usize = a.n() * a.cols(); let slice_size: usize = a.n() * a.cols();
(1..size).for_each(|i| { (0..size).for_each(|i| {
let x: i64 = data[i * slice_size]; let x: i64 = data[i * slice_size];
if i == size - 1 && rem != basek { if i == size - 1 && rem != basek {
let k_rem: usize = basek - rem; let k_rem: usize = basek - rem;
@@ -316,18 +320,25 @@ mod tests {
let module: Module<FFT64> = Module::<FFT64>::new(n); let module: Module<FFT64> = Module::<FFT64>::new(n);
let basek: usize = 17; let basek: usize = 17;
let size: usize = 5; let size: usize = 5;
let k: usize = size * basek - 5; for k in [size * basek - 5] {
let mut a: VecZnx<_> = module.new_vec_znx(2, size); let mut a: VecZnx<_> = module.new_vec_znx(2, size);
let mut source = Source::new([0u8; 32]); let mut source = Source::new([0u8; 32]);
let raw: &mut [i64] = a.raw_mut(); let raw: &mut [i64] = a.raw_mut();
raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64); raw.iter_mut().enumerate().for_each(|(i, x)| *x = i as i64);
(0..a.cols()).for_each(|col_i| { (0..a.cols()).for_each(|col_i| {
let mut have: Vec<i64> = vec![i64::default(); n]; let mut have: Vec<i64> = vec![i64::default(); n];
have.iter_mut().for_each(|x| *x = source.next_i64()); have.iter_mut().for_each(|x| {
a.encode_vec_i64(col_i, basek, k, &have, 64); if k < 64 {
let mut want = vec![i64::default(); n]; *x = source.next_u64n(1 << k, (1 << k) - 1) as i64;
a.decode_vec_i64(col_i, basek, k, &mut want); } else {
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b)); *x = source.next_i64();
}) }
});
a.encode_vec_i64(col_i, basek, k, &have, std::cmp::min(k, 64));
let mut want = vec![i64::default(); n];
a.decode_vec_i64(col_i, basek, k, &mut want);
izip!(want, have).for_each(|(a, b)| assert_eq!(a, b, "{} != {}", a, b));
})
}
} }
} }