diff --git a/src/merkle/simple_smt/mod.rs b/src/merkle/simple_smt/mod.rs index e6e11d1..6fc59fe 100644 --- a/src/merkle/simple_smt/mod.rs +++ b/src/merkle/simple_smt/mod.rs @@ -87,19 +87,20 @@ impl SimpleSmt { return Err(MerkleError::InvalidNumEntries(max, entries.len())); } - // append leaves to the tree returning an error if a duplicate entry for the same key - // is found - let mut empty_entries = BTreeSet::new(); + // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so + // entries with the empty value need additional tracking. + let mut key_set_to_zero = BTreeSet::new(); + for (key, value) in entries { let old_value = tree.update_leaf(key, value)?; - if old_value != Self::EMPTY_VALUE || empty_entries.contains(&key) { - return Err(MerkleError::DuplicateValuesForIndex(key)); - } - // if we've processed an empty entry, add the key to the set of empty entry keys, and - // if this key was already in the set, return an error - if value == Self::EMPTY_VALUE && !empty_entries.insert(key) { + + if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) { return Err(MerkleError::DuplicateValuesForIndex(key)); } + + if value == Self::EMPTY_VALUE { + key_set_to_zero.insert(key); + }; } Ok(tree) } diff --git a/src/merkle/simple_smt/tests.rs b/src/merkle/simple_smt/tests.rs index 5070d91..3fc9386 100644 --- a/src/merkle/simple_smt/tests.rs +++ b/src/merkle/simple_smt/tests.rs @@ -229,22 +229,31 @@ fn small_tree_opening_is_consistent() { } #[test] -fn fail_on_duplicates() { - let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(3))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); +fn test_simplesmt_fail_on_duplicates() { + let values = [ + // same key, same value + (int_to_leaf(1), int_to_leaf(1)), + // same key, different values + (int_to_leaf(1), int_to_leaf(2)), + // same key, set to zero + (EMPTY_WORD, int_to_leaf(1)), + // same key, re-set to zero + (int_to_leaf(1), EMPTY_WORD), + // same key, set to zero twice + (EMPTY_WORD, EMPTY_WORD), + ]; - let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); + for (first, second) in values.iter() { + // consecutive + let entries = [(1, *first), (1, *second)]; + let smt = SimpleSmt::with_leaves(64, entries); + assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1)); - let entries = [(1_u64, int_to_leaf(0)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(1))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); - - let entries = [(1_u64, int_to_leaf(1)), (5, int_to_leaf(2)), (1_u64, int_to_leaf(0))]; - let smt = SimpleSmt::with_leaves(64, entries); - assert!(smt.is_err()); + // not consecutive + let entries = [(1, *first), (5, int_to_leaf(5)), (1, *second)]; + let smt = SimpleSmt::with_leaves(64, entries); + assert_eq!(smt.unwrap_err(), MerkleError::DuplicateValuesForIndex(1)); + } } #[test]