You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

456 lines
17 KiB

  1. use std::collections::HashMap;
  2. #[macro_use]
  3. extern crate arrayref;
  4. extern crate tiny_keccak;
  5. extern crate rustc_hex;
  6. mod utils;
  7. mod node;
  /// Type tag reported by `Db::get` when a key is not present in storage.
  const TYPENODEEMPTY: u8 = 0;
  /// Type tag for an intermediate node whose payload is the bytes of a `node::TreeNode`.
  const TYPENODENORMAL: u8 = 1;
  /// Type tag for a node that stores a leaf's bytes under a level-dependent hash.
  const TYPENODEFINAL: u8 = 2;
  /// Type tag for a leaf stored under the hash of its own bytes.
  const TYPENODEVALUE: u8 = 3;
  /// The all-zero 32-byte hash used to represent an empty node / empty tree root.
  const EMPTYNODEVALUE: [u8; 32] = [0u8; 32];
  /// A simple leaf value: raw bytes plus the length of the prefix used as its index.
  ///
  /// `MerkleTree::add` hashes the first `index_length` bytes of `bytes` to derive
  /// the leaf's position and hashes all of `bytes` to derive the leaf's own hash.
  #[derive(Debug, Clone, PartialEq, Eq)]
  pub struct TestValue {
      /// Full content of the leaf.
      bytes: Vec<u8>,
      /// Number of leading bytes of `bytes` that form the index part.
      index_length: u32,
  }
  /// Read-only view of a value that exposes its bytes and an index length.
  pub trait Value {
      /// Borrows the complete byte content of the value.
      fn bytes(&self) -> &Vec<u8>;

      /// Returns the index length associated with the value.
      // NOTE(review): for `TestValue` this is the number of leading bytes hashed
      // to locate the leaf; presumably the same meaning holds for other
      // implementors — confirm.
      fn index_length(&self) -> u32;
  }
  21. impl Value for TestValue {
  22. fn bytes(&self) -> &Vec<u8> {
  23. &self.bytes
  24. }
  25. fn index_length(&self) -> u32 {
  26. self.index_length
  27. }
  28. }
  /// In-memory key/value store for tree nodes, keyed by 32-byte hash.
  ///
  /// Each stored record is `type (1 byte) || index_length (4 bytes, little-endian)
  /// || payload`, as written by `Db::insert` and decoded by `Db::get`.
  // NOTE(review): `storage` is read and written by `insert`/`get`, so this
  // `allow(dead_code)` looks unnecessary — confirm and drop.
  #[allow(dead_code)]
  #[derive(Debug, Clone, Default)]
  pub struct Db {
      /// Encoded records indexed by node hash.
      storage: HashMap<[u8; 32], Vec<u8>>,
  }
  33. impl Db {
  34. pub fn insert(&mut self, k: [u8; 32], t: u8, il: u32, b: &mut Vec<u8>) {
  35. let mut v: Vec<u8>;
  36. v = [t].to_vec();
  37. let il_bytes = il.to_le_bytes();
  38. v.append(&mut il_bytes.to_vec()); // il_bytes are [u8;4] (4 bytes)
  39. v.append(b);
  40. self.storage.insert(k, v);
  41. }
  42. pub fn get(&self, k: &[u8;32]) -> (u8, u32, Vec<u8>) {
  43. if k.to_vec() == EMPTYNODEVALUE.to_vec() {
  44. return (0, 0, EMPTYNODEVALUE.to_vec());
  45. }
  46. match self.storage.get(k) {
  47. Some(x) => {
  48. let t = x[0];
  49. let il_bytes: [u8; 4] = [x[1], x[2], x[3], x[4]];
  50. let il = u32::from_le_bytes(il_bytes);
  51. let b = &x[5..];
  52. (t, il, b.to_vec())
  53. },
  54. None => (TYPENODEEMPTY, 0, EMPTYNODEVALUE.to_vec()),
  55. }
  56. }
  57. }
  58. pub fn new_db()-> Db {
  59. Db {
  60. storage: HashMap::new(),
  61. }
  62. }
  63. pub struct MerkleTree {
  64. #[allow(dead_code)]
  65. root: [u8; 32],
  66. #[allow(dead_code)]
  67. num_levels: u32,
  68. #[allow(dead_code)]
  69. sto: Db,
  70. }
  71. pub fn new(num_levels: u32) -> MerkleTree {
  72. MerkleTree {
  73. root: EMPTYNODEVALUE,
  74. num_levels,
  75. sto: new_db(),
  76. }
  77. }
  78. impl MerkleTree {
  79. pub fn add(&mut self, v: &TestValue) {
  80. #![allow(unused_variables)]
  81. #[allow(dead_code)]
  82. // println!("adding value: {:?}", v.bytes());
  83. // add the leaf that we are adding
  84. self.sto.insert(utils::hash_vec(v.bytes().to_vec()), TYPENODEVALUE, v.index_length(), &mut v.bytes().to_vec());
  85. let index = v.index_length() as usize;
  86. let hi = utils::hash_vec(v.bytes()[..index].to_vec());
  87. let ht = utils::hash_vec(v.bytes().to_vec());
  88. let path = utils::get_path(self.num_levels, hi);
  89. let mut siblings: Vec<[u8;32]> = Vec::new();
  90. let mut node_hash = self.root;
  91. for i in (0..self.num_levels-1).rev() {
  92. // get node
  93. // let (t, il, node_bytes) = self.sto.get(&utils::hash_vec(node_hash.to_vec()));
  94. let (t, il, node_bytes) = self.sto.get(&node_hash);
  95. if t == TYPENODEFINAL {
  96. let hi_child = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  97. let path_child = utils::get_path(self.num_levels, hi_child);
  98. let pos_diff = utils::compare_paths(&path_child, &path);
  99. if pos_diff == 999 { // TODO use a match here, and instead of 999 return something better
  100. println!("compare paths err");
  101. return;
  102. }
  103. let final_node_1_hash = utils::calc_hash_from_leaf_and_level(pos_diff, &path_child, utils::hash_vec(node_bytes.to_vec()));
  104. self.sto.insert(final_node_1_hash, TYPENODEFINAL, il, &mut node_bytes.to_vec());
  105. let final_node_2_hash = utils::calc_hash_from_leaf_and_level(pos_diff, &path, utils::hash_vec(v.bytes().to_vec()));
  106. self.sto.insert(final_node_2_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  107. // parent node
  108. let parent_node: node::TreeNode;
  109. if path[pos_diff as usize] {
  110. parent_node = node::TreeNode {
  111. child_l: final_node_1_hash,
  112. child_r: final_node_2_hash,
  113. }} else {
  114. parent_node = node::TreeNode {
  115. child_l: final_node_2_hash,
  116. child_r: final_node_1_hash,
  117. }}
  118. let empties = utils::get_empties_between_i_and_pos(i, pos_diff+1);
  119. for empty in &empties {
  120. siblings.push(*empty);
  121. }
  122. let path_from_pos_diff = utils::cut_path(&path, (pos_diff +1) as usize);
  123. self.root = self.replace_leaf(path_from_pos_diff, siblings.clone(), parent_node.ht(), TYPENODENORMAL, 0, &mut parent_node.bytes().to_vec());
  124. return;
  125. }
  126. let node = node::parse_node_bytes(node_bytes);
  127. let sibling: [u8;32];
  128. if path[i as usize] {
  129. node_hash = node.child_l;
  130. sibling = node.child_r;
  131. } else {
  132. sibling = node.child_l;
  133. node_hash = node.child_r;
  134. }
  135. siblings.push(*array_ref!(sibling, 0, 32));
  136. if node_hash == EMPTYNODEVALUE {
  137. if i==self.num_levels-2 && siblings[siblings.len()-1]==EMPTYNODEVALUE {
  138. let final_node_hash = utils::calc_hash_from_leaf_and_level(i+1, &path, utils::hash_vec(v.bytes().to_vec()));
  139. self.sto.insert(final_node_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  140. self.root = final_node_hash;
  141. return;
  142. }
  143. let final_node_hash = utils::calc_hash_from_leaf_and_level(i, &path, utils::hash_vec(v.bytes().to_vec()));
  144. let path_from_i = utils::cut_path(&path, i as usize);
  145. self.root = self.replace_leaf(path_from_i, siblings.clone(), final_node_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  146. return;
  147. }
  148. }
  149. self.root = self.replace_leaf(path, siblings, utils::hash_vec(v.bytes().to_vec()), TYPENODEVALUE, v.index_length(), &mut v.bytes().to_vec());
  150. }
  151. #[allow(dead_code)]
  152. pub fn replace_leaf(&mut self, path: Vec<bool>, siblings: Vec<[u8;32]>, leaf_hash: [u8;32], node_type: u8, index_length: u32, leaf_value: &mut Vec<u8>) -> [u8;32] {
  153. self.sto.insert(leaf_hash, node_type, index_length, leaf_value);
  154. let mut curr_node = leaf_hash;
  155. for i in 0..siblings.len() {
  156. if path[i as usize] {
  157. let node = node::TreeNode {
  158. child_l: curr_node,
  159. child_r: siblings[siblings.len()-1-i],
  160. };
  161. self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes());
  162. curr_node = node.ht();
  163. } else {
  164. let node = node::TreeNode {
  165. child_l: siblings[siblings.len()-1-i],
  166. child_r: curr_node,
  167. };
  168. self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes());
  169. curr_node = node.ht();
  170. }
  171. }
  172. curr_node
  173. }
  174. #[allow(dead_code)]
  175. pub fn get_value_in_pos(&self, hi: [u8;32]) -> Vec<u8> {
  176. let path = utils::get_path(self.num_levels, hi);
  177. let mut node_hash = self.root;
  178. for i in (0..self.num_levels-1).rev() {
  179. let (t, il, node_bytes) = self.sto.get(&node_hash);
  180. if t == TYPENODEFINAL {
  181. let hi_node = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  182. let path_node = utils::get_path(self.num_levels, hi_node);
  183. let pos_diff = utils::compare_paths(&path_node, &path);
  184. // if pos_diff > self.num_levels {
  185. if pos_diff != 999 {
  186. return EMPTYNODEVALUE.to_vec();
  187. }
  188. return node_bytes;
  189. }
  190. let node = node::parse_node_bytes(node_bytes);
  191. if !path[i as usize] {
  192. node_hash = node.child_l;
  193. } else {
  194. node_hash = node.child_r;
  195. }
  196. }
  197. let (_t, _il, node_bytes) = self.sto.get(&node_hash);
  198. node_bytes
  199. }
  200. #[allow(dead_code)]
  201. pub fn generate_proof(&self, hi: [u8;32]) -> Vec<u8> {
  202. let mut mp: Vec<u8> = Vec::new();
  203. let mut empties: [u8;32] = [0;32];
  204. let path = utils::get_path(self.num_levels, hi);
  205. let mut siblings: Vec<[u8;32]> = Vec::new();
  206. let mut node_hash = self.root;
  207. for i in 0..self.num_levels {
  208. let (t, il, node_bytes) = self.sto.get(&node_hash);
  209. if t == TYPENODEFINAL {
  210. let real_value_in_pos = self.get_value_in_pos(hi);
  211. if real_value_in_pos == EMPTYNODEVALUE {
  212. let leaf_hi = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  213. let path_child = utils::get_path(self.num_levels, leaf_hi);
  214. let pos_diff = utils::compare_paths(&path_child, &path);
  215. if pos_diff == self.num_levels { // TODO use a match here, and instead of 999 return something better
  216. return mp;
  217. }
  218. if pos_diff != self.num_levels-1-i {
  219. let sibling = utils::calc_hash_from_leaf_and_level(pos_diff, &path_child, utils::hash_vec(node_bytes.to_vec()));
  220. let mut new_siblings: Vec<[u8;32]> = Vec::new();
  221. new_siblings.push(sibling);
  222. new_siblings.append(&mut siblings);
  223. siblings = new_siblings;
  224. // set empties bit
  225. let bit_pos = self.num_levels-2-pos_diff;
  226. empties[(empties.len() as isize + (bit_pos as isize/8-1) as isize) as usize] |= 1 << (bit_pos%8);
  227. }
  228. }
  229. break
  230. }
  231. let node = node::parse_node_bytes(node_bytes);
  232. let sibling: [u8;32];
  233. if !path[self.num_levels as usize -i as usize -2] {
  234. node_hash = node.child_l;
  235. sibling = node.child_r;
  236. } else {
  237. sibling = node.child_l;
  238. node_hash = node.child_r;
  239. }
  240. if sibling != EMPTYNODEVALUE {
  241. // set empties bit
  242. empties[(empties.len() as isize + (i as isize/8-1) as isize) as usize] |= 1 << (i%8);
  243. let mut new_siblings: Vec<[u8;32]> = Vec::new();
  244. new_siblings.push(sibling);
  245. new_siblings.append(&mut siblings);
  246. siblings = new_siblings;
  247. }
  248. }
  249. mp.append(&mut empties[..].to_vec());
  250. for s in siblings {
  251. mp.append(&mut s.to_vec());
  252. }
  253. mp
  254. }
  255. }
  256. #[allow(dead_code)]
  257. pub fn verify_proof(root: [u8;32], mp: Vec<u8>, hi: [u8;32], ht: [u8;32], num_levels: u32) -> bool {
  258. let empties: Vec<u8>;
  259. empties = mp.split_at(32).0.to_vec();
  260. let mut siblings: Vec<[u8;32]> = Vec::new();
  261. for i in (empties.len()..mp.len()).step_by(EMPTYNODEVALUE.len()) {
  262. let mut sibling: [u8;32] = [0;32];
  263. sibling.copy_from_slice(&mp[i..i+EMPTYNODEVALUE.len()]);
  264. siblings.push(sibling);
  265. }
  266. let path = utils::get_path(num_levels, hi);
  267. let mut node_hash = ht;
  268. let mut sibling_used_pos = 0;
  269. for i in (0..num_levels-1).rev() {
  270. let sibling: [u8;32];
  271. if (empties[empties.len()-i as usize/8-1] & (1 << (i%8))) > 0 {
  272. sibling = siblings[sibling_used_pos];
  273. sibling_used_pos += 1;
  274. } else {
  275. sibling = EMPTYNODEVALUE;
  276. }
  277. let n: node::TreeNode;
  278. if path[num_levels as usize - i as usize-2] {
  279. n = node::TreeNode {
  280. child_l: sibling,
  281. child_r: node_hash,
  282. }} else {
  283. n = node::TreeNode {
  284. child_l: node_hash,
  285. child_r: sibling,
  286. }}
  287. if node_hash == EMPTYNODEVALUE && sibling == EMPTYNODEVALUE {
  288. node_hash = EMPTYNODEVALUE;
  289. } else {
  290. node_hash = n.ht();
  291. }
  292. }
  293. if node_hash==root {
  294. return true;
  295. }
  296. false
  297. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use rustc_hex::ToHex;

      /// Hex of the all-zero root of an empty tree.
      const EMPTY_ROOT_HEX: &str =
          "0000000000000000000000000000000000000000000000000000000000000000";
      /// Root after adding "this is a test leaf" to an empty 140-level tree.
      const ONE_LEAF_ROOT_HEX: &str =
          "b4fdf8a653198f0e179ccb3af7e4fc09d76247f479d6cfc95cd92d6fda589f27";
      /// Root after additionally adding "this is a second test leaf".
      const TWO_LEAVES_ROOT_HEX: &str =
          "8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82";

      /// Builds a text leaf whose first 15 bytes form the index part.
      fn text_leaf(text: &str) -> TestValue {
          TestValue {
              bytes: text.as_bytes().to_vec(),
              index_length: 15,
          }
      }

      /// Hash of the index part of `v`.
      fn index_hash(v: &TestValue) -> [u8; 32] {
          utils::hash_vec(v.bytes()[..v.index_length as usize].to_vec())
      }

      /// Asserts that the raw bytes of `v` are stored under the hash of those bytes.
      fn assert_leaf_stored(mt: &MerkleTree, v: &TestValue) {
          let (_t, _il, b) = mt.sto.get(&utils::hash_vec(v.bytes().to_vec()));
          assert_eq!(*v.bytes(), b);
      }

      #[test]
      fn test_hash_vec() {
          let a: Vec<u8> = From::from("test");
          let h = utils::hash_vec(a);
          assert_eq!("9c22ff5f21f0b81b113e63f7db6da94fedef11b2119b4088b89664fb9a3cb658", h.to_hex());
      }

      #[test]
      fn test_new_mt() {
          let mt: MerkleTree = new(140);
          assert_eq!(140, mt.num_levels);
          assert_eq!(EMPTY_ROOT_HEX, mt.root.to_hex());
          let (_t, _il, b) = mt.sto.get(&[0; 32]);
          assert_eq!(mt.root.to_vec(), b);
      }

      #[test]
      fn test_tree_node() {
          let n = node::TreeNode {
              child_l: [1; 32],
              child_r: [2; 32],
          };
          assert_eq!("01010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202",
                     n.bytes().to_hex());
          assert_eq!("346d8c96a2454213fcc0daff3c96ad0398148181b9fa6488f7ae2c0af5b20aa0", n.ht().to_hex());
      }

      #[test]
      fn test_add() {
          let mut mt: MerkleTree = new(140);
          assert_eq!(EMPTY_ROOT_HEX, mt.root.to_hex());

          let val = TestValue {
              bytes: vec![1, 2, 3, 4, 5],
              index_length: 3,
          };
          mt.add(&val);

          assert_leaf_stored(&mt, &val);
          assert_eq!("a0e72cc948119fcb71b413cf5ada12b2b825d5133299b20a6d9325ffc3e2fbf1", mt.root.to_hex());
      }

      #[test]
      fn test_add_2() {
          let mut mt: MerkleTree = new(140);
          let val = text_leaf("this is a test leaf");
          assert_eq!(EMPTY_ROOT_HEX, mt.root.to_hex());

          mt.add(&val);
          assert_leaf_stored(&mt, &val);
          assert_eq!(ONE_LEAF_ROOT_HEX, mt.root.to_hex());

          let val2 = text_leaf("this is a second test leaf");
          mt.add(&val2);
          assert_leaf_stored(&mt, &val2);
          assert_eq!(TWO_LEAVES_ROOT_HEX, mt.root.to_hex());
      }

      #[test]
      fn test_generate_proof_and_verify_proof() {
          let mut mt: MerkleTree = new(140);
          let val = text_leaf("this is a test leaf");
          assert_eq!(EMPTY_ROOT_HEX, mt.root.to_hex());

          mt.add(&val);
          assert_leaf_stored(&mt, &val);
          assert_eq!(ONE_LEAF_ROOT_HEX, mt.root.to_hex());

          let val2 = text_leaf("this is a second test leaf");
          mt.add(&val2);
          assert_leaf_stored(&mt, &val2);
          assert_eq!(TWO_LEAVES_ROOT_HEX, mt.root.to_hex());

          let hi = index_hash(&val2);
          let mp = mt.generate_proof(hi);
          assert_eq!("0000000000000000000000000000000000000000000000000000000000000001fd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex());

          // verify
          let ht = utils::hash_vec(val2.bytes().to_vec());
          assert!(verify_proof(mt.root, mp, hi, ht, mt.num_levels));
      }

      #[test]
      fn test_generate_proof_empty_leaf_and_verify_proof() {
          let mut mt: MerkleTree = new(140);
          mt.add(&text_leaf("this is a test leaf"));
          mt.add(&text_leaf("this is a second test leaf"));
          assert_eq!(TWO_LEAVES_ROOT_HEX, mt.root.to_hex());

          // proof of empty leaf
          let val3 = text_leaf("this is a third test leaf");
          let hi = index_hash(&val3);
          let mp = mt.generate_proof(hi);
          assert_eq!("000000000000000000000000000000000000000000000000000000000000000389741fa23da77c259781ad8f4331a5a7d793eef1db7e5200ddfc8e5f5ca7ce2bfd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex());

          // verify that is a proof of an empty leaf (EMPTYNODEVALUE)
          assert!(verify_proof(mt.root, mp, hi, EMPTYNODEVALUE, mt.num_levels));
      }
  }
  414. }