You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

333 lines
8.0 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math"
  5. "math/big"
  6. "testing"
  7. qt "github.com/frankban/quicktest"
  8. "go.vocdoni.io/dvote/db/badgerdb"
  9. )
  10. // testVirtualTree adds the given key-values and tests the vt root against the
  11. // Tree
  12. func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
  13. c.Assert(len(keys), qt.Equals, len(values))
  14. // normal tree, to have an expected root value
  15. database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
  16. c.Assert(err, qt.IsNil)
  17. tree, err := NewTree(database, maxLevels, HashFunctionSha256)
  18. c.Assert(err, qt.IsNil)
  19. for i := 0; i < len(keys); i++ {
  20. err := tree.Add(keys[i], values[i])
  21. c.Assert(err, qt.IsNil)
  22. }
  23. // virtual tree
  24. vTree := newVT(maxLevels, HashFunctionSha256)
  25. c.Assert(vTree.root, qt.IsNil)
  26. for i := 0; i < len(keys); i++ {
  27. err := vTree.add(0, keys[i], values[i])
  28. c.Assert(err, qt.IsNil)
  29. }
  30. // compute hashes, and check Root
  31. _, err = vTree.computeHashes()
  32. c.Assert(err, qt.IsNil)
  33. root, err := tree.Root()
  34. c.Assert(err, qt.IsNil)
  35. c.Assert(vTree.root.h, qt.DeepEquals, root)
  36. }
  37. func TestVirtualTreeTestVectors(t *testing.T) {
  38. c := qt.New(t)
  39. maxLevels := 32
  40. keyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
  41. keys := [][]byte{
  42. BigIntToBytes(keyLen, big.NewInt(1)),
  43. BigIntToBytes(keyLen, big.NewInt(33)),
  44. BigIntToBytes(keyLen, big.NewInt(1234)),
  45. BigIntToBytes(keyLen, big.NewInt(123456789)),
  46. }
  47. values := [][]byte{
  48. BigIntToBytes(keyLen, big.NewInt(2)),
  49. BigIntToBytes(keyLen, big.NewInt(44)),
  50. BigIntToBytes(keyLen, big.NewInt(9876)),
  51. BigIntToBytes(keyLen, big.NewInt(987654321)),
  52. }
  53. // check the root for different batches of leafs
  54. testVirtualTree(c, maxLevels, keys[:1], values[:1])
  55. testVirtualTree(c, maxLevels, keys[:2], values[:2])
  56. testVirtualTree(c, maxLevels, keys[:3], values[:3])
  57. testVirtualTree(c, maxLevels, keys[:4], values[:4])
  58. // test with hardcoded values
  59. testvectorKeys := []string{
  60. "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
  61. "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
  62. "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e",
  63. "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d",
  64. "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
  65. "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
  66. "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c",
  67. "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5",
  68. }
  69. keys = [][]byte{}
  70. values = [][]byte{}
  71. for i := 0; i < len(testvectorKeys); i++ {
  72. key, err := hex.DecodeString(testvectorKeys[i])
  73. c.Assert(err, qt.IsNil)
  74. keys = append(keys, key)
  75. values = append(values, []byte{0})
  76. }
  77. // check the root for different batches of leafs
  78. testVirtualTree(c, 256, keys[:1], values[:1])
  79. testVirtualTree(c, 256, keys, values)
  80. }
  81. func TestVirtualTreeRandomKeys(t *testing.T) {
  82. c := qt.New(t)
  83. // test with random values
  84. nLeafs := 1024
  85. keys := make([][]byte, nLeafs)
  86. values := make([][]byte, nLeafs)
  87. for i := 0; i < nLeafs; i++ {
  88. keys[i] = randomBytes(32)
  89. values[i] = randomBytes(32)
  90. }
  91. testVirtualTree(c, 256, keys, values)
  92. }
  93. func TestVirtualTreeAddBatch(t *testing.T) {
  94. c := qt.New(t)
  95. nLeafs := 2000
  96. maxLevels := 256
  97. keys := make([][]byte, nLeafs)
  98. values := make([][]byte, nLeafs)
  99. for i := 0; i < nLeafs; i++ {
  100. keys[i] = randomBytes(32)
  101. values[i] = randomBytes(32)
  102. }
  103. // normal tree, to have an expected root value
  104. database, err := badgerdb.New(badgerdb.Options{Path: c.TempDir()})
  105. c.Assert(err, qt.IsNil)
  106. tree, err := NewTree(database, maxLevels, HashFunctionBlake2b)
  107. c.Assert(err, qt.IsNil)
  108. for i := 0; i < len(keys); i++ {
  109. err := tree.Add(keys[i], values[i])
  110. c.Assert(err, qt.IsNil)
  111. }
  112. // virtual tree
  113. vTree := newVT(maxLevels, HashFunctionBlake2b)
  114. c.Assert(vTree.root, qt.IsNil)
  115. invalids, err := vTree.addBatch(keys, values)
  116. c.Assert(err, qt.IsNil)
  117. c.Assert(len(invalids), qt.Equals, 0)
  118. // compute hashes, and check Root
  119. _, err = vTree.computeHashes()
  120. c.Assert(err, qt.IsNil)
  121. root, err := tree.Root()
  122. c.Assert(err, qt.IsNil)
  123. c.Assert(vTree.root.h, qt.DeepEquals, root)
  124. }
  125. func TestVirtualTreeAddBatchFullyUsed(t *testing.T) {
  126. c := qt.New(t)
  127. vTree1 := newVT(7, HashFunctionPoseidon) // used for add one by one
  128. vTree2 := newVT(7, HashFunctionPoseidon) // used for addBatch
  129. var keys, values [][]byte
  130. for i := 0; i < 128; i++ {
  131. k := BigIntToBytes(1, big.NewInt(int64(i)))
  132. v := k
  133. keys = append(keys, k)
  134. values = append(values, v)
  135. // add one by one expecting no error
  136. err := vTree1.add(0, k, v)
  137. c.Assert(err, qt.IsNil)
  138. }
  139. invalids, err := vTree2.addBatch(keys, values)
  140. c.Assert(err, qt.IsNil)
  141. c.Assert(0, qt.Equals, len(invalids))
  142. }
  143. func TestGetNodesAtLevel(t *testing.T) {
  144. c := qt.New(t)
  145. tree0 := vt{
  146. params: &params{
  147. maxLevels: 100,
  148. hashFunction: HashFunctionBlake2b,
  149. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  150. },
  151. root: nil,
  152. }
  153. tree1 := vt{
  154. params: &params{
  155. maxLevels: 100,
  156. hashFunction: HashFunctionBlake2b,
  157. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  158. },
  159. root: &node{
  160. l: &node{
  161. l: &node{
  162. k: []byte{0, 0, 0, 0},
  163. v: []byte{0, 0, 0, 0},
  164. },
  165. r: &node{
  166. k: []byte{0, 0, 0, 1},
  167. v: []byte{0, 0, 0, 1},
  168. },
  169. },
  170. r: &node{
  171. l: &node{
  172. k: []byte{0, 0, 0, 2},
  173. v: []byte{0, 0, 0, 2},
  174. },
  175. r: &node{
  176. k: []byte{0, 0, 0, 3},
  177. v: []byte{0, 0, 0, 3},
  178. },
  179. },
  180. },
  181. }
  182. // tree1.printGraphviz()
  183. tree2 := vt{
  184. params: &params{
  185. maxLevels: 100,
  186. hashFunction: HashFunctionBlake2b,
  187. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  188. },
  189. root: &node{
  190. l: nil,
  191. r: &node{
  192. l: &node{
  193. l: &node{
  194. l: &node{
  195. k: []byte{0, 0, 0, 0},
  196. v: []byte{0, 0, 0, 0},
  197. },
  198. r: &node{
  199. k: []byte{0, 0, 0, 1},
  200. v: []byte{0, 0, 0, 1},
  201. },
  202. },
  203. r: &node{
  204. k: []byte{0, 0, 0, 2},
  205. v: []byte{0, 0, 0, 2},
  206. },
  207. },
  208. r: &node{
  209. k: []byte{0, 0, 0, 3},
  210. v: []byte{0, 0, 0, 3},
  211. },
  212. },
  213. },
  214. }
  215. // tree2.printGraphviz()
  216. tree3 := vt{
  217. params: &params{
  218. maxLevels: 100,
  219. hashFunction: HashFunctionBlake2b,
  220. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  221. },
  222. root: &node{
  223. l: nil,
  224. r: &node{
  225. l: &node{
  226. l: &node{
  227. l: &node{
  228. k: []byte{0, 0, 0, 0},
  229. v: []byte{0, 0, 0, 0},
  230. },
  231. r: &node{
  232. k: []byte{0, 0, 0, 1},
  233. v: []byte{0, 0, 0, 1},
  234. },
  235. },
  236. r: &node{
  237. k: []byte{0, 0, 0, 2},
  238. v: []byte{0, 0, 0, 2},
  239. },
  240. },
  241. r: nil,
  242. },
  243. },
  244. }
  245. // tree3.printGraphviz()
  246. nodes0, err := tree0.getNodesAtLevel(2)
  247. c.Assert(err, qt.IsNil)
  248. c.Assert(len(nodes0), qt.DeepEquals, 4)
  249. c.Assert("0000", qt.DeepEquals, getNotNils(nodes0))
  250. nodes1, err := tree1.getNodesAtLevel(2)
  251. c.Assert(err, qt.IsNil)
  252. c.Assert(len(nodes1), qt.DeepEquals, 4)
  253. c.Assert("1111", qt.DeepEquals, getNotNils(nodes1))
  254. nodes1, err = tree1.getNodesAtLevel(3)
  255. c.Assert(err, qt.IsNil)
  256. c.Assert(len(nodes1), qt.DeepEquals, 8)
  257. c.Assert("00000000", qt.DeepEquals, getNotNils(nodes1))
  258. nodes2, err := tree2.getNodesAtLevel(2)
  259. c.Assert(err, qt.IsNil)
  260. c.Assert(len(nodes2), qt.DeepEquals, 4)
  261. c.Assert("0011", qt.DeepEquals, getNotNils(nodes2))
  262. nodes2, err = tree2.getNodesAtLevel(3)
  263. c.Assert(err, qt.IsNil)
  264. c.Assert(len(nodes2), qt.DeepEquals, 8)
  265. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes2))
  266. nodes3, err := tree3.getNodesAtLevel(2)
  267. c.Assert(err, qt.IsNil)
  268. c.Assert(len(nodes3), qt.DeepEquals, 4)
  269. c.Assert("0010", qt.DeepEquals, getNotNils(nodes3))
  270. nodes3, err = tree3.getNodesAtLevel(3)
  271. c.Assert(err, qt.IsNil)
  272. c.Assert(len(nodes3), qt.DeepEquals, 8)
  273. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes3))
  274. nodes3, err = tree3.getNodesAtLevel(4)
  275. c.Assert(err, qt.IsNil)
  276. c.Assert(len(nodes3), qt.DeepEquals, 16)
  277. c.Assert("0000000011000000", qt.DeepEquals, getNotNils(nodes3))
  278. }
  279. func getNotNils(nodes []*node) string {
  280. s := ""
  281. for i := 0; i < len(nodes); i++ {
  282. if nodes[i] == nil {
  283. s += "0"
  284. } else {
  285. s += "1"
  286. }
  287. }
  288. return s
  289. }