You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

405 lines
15 KiB

  1. use std::collections::HashMap;
  2. #[macro_use]
  3. extern crate arrayref;
  4. extern crate tiny_keccak;
  5. extern crate rustc_hex;
  6. mod utils;
  7. mod node;
  8. const TYPENODEEMPTY: u8 = 0;
  9. const TYPENODENORMAL: u8 = 1;
  10. const TYPENODEFINAL: u8 = 2;
  11. const TYPENODEVALUE: u8 = 3;
  12. const EMPTYNODEVALUE: [u8;32] = [0;32];
  13. pub struct TestValue {
  14. bytes: Vec<u8>,
  15. index_length: u32,
  16. }
  17. pub trait Value {
  18. fn bytes(&self) -> &Vec<u8>;
  19. fn index_length(&self) -> u32;
  20. }
  21. impl Value for TestValue {
  22. fn bytes(&self) -> &Vec<u8> {
  23. &self.bytes
  24. }
  25. fn index_length(&self) -> u32 {
  26. self.index_length
  27. }
  28. }
  29. #[allow(dead_code)]
  30. pub struct Db {
  31. storage: HashMap<[u8;32], Vec<u8>>,
  32. }
  33. impl Db {
  34. pub fn insert(&mut self, k: [u8; 32], t: u8, il: u32, b: &mut Vec<u8>) {
  35. let mut v: Vec<u8>;
  36. v = [t].to_vec();
  37. let il_bytes = il.to_le_bytes();
  38. v.append(&mut il_bytes.to_vec()); // il_bytes are [u8;4] (4 bytes)
  39. v.append(b);
  40. self.storage.insert(k, v);
  41. }
  42. pub fn get(&self, k: &[u8;32]) -> (u8, u32, Vec<u8>) {
  43. if k.to_vec() == EMPTYNODEVALUE.to_vec() {
  44. return (0, 0, EMPTYNODEVALUE.to_vec());
  45. }
  46. match self.storage.get(k) {
  47. Some(x) => {
  48. let t = x[0];
  49. let il_bytes: [u8; 4] = [x[1], x[2], x[3], x[4]];
  50. let il = u32::from_le_bytes(il_bytes);
  51. let b = &x[5..];
  52. return (t, il, b.to_vec());
  53. },
  54. None => return (TYPENODEEMPTY, 0, EMPTYNODEVALUE.to_vec()),
  55. }
  56. }
  57. }
  58. pub fn new_db()-> Db {
  59. let db = Db {
  60. storage: HashMap::new(),
  61. };
  62. db
  63. }
  64. pub struct MerkleTree {
  65. #[allow(dead_code)]
  66. root: [u8; 32],
  67. #[allow(dead_code)]
  68. num_levels: u32,
  69. #[allow(dead_code)]
  70. sto: Db,
  71. }
  72. pub fn new(num_levels: u32) -> MerkleTree {
  73. let mt = MerkleTree {
  74. root: EMPTYNODEVALUE,
  75. num_levels: num_levels,
  76. sto: new_db(),
  77. };
  78. mt
  79. }
  80. impl MerkleTree {
  81. pub fn add(&mut self, v: &TestValue) {
  82. #![allow(unused_variables)]
  83. #[allow(dead_code)]
  84. // println!("adding value: {:?}", v.bytes());
  85. // add the leaf that we are adding
  86. self.sto.insert(utils::hash_vec(v.bytes().to_vec()), TYPENODEVALUE, v.index_length(), &mut v.bytes().to_vec());
  87. let index = v.index_length() as usize;
  88. let hi = utils::hash_vec(v.bytes()[..index].to_vec());
  89. let ht = utils::hash_vec(v.bytes().to_vec());
  90. let path = utils::get_path(self.num_levels, hi);
  91. let mut siblings: Vec<[u8;32]> = Vec::new();
  92. let mut node_hash = self.root;
  93. for i in (0..self.num_levels-1).rev() {
  94. // get node
  95. // let (t, il, node_bytes) = self.sto.get(&utils::hash_vec(node_hash.to_vec()));
  96. let (t, il, node_bytes) = self.sto.get(&node_hash);
  97. if t == TYPENODEFINAL {
  98. let hi_child = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  99. let path_child = utils::get_path(self.num_levels, hi_child);
  100. let pos_diff = utils::compare_paths(path_child.clone(), path.clone());
  101. if pos_diff == 999 { // TODO use a match here, and instead of 999 return something better
  102. println!("compare paths err");
  103. return;
  104. }
  105. let final_node_1_hash = utils::calc_hash_from_leaf_and_level(pos_diff, path_child.clone(), utils::hash_vec(node_bytes.to_vec()));
  106. self.sto.insert(final_node_1_hash, TYPENODEFINAL, il, &mut node_bytes.to_vec());
  107. let final_node_2_hash = utils::calc_hash_from_leaf_and_level(pos_diff, path.clone(), utils::hash_vec(v.bytes().to_vec()));
  108. self.sto.insert(final_node_2_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  109. // parent node
  110. let parent_node: node::TreeNode;
  111. if path[pos_diff as usize] {
  112. parent_node = node::TreeNode {
  113. child_l: final_node_1_hash,
  114. child_r: final_node_2_hash,
  115. };
  116. } else {
  117. parent_node = node::TreeNode {
  118. child_l: final_node_2_hash,
  119. child_r: final_node_1_hash,
  120. };
  121. }
  122. let empties = utils::get_empties_between_i_and_pos(i, pos_diff+1);
  123. for j in 0..empties.len() {
  124. siblings.push(empties[j]);
  125. }
  126. let path_from_pos_diff = utils::cut_path(path.clone(), (pos_diff +1) as usize);
  127. self.root = self.replace_leaf(path_from_pos_diff, siblings.clone(), parent_node.ht(), TYPENODENORMAL, 0, &mut parent_node.bytes().to_vec());
  128. return;
  129. }
  130. let node = node::parse_node_bytes(node_bytes);
  131. let sibling: [u8;32];
  132. if path.clone().into_iter().nth(i as usize).unwrap() {
  133. node_hash = node.child_l;
  134. sibling = node.child_r;
  135. } else {
  136. sibling = node.child_l;
  137. node_hash = node.child_r;
  138. }
  139. siblings.push(*array_ref!(sibling, 0, 32));
  140. if node_hash == EMPTYNODEVALUE {
  141. if i==self.num_levels-2 && siblings[siblings.len()-1]==EMPTYNODEVALUE {
  142. let final_node_hash = utils::calc_hash_from_leaf_and_level(i+1, path.clone(), utils::hash_vec(v.bytes().to_vec()));
  143. self.sto.insert(final_node_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  144. self.root = final_node_hash;
  145. return;
  146. }
  147. let final_node_hash = utils::calc_hash_from_leaf_and_level(i, path.clone(), utils::hash_vec(v.bytes().to_vec()));
  148. let path_from_i = utils::cut_path(path.clone(), i as usize);
  149. self.root = self.replace_leaf(path_from_i, siblings.clone(), final_node_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec());
  150. return;
  151. }
  152. }
  153. self.root = self.replace_leaf(path, siblings, utils::hash_vec(v.bytes().to_vec()), TYPENODEVALUE, v.index_length(), &mut v.bytes().to_vec());
  154. }
  155. #[allow(dead_code)]
  156. pub fn replace_leaf(&mut self, path: Vec<bool>, siblings: Vec<[u8;32]>, leaf_hash: [u8;32], node_type: u8, index_length: u32, leaf_value: &mut Vec<u8>) -> [u8;32] {
  157. self.sto.insert(leaf_hash, node_type, index_length, leaf_value);
  158. let mut curr_node = leaf_hash;
  159. for i in 0..siblings.len() {
  160. if path.clone().into_iter().nth(i as usize).unwrap() {
  161. let node = node::TreeNode {
  162. child_l: curr_node,
  163. child_r: siblings[siblings.len()-1-i],
  164. };
  165. self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes());
  166. curr_node = node.ht();
  167. } else {
  168. let node = node::TreeNode {
  169. child_l: siblings[siblings.len()-1-i],
  170. child_r: curr_node,
  171. };
  172. self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes());
  173. curr_node = node.ht();
  174. }
  175. }
  176. curr_node
  177. }
  178. #[allow(dead_code)]
  179. pub fn get_value_in_pos(&self, hi: [u8;32]) -> Vec<u8> {
  180. let path = utils::get_path(self.num_levels, hi);
  181. let mut node_hash = self.root;
  182. for i in (0..self.num_levels-1).rev() {
  183. let (t, il, node_bytes) = self.sto.get(&node_hash);
  184. if t == TYPENODEFINAL {
  185. let hi_node = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  186. let path_node = utils::get_path(self.num_levels, hi_node);
  187. let pos_diff = utils::compare_paths(path_node.clone(), path.clone());
  188. // if pos_diff > self.num_levels {
  189. if pos_diff != 999 {
  190. return EMPTYNODEVALUE.to_vec();
  191. }
  192. return node_bytes;
  193. }
  194. let node = node::parse_node_bytes(node_bytes);
  195. if !path.clone().into_iter().nth(i as usize).unwrap() {
  196. node_hash = node.child_l;
  197. } else {
  198. node_hash = node.child_r;
  199. }
  200. }
  201. let (_t, _il, node_bytes) = self.sto.get(&node_hash);
  202. node_bytes
  203. }
  204. #[allow(dead_code)]
  205. pub fn generate_proof(&self, hi: [u8;32]) -> Vec<u8> {
  206. let mut mp: Vec<u8> = Vec::new();
  207. let mut empties: [u8;32] = [0;32];
  208. let path = utils::get_path(self.num_levels, hi);
  209. let mut siblings: Vec<[u8;32]> = Vec::new();
  210. let mut node_hash = self.root;
  211. for i in 0..self.num_levels {
  212. let (t, il, node_bytes) = self.sto.get(&node_hash);
  213. if t == TYPENODEFINAL {
  214. let real_value_in_pos = self.get_value_in_pos(hi);
  215. if real_value_in_pos == EMPTYNODEVALUE {
  216. let leaf_hi = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec());
  217. let path_child = utils::get_path(self.num_levels, leaf_hi);
  218. let pos_diff = utils::compare_paths(path_child.clone(), path.clone());
  219. if pos_diff == self.num_levels { // TODO use a match here, and instead of 999 return something better
  220. return mp;
  221. }
  222. if pos_diff != self.num_levels-1-i {
  223. let sibling = utils::calc_hash_from_leaf_and_level(pos_diff, path_child.clone(), utils::hash_vec(node_bytes.to_vec()));
  224. let mut new_siblings: Vec<[u8;32]> = Vec::new();
  225. new_siblings.push(sibling);
  226. new_siblings.append(&mut siblings);
  227. siblings = new_siblings;
  228. // set empties bit
  229. let bit_pos = self.num_levels-2-pos_diff;
  230. empties[(empties.len() as isize + (bit_pos as isize/8-1) as isize) as usize] |= 1 << bit_pos%8;
  231. }
  232. }
  233. break
  234. }
  235. let node = node::parse_node_bytes(node_bytes);
  236. let sibling: [u8;32];
  237. if !path.clone().into_iter().nth(self.num_levels as usize -i as usize-2 as usize).unwrap() {
  238. node_hash = node.child_l;
  239. sibling = node.child_r;
  240. } else {
  241. sibling = node.child_l;
  242. node_hash = node.child_r;
  243. }
  244. if sibling != EMPTYNODEVALUE {
  245. // set empties bit
  246. empties[(empties.len() as isize + (i as isize/8-1) as isize) as usize] |= 1 << i%8;
  247. let mut new_siblings: Vec<[u8;32]> = Vec::new();
  248. new_siblings.push(sibling);
  249. new_siblings.append(&mut siblings);
  250. siblings = new_siblings;
  251. }
  252. }
  253. mp.append(&mut empties[..].to_vec());
  254. for s in siblings {
  255. mp.append(&mut s.to_vec());
  256. }
  257. mp
  258. }
  259. }
  260. #[cfg(test)]
  261. mod tests {
  262. use super::*;
  263. use rustc_hex::ToHex;
  264. #[test]
  265. fn test_hash_vec() {
  266. let a: Vec<u8> = From::from("test");
  267. let h = utils::hash_vec(a);
  268. assert_eq!("9c22ff5f21f0b81b113e63f7db6da94fedef11b2119b4088b89664fb9a3cb658", h.to_hex());
  269. }
  270. #[test]
  271. fn test_new_mt() {
  272. let mt: MerkleTree = new(140);
  273. assert_eq!(140, mt.num_levels);
  274. assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex());
  275. let (_t, _il, b) = mt.sto.get(&[0;32]);
  276. assert_eq!(mt.root.to_vec(), b);
  277. }
  278. #[test]
  279. fn test_tree_node() {
  280. let n = node::TreeNode {
  281. child_l: [1;32],
  282. child_r: [2;32],
  283. };
  284. assert_eq!("01010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202",
  285. n.bytes().to_hex());
  286. assert_eq!("346d8c96a2454213fcc0daff3c96ad0398148181b9fa6488f7ae2c0af5b20aa0", n.ht().to_hex());
  287. }
  288. #[test]
  289. fn test_add() {
  290. let mut mt: MerkleTree = new(140);
  291. assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex());
  292. let val = TestValue {
  293. bytes: vec![1,2,3,4,5],
  294. index_length: 3,
  295. };
  296. mt.add(&val);
  297. let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec()));
  298. assert_eq!(*val.bytes(), b);
  299. assert_eq!("a0e72cc948119fcb71b413cf5ada12b2b825d5133299b20a6d9325ffc3e2fbf1", mt.root.to_hex());
  300. }
  301. #[test]
  302. fn test_add_2() {
  303. let mut mt: MerkleTree = new(140);
  304. let val = TestValue {
  305. bytes: "this is a test leaf".as_bytes().to_vec(),
  306. index_length: 15,
  307. };
  308. assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex());
  309. mt.add(&val);
  310. let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec()));
  311. assert_eq!(*val.bytes(), b);
  312. assert_eq!("b4fdf8a653198f0e179ccb3af7e4fc09d76247f479d6cfc95cd92d6fda589f27", mt.root.to_hex());
  313. let val2 = TestValue {
  314. bytes: "this is a second test leaf".as_bytes().to_vec(),
  315. index_length: 15,
  316. };
  317. mt.add(&val2);
  318. let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val2.bytes().to_vec()));
  319. assert_eq!(*val2.bytes(), b);
  320. assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex());
  321. }
  322. #[test]
  323. fn test_generate_proof() {
  324. let mut mt: MerkleTree = new(140);
  325. let val = TestValue {
  326. bytes: "this is a test leaf".as_bytes().to_vec(),
  327. index_length: 15,
  328. };
  329. assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex());
  330. mt.add(&val);
  331. let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec()));
  332. assert_eq!(*val.bytes(), b);
  333. assert_eq!("b4fdf8a653198f0e179ccb3af7e4fc09d76247f479d6cfc95cd92d6fda589f27", mt.root.to_hex());
  334. let val2 = TestValue {
  335. bytes: "this is a second test leaf".as_bytes().to_vec(),
  336. index_length: 15,
  337. };
  338. mt.add(&val2);
  339. let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val2.bytes().to_vec()));
  340. assert_eq!(*val2.bytes(), b);
  341. assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex());
  342. let hi = utils::hash_vec(val2.bytes().to_vec().split_at(val2.index_length as usize).0.to_vec());
  343. let mp = mt.generate_proof(hi);
  344. assert_eq!("0000000000000000000000000000000000000000000000000000000000000001fd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex())
  345. }
  346. #[test]
  347. fn test_generate_proof_empty_leaf() {
  348. let mut mt: MerkleTree = new(140);
  349. let val = TestValue {
  350. bytes: "this is a test leaf".as_bytes().to_vec(),
  351. index_length: 15,
  352. };
  353. mt.add(&val);
  354. let val2 = TestValue {
  355. bytes: "this is a second test leaf".as_bytes().to_vec(),
  356. index_length: 15,
  357. };
  358. mt.add(&val2);
  359. assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex());
  360. // proof of empty leaf
  361. let val3 = TestValue {
  362. bytes: "this is a third test leaf".as_bytes().to_vec(),
  363. index_length: 15,
  364. };
  365. let hi = utils::hash_vec(val3.bytes().to_vec().split_at(val3.index_length as usize).0.to_vec());
  366. let mp = mt.generate_proof(hi);
  367. assert_eq!("000000000000000000000000000000000000000000000000000000000000000389741fa23da77c259781ad8f4331a5a7d793eef1db7e5200ddfc8e5f5ca7ce2bfd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex())
  368. }
  369. }