You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

372 lines
13 KiB

  // This file contains an initial implementation of the sum-check protocol. It is not yet used
  // by the rest of the repo; it was written as an exercise and will probably be used in the future.
  3. use ark_ec::CurveGroup;
  4. use ark_ff::{BigInteger, PrimeField};
  5. use ark_poly::{
  6. multivariate::{SparsePolynomial, SparseTerm, Term},
  7. DenseMVPolynomial, DenseUVPolynomial, Polynomial,
  8. };
  9. use ark_std::log2;
  10. use ark_std::marker::PhantomData;
  11. use ark_crypto_primitives::sponge::{poseidon::PoseidonConfig, Absorb};
  12. use crate::transcript::Transcript;
  13. pub struct Point<F: PrimeField> {
  14. _f: PhantomData<F>,
  15. }
  16. impl<F: PrimeField> Point<F> {
  17. pub fn point_normal(n_elems: usize, iter_num: usize) -> Vec<F> {
  18. let p = Self::point(vec![], false, n_elems, iter_num);
  19. let mut r = vec![F::zero(); n_elems];
  20. for i in 0..n_elems {
  21. r[i] = p[i].unwrap();
  22. }
  23. r
  24. }
  25. pub fn point_complete(challenges: Vec<F>, n_elems: usize, iter_num: usize) -> Vec<F> {
  26. let p = Self::point(challenges, false, n_elems, iter_num);
  27. let mut r = vec![F::zero(); n_elems];
  28. for i in 0..n_elems {
  29. r[i] = p[i].unwrap();
  30. }
  31. r
  32. }
  33. fn point(challenges: Vec<F>, none: bool, n_elems: usize, iter_num: usize) -> Vec<Option<F>> {
  34. let mut n_vars = n_elems - challenges.len();
  35. assert!(n_vars >= log2(iter_num + 1) as usize);
  36. if none {
  37. // WIP
  38. if n_vars == 0 {
  39. panic!("err"); // or return directly challenges vector
  40. }
  41. n_vars -= 1;
  42. }
  43. let none_pos = if none {
  44. challenges.len() + 1
  45. } else {
  46. challenges.len()
  47. };
  48. let mut p: Vec<Option<F>> = vec![None; n_elems];
  49. for i in 0..challenges.len() {
  50. p[i] = Some(challenges[i]);
  51. }
  52. for i in 0..n_vars {
  53. let k = F::from(iter_num as u64).into_bigint().to_bytes_le();
  54. let bit = k[i / 8] & (1 << (i % 8));
  55. if bit == 0 {
  56. p[none_pos + i] = Some(F::zero());
  57. } else {
  58. p[none_pos + i] = Some(F::one());
  59. }
  60. }
  61. p
  62. }
  63. }
  64. pub struct SumCheck<
  65. F: PrimeField + Absorb,
  66. C: CurveGroup,
  67. UV: DenseUVPolynomial<F>,
  68. MV: DenseMVPolynomial<F>,
  69. > {
  70. _f: PhantomData<F>,
  71. _c: PhantomData<C>,
  72. _uv: PhantomData<UV>,
  73. _mv: PhantomData<MV>,
  74. }
  75. impl<
  76. F: PrimeField + Absorb,
  77. C: CurveGroup,
  78. UV: Polynomial<F> + DenseUVPolynomial<F>,
  79. MV: Polynomial<F> + DenseMVPolynomial<F>,
  80. > SumCheck<F, C, UV, MV>
  81. where
  82. <C as CurveGroup>::BaseField: Absorb,
  83. {
  84. fn partial_evaluate(g: &MV, point: &[Option<F>]) -> UV {
  85. assert!(point.len() >= g.num_vars(), "Invalid evaluation domain");
  86. // TODO: add check: there can only be 1 'None' value in point
  87. if g.is_zero() {
  88. return UV::from_coefficients_vec(vec![F::zero()]);
  89. }
  90. // note: this can be parallelized with cfg_into_iter
  91. let mut univ_terms: Vec<(F, SparseTerm)> = vec![];
  92. for (coef, term) in g.terms().iter() {
  93. // partial_evaluate each term
  94. let mut new_coef = F::one();
  95. let mut new_term = Vec::new();
  96. for (var, power) in term.iter() {
  97. match point[*var] {
  98. Some(v) => {
  99. if v.is_zero() {
  100. new_coef = F::zero();
  101. new_term = vec![];
  102. break;
  103. } else {
  104. new_coef *= v.pow([(*power) as u64]);
  105. }
  106. }
  107. _ => {
  108. new_term.push((*var, *power));
  109. }
  110. };
  111. }
  112. let new_term = SparseTerm::new(new_term);
  113. let new_coef = new_coef * coef;
  114. univ_terms.push((new_coef, new_term));
  115. }
  116. let mv_poly: SparsePolynomial<F, SparseTerm> =
  117. DenseMVPolynomial::<F>::from_coefficients_vec(g.num_vars(), univ_terms.clone());
  118. let mut univ_coeffs: Vec<F> = vec![F::zero(); mv_poly.degree() + 1];
  119. for (coef, term) in univ_terms {
  120. if term.is_empty() {
  121. univ_coeffs[0] += coef;
  122. continue;
  123. }
  124. for (_, power) in term.iter() {
  125. univ_coeffs[*power] += coef;
  126. }
  127. }
  128. UV::from_coefficients_vec(univ_coeffs)
  129. }
  130. pub fn prove(poseidon_config: &PoseidonConfig<F>, g: MV) -> (F, Vec<UV>, F)
  131. where
  132. <MV as Polynomial<F>>::Point: From<Vec<F>>,
  133. {
  134. // init transcript
  135. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  136. let v = g.num_vars();
  137. // compute T
  138. let mut T = F::zero();
  139. for i in 0..(2_u64.pow(v as u32) as usize) {
  140. let p = Point::<F>::point_complete(vec![], v, i);
  141. T += g.evaluate(&p.into());
  142. }
  143. transcript.add(&T);
  144. let mut ss: Vec<UV> = Vec::new();
  145. let mut r: Vec<F> = vec![];
  146. for i in 0..v {
  147. let r_i = transcript.get_challenge();
  148. r.push(r_i);
  149. let var_slots = v - 1 - i;
  150. let n_points = 2_u64.pow(var_slots as u32) as usize;
  151. let mut s_i = UV::zero();
  152. for j in 0..n_points {
  153. let point = Point::<F>::point(r[..i].to_vec(), true, v, j);
  154. s_i = s_i + Self::partial_evaluate(&g, &point);
  155. }
  156. transcript.add_vec(s_i.coeffs());
  157. ss.push(s_i);
  158. }
  159. let last_g_eval = g.evaluate(&r.into());
  160. // ss: intermediate univariate polynomials
  161. (T, ss, last_g_eval)
  162. }
  163. pub fn verify(poseidon_config: &PoseidonConfig<F>, proof: (F, Vec<UV>, F)) -> bool {
  164. // init transcript
  165. let mut transcript = Transcript::<F, C>::new(poseidon_config);
  166. transcript.add(&proof.0);
  167. let (c, ss, last_g_eval) = proof;
  168. let mut r: Vec<F> = vec![];
  169. for (i, s) in ss.iter().enumerate() {
  170. // TODO check degree
  171. if i == 0 {
  172. if c != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  173. return false;
  174. }
  175. let r_i = transcript.get_challenge();
  176. r.push(r_i);
  177. transcript.add_vec(s.coeffs());
  178. continue;
  179. }
  180. let r_i = transcript.get_challenge();
  181. r.push(r_i);
  182. if ss[i - 1].evaluate(&r[i - 1]) != s.evaluate(&F::zero()) + s.evaluate(&F::one()) {
  183. return false;
  184. }
  185. transcript.add_vec(s.coeffs());
  186. }
  187. // last round
  188. if ss[ss.len() - 1].evaluate(&r[r.len() - 1]) != last_g_eval {
  189. return false;
  190. }
  191. true
  192. }
  193. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::transcript::poseidon_test_config;
      use ark_mnt4_298::{Fr, G1Projective}; // scalar field
      use ark_poly::{
          multivariate::{SparsePolynomial, SparseTerm, Term},
          univariate::DensePolynomial,
          DenseMVPolynomial, DenseUVPolynomial,
      };
      use ark_std::{rand::Rng, UniformRand};

      /// Sum-check instantiation shared by the tests.
      type SC = SumCheck<Fr, G1Projective, DensePolynomial<Fr>, SparsePolynomial<Fr, SparseTerm>>;
      type P = Point<Fr>;

      /// g(X_0, X_1, X_2) = 2 X_0^3 + X_0 X_2 + X_1 X_2
      fn sample_poly() -> SparsePolynomial<Fr, SparseTerm> {
          let terms = vec![
              (Fr::from(2u32), SparseTerm::new(vec![(0_usize, 3)])),
              (
                  Fr::from(1u32),
                  SparseTerm::new(vec![(0_usize, 1), (2_usize, 1)]),
              ),
              (
                  Fr::from(1u32),
                  SparseTerm::new(vec![(1_usize, 1), (2_usize, 1)]),
              ),
          ];
          SparsePolynomial::from_coefficients_slice(3, &terms)
      }

      #[test]
      fn test_new_point() {
          let f4 = Fr::from(4_u32);
          let f1 = Fr::from(1_u32);
          let f0 = Fr::from(0_u32);

          // with a free (None) coordinate right after the challenge
          let p = P::point(vec![f4], true, 5, 0);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], true, 5, 1);
          assert_eq!(vec![Some(f4), None, Some(f1), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], true, 5, 2);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f1), Some(f0),], p);
          let p = P::point(vec![f4], true, 5, 3);
          assert_eq!(vec![Some(f4), None, Some(f1), Some(f1), Some(f0),], p);
          let p = P::point(vec![f4], true, 5, 4);
          assert_eq!(vec![Some(f4), None, Some(f0), Some(f0), Some(f1),], p);

          // without None
          let p = P::point(vec![], false, 4, 0);
          assert_eq!(vec![Some(f0), Some(f0), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], false, 5, 0);
          assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], false, 5, 1);
          assert_eq!(vec![Some(f4), Some(f1), Some(f0), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], false, 5, 3);
          assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f0), Some(f0),], p);
          let p = P::point(vec![f4], false, 5, 4);
          assert_eq!(vec![Some(f4), Some(f0), Some(f0), Some(f1), Some(f0),], p);
          let p = P::point(vec![f4], false, 5, 10);
          assert_eq!(vec![Some(f4), Some(f0), Some(f1), Some(f0), Some(f1),], p);
          let p = P::point(vec![f4], false, 5, 15);
          assert_eq!(vec![Some(f4), Some(f1), Some(f1), Some(f1), Some(f1),], p);
      }

      /// An index that needs more bits than there are boolean coordinates is rejected.
      #[test]
      #[should_panic]
      fn test_new_point_index_out_of_range() {
          // 3 boolean coordinates are left, but 16 needs 5 bits
          let _ = P::point(vec![Fr::from(4_u32)], false, 4, 16);
      }

      #[test]
      fn test_partial_evaluate() {
          let p = sample_poly();

          let e0 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(0_u32))]);
          assert_eq!(e0.coeffs(), vec![Fr::from(16_u32)]);

          let e1 = SC::partial_evaluate(&p, &[Some(Fr::from(2_u32)), None, Some(Fr::from(1_u32))]);
          assert_eq!(e1.coeffs(), vec![Fr::from(18_u32), Fr::from(1_u32)]);

          assert_eq!(
              (e0 + e1).coeffs(),
              vec![Fr::from(34_u32), Fr::from(1_u32)]
          );
      }

      #[test]
      fn test_flow_hardcoded_values() {
          let p = sample_poly();
          let poseidon_config = poseidon_test_config::<Fr>();

          let proof = SC::prove(&poseidon_config, p);
          // the sum of g over {0,1}^3 is 12
          assert_eq!(proof.0, Fr::from(12_u32));

          assert!(SC::verify(&poseidon_config, proof));
      }

      /// Random polynomial in `l` variables with per-variable degree in `[1, d]`.
      fn rand_poly<R: Rng>(l: usize, d: usize, rng: &mut R) -> SparsePolynomial<Fr, SparseTerm> {
          // This method is from the arkworks/algebra/poly/multivariate test:
          // https://github.com/arkworks-rs/algebra/blob/bc991d44c5e579025b7ed56df3d30267a7b9acac/poly/src/polynomial/multivariate/sparse.rs#L303
          let mut random_terms = Vec::new();
          let num_terms = rng.gen_range(1..1000);
          // For each term, randomly select up to `l` variables with degree
          // in [1,d] and random coefficient
          random_terms.push((Fr::rand(rng), SparseTerm::new(vec![])));
          for _ in 1..num_terms {
              let term = (0..l)
                  .filter_map(|i| {
                      if rng.gen_bool(0.5) {
                          Some((i, rng.gen_range(1..(d + 1))))
                      } else {
                          None
                      }
                  })
                  .collect();
              let coeff = Fr::rand(rng);
              random_terms.push((coeff, SparseTerm::new(term)));
          }
          SparsePolynomial::from_coefficients_slice(l, &random_terms)
      }

      #[test]
      fn test_flow_rng() {
          let mut rng = ark_std::test_rng();
          let p = rand_poly(3, 3, &mut rng);
          let poseidon_config = poseidon_test_config::<Fr>();

          let proof = SC::prove(&poseidon_config, p);
          assert!(SC::verify(&poseidon_config, proof));
      }
  }