diff --git a/src/lib.rs b/src/lib.rs index c1bcb4f..ed49809 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -14,7 +14,6 @@ const TYPENODEEMPTY: u8 = 0; const TYPENODENORMAL: u8 = 1; const TYPENODEFINAL: u8 = 2; const TYPENODEVALUE: u8 = 3; -// const TYPENODEROOT: u8 = 4; const EMPTYNODEVALUE: [u8;32] = [0;32]; pub struct TestValue { @@ -113,7 +112,7 @@ impl MerkleTree { // let (t, il, node_bytes) = self.sto.get(&utils::hash_vec(node_hash.to_vec())); let (t, il, node_bytes) = self.sto.get(&node_hash); if t == TYPENODEFINAL { - let hi_child = utils::hash_vec(v.bytes().to_vec().split_off(il as usize)); + let hi_child = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec()); let path_child = utils::get_path(self.num_levels, hi_child); let pos_diff = utils::compare_paths(path_child.clone(), path.clone()); if pos_diff == 999 { // TODO use a match here, and instead of 999 return something better @@ -160,7 +159,6 @@ impl MerkleTree { siblings.push(*array_ref!(sibling, 0, 32)); if node_hash == EMPTYNODEVALUE { if i==self.num_levels-2 && siblings[siblings.len()-1]==EMPTYNODEVALUE { - let final_node_hash = utils::calc_hash_from_leaf_and_level(i+1, path.clone(), utils::hash_vec(v.bytes().to_vec())); self.sto.insert(final_node_hash, TYPENODEFINAL, v.index_length(), &mut v.bytes().to_vec()); self.root = final_node_hash; @@ -173,33 +171,119 @@ impl MerkleTree { } } self.root = self.replace_leaf(path, siblings, utils::hash_vec(v.bytes().to_vec()), TYPENODEVALUE, v.index_length(), &mut v.bytes().to_vec()); + } + + #[allow(dead_code)] + pub fn replace_leaf(&mut self, path: Vec, siblings: Vec<[u8;32]>, leaf_hash: [u8;32], node_type: u8, index_length: u32, leaf_value: &mut Vec) -> [u8;32] { + self.sto.insert(leaf_hash, node_type, index_length, leaf_value); + let mut curr_node = leaf_hash; + + for i in 0..siblings.len() { + if path.clone().into_iter().nth(i as usize).unwrap() { + let node = node::TreeNode { + child_l: curr_node, + child_r: siblings[siblings.len()-1-i], + }; + self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes()); + curr_node = node.ht(); + } else { + let node = node::TreeNode { + child_l: siblings[siblings.len()-1-i], + child_r: curr_node, + }; + self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes()); + curr_node = node.ht(); + } } + curr_node + } - #[allow(dead_code)] - pub fn replace_leaf(&mut self, path: Vec, siblings: Vec<[u8;32]>, leaf_hash: [u8;32], node_type: u8, index_length: u32, leaf_value: &mut Vec) -> [u8;32] { - self.sto.insert(leaf_hash, node_type, index_length, leaf_value); - let mut curr_node = leaf_hash; - - for i in 0..siblings.len() { - if path.clone().into_iter().nth(i as usize).unwrap() { - let node = node::TreeNode { - child_l: curr_node, - child_r: siblings[siblings.len()-1-i], - }; - self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes()); - curr_node = node.ht(); - } else { - let node = node::TreeNode { - child_l: siblings[siblings.len()-1-i], - child_r: curr_node, - }; - self.sto.insert(node.ht(), TYPENODENORMAL, 0, &mut node.bytes()); - curr_node = node.ht(); + #[allow(dead_code)] + pub fn get_value_in_pos(&self, hi: [u8;32]) -> Vec { + let path = utils::get_path(self.num_levels, hi); + let mut node_hash = self.root; + for i in (0..self.num_levels-1).rev() { + let (t, il, node_bytes) = self.sto.get(&node_hash); + if t == TYPENODEFINAL { + let hi_node = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec()); + let path_node = utils::get_path(self.num_levels, hi_node); + let pos_diff = utils::compare_paths(path_node.clone(), path.clone()); + // if pos_diff > self.num_levels { + if pos_diff != 999 { + return EMPTYNODEVALUE.to_vec(); } + return node_bytes; + } + let node = node::parse_node_bytes(node_bytes); + if !path.clone().into_iter().nth(i as usize).unwrap() { + node_hash = node.child_l; + } else { + node_hash = node.child_r; + } + } + let (_t, _il, node_bytes) = self.sto.get(&node_hash); + node_bytes + } + + #[allow(dead_code)] + pub fn generate_proof(&self, hi: [u8;32]) -> Vec { + let mut mp: Vec = Vec::new(); + + let mut empties: [u8;32] = [0;32]; + let path = utils::get_path(self.num_levels, hi); + + let mut siblings: Vec<[u8;32]> = Vec::new(); + let mut node_hash = self.root; + + for i in 0..self.num_levels { + let (t, il, node_bytes) = self.sto.get(&node_hash); + if t == TYPENODEFINAL { + let real_value_in_pos = self.get_value_in_pos(hi); + if real_value_in_pos == EMPTYNODEVALUE { + let leaf_hi = utils::hash_vec(node_bytes.to_vec().split_at(il as usize).0.to_vec()); + let path_child = utils::get_path(self.num_levels, leaf_hi); + let pos_diff = utils::compare_paths(path_child.clone(), path.clone()); + if pos_diff == self.num_levels { // TODO use a match here, and instead of 999 return something better + return mp; + } + if pos_diff != self.num_levels-1-i { + let sibling = utils::calc_hash_from_leaf_and_level(pos_diff, path_child.clone(), utils::hash_vec(node_bytes.to_vec())); + let mut new_siblings: Vec<[u8;32]> = Vec::new(); + new_siblings.push(sibling); + new_siblings.append(&mut siblings); + siblings = new_siblings; + // set empties bit + let bit_pos = self.num_levels-2-pos_diff; + empties[(empties.len() as isize + (bit_pos as isize/8-1) as isize) as usize] |= 1 << bit_pos%8; + } + } + break + } + let node = node::parse_node_bytes(node_bytes); + let sibling: [u8;32]; + if !path.clone().into_iter().nth(self.num_levels as usize -i as usize-2 as usize).unwrap() { + node_hash = node.child_l; + sibling = node.child_r; + } else { + sibling = node.child_l; + node_hash = node.child_r; + } + if sibling != EMPTYNODEVALUE { + // set empties bit + empties[(empties.len() as isize + (i as isize/8-1) as isize) as usize] |= 1 << i%8; + let mut new_siblings: Vec<[u8;32]> = Vec::new(); + new_siblings.push(sibling); + new_siblings.append(&mut siblings); + siblings = new_siblings; } - curr_node } + mp.append(&mut empties[..].to_vec()); + for s in siblings { + mp.append(&mut s.to_vec()); + } + mp } +} #[cfg(test)] @@ -257,8 +341,65 @@ impl MerkleTree { }; assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex()); mt.add(&val); - let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec())); + let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec())); assert_eq!(*val.bytes(), b); assert_eq!("b4fdf8a653198f0e179ccb3af7e4fc09d76247f479d6cfc95cd92d6fda589f27", mt.root.to_hex()); + let val2 = TestValue { + bytes: "this is a second test leaf".as_bytes().to_vec(), + index_length: 15, + }; + mt.add(&val2); + let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val2.bytes().to_vec())); + assert_eq!(*val2.bytes(), b); + assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex()); + } + #[test] + fn test_generate_proof() { + let mut mt: MerkleTree = new(140); + let val = TestValue { + bytes: "this is a test leaf".as_bytes().to_vec(), + index_length: 15, + }; + assert_eq!("0000000000000000000000000000000000000000000000000000000000000000", mt.root.to_hex()); + mt.add(&val); + let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val.bytes().to_vec())); + assert_eq!(*val.bytes(), b); + assert_eq!("b4fdf8a653198f0e179ccb3af7e4fc09d76247f479d6cfc95cd92d6fda589f27", mt.root.to_hex()); + let val2 = TestValue { + bytes: "this is a second test leaf".as_bytes().to_vec(), + index_length: 15, + }; + mt.add(&val2); + let (_t, _il, b) = mt.sto.get(&utils::hash_vec(val2.bytes().to_vec())); + assert_eq!(*val2.bytes(), b); + assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex()); + + let hi = utils::hash_vec(val2.bytes().to_vec().split_at(val2.index_length as usize).0.to_vec()); + let mp = mt.generate_proof(hi); + assert_eq!("0000000000000000000000000000000000000000000000000000000000000001fd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex()) + } + #[test] + fn test_generate_proof_empty_leaf() { + let mut mt: MerkleTree = new(140); + let val = TestValue { + bytes: "this is a test leaf".as_bytes().to_vec(), + index_length: 15, + }; + mt.add(&val); + let val2 = TestValue { + bytes: "this is a second test leaf".as_bytes().to_vec(), + index_length: 15, + }; + mt.add(&val2); + assert_eq!("8ac95e9c8a6fbd40bb21de7895ee35f9c8f30ca029dbb0972c02344f49462e82", mt.root.to_hex()); + + // proof of empty leaf + let val3 = TestValue { + bytes: "this is a third test leaf".as_bytes().to_vec(), + index_length: 15, + }; + let hi = utils::hash_vec(val3.bytes().to_vec().split_at(val3.index_length as usize).0.to_vec()); + let mp = mt.generate_proof(hi); + assert_eq!("000000000000000000000000000000000000000000000000000000000000000389741fa23da77c259781ad8f4331a5a7d793eef1db7e5200ddfc8e5f5ca7ce2bfd8e1a60cdb23c0c7b2cf8462c99fafd905054dccb0ed75e7c8a7d6806749b6b", mp.to_hex()) } } diff --git a/src/utils.rs b/src/utils.rs index c00290f..c94e866 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -51,13 +51,12 @@ pub fn cut_path(path: Vec, i: usize) -> Vec { } pub fn compare_paths(a: Vec, b: Vec) -> u32 { - for i in (0..a.len()-1).rev() { + for i in (0..a.len()).rev() { if a[i] != b[i] { return i as u32; } } return 999; - } pub fn get_empties_between_i_and_pos(i: u32, pos: u32) -> Vec<[u8;32]> {