diff --git a/tensor_pcs/src/utils.rs b/tensor_pcs/src/utils.rs index 871c18d..ec33f8d 100644 --- a/tensor_pcs/src/utils.rs +++ b/tensor_pcs/src/utils.rs @@ -1,8 +1,6 @@ -use tiny_keccak::{Hasher, Keccak}; - -use crate::FieldExt; - use crate::transcript::Transcript; +use crate::FieldExt; +use tiny_keccak::{Hasher, Keccak}; pub fn rlc_rows(x: Vec>, r: &[F]) -> Vec { debug_assert_eq!(x.len(), r.len()); @@ -68,7 +66,7 @@ pub fn sample_indices( let mut indices = Vec::with_capacity(num_indices); let mut counter: u32 = 0; - // TODO: Don't sample at n and n + N + let n = max_index / 2; while indices.len() < num_indices { let mut random_bytes = [0u8; 64]; @@ -76,10 +74,8 @@ pub fn sample_indices( transcript.challenge_bytes(&mut random_bytes); let index = sample_index(random_bytes, max_index); - if !indices.contains(&index) - // || !indices.contains(&(index + (max_index / 2))) - // || !indices.contains(&(index - (max_index / 2))) - { + let pair_index = if index > n { index - n } else { index + n }; + if !indices.contains(&index) && !indices.contains(&pair_index) { indices.push(index); } counter += 1; @@ -87,3 +83,27 @@ pub fn sample_indices( indices } + +#[cfg(test)] +mod tests { + use super::*; + type F = halo2curves::secp256k1::Fp; + + #[test] + fn test_sample_indices() { + let mut transcript = Transcript::::new(b"test_sample_index"); + let num_indices = 10; + let max_index = 100; + let indices = sample_indices(num_indices, max_index, &mut transcript); + + assert_eq!(indices.len(), 10); + let n = max_index / 2; + for index in &indices { + if *index > n { + assert!(!indices.contains(&(index - n))); + } else { + assert!(!indices.contains(&(index + n))); + } + } + } +}