You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

357 lines
12 KiB

// This file contains an initial implementation of the sum-check protocol. It is not used by the
// rest of the repo yet: it was written as an exercise and will probably be used in the future.
  3. use ark_ec::CurveGroup;
  4. use ark_ff::{BigInteger, PrimeField};
  5. use ark_poly::{
  6. multivariate::{SparsePolynomial, SparseTerm, Term},
  7. DenseMVPolynomial, DenseUVPolynomial, Polynomial,
  8. };
  9. use ark_std::log2;
  10. use ark_std::marker::PhantomData;
  11. use ark_crypto_primitives::sponge::{poseidon::PoseidonConfig, Absorb};
  12. use crate::transcript::Transcript;
  13. pub struct SumCheck<
  14. F: PrimeField + Absorb,
  15. C: CurveGroup,
  16. UV: Polynomial<F> + DenseUVPolynomial<F>,
  17. MV: Polynomial<F> + DenseMVPolynomial<F>,
  18. > {
  19. _f: PhantomData<F>,
  20. _c: PhantomData<C>,
  21. _uv: PhantomData<UV>,
  22. _mv: PhantomData<MV>,
  23. }
  24. impl<
  25. F: PrimeField + Absorb,
  26. C: CurveGroup,
  27. UV: Polynomial<F> + DenseUVPolynomial<F>,
  28. MV: Polynomial<F> + DenseMVPolynomial<F>,
  29. > SumCheck<F, C, UV, MV>
  30. where
  31. <C as CurveGroup>::BaseField: Absorb,
  32. {
  33. fn partial_evaluate(g: &MV, point: &[Option<F>]) -> UV {
  34. assert!(point.len() >= g.num_vars(), "Invalid evaluation domain");
  35. // TODO: add check: there can only be 1 'None' value in point
  36. if g.is_zero() {
  37. return UV::from_coefficients_vec(vec![F::zero()]);
  38. }
  39. // note: this can be parallelized with cfg_into_iter
  40. let mut univ_terms: Vec<(F, SparseTerm)> = vec![];
  41. for (coef, term) in g.terms().iter() {
  42. // partial_evaluate each term
  43. let mut new_coef = F::one();
  44. let mut new_term = Vec::new();
  45. for (var, power) in term.iter() {
  46. match point[*var] {
  47. Some(v) => {
  48. if v.is_zero() {
  49. new_coef = F::zero();
  50. new_term = vec![];
  51. break;
  52. } else {
  53. new_coef *= v.pow([(*power) as u64]);
  54. }
  55. }
  56. _ => {
  57. new_term.push((*var, *power));
  58. }
  59. };
  60. }
  61. let new_term = SparseTerm::new(new_term);
  62. let new_coef = new_coef * coef;
  63. univ_terms.push((new_coef, new_term));
  64. }
  65. let mv_poly: SparsePolynomial<F, SparseTerm> =
  66. DenseMVPolynomial::<F>::from_coefficients_vec(g.num_vars(), univ_terms.clone());
  67. let mut univ_coeffs: Vec<F> = vec![F::zero(); mv_poly.degree() + 1];
  68. for (coef, term) in univ_terms {
  69. if term.is_empty() {
  70. univ_coeffs[0] += coef;
  71. continue;
  72. }
  73. for (_, power) in term.iter() {
  74. univ_coeffs[*power] += coef;
  75. }
  76. }
  77. UV::from_coefficients_vec(univ_coeffs)
  78. }
  79. fn point_complete(challenges: Vec<F>, n_elems: usize, iter_num: usize) -> Vec<F> {
  80. let p = Self::point(challenges, false, n_elems, iter_num);
  81. let mut r = vec![F::zero(); n_elems];
  82. for i in 0..n_elems {
  83. r[i] = p[i].unwrap();
  84. }
  85. r
  86. }
  87. fn point(challenges: Vec<F>, none: bool, n_elems: usize, iter_num: usize) -> Vec<Option<F>> {
  88. let mut n_vars = n_elems - challenges.len();
  89. assert!(n_vars >= log2(iter_num + 1) as usize);
  90. if none {
  91. // WIP
  92. if n_vars == 0 {
  93. panic!("err"); // or return directly challenges vector
  94. }
  95. n_vars -= 1;
  96. }
  97. let none_pos = if none {
  98. challenges.len() + 1
  99. } else {
  100. challenges.len()
  101. };
  102. let mut p: Vec<Option<F>> = vec![None; n_elems];
  103. for i in 0..challenges.len() {
  104. p[i] = Some(challenges[i]);
  105. }
  106. for i in 0..n_vars {
  107. let k = F::from(iter_num as u64).into_bigint().to_bytes_le();
  108. let bit = k[i / 8] & (1 << (i % 8));
  109. if bit == 0 {
  110. p[none_pos + i] = Some(F::zero());
  111. } else {
  112. p[none_pos + i] = Some(F::one());
  113. }
  114. }
  115. p
  116. }
  117. pub fn prove(poseidon_config: &PoseidonConfig<F>, g: MV) -> (F, Vec<UV>, F)
  118. where
  119. <MV as Polynomial<F>>::Point: From<Vec<F>>,
  120. {
  121. // init transcript
  122. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  123. let v = g.num_vars();
  124. // compute H
  125. let mut H = F::zero();
  126. for i in 0..(2_u64.pow(v as u32) as usize) {
  127. let p = Self::point_complete(vec![], v, i);
  128. H += g.evaluate(&p.into());
  129. }
  130. transcript.add(&H);
  131. let mut ss: Vec<UV> = Vec::new();
  132. let mut r: Vec<F> = vec![];
  133. for i in 0..v {
  134. let r_i = transcript.get_challenge();
  135. r.push(r_i);
  136. let var_slots = v - 1 - i;
  137. let n_points = 2_u64.pow(var_slots as u32) as usize;
  138. let mut s_i = UV::zero();
  139. for j in 0..n_points {
  140. let point = Self::point(r[..i].to_vec(), true, v, j);
  141. s_i = s_i + Self::partial_evaluate(&g, &point);
  142. }
  143. transcript.add_vec(s_i.coeffs());
  144. ss.push(s_i);
  145. }
  146. let last_g_eval = g.evaluate(&r.into());
  147. (H, ss, last_g_eval)
  148. }
  149. pub fn verify(poseidon_config: &PoseidonConfig<F>, proof: (F, Vec<UV>, F)) -> bool {
  150. // init transcript
  151. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  152. transcript.add(&proof.0);
  153. let (c, ss, last_g_eval) = proof;
  154. let mut r: Vec<F> = vec![];
  155. for (i, s) in ss.iter().enumerate() {
  156. // TODO check degree
  157. if i == 0 {
  158. if c != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  159. return false;
  160. }
  161. let r_i = transcript.get_challenge();
  162. r.push(r_i);
  163. transcript.add_vec(s.coeffs());
  164. continue;
  165. }
  166. let r_i = transcript.get_challenge();
  167. r.push(r_i);
  168. if ss[i - 1].evaluate(&r[i - 1]) != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  169. return false;
  170. }
  171. transcript.add_vec(s.coeffs());
  172. }
  173. // last round
  174. if ss[ss.len() - 1].evaluate(&r[r.len() - 1]) != last_g_eval {
  175. return false;
  176. }
  177. true
  178. }
  179. }
  180. #[cfg(test)]
  181. mod tests {
  182. use super::*;
  183. use crate::transcript::poseidon_test_config;
  184. use ark_mnt4_298::{Fr, G1Projective}; // scalar field
  185. use ark_poly::{
  186. multivariate::{SparsePolynomial, SparseTerm, Term},
  187. univariate::DensePolynomial,
  188. DenseMVPolynomial, DenseUVPolynomial,
  189. };
  190. use ark_std::{rand::Rng, UniformRand};
  191. #[test]
  192. fn test_new_point() {
  193. let f4 = Fr::from(4_u32);
  194. let f1 = Fr::from(1);
  195. let f0 = Fr::from(0);
  196. type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  197. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 0);
  198. assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f0),], p);
  199. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 1);
  200. assert_eq!(vec![Some(f4), None, Some(f1), Some(f0), Some(f0),], p);
  201. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 2);
  202. assert_eq!(vec![Some(f4), None, Some(f0), Some(f1), Some(f0),], p);
  203. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 3);
  204. assert_eq!(vec![Some(f4), None, Some(f1), Some(f1), Some(f0),], p);
  205. let p = SC::point(vec![Fr::from(4_u32)], true, 5, 4);
  206. assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f1),], p);
  207. // without None
  208. let p = SC::point(vec![], false, 4, 0);
  209. assert_eq!(vec![Some(f0), Some(f0), Some(f0), Some(f0),], p);
  210. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 0);
  211. assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f0), Some(f0),], p);
  212. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 1);
  213. assert_eq!(vec![Some(f4), Some(f1), Some(f0), Some(f0), Some(f0),], p);
  214. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 3);
  215. assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f0), Some(f0),], p);
  216. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 4);
  217. assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f1), Some(f0),], p);
  218. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 10);
  219. assert_eq!(vec![Some(f4), Some(f0), Some(f1), Some(f0), Some(f1),], p);
  220. let p = SC::point(vec![Fr::from(4_u32)], false, 5, 15);
  221. assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f1), Some(f1),], p);
  222. // let p = SC::point(vec![Fr::from(4_u32)], false, 4, 16); // TODO expect error
  223. }
  224. #[test]
  225. fn test_partial_evaluate() {
  226. // g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
  227. let terms = vec![
  228. (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
  229. (
  230. Fr::from(1u32),
  231. SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
  232. ),
  233. (
  234. Fr::from(1u32),
  235. SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
  236. ),
  237. ];
  238. let p = SparsePolynomial::from_coefficients_slice(3, &terms);
  239. type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  240. let e0 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(0_u32))]);
  241. assert_eq!(e0.coeffs(), vec![Fr::from(16_u32)]);
  242. let e1 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(1_u32))]);
  243. assert_eq!(e1.coeffs(), vec![Fr::from(18_u32), Fr::from(1)]);
  244. assert_eq!((e0 + e1).coeffs(), vec![Fr::from(34_u32), Fr::from(1)]);
  245. }
  246. #[test]
  247. fn test_flow_hardcoded_values() {
  248. // g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
  249. let terms = vec![
  250. (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
  251. (
  252. Fr::from(1u32),
  253. SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
  254. ),
  255. (
  256. Fr::from(1u32),
  257. SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
  258. ),
  259. ];
  260. let p = SparsePolynomial::from_coefficients_slice(3, &terms);
  261. // println!("p {:?}", p);
  262. let poseidon_config = poseidon_test_config::<Fr>();
  263. type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  264. let proof = SC::prove(&poseidon_config, p);
  265. assert_eq!(proof.0, Fr::from(12_u32));
  266. // println!("proof {:?}", proof);
  267. let v = SC::verify(&poseidon_config, proof);
  268. assert!(v);
  269. }
  270. fn rand_poly<R: Rng>(l: usize, d: usize, rng: &mut R) -> SparsePolynomial<Fr, SparseTerm> {
  271. // This method is from the arkworks/algebra/poly/multivariate test:
  272. // https://github.com/arkworks-rs/algebra/blob/bc991d44c5e579025b7ed56df3d30267a7b9acac/poly/src/polynomial/multivariate/sparse.rs#L303
  273. let mut random_terms = Vec::new();
  274. let num_terms = rng.gen_range(1..1000);
  275. // For each term, randomly select up to `l` variables with degree
  276. // in [1,d] and random coefficient
  277. random_terms.push((Fr::rand(rng), SparseTerm::new(vec![])));
  278. for _ in 1..num_terms {
  279. let term = (0..l)
  280. .map(|i| {
  281. if rng.gen_bool(0.5) {
  282. Some((i, rng.gen_range(1..(d + 1))))
  283. } else {
  284. None
  285. }
  286. })
  287. .flatten()
  288. .collect();
  289. let coeff = Fr::rand(rng);
  290. random_terms.push((coeff, SparseTerm::new(term)));
  291. }
  292. SparsePolynomial::from_coefficients_slice(l, &random_terms)
  293. }
  294. #[test]
  295. fn test_flow_rng() {
  296. let mut rng = ark_std::test_rng();
  297. // let p = SparsePolynomial::<Fr, SparseTerm>::rand(3, 3, &mut rng);
  298. let p = rand_poly(3, 3, &mut rng);
  299. // println!("p {:?}", p);
  300. let poseidon_config = poseidon_test_config::<Fr>();
  301. type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
  302. let proof = SC::prove(&poseidon_config, p);
  303. // println!("proof.s len {:?}", proof.1.len());
  304. let v = SC::verify(&poseidon_config, proof);
  305. assert!(v);
  306. }
  307. }