You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

357 lines
12 KiB

  1. // this file contains a sum-check protocol initial implementation, not used by the rest of the repo
  2. // but implemented as an exercise and it will probably be used in the future.
  3. use ark_ec::{CurveGroup, Group};
  4. use ark_ff::{BigInteger, PrimeField};
  5. use ark_poly::{
  6. multivariate::{SparsePolynomial, SparseTerm, Term},
  7. univariate::DensePolynomial,
  8. DenseMVPolynomial, DenseUVPolynomial, EvaluationDomain, GeneralEvaluationDomain, Polynomial,
  9. SparseMultilinearExtension,
  10. };
  11. use ark_std::cfg_into_iter;
  12. use ark_std::log2;
  13. use ark_std::marker::PhantomData;
  14. use ark_std::ops::Mul;
  15. use ark_std::{rand::Rng, UniformRand};
  16. use ark_crypto_primitives::sponge::{poseidon::PoseidonConfig, Absorb};
  17. use crate::transcript::Transcript;
  /// Sum-check prover and verifier driven by a [`Transcript`] (Fiat–Shamir style).
  ///
  /// The struct has no constructor in this file; it only groups the associated functions
  /// (`prove`, `verify`, and their helpers) and carries their generic parameters:
  /// - `F`: field the polynomials are defined over;
  /// - `C`: curve group passed as the second parameter of `Transcript<F, C>`;
  /// - `UV`: univariate polynomial type used for the per-round messages `s_i`;
  /// - `MV`: multivariate polynomial type of the summed polynomial `g`.
  ///
  /// Trait bounds are declared on the `impl` block that needs them rather than on the
  /// struct, so merely naming the type imposes no requirements.
  pub struct SumCheck<F, C, UV, MV> {
      _f: PhantomData<F>,
      _c: PhantomData<C>,
      _uv: PhantomData<UV>,
      _mv: PhantomData<MV>,
  }
  29. impl<
  30. F: PrimeField + Absorb,
  31. C: CurveGroup,
  32. UV: Polynomial<F> + DenseUVPolynomial<F>,
  33. MV: Polynomial<F> + DenseMVPolynomial<F>,
  34. > SumCheck<F, C, UV, MV>
  35. where
  36. <C as CurveGroup>::BaseField: Absorb,
  37. {
  38. fn partial_evaluate(g: &MV, point: &[Option<F>]) -> UV {
  39. assert!(point.len() >= g.num_vars(), "Invalid evaluation domain");
  40. // TODO: add check: there can only be 1 'None' value in point
  41. if g.is_zero() {
  42. return UV::from_coefficients_vec(vec![F::zero()]);
  43. }
  44. // note: this can be parallelized with cfg_into_iter
  45. let mut univ_terms: Vec<(F, SparseTerm)> = vec![];
  46. for (coef, term) in g.terms().iter() {
  47. // partial_evaluate each term
  48. let mut new_coef = F::one();
  49. let mut new_term = Vec::new();
  50. for (var, power) in term.iter() {
  51. match point[*var] {
  52. Some(v) => {
  53. if v.is_zero() {
  54. new_coef = F::zero();
  55. new_term = vec![];
  56. break;
  57. } else {
  58. new_coef = new_coef * v.pow([(*power) as u64]);
  59. }
  60. }
  61. _ => {
  62. new_term.push((*var, *power));
  63. }
  64. };
  65. }
  66. let new_term = SparseTerm::new(new_term);
  67. let new_coef = new_coef * coef;
  68. univ_terms.push((new_coef, new_term));
  69. }
  70. let mv_poly: SparsePolynomial<F, SparseTerm> =
  71. DenseMVPolynomial::<F>::from_coefficients_vec(g.num_vars(), univ_terms.clone());
  72. let mut univ_coeffs: Vec<F> = vec![F::zero(); mv_poly.degree() + 1];
  73. for (coef, term) in univ_terms {
  74. if term.is_empty() {
  75. univ_coeffs[0] += coef;
  76. continue;
  77. }
  78. for (_, power) in term.iter() {
  79. univ_coeffs[*power] += coef;
  80. }
  81. }
  82. UV::from_coefficients_vec(univ_coeffs)
  83. }
  84. fn point_complete(challenges: Vec<F>, n_elems: usize, iter_num: usize) -> Vec<F> {
  85. let p = Self::point(challenges, false, n_elems, iter_num);
  86. let mut r = vec![F::zero(); n_elems];
  87. for i in 0..n_elems {
  88. r[i] = p[i].unwrap();
  89. }
  90. r
  91. }
  92. fn point(challenges: Vec<F>, none: bool, n_elems: usize, iter_num: usize) -> Vec<Option<F>> {
  93. let mut n_vars = n_elems - challenges.len();
  94. assert!(n_vars >= log2(iter_num + 1) as usize);
  95. if none {
  96. // WIP
  97. if n_vars == 0 {
  98. panic!("err"); // or return directly challenges vector
  99. }
  100. n_vars -= 1;
  101. }
  102. let none_pos = if none {
  103. challenges.len() + 1
  104. } else {
  105. challenges.len()
  106. };
  107. let mut p: Vec<Option<F>> = vec![None; n_elems];
  108. for i in 0..challenges.len() {
  109. p[i] = Some(challenges[i]);
  110. }
  111. for i in 0..n_vars {
  112. let k = F::from(iter_num as u64).into_bigint().to_bytes_le();
  113. let bit = k[(i / 8) as usize] & (1 << (i % 8));
  114. if bit == 0 {
  115. p[none_pos + i] = Some(F::zero());
  116. } else {
  117. p[none_pos + i] = Some(F::one());
  118. }
  119. }
  120. p
  121. }
  122. pub fn prove(poseidon_config: &PoseidonConfig<F>, g: MV) -> (F, Vec<UV>, F)
  123. where
  124. <MV as Polynomial<F>>::Point: From<Vec<F>>,
  125. {
  126. // init transcript
  127. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  128. let v = g.num_vars();
  129. // compute H
  130. let mut H = F::zero();
  131. for i in 0..(2_u64.pow(v as u32) as usize) {
  132. let p = Self::point_complete(vec![], v, i);
  133. H = H + g.evaluate(&p.into());
  134. }
  135. transcript.add(&H);
  136. let mut ss: Vec<UV> = Vec::new();
  137. let mut r: Vec<F> = vec![];
  138. for i in 0..v {
  139. let r_i = transcript.get_challenge();
  140. r.push(r_i);
  141. let var_slots = v - 1 - i;
  142. let n_points = 2_u64.pow(var_slots as u32) as usize;
  143. let mut s_i = UV::zero();
  144. for j in 0..n_points {
  145. let point = Self::point(r[..i].to_vec(), true, v, j);
  146. s_i = s_i + Self::partial_evaluate(&g, &point);
  147. }
  148. transcript.add_vec(s_i.coeffs());
  149. ss.push(s_i);
  150. }
  151. let last_g_eval = g.evaluate(&r.into());
  152. (H, ss, last_g_eval)
  153. }
  154. pub fn verify(poseidon_config: &PoseidonConfig<F>, proof: (F, Vec<UV>, F)) -> bool {
  155. // init transcript
  156. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  157. transcript.add(&proof.0);
  158. let (c, ss, last_g_eval) = proof;
  159. let mut r: Vec<F> = vec![];
  160. for (i, s) in ss.iter().enumerate() {
  161. // TODO check degree
  162. if i == 0 {
  163. if c != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  164. return false;
  165. }
  166. let r_i = transcript.get_challenge();
  167. r.push(r_i);
  168. transcript.add_vec(s.coeffs());
  169. continue;
  170. }
  171. let r_i = transcript.get_challenge();
  172. r.push(r_i);
  173. if ss[i - 1].evaluate(&r[i - 1]) != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  174. return false;
  175. }
  176. transcript.add_vec(s.coeffs());
  177. }
  178. // last round
  179. if ss[ss.len() - 1].evaluate(&r[r.len() - 1]) != last_g_eval {
  180. return false;
  181. }
  182. true
  183. }
  184. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::transcript::poseidon_test_config;
      use ark_mnt4_298::{Fr, G1Projective}; // scalar field

      type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;

      /// g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
      fn sample_poly() -> SparsePolynomial<Fr, SparseTerm> {
          let terms = vec![
              (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
              (
                  Fr::from(1u32),
                  SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
              ),
              (
                  Fr::from(1u32),
                  SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
              ),
          ];
          SparsePolynomial::from_coefficients_slice(3, &terms)
      }

      #[test]
      fn test_new_point() {
          let f4 = Fr::from(4_u32);
          let f1 = Fr::from(1);
          let f0 = Fr::from(0);
          // with a None slot right after the challenges
          let p = SC::point(vec![f4], true, 5, 0);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], true, 5, 1);
          assert_eq!(vec![Some(f4), None, Some(f1), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], true, 5, 2);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f1), Some(f0)], p);
          let p = SC::point(vec![f4], true, 5, 3);
          assert_eq!(vec![Some(f4), None, Some(f1), Some(f1), Some(f0)], p);
          let p = SC::point(vec![f4], true, 5, 4);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f1)], p);
          // without None
          let p = SC::point(vec![], false, 4, 0);
          assert_eq!(vec![Some(f0), Some(f0), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], false, 5, 0);
          assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], false, 5, 1);
          assert_eq!(vec![Some(f4), Some(f1), Some(f0), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], false, 5, 3);
          assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f0), Some(f0)], p);
          let p = SC::point(vec![f4], false, 5, 4);
          assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f1), Some(f0)], p);
          let p = SC::point(vec![f4], false, 5, 10);
          assert_eq!(vec![Some(f4), Some(f0), Some(f1), Some(f0), Some(f1)], p);
          let p = SC::point(vec![f4], false, 5, 15);
          assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f1), Some(f1)], p);
      }

      #[test]
      #[should_panic]
      fn test_point_rejects_too_large_iter_num() {
          // 16 needs 5 bits, but only 3 variables remain after the single challenge
          let _ = SC::point(vec![Fr::from(4_u32)], false, 4, 16);
      }

      #[test]
      fn test_partial_evaluate() {
          let p = sample_poly();
          let e0 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(0_u32))]);
          assert_eq!(e0.coeffs(), vec![Fr::from(16_u32)]);
          let e1 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(1_u32))]);
          assert_eq!(e1.coeffs(), vec![Fr::from(18_u32), Fr::from(1)]);
          assert_eq!((e0 + e1).coeffs(), vec![Fr::from(34_u32), Fr::from(1)]);
      }

      #[test]
      fn test_flow_hardcoded_values() {
          let poseidon_config = poseidon_test_config::<Fr>();
          let proof = SC::prove(&poseidon_config, sample_poly());
          assert_eq!(proof.0, Fr::from(12_u32));
          assert!(SC::verify(&poseidon_config, proof));
      }

      #[test]
      fn test_flow_rejects_tampered_proof() {
          let poseidon_config = poseidon_test_config::<Fr>();
          let (sum, ss, last_g_eval) = SC::prove(&poseidon_config, sample_poly());
          // wrong claimed sum
          let bad_sum = (sum + Fr::from(1_u32), ss.clone(), last_g_eval);
          assert!(!SC::verify(&poseidon_config, bad_sum));
          // wrong final evaluation of g
          let bad_eval = (sum, ss, last_g_eval + Fr::from(1_u32));
          assert!(!SC::verify(&poseidon_config, bad_eval));
      }

      fn rand_poly<R: Rng>(l: usize, d: usize, rng: &mut R) -> SparsePolynomial<Fr, SparseTerm> {
          // This method is from the arkworks/algebra/poly/multivariate test:
          // https://github.com/arkworks-rs/algebra/blob/bc991d44c5e579025b7ed56df3d30267a7b9acac/poly/src/polynomial/multivariate/sparse.rs#L303
          let mut random_terms = Vec::new();
          let num_terms = rng.gen_range(1..1000);
          // For each term, randomly select up to `l` variables with degree
          // in [1,d] and random coefficient
          random_terms.push((Fr::rand(rng), SparseTerm::new(vec![])));
          for _ in 1..num_terms {
              let term = (0..l)
                  .filter_map(|i| {
                      if rng.gen_bool(0.5) {
                          Some((i, rng.gen_range(1..(d + 1))))
                      } else {
                          None
                      }
                  })
                  .collect();
              let coeff = Fr::rand(rng);
              random_terms.push((coeff, SparseTerm::new(term)));
          }
          SparsePolynomial::from_coefficients_slice(l, &random_terms)
      }

      #[test]
      fn test_flow_rng() {
          let mut rng = ark_std::test_rng();
          let p = rand_poly(3, 3, &mut rng);
          let poseidon_config = poseidon_test_config::<Fr>();
          let proof = SC::prove(&poseidon_config, p);
          assert!(SC::verify(&poseidon_config, proof));
      }
  }