You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

517 lines
14 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "math/big"
  6. "testing"
  7. "time"
  8. qt "github.com/frankban/quicktest"
  9. "github.com/iden3/go-merkletree/db/memory"
  10. )
  11. var debug = true
  12. func debugTime(descr string, time1, time2 time.Duration) {
  13. if debug {
  14. fmt.Printf("%s was %f times faster than without AddBatch\n",
  15. descr, float64(time1)/float64(time2))
  16. }
  17. }
  18. func testInit(c *qt.C, n int) (*Tree, *Tree) {
  19. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  20. c.Assert(err, qt.IsNil)
  21. defer tree1.db.Close()
  22. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  23. c.Assert(err, qt.IsNil)
  24. defer tree2.db.Close()
  25. // add the initial leafs to fill a bit the trees before calling the
  26. // AddBatch method
  27. for i := 0; i < n; i++ {
  28. k := BigIntToBytes(big.NewInt(int64(i)))
  29. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  30. if err := tree1.Add(k, v); err != nil {
  31. c.Fatal(err)
  32. }
  33. if err := tree2.Add(k, v); err != nil {
  34. c.Fatal(err)
  35. }
  36. }
  37. return tree1, tree2
  38. }
  39. func TestAddBatchCaseA(t *testing.T) {
  40. c := qt.New(t)
  41. nLeafs := 1024
  42. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  43. c.Assert(err, qt.IsNil)
  44. defer tree.db.Close()
  45. start := time.Now()
  46. for i := 0; i < nLeafs; i++ {
  47. k := BigIntToBytes(big.NewInt(int64(i)))
  48. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  49. if err := tree.Add(k, v); err != nil {
  50. t.Fatal(err)
  51. }
  52. }
  53. time1 := time.Since(start)
  54. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  55. c.Assert(err, qt.IsNil)
  56. defer tree2.db.Close()
  57. var keys, values [][]byte
  58. for i := 0; i < nLeafs; i++ {
  59. k := BigIntToBytes(big.NewInt(int64(i)))
  60. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  61. keys = append(keys, k)
  62. values = append(values, v)
  63. }
  64. start = time.Now()
  65. indexes, err := tree2.AddBatch(keys, values)
  66. c.Assert(err, qt.IsNil)
  67. time2 := time.Since(start)
  68. debugTime("CASE A, AddBatch", time1, time2)
  69. c.Check(len(indexes), qt.Equals, 0)
  70. // check that both trees roots are equal
  71. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  72. }
  73. func TestAddBatchCaseANotPowerOf2(t *testing.T) {
  74. c := qt.New(t)
  75. nLeafs := 1027
  76. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  77. c.Assert(err, qt.IsNil)
  78. defer tree.db.Close()
  79. for i := 0; i < nLeafs; i++ {
  80. k := BigIntToBytes(big.NewInt(int64(i)))
  81. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  82. if err := tree.Add(k, v); err != nil {
  83. t.Fatal(err)
  84. }
  85. }
  86. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  87. c.Assert(err, qt.IsNil)
  88. defer tree2.db.Close()
  89. var keys, values [][]byte
  90. for i := 0; i < nLeafs; i++ {
  91. k := BigIntToBytes(big.NewInt(int64(i)))
  92. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  93. keys = append(keys, k)
  94. values = append(values, v)
  95. }
  96. indexes, err := tree2.AddBatch(keys, values)
  97. c.Assert(err, qt.IsNil)
  98. c.Check(len(indexes), qt.Equals, 0)
  99. // check that both trees roots are equal
  100. c.Check(tree2.Root(), qt.DeepEquals, tree.Root())
  101. }
  102. func TestAddBatchCaseB(t *testing.T) {
  103. c := qt.New(t)
  104. nLeafs := 1024
  105. initialNLeafs := 99 // TMP TODO use const minLeafsThreshold-1 once ready
  106. tree1, tree2 := testInit(c, initialNLeafs)
  107. start := time.Now()
  108. for i := initialNLeafs; i < nLeafs; i++ {
  109. k := BigIntToBytes(big.NewInt(int64(i)))
  110. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  111. if err := tree1.Add(k, v); err != nil {
  112. t.Fatal(err)
  113. }
  114. }
  115. time1 := time.Since(start)
  116. // prepare the key-values to be added
  117. var keys, values [][]byte
  118. for i := initialNLeafs; i < nLeafs; i++ {
  119. k := BigIntToBytes(big.NewInt(int64(i)))
  120. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  121. keys = append(keys, k)
  122. values = append(values, v)
  123. }
  124. start = time.Now()
  125. indexes, err := tree2.AddBatch(keys, values)
  126. c.Assert(err, qt.IsNil)
  127. time2 := time.Since(start)
  128. debugTime("CASE B, AddBatch", time1, time2)
  129. c.Check(len(indexes), qt.Equals, 0)
  130. // check that both trees roots are equal
  131. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  132. }
  133. func TestAddBatchCaseBRepeatedLeafs(t *testing.T) {
  134. c := qt.New(t)
  135. nLeafs := 1024
  136. initialNLeafs := 99 // TMP TODO use const minLeafsThreshold-1 once ready
  137. tree1, tree2 := testInit(c, initialNLeafs)
  138. for i := initialNLeafs; i < nLeafs; i++ {
  139. k := BigIntToBytes(big.NewInt(int64(i)))
  140. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  141. if err := tree1.Add(k, v); err != nil {
  142. t.Fatal(err)
  143. }
  144. }
  145. // prepare the key-values to be added, including already existing keys
  146. var keys, values [][]byte
  147. for i := 0; i < nLeafs; i++ {
  148. k := BigIntToBytes(big.NewInt(int64(i)))
  149. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  150. keys = append(keys, k)
  151. values = append(values, v)
  152. }
  153. indexes, err := tree2.AddBatch(keys, values)
  154. c.Assert(err, qt.IsNil)
  155. c.Check(len(indexes), qt.Equals, initialNLeafs)
  156. // check that both trees roots are equal
  157. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  158. }
  159. func TestCombineInKVSet(t *testing.T) {
  160. c := qt.New(t)
  161. var a, b, expected []kv
  162. for i := 0; i < 10; i++ {
  163. k := BigIntToBytes(big.NewInt(int64(i)))
  164. kv := kv{k: k}
  165. if i < 7 {
  166. a = append(a, kv)
  167. }
  168. if i >= 4 {
  169. b = append(b, kv)
  170. }
  171. expected = append(expected, kv)
  172. }
  173. r, invalids := combineInKVSet(a, b)
  174. for i := 0; i < len(r); i++ {
  175. c.Assert(r[i].k, qt.DeepEquals, expected[i].k)
  176. }
  177. c.Assert(len(invalids), qt.Equals, 7-4)
  178. }
  179. func TestGetKeysAtLevel(t *testing.T) {
  180. c := qt.New(t)
  181. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  182. c.Assert(err, qt.IsNil)
  183. defer tree.db.Close()
  184. for i := 0; i < 32; i++ {
  185. k := BigIntToBytes(big.NewInt(int64(i)))
  186. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  187. if err := tree.Add(k, v); err != nil {
  188. t.Fatal(err)
  189. }
  190. }
  191. keys, err := tree.getKeysAtLevel(2)
  192. c.Assert(err, qt.IsNil)
  193. expected := []string{
  194. "a5d5f14fce7348e40751496cf25d107d91b0bd043435b9577d778a01f8aa6111",
  195. "e9e8dd9b28a7f81d1ff34cb5cefc0146dd848b31031a427b79bdadb62e7f6910",
  196. }
  197. for i := 0; i < len(keys); i++ {
  198. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  199. }
  200. keys, err = tree.getKeysAtLevel(3)
  201. c.Assert(err, qt.IsNil)
  202. expected = []string{
  203. "9f12c13e52bca96ad4882a26558e48ab67ddd63e062b839207e893d961390f01",
  204. "16d246dd6826ec7346c7328f11c4261facf82d4689f33263ff6e207956a77f21",
  205. "4a22cc901c6337daa17a431fa20170684b710e5f551509511492ec24e81a8f2f",
  206. "470d61abcbd154977bffc9a9ec5a8daff0caabcf2a25e8441f604c79daa0f82d",
  207. }
  208. for i := 0; i < len(keys); i++ {
  209. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  210. }
  211. keys, err = tree.getKeysAtLevel(4)
  212. c.Assert(err, qt.IsNil)
  213. expected = []string{
  214. "7a5d1c81f7b96318012de3417e53d4f13df5b1337718651cd29d0cb0a66edd20",
  215. "3408213e4e844bdf3355eb8781c74e31626812898c2dbe141ed6d2c92256fc1c",
  216. "dfd8a4d0b6954a3e9f3892e655b58d456eeedf9367f27dfdd9bc2dd6a5577312",
  217. "9e99fbec06fb2a6725997c12c4995f62725eb4cce4808523a5a5e80cca64b007",
  218. "0befa1e070231dbf4e8ff841c05878cdec823e0c09594c24910a248b3ff5a628",
  219. "b7131b0a15c772a57005a4dc5d0d6dd4b3414f5d9ee7408ce5e86c5ab3520e04",
  220. "6d1abe0364077846a56bab1deb1a04883eb796b74fe531a7676a9a370f83ab21",
  221. "4270116394bede69cf9cd72069eca018238080380bef5de75be8dcbbe968e105",
  222. }
  223. for i := 0; i < len(keys); i++ {
  224. c.Assert(hex.EncodeToString(keys[i]), qt.Equals, expected[i])
  225. }
  226. }
  227. func TestSplitInBuckets(t *testing.T) {
  228. c := qt.New(t)
  229. nLeafs := 16
  230. kvs := make([]kv, nLeafs)
  231. for i := 0; i < nLeafs; i++ {
  232. k := BigIntToBytes(big.NewInt(int64(i)))
  233. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  234. keyPath := make([]byte, 32)
  235. copy(keyPath[:], k)
  236. kvs[i].pos = i
  237. kvs[i].keyPath = k
  238. kvs[i].k = k
  239. kvs[i].v = v
  240. }
  241. // check keyToBucket results for 4 buckets & 8 keys
  242. c.Assert(keyToBucket(kvs[0].k, 4), qt.Equals, 0)
  243. c.Assert(keyToBucket(kvs[1].k, 4), qt.Equals, 2)
  244. c.Assert(keyToBucket(kvs[2].k, 4), qt.Equals, 1)
  245. c.Assert(keyToBucket(kvs[3].k, 4), qt.Equals, 3)
  246. c.Assert(keyToBucket(kvs[4].k, 4), qt.Equals, 0)
  247. c.Assert(keyToBucket(kvs[5].k, 4), qt.Equals, 2)
  248. c.Assert(keyToBucket(kvs[6].k, 4), qt.Equals, 1)
  249. c.Assert(keyToBucket(kvs[7].k, 4), qt.Equals, 3)
  250. // check keyToBucket results for 8 buckets & 8 keys
  251. c.Assert(keyToBucket(kvs[0].k, 8), qt.Equals, 0)
  252. c.Assert(keyToBucket(kvs[1].k, 8), qt.Equals, 4)
  253. c.Assert(keyToBucket(kvs[2].k, 8), qt.Equals, 2)
  254. c.Assert(keyToBucket(kvs[3].k, 8), qt.Equals, 6)
  255. c.Assert(keyToBucket(kvs[4].k, 8), qt.Equals, 1)
  256. c.Assert(keyToBucket(kvs[5].k, 8), qt.Equals, 5)
  257. c.Assert(keyToBucket(kvs[6].k, 8), qt.Equals, 3)
  258. c.Assert(keyToBucket(kvs[7].k, 8), qt.Equals, 7)
  259. buckets := splitInBuckets(kvs, 4)
  260. expected := [][]string{
  261. {
  262. "00000000", // bucket 0
  263. "08000000",
  264. "04000000",
  265. "0c000000",
  266. },
  267. {
  268. "02000000", // bucket 1
  269. "0a000000",
  270. "06000000",
  271. "0e000000",
  272. },
  273. {
  274. "01000000", // bucket 2
  275. "09000000",
  276. "05000000",
  277. "0d000000",
  278. },
  279. {
  280. "03000000", // bucket 3
  281. "0b000000",
  282. "07000000",
  283. "0f000000",
  284. },
  285. }
  286. for i := 0; i < len(buckets); i++ {
  287. sortKvs(buckets[i])
  288. c.Assert(len(buckets[i]), qt.Equals, len(expected[i]))
  289. for j := 0; j < len(buckets[i]); j++ {
  290. c.Check(hex.EncodeToString(buckets[i][j].k[:4]), qt.Equals, expected[i][j])
  291. }
  292. }
  293. }
  294. func TestAddBatchCaseC(t *testing.T) {
  295. c := qt.New(t)
  296. nLeafs := 1024
  297. initialNLeafs := 101 // TMP TODO use const minLeafsThreshold+1 once ready
  298. tree1, tree2 := testInit(c, initialNLeafs)
  299. start := time.Now()
  300. for i := initialNLeafs; i < nLeafs; i++ {
  301. k := BigIntToBytes(big.NewInt(int64(i)))
  302. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  303. if err := tree1.Add(k, v); err != nil {
  304. t.Fatal(err)
  305. }
  306. }
  307. time1 := time.Since(start)
  308. // prepare the key-values to be added
  309. var keys, values [][]byte
  310. for i := initialNLeafs; i < nLeafs; i++ {
  311. k := BigIntToBytes(big.NewInt(int64(i)))
  312. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  313. keys = append(keys, k)
  314. values = append(values, v)
  315. }
  316. start = time.Now()
  317. indexes, err := tree2.AddBatch(keys, values)
  318. c.Assert(err, qt.IsNil)
  319. time2 := time.Since(start)
  320. debugTime("CASE C, AddBatch", time1, time2)
  321. c.Check(len(indexes), qt.Equals, 0)
  322. // check that both trees roots are equal
  323. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  324. }
  325. func TestAddBatchCaseD(t *testing.T) {
  326. c := qt.New(t)
  327. nLeafs := 4096
  328. initialNLeafs := 900
  329. tree1, tree2 := testInit(c, initialNLeafs)
  330. start := time.Now()
  331. for i := initialNLeafs; i < nLeafs; i++ {
  332. k := BigIntToBytes(big.NewInt(int64(i)))
  333. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  334. if err := tree1.Add(k, v); err != nil {
  335. t.Fatal(err)
  336. }
  337. }
  338. time1 := time.Since(start)
  339. // prepare the key-values to be added
  340. var keys, values [][]byte
  341. for i := initialNLeafs; i < nLeafs; i++ {
  342. k := BigIntToBytes(big.NewInt(int64(i)))
  343. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  344. keys = append(keys, k)
  345. values = append(values, v)
  346. }
  347. start = time.Now()
  348. indexes, err := tree2.AddBatch(keys, values)
  349. c.Assert(err, qt.IsNil)
  350. time2 := time.Since(start)
  351. debugTime("CASE D, AddBatch", time1, time2)
  352. c.Check(len(indexes), qt.Equals, 0)
  353. // check that both trees roots are equal
  354. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  355. }
  356. func TestAddBatchCaseE(t *testing.T) {
  357. c := qt.New(t)
  358. nLeafs := 4096
  359. initialNLeafs := 900
  360. tree1, _ := testInit(c, initialNLeafs)
  361. start := time.Now()
  362. for i := initialNLeafs; i < nLeafs; i++ {
  363. k := BigIntToBytes(big.NewInt(int64(i)))
  364. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  365. if err := tree1.Add(k, v); err != nil {
  366. t.Fatal(err)
  367. }
  368. }
  369. time1 := time.Since(start)
  370. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  371. c.Assert(err, qt.IsNil)
  372. defer tree2.db.Close()
  373. var keys, values [][]byte
  374. // add the initial leafs to fill a bit the tree before calling the
  375. // AddBatch method
  376. for i := 0; i < initialNLeafs; i++ {
  377. k := BigIntToBytes(big.NewInt(int64(i)))
  378. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  379. // use only the keys of one bucket, store the not used ones for
  380. // later
  381. if i%4 != 0 {
  382. keys = append(keys, k)
  383. values = append(values, v)
  384. continue
  385. }
  386. if err := tree2.Add(k, v); err != nil {
  387. t.Fatal(err)
  388. }
  389. }
  390. for i := initialNLeafs; i < nLeafs; i++ {
  391. k := BigIntToBytes(big.NewInt(int64(i)))
  392. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  393. keys = append(keys, k)
  394. values = append(values, v)
  395. }
  396. start = time.Now()
  397. indexes, err := tree2.AddBatch(keys, values)
  398. c.Assert(err, qt.IsNil)
  399. time2 := time.Since(start)
  400. debugTime("CASE E, AddBatch", time1, time2)
  401. c.Check(len(indexes), qt.Equals, 0)
  402. // check that both trees roots are equal
  403. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  404. }
  405. func TestHighestPowerOfTwo(t *testing.T) {
  406. c := qt.New(t)
  407. c.Assert(highestPowerOfTwo(31), qt.Equals, 16)
  408. c.Assert(highestPowerOfTwo(32), qt.Equals, 32)
  409. c.Assert(highestPowerOfTwo(33), qt.Equals, 32)
  410. c.Assert(highestPowerOfTwo(63), qt.Equals, 32)
  411. c.Assert(highestPowerOfTwo(64), qt.Equals, 64)
  412. }
  413. // func printLeafs(name string, t *Tree) {
  414. // w := bytes.NewBufferString("")
  415. //
  416. // err := t.Iterate(func(k, v []byte) {
  417. // if v[0] != PrefixValueLeaf {
  418. // return
  419. // }
  420. // leafK, _ := readLeafValue(v)
  421. // fmt.Fprintf(w, hex.EncodeToString(leafK[:4])+"\n")
  422. // })
  423. // if err != nil {
  424. // panic(err)
  425. // }
  426. // err = ioutil.WriteFile(name, w.Bytes(), 0644)
  427. // if err != nil {
  428. // panic(err)
  429. // }
  430. //
  431. // }
  432. // func TestComputeCosts(t *testing.T) {
  433. // fmt.Println(computeSimpleAddCost(10))
  434. // fmt.Println(computeBottomUpAddCost(10))
  435. //
  436. // fmt.Println(computeSimpleAddCost(1024))
  437. // fmt.Println(computeBottomUpAddCost(1024))
  438. // }
// TODO: test a tree with nLeafs > minLeafsThreshold that, at level L, has
// fewer keys than nBuckets (so CASE C could be applied if the first few leafs
// are added to balance the tree).
// TODO: test adding a batch that contains repeated keys.
// TODO: test adding a batch with multiple invalid keys.