143 lines
4.9 KiB

  1. //! This module implements useful functions for the permutation check protocol.
  2. use crate::PolyIOPErrors;
  3. use ark_ff::PrimeField;
  4. use ark_poly::DenseMultilinearExtension;
  5. use ark_std::{end_timer, rand::RngCore, start_timer};
  6. /// Returns three MLEs:
  7. /// - prod(0,x)
  8. /// - numerator
  9. /// - denominator
  10. ///
  11. /// where
  12. /// - `prod(0,x) := prod(0, x1, …, xn)` which is the MLE over the
  13. /// evaluations of the following polynomial on the boolean hypercube {0,1}^n:
  14. ///
  15. /// (f(x) + \beta s_id(x) + \gamma)/(g(x) + \beta s_perm(x) + \gamma)
  16. ///
  17. /// where
  18. /// - beta and gamma are challenges
  19. /// - f(x), g(x), s_id(x), s_perm(x) are mle-s
  20. ///
  21. /// - numerator is the MLE for `f(x) + \beta s_id(x) + \gamma`
  22. /// - denominator is the MLE for `g(x) + \beta s_perm(x) + \gamma`
  23. ///
  24. /// The caller needs to check num_vars matches in f/g/s_id/s_perm
  25. /// Cost: linear in N.
  26. #[allow(clippy::type_complexity)]
  27. pub(super) fn compute_prod_0<F: PrimeField>(
  28. beta: &F,
  29. gamma: &F,
  30. fx: &DenseMultilinearExtension<F>,
  31. gx: &DenseMultilinearExtension<F>,
  32. s_perm: &DenseMultilinearExtension<F>,
  33. ) -> Result<
  34. (
  35. DenseMultilinearExtension<F>,
  36. DenseMultilinearExtension<F>,
  37. DenseMultilinearExtension<F>,
  38. ),
  39. PolyIOPErrors,
  40. > {
  41. let start = start_timer!(|| "compute prod(1,x)");
  42. let num_vars = fx.num_vars;
  43. let mut prod_0x_evals = vec![];
  44. let mut numerator_evals = vec![];
  45. let mut denominator_evals = vec![];
  46. let s_id = identity_permutation_mle::<F>(num_vars);
  47. for (&fi, (&gi, (&s_id_i, &s_perm_i))) in
  48. fx.iter().zip(gx.iter().zip(s_id.iter().zip(s_perm.iter())))
  49. {
  50. let numerator = fi + *beta * s_id_i + gamma;
  51. let denominator = gi + *beta * s_perm_i + gamma;
  52. prod_0x_evals.push(numerator / denominator);
  53. numerator_evals.push(numerator);
  54. denominator_evals.push(denominator);
  55. }
  56. let prod_0x = DenseMultilinearExtension::from_evaluations_vec(num_vars, prod_0x_evals);
  57. let numerator = DenseMultilinearExtension::from_evaluations_vec(num_vars, numerator_evals);
  58. let denominator = DenseMultilinearExtension::from_evaluations_vec(num_vars, denominator_evals);
  59. end_timer!(start);
  60. Ok((prod_0x, numerator, denominator))
  61. }
  62. /// An MLE that represent an identity permutation: `f(index) \mapto index`
  63. pub fn identity_permutation_mle<F: PrimeField>(num_vars: usize) -> DenseMultilinearExtension<F> {
  64. let s_id_vec = (0..1u64 << num_vars).map(F::from).collect();
  65. DenseMultilinearExtension::from_evaluations_vec(num_vars, s_id_vec)
  66. }
  67. /// An MLE that represent a random permutation
  68. pub fn random_permutation_mle<F: PrimeField, R: RngCore>(
  69. num_vars: usize,
  70. rng: &mut R,
  71. ) -> DenseMultilinearExtension<F> {
  72. let len = 1u64 << num_vars;
  73. let mut s_id_vec: Vec<F> = (0..len).map(F::from).collect();
  74. let mut s_perm_vec = vec![];
  75. for _ in 0..len {
  76. let index = rng.next_u64() as usize % s_id_vec.len();
  77. s_perm_vec.push(s_id_vec.remove(index));
  78. }
  79. DenseMultilinearExtension::from_evaluations_vec(num_vars, s_perm_vec)
  80. }
  81. #[cfg(test)]
  82. mod test {
  83. use super::*;
  84. use crate::utils::bit_decompose;
  85. use ark_bls12_381::Fr;
  86. use ark_ff::UniformRand;
  87. use ark_poly::MultilinearExtension;
  88. use ark_std::test_rng;
  89. #[test]
  90. fn test_compute_prod_0() -> Result<(), PolyIOPErrors> {
  91. let mut rng = test_rng();
  92. for num_vars in 2..6 {
  93. let f = DenseMultilinearExtension::rand(num_vars, &mut rng);
  94. let g = DenseMultilinearExtension::rand(num_vars, &mut rng);
  95. let s_id = identity_permutation_mle::<Fr>(num_vars);
  96. let s_perm = random_permutation_mle(num_vars, &mut rng);
  97. let beta = Fr::rand(&mut rng);
  98. let gamma = Fr::rand(&mut rng);
  99. let (prod_0, numerator, denominator) = compute_prod_0(&beta, &gamma, &f, &g, &s_perm)?;
  100. for i in 0..1 << num_vars {
  101. let r: Vec<Fr> = bit_decompose(i, num_vars)
  102. .iter()
  103. .map(|&x| Fr::from(x))
  104. .collect();
  105. let prod_0_eval = prod_0.evaluate(&r).unwrap();
  106. let numerator_eval = numerator.evaluate(&r).unwrap();
  107. let denominator_eval = denominator.evaluate(&r).unwrap();
  108. let f_eval = f.evaluate(&r).unwrap();
  109. let g_eval = g.evaluate(&r).unwrap();
  110. let s_id_eval = s_id.evaluate(&r).unwrap();
  111. let s_perm_eval = s_perm.evaluate(&r).unwrap();
  112. let numerator_eval_rec = f_eval + beta * s_id_eval + gamma;
  113. let denominator_eval_rec = g_eval + beta * s_perm_eval + gamma;
  114. let prod_0_eval_rec = numerator_eval_rec / denominator_eval_rec;
  115. assert_eq!(numerator_eval, numerator_eval_rec);
  116. assert_eq!(denominator_eval, denominator_eval_rec);
  117. assert_eq!(prod_0_eval, prod_0_eval_rec);
  118. }
  119. }
  120. Ok(())
  121. }
  122. }