You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

302 lines
9.7 KiB

  // merkletree.rs implements a simple binary insert-only Merkle tree in which each leaf's position
  // is determined by the binary representation of the leaf value. Inspired by
  // https://docs.iden3.io/publications/pdfs/Merkle-Tree.pdf (an implementation can be found at
  // https://github.com/vocdoni/arbo).
  5. use ark_ff::{BigInteger, PrimeField};
  6. use ark_std::log2;
  7. use arkworks_native_gadgets::poseidon;
  8. use arkworks_native_gadgets::poseidon::FieldHasher;
  9. use arkworks_utils::{
  10. bytes_matrix_to_f, bytes_vec_to_f, poseidon_params::setup_poseidon_params, Curve,
  11. };
  12. pub struct Params<F: PrimeField> {
  13. pub poseidon_hash: poseidon::Poseidon<F>,
  14. }
  15. #[derive(Clone, Debug)]
  16. pub struct Node<F: PrimeField> {
  17. hash: F,
  18. left: Option<Box<Node<F>>>,
  19. right: Option<Box<Node<F>>>,
  20. value: Option<F>,
  21. }
  22. impl<F: PrimeField> Node<F> {
  23. pub fn new_leaf(params: &Params<F>, v: F) -> Self {
  24. let h = params.poseidon_hash.hash(&[v]).unwrap();
  25. Self {
  26. hash: h,
  27. left: None,
  28. right: None,
  29. value: Some(v),
  30. }
  31. }
  32. pub fn new_node(params: &Params<F>, l: Self, r: Self) -> Self {
  33. let left = Box::new(l);
  34. let right = Box::new(r);
  35. let hash = params.poseidon_hash.hash(&[left.hash, right.hash]).unwrap();
  36. Self {
  37. hash,
  38. left: Some(left),
  39. right: Some(right),
  40. value: None,
  41. }
  42. }
  43. }
  44. pub struct MerkleTree<F: PrimeField> {
  45. pub root: Node<F>,
  46. nlevels: u32,
  47. }
  48. impl<F: PrimeField> MerkleTree<F> {
  49. pub fn setup(poseidon_hash: &poseidon::Poseidon<F>) -> Params<F> {
  50. Params {
  51. poseidon_hash: poseidon_hash.clone(),
  52. }
  53. }
  54. pub fn new(params: &Params<F>, values: Vec<F>) -> Self {
  55. // for the moment assume that values length is a power of 2.
  56. if (values.len() != 0) && (values.len() & (values.len() - 1) != 0) {
  57. panic!("values.len() should be a power of 2");
  58. }
  59. // prepare the leafs
  60. let mut leaf_nodes: Vec<Node<F>> = Vec::new();
  61. for i in 0..values.len() {
  62. let node = Node::<F>::new_leaf(&params, values[i]);
  63. leaf_nodes.push(node);
  64. }
  65. // go up from the leafs to the root
  66. let top_nodes = Self::up_from_nodes(&params, leaf_nodes);
  67. Self {
  68. root: top_nodes[0].clone(),
  69. nlevels: log2(values.len()),
  70. }
  71. }
  72. fn up_from_nodes(params: &Params<F>, nodes: Vec<Node<F>>) -> Vec<Node<F>> {
  73. if nodes.len() == 0 {
  74. return [Node::<F> {
  75. hash: F::from(0_u32),
  76. left: None,
  77. right: None,
  78. value: None,
  79. }]
  80. .to_vec();
  81. }
  82. if nodes.len() == 1 {
  83. return nodes;
  84. }
  85. let mut next_level_nodes: Vec<Node<F>> = Vec::new();
  86. for i in (0..nodes.len()).step_by(2) {
  87. let node = Node::<F>::new_node(&params, nodes[i].clone(), nodes[i + 1].clone());
  88. next_level_nodes.push(node);
  89. }
  90. return Self::up_from_nodes(params, next_level_nodes);
  91. }
  92. fn get_path(num_levels: u32, value: F) -> Vec<bool> {
  93. let value_bytes = value.into_repr().to_bytes_le();
  94. let mut path = Vec::new();
  95. for i in 0..num_levels {
  96. path.push(value_bytes[(i / 8) as usize] & (1 << (i % 8)) != 0);
  97. }
  98. path
  99. }
  100. pub fn gen_proof(&self, index: usize) -> Vec<F> {
  101. // start from root, and go down to the index, while getting the siblings at each level
  102. let path = Self::get_path(self.nlevels, F::from(index as u32));
  103. // reverse path as we're going from up to down
  104. let path_inv = path.iter().copied().rev().collect();
  105. let mut siblings: Vec<F> = Vec::new();
  106. siblings = Self::go_down(path_inv, self.root.clone(), siblings);
  107. return siblings;
  108. }
  109. fn go_down(path: Vec<bool>, node: Node<F>, mut siblings: Vec<F>) -> Vec<F> {
  110. if !node.value.is_none() {
  111. return siblings;
  112. }
  113. if !path[0] {
  114. siblings.push(node.right.unwrap().hash);
  115. return Self::go_down(path[1..].to_vec(), *node.left.unwrap(), siblings);
  116. } else {
  117. siblings.push(node.left.unwrap().hash);
  118. return Self::go_down(path[1..].to_vec(), *node.right.unwrap(), siblings);
  119. }
  120. }
  121. pub fn verify(params: &Params<F>, root: F, index: usize, value: F, siblings: Vec<F>) -> bool {
  122. let mut h = params.poseidon_hash.hash(&[value]).unwrap();
  123. let path = Self::get_path(siblings.len() as u32, F::from(index as u32));
  124. for i in 0..siblings.len() {
  125. if !path[i] {
  126. h = params
  127. .poseidon_hash
  128. .hash(&[h, siblings[siblings.len() - 1 - i]])
  129. .unwrap();
  130. } else {
  131. h = params
  132. .poseidon_hash
  133. .hash(&[siblings[siblings.len() - 1 - i], h])
  134. .unwrap();
  135. }
  136. }
  137. if h == root {
  138. return true;
  139. }
  140. false
  141. }
  142. }
  143. pub struct MerkleTreePoseidon<F: PrimeField>(MerkleTree<F>);
  144. impl<F: PrimeField> MerkleTreePoseidon<F> {
  145. pub fn commit(values: &[F]) -> (F, Self) {
  146. let poseidon_params = poseidon_setup_params::<F>(Curve::Bn254, 5, 4);
  147. let poseidon_hash = poseidon::Poseidon::new(poseidon_params);
  148. let params = MerkleTree::setup(&poseidon_hash);
  149. let mt = MerkleTree::new(&params, values.to_vec());
  150. (mt.root.hash, MerkleTreePoseidon(mt))
  151. }
  152. pub fn prove(&self, index: usize) -> Vec<F> {
  153. self.0.gen_proof(index)
  154. }
  155. pub fn verify(root: F, index: usize, value: F, siblings: Vec<F>) -> bool {
  156. let poseidon_params = poseidon_setup_params::<F>(Curve::Bn254, 5, 4);
  157. let poseidon_hash = poseidon::Poseidon::new(poseidon_params);
  158. let params = MerkleTree::setup(&poseidon_hash);
  159. MerkleTree::verify(&params, root, index, value, siblings)
  160. }
  161. }
  162. pub fn poseidon_setup_params<F: PrimeField>(
  163. curve: Curve,
  164. exp: i8,
  165. width: u8,
  166. ) -> poseidon::PoseidonParameters<F> {
  167. let pos_data = setup_poseidon_params(curve, exp, width).unwrap();
  168. let mds_f = bytes_matrix_to_f(&pos_data.mds);
  169. let rounds_f = bytes_vec_to_f(&pos_data.rounds);
  170. poseidon::PoseidonParameters {
  171. mds_matrix: mds_f,
  172. round_keys: rounds_f,
  173. full_rounds: pos_data.full_rounds,
  174. partial_rounds: pos_data.partial_rounds,
  175. sbox: poseidon::sbox::PoseidonSbox(pos_data.exp),
  176. width: pos_data.width,
  177. }
  178. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use ark_std::UniformRand;
      pub type Fr = ark_bn254::Fr; // scalar field

      /// Tree parameters over the Bn254 scalar field shared by the tests below.
      fn bn254_params() -> Params<Fr> {
          let hasher = poseidon::Poseidon::new(poseidon_setup_params::<Fr>(Curve::Bn254, 5, 4));
          MerkleTree::setup(&hasher)
      }

      /// The path of a value is its bit decomposition, least-significant bit first.
      #[test]
      fn test_path() {
          let cases: [(u32, [bool; 8]); 6] = [
              (0, [false, false, false, false, false, false, false, false]),
              (1, [true, false, false, false, false, false, false, false]),
              (2, [false, true, false, false, false, false, false, false]),
              (3, [true, true, false, false, false, false, false, false]),
              (254, [false, true, true, true, true, true, true, true]),
              (255, [true, true, true, true, true, true, true, true]),
          ];
          for &(value, expected) in cases.iter() {
              assert_eq!(MerkleTree::get_path(8, Fr::from(value)), expected.to_vec());
          }
      }

      /// A tree without leaves has the zero element as root hash.
      #[test]
      fn test_new_empty_tree() {
          let params = bn254_params();
          let tree = MerkleTree::new(&params, Vec::new());
          assert_eq!(tree.root.hash, Fr::from(0_u32));
      }

      /// Fixed 8-leaf tree: checks the known root and one membership proof.
      #[test]
      fn test_proof() {
          let params = bn254_params();
          let values: Vec<Fr> = [0_u32, 1, 2, 3, 200, 201, 202, 203]
              .iter()
              .map(|&v| Fr::from(v))
              .collect();

          let tree = MerkleTree::new(&params, values.clone());
          assert_eq!(
              tree.root.hash.to_string(),
              "Fp256 \"(10E90845D7A113686E4F2F30D73B315BA891A5DB9A58782F1260C35F99660794)\""
          );

          let index = 3;
          let proof = tree.gen_proof(index);
          assert!(MerkleTree::verify(
              &params,
              tree.root.hash,
              index,
              values[index],
              proof
          ));
      }

      /// Random 256-leaf tree: every leaf must verify against the root.
      #[test]
      fn test_proofs() {
          let params = bn254_params();
          let mut rng = ark_std::test_rng();
          let n_values = 256;
          let values: Vec<Fr> = (0..n_values).map(|_| Fr::rand(&mut rng)).collect();

          let tree = MerkleTree::new(&params, values.clone());
          assert_eq!(tree.nlevels, 8);

          for (i, leaf) in values.iter().enumerate() {
              let proof = tree.gen_proof(i);
              assert!(MerkleTree::verify(&params, tree.root.hash, i, *leaf, proof));
          }
      }
  }