You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

520 lines
14 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "math/big"
  6. "testing"
  7. "time"
  8. qt "github.com/frankban/quicktest"
  9. "github.com/iden3/go-merkletree/db/memory"
  10. )
  11. func TestBatchAux(t *testing.T) { // TODO TMP this test will be delted
  12. c := qt.New(t)
  13. nLeafs := 16
  14. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  15. c.Assert(err, qt.IsNil)
  16. defer tree.db.Close()
  17. start := time.Now()
  18. for i := 0; i < nLeafs; i++ {
  19. k := BigIntToBytes(big.NewInt(int64(i)))
  20. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  21. if err := tree.Add(k, v); err != nil {
  22. t.Fatal(err)
  23. }
  24. }
  25. fmt.Println(time.Since(start))
  26. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  27. c.Assert(err, qt.IsNil)
  28. defer tree2.db.Close()
  29. for i := 0; i < 8; i++ {
  30. k := BigIntToBytes(big.NewInt(int64(i)))
  31. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  32. if err := tree2.Add(k, v); err != nil {
  33. t.Fatal(err)
  34. }
  35. }
  36. // tree.PrintGraphviz(nil)
  37. // tree2.PrintGraphviz(nil)
  38. var keys, values [][]byte
  39. for i := 8; i < nLeafs; i++ {
  40. k := BigIntToBytes(big.NewInt(int64(i)))
  41. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  42. keys = append(keys, k)
  43. values = append(values, v)
  44. }
  45. start = time.Now()
  46. indexes, err := tree2.AddBatchOpt(keys, values)
  47. c.Assert(err, qt.IsNil)
  48. fmt.Println(time.Since(start))
  49. c.Check(len(indexes), qt.Equals, 0)
  50. // check that both trees roots are equal
  51. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  52. }
  53. func TestAddBatchCaseA(t *testing.T) {
  54. c := qt.New(t)
  55. nLeafs := 1024
  56. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  57. c.Assert(err, qt.IsNil)
  58. defer tree.db.Close()
  59. start := time.Now()
  60. for i := 0; i < nLeafs; i++ {
  61. k := BigIntToBytes(big.NewInt(int64(i)))
  62. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  63. if err := tree.Add(k, v); err != nil {
  64. t.Fatal(err)
  65. }
  66. }
  67. fmt.Println("time elapsed without CASE A: ", time.Since(start))
  68. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  69. c.Assert(err, qt.IsNil)
  70. defer tree2.db.Close()
  71. var keys, values [][]byte
  72. for i := 0; i < nLeafs; i++ {
  73. k := BigIntToBytes(big.NewInt(int64(i)))
  74. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  75. keys = append(keys, k)
  76. values = append(values, v)
  77. }
  78. start = time.Now()
  79. indexes, err := tree2.AddBatchOpt(keys, values)
  80. c.Assert(err, qt.IsNil)
  81. fmt.Println("time elapsed with CASE A: ", time.Since(start))
  82. c.Check(len(indexes), qt.Equals, 0)
  83. // check that both trees roots are equal
  84. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  85. }
  86. func TestAddBatchCaseANotPowerOf2(t *testing.T) {
  87. c := qt.New(t)
  88. nLeafs := 1027
  89. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  90. c.Assert(err, qt.IsNil)
  91. defer tree.db.Close()
  92. for i := 0; i < nLeafs; i++ {
  93. k := BigIntToBytes(big.NewInt(int64(i)))
  94. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  95. if err := tree.Add(k, v); err != nil {
  96. t.Fatal(err)
  97. }
  98. }
  99. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  100. c.Assert(err, qt.IsNil)
  101. defer tree2.db.Close()
  102. var keys, values [][]byte
  103. for i := 0; i < nLeafs; i++ {
  104. k := BigIntToBytes(big.NewInt(int64(i)))
  105. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  106. keys = append(keys, k)
  107. values = append(values, v)
  108. }
  109. indexes, err := tree2.AddBatchOpt(keys, values)
  110. c.Assert(err, qt.IsNil)
  111. c.Check(len(indexes), qt.Equals, 0)
  112. // check that both trees roots are equal
  113. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  114. }
  115. func TestAddBatchCaseB(t *testing.T) {
  116. c := qt.New(t)
  117. nLeafs := 1024
  118. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  119. c.Assert(err, qt.IsNil)
  120. defer tree.db.Close()
  121. start := time.Now()
  122. for i := 0; i < nLeafs; i++ {
  123. k := BigIntToBytes(big.NewInt(int64(i)))
  124. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  125. if err := tree.Add(k, v); err != nil {
  126. t.Fatal(err)
  127. }
  128. }
  129. fmt.Println("time elapsed without CASE B: ", time.Since(start))
  130. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  131. c.Assert(err, qt.IsNil)
  132. defer tree2.db.Close()
  133. // add the initial leafs to fill a bit the tree before calling the
  134. // AddBatch method
  135. for i := 0; i < 99; i++ { // TMP TODO use const minLeafsThreshold-1 once ready
  136. k := BigIntToBytes(big.NewInt(int64(i)))
  137. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  138. if err := tree2.Add(k, v); err != nil {
  139. t.Fatal(err)
  140. }
  141. }
  142. var keys, values [][]byte
  143. for i := 99; i < nLeafs; i++ {
  144. k := BigIntToBytes(big.NewInt(int64(i)))
  145. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  146. keys = append(keys, k)
  147. values = append(values, v)
  148. }
  149. start = time.Now()
  150. indexes, err := tree2.AddBatchOpt(keys, values)
  151. c.Assert(err, qt.IsNil)
  152. fmt.Println("time elapsed with CASE B: ", time.Since(start))
  153. c.Check(len(indexes), qt.Equals, 0)
  154. // check that both trees roots are equal
  155. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  156. }
  157. func TestGetKeysAtLevel(t *testing.T) {
  158. c := qt.New(t)
  159. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  160. c.Assert(err, qt.IsNil)
  161. defer tree.db.Close()
  162. for i := 0; i < 32; i++ {
  163. k := BigIntToBytes(big.NewInt(int64(i)))
  164. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  165. if err := tree.Add(k, v); err != nil {
  166. t.Fatal(err)
  167. }
  168. }
  169. keys, err := tree.getKeysAtLevel(2)
  170. c.Assert(err, qt.IsNil)
  171. expected := []string{
  172. "a5d5f14fce7348e40751496cf25d107d91b0bd043435b9577d778a01f8aa6111",
  173. "e9e8dd9b28a7f81d1ff34cb5cefc0146dd848b31031a427b79bdadb62e7f6910",
  174. }
  175. for i := 0; i < len(keys); i++ {
  176. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  177. }
  178. keys, err = tree.getKeysAtLevel(3)
  179. c.Assert(err, qt.IsNil)
  180. expected = []string{
  181. "9f12c13e52bca96ad4882a26558e48ab67ddd63e062b839207e893d961390f01",
  182. "16d246dd6826ec7346c7328f11c4261facf82d4689f33263ff6e207956a77f21",
  183. "4a22cc901c6337daa17a431fa20170684b710e5f551509511492ec24e81a8f2f",
  184. "470d61abcbd154977bffc9a9ec5a8daff0caabcf2a25e8441f604c79daa0f82d",
  185. }
  186. for i := 0; i < len(keys); i++ {
  187. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  188. }
  189. keys, err = tree.getKeysAtLevel(4)
  190. c.Assert(err, qt.IsNil)
  191. expected = []string{
  192. "7a5d1c81f7b96318012de3417e53d4f13df5b1337718651cd29d0cb0a66edd20",
  193. "3408213e4e844bdf3355eb8781c74e31626812898c2dbe141ed6d2c92256fc1c",
  194. "dfd8a4d0b6954a3e9f3892e655b58d456eeedf9367f27dfdd9bc2dd6a5577312",
  195. "9e99fbec06fb2a6725997c12c4995f62725eb4cce4808523a5a5e80cca64b007",
  196. "0befa1e070231dbf4e8ff841c05878cdec823e0c09594c24910a248b3ff5a628",
  197. "b7131b0a15c772a57005a4dc5d0d6dd4b3414f5d9ee7408ce5e86c5ab3520e04",
  198. "6d1abe0364077846a56bab1deb1a04883eb796b74fe531a7676a9a370f83ab21",
  199. "4270116394bede69cf9cd72069eca018238080380bef5de75be8dcbbe968e105",
  200. }
  201. for i := 0; i < len(keys); i++ {
  202. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  203. }
  204. }
  205. func TestSplitInBuckets(t *testing.T) {
  206. c := qt.New(t)
  207. nLeafs := 16
  208. kvs := make([]kv, nLeafs)
  209. for i := 0; i < nLeafs; i++ {
  210. k := BigIntToBytes(big.NewInt(int64(i)))
  211. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  212. keyPath := make([]byte, 32)
  213. copy(keyPath[:], k)
  214. kvs[i].pos = i
  215. kvs[i].keyPath = k
  216. kvs[i].k = k
  217. kvs[i].v = v
  218. }
  219. // check keyToBucket results for 4 buckets & 8 keys
  220. c.Assert(keyToBucket(kvs[0].k, 4), qt.Equals, 0)
  221. c.Assert(keyToBucket(kvs[1].k, 4), qt.Equals, 2)
  222. c.Assert(keyToBucket(kvs[2].k, 4), qt.Equals, 1)
  223. c.Assert(keyToBucket(kvs[3].k, 4), qt.Equals, 3)
  224. c.Assert(keyToBucket(kvs[4].k, 4), qt.Equals, 0)
  225. c.Assert(keyToBucket(kvs[5].k, 4), qt.Equals, 2)
  226. c.Assert(keyToBucket(kvs[6].k, 4), qt.Equals, 1)
  227. c.Assert(keyToBucket(kvs[7].k, 4), qt.Equals, 3)
  228. // check keyToBucket results for 8 buckets & 8 keys
  229. c.Assert(keyToBucket(kvs[0].k, 8), qt.Equals, 0)
  230. c.Assert(keyToBucket(kvs[1].k, 8), qt.Equals, 4)
  231. c.Assert(keyToBucket(kvs[2].k, 8), qt.Equals, 2)
  232. c.Assert(keyToBucket(kvs[3].k, 8), qt.Equals, 6)
  233. c.Assert(keyToBucket(kvs[4].k, 8), qt.Equals, 1)
  234. c.Assert(keyToBucket(kvs[5].k, 8), qt.Equals, 5)
  235. c.Assert(keyToBucket(kvs[6].k, 8), qt.Equals, 3)
  236. c.Assert(keyToBucket(kvs[7].k, 8), qt.Equals, 7)
  237. buckets := splitInBuckets(kvs, 4)
  238. expected := [][]string{
  239. {
  240. "00000000", // bucket 0
  241. "08000000",
  242. "04000000",
  243. "0c000000",
  244. },
  245. {
  246. "02000000", // bucket 1
  247. "0a000000",
  248. "06000000",
  249. "0e000000",
  250. },
  251. {
  252. "01000000", // bucket 2
  253. "09000000",
  254. "05000000",
  255. "0d000000",
  256. },
  257. {
  258. "03000000", // bucket 3
  259. "0b000000",
  260. "07000000",
  261. "0f000000",
  262. },
  263. }
  264. for i := 0; i < len(buckets); i++ {
  265. sortKvs(buckets[i])
  266. c.Assert(len(buckets[i]), qt.Equals, len(expected[i]))
  267. for j := 0; j < len(buckets[i]); j++ {
  268. c.Check(hex.EncodeToString(buckets[i][j].k[:4]), qt.Equals, expected[i][j])
  269. }
  270. }
  271. }
  272. func TestAddBatchCaseC(t *testing.T) {
  273. c := qt.New(t)
  274. nLeafs := 1024
  275. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  276. c.Assert(err, qt.IsNil)
  277. defer tree.db.Close()
  278. start := time.Now()
  279. for i := 0; i < nLeafs; i++ {
  280. k := BigIntToBytes(big.NewInt(int64(i)))
  281. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  282. if err := tree.Add(k, v); err != nil {
  283. t.Fatal(err)
  284. }
  285. }
  286. fmt.Println("time elapsed without CASE C: ", time.Since(start))
  287. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  288. c.Assert(err, qt.IsNil)
  289. defer tree2.db.Close()
  290. // add the initial leafs to fill a bit the tree before calling the
  291. // AddBatch method
  292. for i := 0; i < 101; i++ { // TMP TODO use const minLeafsThreshold-1 once ready
  293. k := BigIntToBytes(big.NewInt(int64(i)))
  294. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  295. if err := tree2.Add(k, v); err != nil {
  296. t.Fatal(err)
  297. }
  298. }
  299. var keys, values [][]byte
  300. for i := 101; i < nLeafs; i++ {
  301. k := BigIntToBytes(big.NewInt(int64(i)))
  302. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  303. keys = append(keys, k)
  304. values = append(values, v)
  305. }
  306. start = time.Now()
  307. indexes, err := tree2.AddBatchOpt(keys, values)
  308. c.Assert(err, qt.IsNil)
  309. fmt.Println("time elapsed with CASE C: ", time.Since(start))
  310. c.Check(len(indexes), qt.Equals, 0)
  311. // check that both trees roots are equal
  312. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  313. }
  314. func TestAddBatchCaseD(t *testing.T) {
  315. c := qt.New(t)
  316. nLeafs := 4096
  317. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  318. c.Assert(err, qt.IsNil)
  319. defer tree.db.Close()
  320. start := time.Now()
  321. for i := 0; i < nLeafs; i++ {
  322. k := BigIntToBytes(big.NewInt(int64(i)))
  323. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  324. if err := tree.Add(k, v); err != nil {
  325. t.Fatal(err)
  326. }
  327. }
  328. fmt.Println("time elapsed without CASE D: ", time.Since(start))
  329. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  330. c.Assert(err, qt.IsNil)
  331. defer tree2.db.Close()
  332. // add the initial leafs to fill a bit the tree before calling the
  333. // AddBatch method
  334. for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold+1 once ready
  335. k := BigIntToBytes(big.NewInt(int64(i)))
  336. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  337. if err := tree2.Add(k, v); err != nil {
  338. t.Fatal(err)
  339. }
  340. }
  341. var keys, values [][]byte
  342. for i := 900; i < nLeafs; i++ {
  343. k := BigIntToBytes(big.NewInt(int64(i)))
  344. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  345. keys = append(keys, k)
  346. values = append(values, v)
  347. }
  348. start = time.Now()
  349. indexes, err := tree2.AddBatchOpt(keys, values)
  350. c.Assert(err, qt.IsNil)
  351. fmt.Println("time elapsed with CASE D: ", time.Since(start))
  352. c.Check(len(indexes), qt.Equals, 0)
  353. // check that both trees roots are equal
  354. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  355. }
  356. func TestAddBatchCaseE(t *testing.T) {
  357. c := qt.New(t)
  358. nLeafs := 4096
  359. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  360. c.Assert(err, qt.IsNil)
  361. defer tree.db.Close()
  362. start := time.Now()
  363. for i := 0; i < nLeafs; i++ {
  364. k := BigIntToBytes(big.NewInt(int64(i)))
  365. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  366. if err := tree.Add(k, v); err != nil {
  367. t.Fatal(err)
  368. }
  369. }
  370. fmt.Println("time elapsed without CASE E: ", time.Since(start))
  371. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  372. c.Assert(err, qt.IsNil)
  373. defer tree2.db.Close()
  374. var keys, values [][]byte
  375. // add the initial leafs to fill a bit the tree before calling the
  376. // AddBatch method
  377. for i := 0; i < 900; i++ { // TMP TODO use const minLeafsThreshold+1 once ready
  378. k := BigIntToBytes(big.NewInt(int64(i)))
  379. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  380. // use only the keys of one bucket, store the not used ones for
  381. // later
  382. if i%4 != 0 {
  383. keys = append(keys, k)
  384. values = append(values, v)
  385. continue
  386. }
  387. if err := tree2.Add(k, v); err != nil {
  388. t.Fatal(err)
  389. }
  390. }
  391. for i := 900; i < nLeafs; i++ {
  392. k := BigIntToBytes(big.NewInt(int64(i)))
  393. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  394. keys = append(keys, k)
  395. values = append(values, v)
  396. }
  397. start = time.Now()
  398. indexes, err := tree2.AddBatchOpt(keys, values)
  399. c.Assert(err, qt.IsNil)
  400. fmt.Println("time elapsed with CASE E: ", time.Since(start))
  401. c.Check(len(indexes), qt.Equals, 0)
  402. // check that both trees roots are equal
  403. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  404. }
  405. func TestHighestPowerOfTwo(t *testing.T) {
  406. c := qt.New(t)
  407. c.Assert(highestPowerOfTwo(31), qt.Equals, 16)
  408. c.Assert(highestPowerOfTwo(32), qt.Equals, 32)
  409. c.Assert(highestPowerOfTwo(33), qt.Equals, 32)
  410. c.Assert(highestPowerOfTwo(63), qt.Equals, 32)
  411. c.Assert(highestPowerOfTwo(64), qt.Equals, 64)
  412. }
  413. // func printLeafs(name string, t *Tree) {
  414. // w := bytes.NewBufferString("")
  415. //
  416. // err := t.Iterate(func(k, v []byte) {
  417. // if v[0] != PrefixValueLeaf {
  418. // return
  419. // }
  420. // leafK, _ := readLeafValue(v)
  421. // fmt.Fprintf(w, hex.EncodeToString(leafK[:4])+"\n")
  422. // })
  423. // if err != nil {
  424. // panic(err)
  425. // }
  426. // err = ioutil.WriteFile(name, w.Bytes(), 0644)
  427. // if err != nil {
  428. // panic(err)
  429. // }
  430. //
  431. // }
  432. // func TestComputeCosts(t *testing.T) {
  433. // fmt.Println(computeSimpleAddCost(10))
  434. // fmt.Println(computeBottomUpAddCost(10))
  435. //
  436. // fmt.Println(computeSimpleAddCost(1024))
  437. // fmt.Println(computeBottomUpAddCost(1024))
  438. // }
// TODO test a tree with nLeafs > minLeafsThreshold that has fewer keys than
// nBuckets at level L (so CASE C could be applied if the first few leafs are
// added to balance the tree)
// TODO for the CASE tests, add the initial keys, take a snapshot, and then
// measure the time of adding the remaining keys both with a loop of plain Add
// calls and with AddBatch
// TODO test adding a batch with repeated keys in the batch
// TODO test adding a batch with multiple invalid keys