You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

336 lines
8.1 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math"
  5. "math/big"
  6. "testing"
  7. qt "github.com/frankban/quicktest"
  8. "go.vocdoni.io/dvote/db"
  9. "go.vocdoni.io/dvote/db/badgerdb"
  10. )
  11. // testVirtualTree adds the given key-values and tests the vt root against the
  12. // Tree
  13. func testVirtualTree(c *qt.C, maxLevels int, keys, values [][]byte) {
  14. c.Assert(len(keys), qt.Equals, len(values))
  15. // normal tree, to have an expected root value
  16. database, err := badgerdb.New(db.Options{Path: c.TempDir()})
  17. c.Assert(err, qt.IsNil)
  18. tree, err := NewTree(Config{Database: database, MaxLevels: maxLevels,
  19. HashFunction: HashFunctionSha256})
  20. c.Assert(err, qt.IsNil)
  21. for i := 0; i < len(keys); i++ {
  22. err := tree.Add(keys[i], values[i])
  23. c.Assert(err, qt.IsNil)
  24. }
  25. // virtual tree
  26. vTree := newVT(maxLevels, HashFunctionSha256)
  27. c.Assert(vTree.root, qt.IsNil)
  28. for i := 0; i < len(keys); i++ {
  29. err := vTree.add(0, keys[i], values[i])
  30. c.Assert(err, qt.IsNil)
  31. }
  32. // compute hashes, and check Root
  33. _, err = vTree.computeHashes()
  34. c.Assert(err, qt.IsNil)
  35. root, err := tree.Root()
  36. c.Assert(err, qt.IsNil)
  37. c.Assert(vTree.root.h, qt.DeepEquals, root)
  38. }
  39. func TestVirtualTreeTestVectors(t *testing.T) {
  40. c := qt.New(t)
  41. maxLevels := 32
  42. keyLen := int(math.Ceil(float64(maxLevels) / float64(8))) //nolint:gomnd
  43. keys := [][]byte{
  44. BigIntToBytes(keyLen, big.NewInt(1)),
  45. BigIntToBytes(keyLen, big.NewInt(33)),
  46. BigIntToBytes(keyLen, big.NewInt(1234)),
  47. BigIntToBytes(keyLen, big.NewInt(123456789)),
  48. }
  49. values := [][]byte{
  50. BigIntToBytes(keyLen, big.NewInt(2)),
  51. BigIntToBytes(keyLen, big.NewInt(44)),
  52. BigIntToBytes(keyLen, big.NewInt(9876)),
  53. BigIntToBytes(keyLen, big.NewInt(987654321)),
  54. }
  55. // check the root for different batches of leafs
  56. testVirtualTree(c, maxLevels, keys[:1], values[:1])
  57. testVirtualTree(c, maxLevels, keys[:2], values[:2])
  58. testVirtualTree(c, maxLevels, keys[:3], values[:3])
  59. testVirtualTree(c, maxLevels, keys[:4], values[:4])
  60. // test with hardcoded values
  61. testvectorKeys := []string{
  62. "1c7c2265e368314ca58ed2e1f33a326f1220e234a566d55c3605439dbe411642",
  63. "2c9f0a578afff5bfa4e0992a43066460faaab9e8e500db0b16647c701cdb16bf",
  64. "9cb87ec67e875c61390edcd1ab517f443591047709a4d4e45b0f9ed980857b8e",
  65. "9b4e9e92e974a589f426ceeb4cb291dc24893513fecf8e8460992dcf52621d4d",
  66. "1c45cb31f2fa39ec7b9ebf0fad40e0b8296016b5ce8844ae06ff77226379d9a5",
  67. "d8af98bbbb585129798ae54d5eabbc9d0561d583faf1663b3a3724d15bda4ec7",
  68. "3cd55dbfb8f975f20a0925dfbdabe79fa2d51dd0268afbb8ba6b01de9dfcdd3c",
  69. "5d0a9d6d9f197c091bf054fac9cb60e11ec723d6610ed8578e617b4d46cb43d5",
  70. }
  71. keys = [][]byte{}
  72. values = [][]byte{}
  73. for i := 0; i < len(testvectorKeys); i++ {
  74. key, err := hex.DecodeString(testvectorKeys[i])
  75. c.Assert(err, qt.IsNil)
  76. keys = append(keys, key)
  77. values = append(values, []byte{0})
  78. }
  79. // check the root for different batches of leafs
  80. testVirtualTree(c, 256, keys[:1], values[:1])
  81. testVirtualTree(c, 256, keys, values)
  82. }
  83. func TestVirtualTreeRandomKeys(t *testing.T) {
  84. c := qt.New(t)
  85. // test with random values
  86. nLeafs := 1024
  87. keys := make([][]byte, nLeafs)
  88. values := make([][]byte, nLeafs)
  89. for i := 0; i < nLeafs; i++ {
  90. keys[i] = randomBytes(32)
  91. values[i] = randomBytes(32)
  92. }
  93. testVirtualTree(c, 256, keys, values)
  94. }
  95. func TestVirtualTreeAddBatch(t *testing.T) {
  96. c := qt.New(t)
  97. nLeafs := 2000
  98. maxLevels := 256
  99. keys := make([][]byte, nLeafs)
  100. values := make([][]byte, nLeafs)
  101. for i := 0; i < nLeafs; i++ {
  102. keys[i] = randomBytes(32)
  103. values[i] = randomBytes(32)
  104. }
  105. // normal tree, to have an expected root value
  106. database, err := badgerdb.New(db.Options{Path: c.TempDir()})
  107. c.Assert(err, qt.IsNil)
  108. tree, err := NewTree(Config{Database: database, MaxLevels: maxLevels,
  109. HashFunction: HashFunctionBlake2b})
  110. c.Assert(err, qt.IsNil)
  111. for i := 0; i < len(keys); i++ {
  112. err := tree.Add(keys[i], values[i])
  113. c.Assert(err, qt.IsNil)
  114. }
  115. // virtual tree
  116. vTree := newVT(maxLevels, HashFunctionBlake2b)
  117. c.Assert(vTree.root, qt.IsNil)
  118. invalids, err := vTree.addBatch(keys, values)
  119. c.Assert(err, qt.IsNil)
  120. c.Assert(len(invalids), qt.Equals, 0)
  121. // compute hashes, and check Root
  122. _, err = vTree.computeHashes()
  123. c.Assert(err, qt.IsNil)
  124. root, err := tree.Root()
  125. c.Assert(err, qt.IsNil)
  126. c.Assert(vTree.root.h, qt.DeepEquals, root)
  127. }
  128. func TestVirtualTreeAddBatchFullyUsed(t *testing.T) {
  129. c := qt.New(t)
  130. vTree1 := newVT(7, HashFunctionPoseidon) // used for add one by one
  131. vTree2 := newVT(7, HashFunctionPoseidon) // used for addBatch
  132. var keys, values [][]byte
  133. for i := 0; i < 128; i++ {
  134. k := BigIntToBytes(1, big.NewInt(int64(i)))
  135. v := k
  136. keys = append(keys, k)
  137. values = append(values, v)
  138. // add one by one expecting no error
  139. err := vTree1.add(0, k, v)
  140. c.Assert(err, qt.IsNil)
  141. }
  142. invalids, err := vTree2.addBatch(keys, values)
  143. c.Assert(err, qt.IsNil)
  144. c.Assert(0, qt.Equals, len(invalids))
  145. }
  146. func TestGetNodesAtLevel(t *testing.T) {
  147. c := qt.New(t)
  148. tree0 := vt{
  149. params: &params{
  150. maxLevels: 100,
  151. hashFunction: HashFunctionBlake2b,
  152. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  153. },
  154. root: nil,
  155. }
  156. tree1 := vt{
  157. params: &params{
  158. maxLevels: 100,
  159. hashFunction: HashFunctionBlake2b,
  160. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  161. },
  162. root: &node{
  163. l: &node{
  164. l: &node{
  165. k: []byte{0, 0, 0, 0},
  166. v: []byte{0, 0, 0, 0},
  167. },
  168. r: &node{
  169. k: []byte{0, 0, 0, 1},
  170. v: []byte{0, 0, 0, 1},
  171. },
  172. },
  173. r: &node{
  174. l: &node{
  175. k: []byte{0, 0, 0, 2},
  176. v: []byte{0, 0, 0, 2},
  177. },
  178. r: &node{
  179. k: []byte{0, 0, 0, 3},
  180. v: []byte{0, 0, 0, 3},
  181. },
  182. },
  183. },
  184. }
  185. // tree1.printGraphviz()
  186. tree2 := vt{
  187. params: &params{
  188. maxLevels: 100,
  189. hashFunction: HashFunctionBlake2b,
  190. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  191. },
  192. root: &node{
  193. l: nil,
  194. r: &node{
  195. l: &node{
  196. l: &node{
  197. l: &node{
  198. k: []byte{0, 0, 0, 0},
  199. v: []byte{0, 0, 0, 0},
  200. },
  201. r: &node{
  202. k: []byte{0, 0, 0, 1},
  203. v: []byte{0, 0, 0, 1},
  204. },
  205. },
  206. r: &node{
  207. k: []byte{0, 0, 0, 2},
  208. v: []byte{0, 0, 0, 2},
  209. },
  210. },
  211. r: &node{
  212. k: []byte{0, 0, 0, 3},
  213. v: []byte{0, 0, 0, 3},
  214. },
  215. },
  216. },
  217. }
  218. // tree2.printGraphviz()
  219. tree3 := vt{
  220. params: &params{
  221. maxLevels: 100,
  222. hashFunction: HashFunctionBlake2b,
  223. emptyHash: make([]byte, HashFunctionBlake2b.Len()),
  224. },
  225. root: &node{
  226. l: nil,
  227. r: &node{
  228. l: &node{
  229. l: &node{
  230. l: &node{
  231. k: []byte{0, 0, 0, 0},
  232. v: []byte{0, 0, 0, 0},
  233. },
  234. r: &node{
  235. k: []byte{0, 0, 0, 1},
  236. v: []byte{0, 0, 0, 1},
  237. },
  238. },
  239. r: &node{
  240. k: []byte{0, 0, 0, 2},
  241. v: []byte{0, 0, 0, 2},
  242. },
  243. },
  244. r: nil,
  245. },
  246. },
  247. }
  248. // tree3.printGraphviz()
  249. nodes0, err := tree0.getNodesAtLevel(2)
  250. c.Assert(err, qt.IsNil)
  251. c.Assert(len(nodes0), qt.DeepEquals, 4)
  252. c.Assert("0000", qt.DeepEquals, getNotNils(nodes0))
  253. nodes1, err := tree1.getNodesAtLevel(2)
  254. c.Assert(err, qt.IsNil)
  255. c.Assert(len(nodes1), qt.DeepEquals, 4)
  256. c.Assert("1111", qt.DeepEquals, getNotNils(nodes1))
  257. nodes1, err = tree1.getNodesAtLevel(3)
  258. c.Assert(err, qt.IsNil)
  259. c.Assert(len(nodes1), qt.DeepEquals, 8)
  260. c.Assert("00000000", qt.DeepEquals, getNotNils(nodes1))
  261. nodes2, err := tree2.getNodesAtLevel(2)
  262. c.Assert(err, qt.IsNil)
  263. c.Assert(len(nodes2), qt.DeepEquals, 4)
  264. c.Assert("0011", qt.DeepEquals, getNotNils(nodes2))
  265. nodes2, err = tree2.getNodesAtLevel(3)
  266. c.Assert(err, qt.IsNil)
  267. c.Assert(len(nodes2), qt.DeepEquals, 8)
  268. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes2))
  269. nodes3, err := tree3.getNodesAtLevel(2)
  270. c.Assert(err, qt.IsNil)
  271. c.Assert(len(nodes3), qt.DeepEquals, 4)
  272. c.Assert("0010", qt.DeepEquals, getNotNils(nodes3))
  273. nodes3, err = tree3.getNodesAtLevel(3)
  274. c.Assert(err, qt.IsNil)
  275. c.Assert(len(nodes3), qt.DeepEquals, 8)
  276. c.Assert("00001100", qt.DeepEquals, getNotNils(nodes3))
  277. nodes3, err = tree3.getNodesAtLevel(4)
  278. c.Assert(err, qt.IsNil)
  279. c.Assert(len(nodes3), qt.DeepEquals, 16)
  280. c.Assert("0000000011000000", qt.DeepEquals, getNotNils(nodes3))
  281. }
  282. func getNotNils(nodes []*node) string {
  283. s := ""
  284. for i := 0; i < len(nodes); i++ {
  285. if nodes[i] == nil {
  286. s += "0"
  287. } else {
  288. s += "1"
  289. }
  290. }
  291. return s
  292. }