You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

346 lines
12 KiB

  1. // Sum-check protocol initial implementation, not used by the rest of the repo but implemented as
  2. // an exercise and it will probably be used in the future.
  3. use ark_ff::{BigInteger, PrimeField};
  4. use ark_poly::{
  5. multivariate::{SparsePolynomial, SparseTerm, Term},
  6. univariate::DensePolynomial,
  7. DenseMVPolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial,
  8. SparseMultilinearExtension,
  9. };
  10. use ark_std::cfg_into_iter;
  11. use ark_std::log2;
  12. use ark_std::marker::PhantomData;
  13. use ark_std::ops::Mul;
  14. use ark_std::{rand::Rng, UniformRand};
  15. use crate::transcript::Transcript;
  16. pub struct SumCheck<
  17. F: PrimeField,
  18. UV: Polynomial<F> + DenseUVPolynomial<F>,
  19. MV: Polynomial<F> + DenseMVPolynomial<F>,
  20. > {
  21. _f: PhantomData<F>,
  22. _uv: PhantomData<UV>,
  23. _mv: PhantomData<MV>,
  24. }
  25. impl<
  26. F: PrimeField,
  27. UV: Polynomial<F> + DenseUVPolynomial<F>,
  28. MV: Polynomial<F> + DenseMVPolynomial<F>,
  29. > SumCheck<F, UV, MV>
  30. {
  31. fn partial_evaluate(g: &MV, point: &[Option<F>]) -> UV {
  32. assert!(point.len() >= g.num_vars(), "Invalid evaluation domain");
  33. // TODO: add check: there can only be 1 'None' value in point
  34. if g.is_zero() {
  35. return UV::from_coefficients_vec(vec![F::zero()]);
  36. }
  37. // note: this can be parallelized with cfg_into_iter
  38. let mut univ_terms: Vec<(F, SparseTerm)> = vec![];
  39. for (coef, term) in g.terms().iter() {
  40. // partial_evaluate each term
  41. let mut new_coef = F::one();
  42. let mut new_term = Vec::new();
  43. for (var, power) in term.iter() {
  44. match point[*var] {
  45. Some(v) => {
  46. if v.is_zero() {
  47. new_coef = F::zero();
  48. new_term = vec![];
  49. break;
  50. } else {
  51. new_coef = new_coef * v.pow([(*power) as u64]);
  52. }
  53. }
  54. _ => {
  55. new_term.push((*var, *power));
  56. }
  57. };
  58. }
  59. let new_term = SparseTerm::new(new_term);
  60. let new_coef = new_coef * coef;
  61. univ_terms.push((new_coef, new_term));
  62. }
  63. let mv_poly: SparsePolynomial<F, SparseTerm> =
  64. DenseMVPolynomial::<F>::from_coefficients_vec(g.num_vars(), univ_terms.clone());
  65. let mut univ_coeffs: Vec<F> = vec![F::zero(); mv_poly.degree() + 1];
  66. for (coef, term) in univ_terms {
  67. if term.is_empty() {
  68. univ_coeffs[0] += coef;
  69. continue;
  70. }
  71. for (_, power) in term.iter() {
  72. univ_coeffs[*power] += coef;
  73. }
  74. }
  75. UV::from_coefficients_vec(univ_coeffs)
  76. }
  77. fn point_complete(challenges: Vec<F>, n_elems: usize, iter_num: usize) -> Vec<F> {
  78. let p = Self::point(challenges, false, n_elems, iter_num);
  79. let mut r = vec![F::zero(); n_elems];
  80. for i in 0..n_elems {
  81. r[i] = p[i].unwrap();
  82. }
  83. r
  84. }
  85. fn point(challenges: Vec<F>, none: bool, n_elems: usize, iter_num: usize) -> Vec<Option<F>> {
  86. let mut n_vars = n_elems - challenges.len();
  87. assert!(n_vars >= log2(iter_num + 1) as usize);
  88. if none {
  89. // WIP
  90. if n_vars == 0 {
  91. panic!("err"); // or return directly challenges vector
  92. }
  93. n_vars -= 1;
  94. }
  95. let none_pos = if none {
  96. challenges.len() + 1
  97. } else {
  98. challenges.len()
  99. };
  100. let mut p: Vec<Option<F>> = vec![None; n_elems];
  101. for i in 0..challenges.len() {
  102. p[i] = Some(challenges[i]);
  103. }
  104. for i in 0..n_vars {
  105. let k = F::from(iter_num as u64).into_bigint().to_bytes_le();
  106. let bit = k[(i / 8) as usize] & (1 << (i % 8));
  107. if bit == 0 {
  108. p[none_pos + i] = Some(F::zero());
  109. } else {
  110. p[none_pos + i] = Some(F::one());
  111. }
  112. }
  113. p
  114. }
  115. pub fn prove(g: MV) -> (F, Vec<UV>, F)
  116. where
  117. <MV as Polynomial<F>>::Point: From<Vec<F>>,
  118. {
  119. // init transcript
  120. let mut transcript: Transcript<F> = Transcript::<F>::new();
  121. let v = g.num_vars();
  122. // compute H
  123. let mut H = F::zero();
  124. for i in 0..(2_u64.pow(v as u32) as usize) {
  125. let p = Self::point_complete(vec![], v, i);
  126. H = H + g.evaluate(&p.into());
  127. }
  128. transcript.add(b"H", &H);
  129. let mut ss: Vec<UV> = Vec::new();
  130. let mut r: Vec<F> = vec![];
  131. for i in 0..v {
  132. let r_i = transcript.get_challenge(b"r_i");
  133. r.push(r_i);
  134. let var_slots = v - 1 - i;
  135. let n_points = 2_u64.pow(var_slots as u32) as usize;
  136. let mut s_i = UV::zero();
  137. for j in 0..n_points {
  138. let point = Self::point(r[..i].to_vec(), true, v, j);
  139. s_i = s_i + Self::partial_evaluate(&g, &point);
  140. }
  141. transcript.add(b"s_i", &s_i);
  142. ss.push(s_i);
  143. }
  144. let last_g_eval = g.evaluate(&r.into());
  145. (H, ss, last_g_eval)
  146. }
  147. pub fn verify(proof: (F, Vec<UV>, F)) -> bool {
  148. // init transcript
  149. let mut transcript: Transcript<F> = Transcript::<F>::new();
  150. transcript.add(b"H", &proof.0);
  151. let (c, ss, last_g_eval) = proof;
  152. let mut r: Vec<F> = vec![];
  153. for (i, s) in ss.iter().enumerate() {
  154. // TODO check degree
  155. if i == 0 {
  156. if c != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  157. return false;
  158. }
  159. let r_i = transcript.get_challenge(b"r_i");
  160. r.push(r_i);
  161. transcript.add(b"s_i", s);
  162. continue;
  163. }
  164. let r_i = transcript.get_challenge(b"r_i");
  165. r.push(r_i);
  166. if ss[i - 1].evaluate(&r[i - 1]) != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  167. return false;
  168. }
  169. transcript.add(b"s_i", s);
  170. }
  171. // last round
  172. if ss[ss.len() - 1].evaluate(&r[r.len() - 1]) != last_g_eval {
  173. return false;
  174. }
  175. true
  176. }
  177. }
  178. #[cfg(test)]
  179. mod tests {
  180. use super::*;
  181. use ark_bn254::Fr; // scalar field
  182. #[test]
  183. fn test_new_point() {
  184. let f4 = Fr::from(4_u32);
  185. let f1 = Fr::from(1);
  186. let f0 = Fr::from(0);
  187. type SC = SumCheck<Fr, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  188. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 0);
  189. assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f0),], p);
  190. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 1);
  191. assert_eq!(vec![Some(f4), None, Some(f1), Some(f0), Some(f0),], p);
  192. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 2);
  193. assert_eq!(vec![Some(f4), None, Some(f0), Some(f1), Some(f0),], p);
  194. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 3);
  195. assert_eq!(vec![Some(f4), None, Some(f1), Some(f1), Some(f0),], p);
  196. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 4);
  197. assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f1),], p);
  198. // without None
  199. let p = SC::point(vec![], false, 4, 0);
  200. assert_eq!(vec![Some(f0), Some(f0), Some(f0), Some(f0),], p);
  201. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 0);
  202. assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f0), Some(f0),], p);
  203. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 1);
  204. assert_eq!(vec![Some(f4), Some(f1), Some(f0), Some(f0), Some(f0),], p);
  205. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 3);
  206. assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f0), Some(f0),], p);
  207. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 4);
  208. assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f1), Some(f0),], p);
  209. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 10);
  210. assert_eq!(vec![Some(f4), Some(f0), Some(f1), Some(f0), Some(f1),], p);
  211. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 15);
  212. assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f1), Some(f1),], p);
  213. // let p = SC::point(vec![Fr::from(4_u32)], false, 4, 16); // TODO expect error
  214. }
  215. #[test]
  216. fn test_partial_evaluate() {
  217. // g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
  218. let terms = vec![
  219. (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
  220. (
  221. Fr::from(1u32),
  222. SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
  223. ),
  224. (
  225. Fr::from(1u32),
  226. SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
  227. ),
  228. ];
  229. let p = SparsePolynomial::from_coefficients_slice(3, &terms);
  230. type SC = SumCheck<Fr, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  231. let e0 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(0_u32))]);
  232. assert_eq!(e0.coeffs(), vec![Fr::from(16_u32)]);
  233. let e1 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(1_u32))]);
  234. assert_eq!(e1.coeffs(), vec![Fr::from(18_u32), Fr::from(1)]);
  235. assert_eq!((e0 + e1).coeffs(), vec![Fr::from(34_u32), Fr::from(1)]);
  236. }
  237. #[test]
  238. fn test_flow_hardcoded_values() {
  239. let mut rng = ark_std::test_rng();
  240. // g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
  241. let terms = vec![
  242. (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
  243. (
  244. Fr::from(1u32),
  245. SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
  246. ),
  247. (
  248. Fr::from(1u32),
  249. SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
  250. ),
  251. ];
  252. let p = SparsePolynomial::from_coefficients_slice(3, &terms);
  253. // println!("p {:?}", p);
  254. type SC = SumCheck<Fr, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  255. let proof = SC::prove(p);
  256. assert_eq!(proof.0, Fr::from(12_u32));
  257. // println!("proof {:?}", proof);
  258. let v = SC::verify(proof);
  259. assert!(v);
  260. }
  261. fn rand_poly<R: Rng>(l: usize, d: usize, rng: &mut R) -> SparsePolynomial<Fr, SparseTerm> {
  262. // This method is from the arkworks/algebra/poly/multivariate test:
  263. // https://github.com/arkworks-rs/algebra/blob/bc991d44c5e579025b7ed56df3d30267a7b9acac/poly/src/polynomial/multivariate/sparse.rs#L303
  264. let mut random_terms = Vec::new();
  265. let num_terms = rng.gen_range(1..1000);
  266. // For each term, randomly select up to `l` variables with degree
  267. // in [1,d] and random coefficient
  268. random_terms.push((Fr::rand(rng), SparseTerm::new(vec![])));
  269. for _ in 1..num_terms {
  270. let term = (0..l)
  271. .map(|i| {
  272. if rng.gen_bool(0.5) {
  273. Some((i, rng.gen_range(1..(d + 1))))
  274. } else {
  275. None
  276. }
  277. })
  278. .flatten()
  279. .collect();
  280. let coeff = Fr::rand(rng);
  281. random_terms.push((coeff, SparseTerm::new(term)));
  282. }
  283. SparsePolynomial::from_coefficients_slice(l, &random_terms)
  284. }
  285. #[test]
  286. fn test_flow_rng() {
  287. let mut rng = ark_std::test_rng();
  288. // let p = SparsePolynomial::<Fr, SparseTerm>::rand(3, 3, &mut rng);
  289. let p = rand_poly(3, 3, &mut rng);
  290. // println!("p {:?}", p);
  291. type SC = SumCheck<Fr, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  292. let proof = SC::prove(p);
  293. println!("proof.s len {:?}", proof.1.len());
  294. let v = SC::verify(proof);
  295. assert!(v);
  296. }
  297. }