You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

334 lines
8.0 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math"
  5. "math/big"
  6. "testing"
  7. qt "github.com/frankban/quicktest"
  8. "go.vocdoni.io/dvote/db"
  9. "go.vocdoni.io/dvote/db/badgerdb"
  10. )
  11. // testVirtualTree adds the given key-values and tests the vt root against the
  12. // Tree
  13. func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
  14. c.Assert(len(keys), qt.Equals, len(values))
  15. // normal tree, to have an expected root value
  16. database, err := badgerdb.New(db.Options{Path: c.TempDir()})
  17. c.Assert(err, qt.IsNil)
  18. tree, err := NewTree(database, maxLevels, HashFunctionSha256)
  19. c.Assert(err, qt.IsNil)
  20. for i := 0; i < len(keys); i++ {
  21. err := tree.Add(keys[i], values[i])
  22. c.Assert(err, qt.IsNil)
  23. }
  24. // virtual tree
  25. vTree := newVT(maxLevels, HashFunctionSha256)
  26. c.Assert(vTree.root, qt.IsNil)
  27. for i := 0; i < len(keys); i++ {
  28. err := vTree.add(0, keys[i], values[i])
  29. c.Assert(err, qt.IsNil)
  30. }
  31. // compute hashes, and check Root
  32. _, err = vTree.computeHashes()
  33. c.Assert(err, qt.IsNil)
  34. root, err := tree.Root()
  35. c.Assert(err, qt.IsNil)
  36. c.Assert(vTree.root.h, qt.DeepEquals, root)
  37. }
  38. func TestVirtualTreeTestVectors(t *testing.T) {
  39. c := qt.New(t)
  40. maxLevels := 32
  41. keyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
  42. keys := [][]byte{
  43. BigIntToBytes(keyLen, big.NewInt(1)),
  44. BigIntToBytes(keyLen, big.NewInt(33)),
  45. BigIntToBytes(keyLen, big.NewInt(1234)),
  46. BigIntToBytes(keyLen, big.NewInt(123456789)),
  47. }
  48. values := [][]byte{
  49. BigIntToBytes(keyLen, big.NewInt(2)),
  50. BigIntToBytes(keyLen, big.NewInt(44)),
  51. BigIntToBytes(keyLen, big.NewInt(9876)),
  52. BigIntToBytes(keyLen, big.NewInt(987654321)),
  53. }
  54. // check the root for different batches of leafs
  55. testVirtualTree(c, maxLevels, keys[:1], values[:1])
  56. testVirtualTree(c, maxLevels, keys[:2], values[:2])
  57. testVirtualTree(c, maxLevels, keys[:3], values[:3])
  58. testVirtualTree(c, maxLevels, keys[:4], values[:4])
  59. // test with hardcoded values
  60. testvectorKeys := []string{
  61. "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
  62. "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
  63. "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e",
  64. "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d",
  65. "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
  66. "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
  67. "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c",
  68. "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5",
  69. }
  70. keys = [][]byte{}
  71. values = [][]byte{}
  72. for i := 0; i < len(testvectorKeys); i++ {
  73. key, err := hex.DecodeString(testvectorKeys[i])
  74. c.Assert(err, qt.IsNil)
  75. keys = append(keys, key)
  76. values = append(values, []byte{0})
  77. }
  78. // check the root for different batches of leafs
  79. testVirtualTree(c, 256, keys[:1], values[:1])
  80. testVirtualTree(c, 256, keys, values)
  81. }
  82. func TestVirtualTreeRandomKeys(t *testing.T) {
  83. c := qt.New(t)
  84. // test with random values
  85. nLeafs := 1024
  86. keys := make([][]byte, nLeafs)
  87. values := make([][]byte, nLeafs)
  88. for i := 0; i < nLeafs; i++ {
  89. keys[i] = randomBytes(32)
  90. values[i] = randomBytes(32)
  91. }
  92. testVirtualTree(c, 256, keys, values)
  93. }
  94. func TestVirtualTreeAddBatch(t *testing.T) {
  95. c := qt.New(t)
  96. nLeafs := 2000
  97. maxLevels := 256
  98. keys := make([][]byte, nLeafs)
  99. values := make([][]byte, nLeafs)
  100. for i := 0; i < nLeafs; i++ {
  101. keys[i] = randomBytes(32)
  102. values[i] = randomBytes(32)
  103. }
  104. // normal tree, to have an expected root value
  105. database, err := badgerdb.New(db.Options{Path: c.TempDir()})
  106. c.Assert(err, qt.IsNil)
  107. tree, err := NewTree(database, maxLevels, HashFunctionBlake2b)
  108. c.Assert(err, qt.IsNil)
  109. for i := 0; i < len(keys); i++ {
  110. err := tree.Add(keys[i], values[i])
  111. c.Assert(err, qt.IsNil)
  112. }
  113. // virtual tree
  114. vTree := newVT(maxLevels, HashFunctionBlake2b)
  115. c.Assert(vTree.root, qt.IsNil)
  116. invalids, err := vTree.addBatch(keys, values)
  117. c.Assert(err, qt.IsNil)
  118. c.Assert(len(invalids), qt.Equals, 0)
  119. // compute hashes, and check Root
  120. _, err = vTree.computeHashes()
  121. c.Assert(err, qt.IsNil)
  122. root, err := tree.Root()
  123. c.Assert(err, qt.IsNil)
  124. c.Assert(vTree.root.h, qt.DeepEquals, root)
  125. }
  126. func TestVirtualTreeAddBatchFullyUsed(t *testing.T) {
  127. c := qt.New(t)
  128. vTree1 := newVT(7, HashFunctionPoseidon) // used for add one by one
  129. vTree2 := newVT(7, HashFunctionPoseidon) // used for addBatch
  130. var keys, values [][]byte
  131. for i := 0; i < 128; i++ {
  132. k := BigIntToBytes(1, big.NewInt(int64(i)))
  133. v := k
  134. keys = append(keys, k)
  135. values = append(values, v)
  136. // add one by one expecting no error
  137. err := vTree1.add(0, k, v)
  138. c.Assert(err, qt.IsNil)
  139. }
  140. invalids, err := vTree2.addBatch(keys, values)
  141. c.Assert(err, qt.IsNil)
  142. c.Assert(0, qt.Equals, len(invalids))
  143. }
  144. func TestGetNodesAtLevel(t *testing.T) {
  145. c := qt.New(t)
  146. tree0 := vt{
  147. params: &params{
  148. maxLevels: 100,
  149. hashFunction: HashFunctionBlake2b,
  150. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  151. },
  152. root: nil,
  153. }
  154. tree1 := vt{
  155. params: &params{
  156. maxLevels: 100,
  157. hashFunction: HashFunctionBlake2b,
  158. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  159. },
  160. root: &node{
  161. l: &node{
  162. l: &node{
  163. k: []byte{0, 0, 0, 0},
  164. v: []byte{0, 0, 0, 0},
  165. },
  166. r: &node{
  167. k: []byte{0, 0, 0, 1},
  168. v: []byte{0, 0, 0, 1},
  169. },
  170. },
  171. r: &node{
  172. l: &node{
  173. k: []byte{0, 0, 0, 2},
  174. v: []byte{0, 0, 0, 2},
  175. },
  176. r: &node{
  177. k: []byte{0, 0, 0, 3},
  178. v: []byte{0, 0, 0, 3},
  179. },
  180. },
  181. },
  182. }
  183. // tree1.printGraphviz()
  184. tree2 := vt{
  185. params: &params{
  186. maxLevels: 100,
  187. hashFunction: HashFunctionBlake2b,
  188. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  189. },
  190. root: &node{
  191. l: nil,
  192. r: &node{
  193. l: &node{
  194. l: &node{
  195. l: &node{
  196. k: []byte{0, 0, 0, 0},
  197. v: []byte{0, 0, 0, 0},
  198. },
  199. r: &node{
  200. k: []byte{0, 0, 0, 1},
  201. v: []byte{0, 0, 0, 1},
  202. },
  203. },
  204. r: &node{
  205. k: []byte{0, 0, 0, 2},
  206. v: []byte{0, 0, 0, 2},
  207. },
  208. },
  209. r: &node{
  210. k: []byte{0, 0, 0, 3},
  211. v: []byte{0, 0, 0, 3},
  212. },
  213. },
  214. },
  215. }
  216. // tree2.printGraphviz()
  217. tree3 := vt{
  218. params: &params{
  219. maxLevels: 100,
  220. hashFunction: HashFunctionBlake2b,
  221. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  222. },
  223. root: &node{
  224. l: nil,
  225. r: &node{
  226. l: &node{
  227. l: &node{
  228. l: &node{
  229. k: []byte{0, 0, 0, 0},
  230. v: []byte{0, 0, 0, 0},
  231. },
  232. r: &node{
  233. k: []byte{0, 0, 0, 1},
  234. v: []byte{0, 0, 0, 1},
  235. },
  236. },
  237. r: &node{
  238. k: []byte{0, 0, 0, 2},
  239. v: []byte{0, 0, 0, 2},
  240. },
  241. },
  242. r: nil,
  243. },
  244. },
  245. }
  246. // tree3.printGraphviz()
  247. nodes0, err := tree0.getNodesAtLevel(2)
  248. c.Assert(err, qt.IsNil)
  249. c.Assert(len(nodes0), qt.DeepEquals, 4)
  250. c.Assert("0000", qt.DeepEquals, getNotNils(nodes0))
  251. nodes1, err := tree1.getNodesAtLevel(2)
  252. c.Assert(err, qt.IsNil)
  253. c.Assert(len(nodes1), qt.DeepEquals, 4)
  254. c.Assert("1111", qt.DeepEquals, getNotNils(nodes1))
  255. nodes1, err = tree1.getNodesAtLevel(3)
  256. c.Assert(err, qt.IsNil)
  257. c.Assert(len(nodes1), qt.DeepEquals, 8)
  258. c.Assert("00000000", qt.DeepEquals, getNotNils(nodes1))
  259. nodes2, err := tree2.getNodesAtLevel(2)
  260. c.Assert(err, qt.IsNil)
  261. c.Assert(len(nodes2), qt.DeepEquals, 4)
  262. c.Assert("0011", qt.DeepEquals, getNotNils(nodes2))
  263. nodes2, err = tree2.getNodesAtLevel(3)
  264. c.Assert(err, qt.IsNil)
  265. c.Assert(len(nodes2), qt.DeepEquals, 8)
  266. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes2))
  267. nodes3, err := tree3.getNodesAtLevel(2)
  268. c.Assert(err, qt.IsNil)
  269. c.Assert(len(nodes3), qt.DeepEquals, 4)
  270. c.Assert("0010", qt.DeepEquals, getNotNils(nodes3))
  271. nodes3, err = tree3.getNodesAtLevel(3)
  272. c.Assert(err, qt.IsNil)
  273. c.Assert(len(nodes3), qt.DeepEquals, 8)
  274. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes3))
  275. nodes3, err = tree3.getNodesAtLevel(4)
  276. c.Assert(err, qt.IsNil)
  277. c.Assert(len(nodes3), qt.DeepEquals, 16)
  278. c.Assert("0000000011000000", qt.DeepEquals, getNotNils(nodes3))
  279. }
  280. func getNotNils(nodes []*node) string {
  281. s := ""
  282. for i := 0; i < len(nodes); i++ {
  283. if nodes[i] == nil {
  284. s += "0"
  285. } else {
  286. s += "1"
  287. }
  288. }
  289. return s
  290. }