You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

415 lines
11 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math/big"
  5. "testing"
  6. "time"
  7. qt "github.com/frankban/quicktest"
  8. "github.com/iden3/go-merkletree/db/memory"
  9. )
  10. func TestAddTestVectors(t *testing.T) {
  11. c := qt.New(t)
  12. // Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
  13. testVectorsPoseidon := []string{
  14. "0000000000000000000000000000000000000000000000000000000000000000",
  15. "13578938674299138072471463694055224830892726234048532520316387704878000008795",
  16. "5412393676474193513566895793055462193090331607895808993925969873307089394741",
  17. "14204494359367183802864593755198662203838502594566452929175967972147978322084",
  18. }
  19. testAdd(c, HashFunctionPoseidon, testVectorsPoseidon)
  20. testVectorsSha256 := []string{
  21. "0000000000000000000000000000000000000000000000000000000000000000",
  22. "46910109172468462938850740851377282682950237270676610513794735904325820156367",
  23. "59481735341404520835410489183267411392292882901306595567679529387376287440550",
  24. "20573794434149960984975763118181266662429997821552560184909083010514790081771",
  25. }
  26. testAdd(c, HashFunctionSha256, testVectorsSha256)
  27. }
  28. func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) {
  29. tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
  30. c.Assert(err, qt.IsNil)
  31. defer tree.db.Close()
  32. c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0])
  33. err = tree.Add(
  34. BigIntToBytes(big.NewInt(1)),
  35. BigIntToBytes(big.NewInt(2)))
  36. c.Assert(err, qt.IsNil)
  37. rootBI := BytesToBigInt(tree.Root())
  38. c.Check(rootBI.String(), qt.Equals, testVectors[1])
  39. err = tree.Add(
  40. BigIntToBytes(big.NewInt(33)),
  41. BigIntToBytes(big.NewInt(44)))
  42. c.Assert(err, qt.IsNil)
  43. rootBI = BytesToBigInt(tree.Root())
  44. c.Check(rootBI.String(), qt.Equals, testVectors[2])
  45. err = tree.Add(
  46. BigIntToBytes(big.NewInt(1234)),
  47. BigIntToBytes(big.NewInt(9876)))
  48. c.Assert(err, qt.IsNil)
  49. rootBI = BytesToBigInt(tree.Root())
  50. c.Check(rootBI.String(), qt.Equals, testVectors[3])
  51. }
  52. func TestAddBatch(t *testing.T) {
  53. c := qt.New(t)
  54. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  55. c.Assert(err, qt.IsNil)
  56. defer tree.db.Close()
  57. for i := 0; i < 1000; i++ {
  58. k := BigIntToBytes(big.NewInt(int64(i)))
  59. v := BigIntToBytes(big.NewInt(0))
  60. if err := tree.Add(k, v); err != nil {
  61. t.Fatal(err)
  62. }
  63. }
  64. rootBI := BytesToBigInt(tree.Root())
  65. c.Check(rootBI.String(), qt.Equals,
  66. "296519252211642170490407814696803112091039265640052570497930797516015811235")
  67. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  68. c.Assert(err, qt.IsNil)
  69. defer tree2.db.Close()
  70. var keys, values [][]byte
  71. for i := 0; i < 1000; i++ {
  72. k := BigIntToBytes(big.NewInt(int64(i)))
  73. v := BigIntToBytes(big.NewInt(0))
  74. keys = append(keys, k)
  75. values = append(values, v)
  76. }
  77. indexes, err := tree2.AddBatch(keys, values)
  78. c.Assert(err, qt.IsNil)
  79. c.Check(len(indexes), qt.Equals, 0)
  80. rootBI = BytesToBigInt(tree2.Root())
  81. c.Check(rootBI.String(), qt.Equals,
  82. "296519252211642170490407814696803112091039265640052570497930797516015811235")
  83. }
  84. func TestAddDifferentOrder(t *testing.T) {
  85. c := qt.New(t)
  86. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  87. c.Assert(err, qt.IsNil)
  88. defer tree1.db.Close()
  89. for i := 0; i < 16; i++ {
  90. k := BigIntToBytes(big.NewInt(int64(i)))
  91. v := BigIntToBytes(big.NewInt(0))
  92. if err := tree1.Add(k, v); err != nil {
  93. t.Fatal(err)
  94. }
  95. }
  96. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  97. c.Assert(err, qt.IsNil)
  98. defer tree2.db.Close()
  99. for i := 16 - 1; i >= 0; i-- {
  100. k := BigIntToBytes(big.NewInt(int64(i)))
  101. v := BigIntToBytes(big.NewInt(0))
  102. if err := tree2.Add(k, v); err != nil {
  103. t.Fatal(err)
  104. }
  105. }
  106. c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, hex.EncodeToString(tree1.Root()))
  107. c.Check(hex.EncodeToString(tree1.Root()), qt.Equals,
  108. "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f")
  109. }
  110. func TestAddRepeatedIndex(t *testing.T) {
  111. c := qt.New(t)
  112. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  113. c.Assert(err, qt.IsNil)
  114. defer tree.db.Close()
  115. k := BigIntToBytes(big.NewInt(int64(3)))
  116. v := BigIntToBytes(big.NewInt(int64(12)))
  117. if err := tree.Add(k, v); err != nil {
  118. t.Fatal(err)
  119. }
  120. err = tree.Add(k, v)
  121. c.Assert(err, qt.Not(qt.IsNil))
  122. c.Check(err, qt.ErrorMatches, "max virtual level 100")
  123. }
  124. func TestUpdate(t *testing.T) {
  125. c := qt.New(t)
  126. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  127. c.Assert(err, qt.IsNil)
  128. defer tree.db.Close()
  129. k := BigIntToBytes(big.NewInt(int64(20)))
  130. v := BigIntToBytes(big.NewInt(int64(12)))
  131. if err := tree.Add(k, v); err != nil {
  132. t.Fatal(err)
  133. }
  134. v = BigIntToBytes(big.NewInt(int64(11)))
  135. err = tree.Update(k, v)
  136. c.Assert(err, qt.IsNil)
  137. gettedKey, gettedValue, err := tree.Get(k)
  138. c.Assert(err, qt.IsNil)
  139. c.Check(gettedKey, qt.DeepEquals, k)
  140. c.Check(gettedValue, qt.DeepEquals, v)
  141. // add more leafs to the tree to do another test
  142. for i := 0; i < 16; i++ {
  143. k := BigIntToBytes(big.NewInt(int64(i)))
  144. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  145. if err := tree.Add(k, v); err != nil {
  146. t.Fatal(err)
  147. }
  148. }
  149. k = BigIntToBytes(big.NewInt(int64(3)))
  150. v = BigIntToBytes(big.NewInt(int64(11)))
  151. // check that before the Update, value for 3 is !=11
  152. gettedKey, gettedValue, err = tree.Get(k)
  153. c.Assert(err, qt.IsNil)
  154. c.Check(gettedKey, qt.DeepEquals, k)
  155. c.Check(gettedValue, qt.Not(qt.DeepEquals), v)
  156. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6)))
  157. err = tree.Update(k, v)
  158. c.Assert(err, qt.IsNil)
  159. // check that after Update, the value for 3 is ==11
  160. gettedKey, gettedValue, err = tree.Get(k)
  161. c.Assert(err, qt.IsNil)
  162. c.Check(gettedKey, qt.DeepEquals, k)
  163. c.Check(gettedValue, qt.DeepEquals, v)
  164. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
  165. }
  166. func TestAux(t *testing.T) { // TODO split in proper tests
  167. c := qt.New(t)
  168. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  169. c.Assert(err, qt.IsNil)
  170. defer tree.db.Close()
  171. k := BigIntToBytes(big.NewInt(int64(1)))
  172. v := BigIntToBytes(big.NewInt(int64(0)))
  173. err = tree.Add(k, v)
  174. c.Assert(err, qt.IsNil)
  175. k = BigIntToBytes(big.NewInt(int64(256)))
  176. err = tree.Add(k, v)
  177. c.Assert(err, qt.IsNil)
  178. k = BigIntToBytes(big.NewInt(int64(257)))
  179. err = tree.Add(k, v)
  180. c.Assert(err, qt.IsNil)
  181. k = BigIntToBytes(big.NewInt(int64(515)))
  182. err = tree.Add(k, v)
  183. c.Assert(err, qt.IsNil)
  184. k = BigIntToBytes(big.NewInt(int64(770)))
  185. err = tree.Add(k, v)
  186. c.Assert(err, qt.IsNil)
  187. k = BigIntToBytes(big.NewInt(int64(388)))
  188. err = tree.Add(k, v)
  189. c.Assert(err, qt.IsNil)
  190. k = BigIntToBytes(big.NewInt(int64(900)))
  191. err = tree.Add(k, v)
  192. c.Assert(err, qt.IsNil)
  193. //
  194. // err = tree.PrintGraphviz(nil)
  195. // c.Assert(err, qt.IsNil)
  196. }
  197. func TestGet(t *testing.T) {
  198. c := qt.New(t)
  199. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  200. c.Assert(err, qt.IsNil)
  201. defer tree.db.Close()
  202. for i := 0; i < 10; i++ {
  203. k := BigIntToBytes(big.NewInt(int64(i)))
  204. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  205. if err := tree.Add(k, v); err != nil {
  206. t.Fatal(err)
  207. }
  208. }
  209. k := BigIntToBytes(big.NewInt(int64(7)))
  210. gettedKey, gettedValue, err := tree.Get(k)
  211. c.Assert(err, qt.IsNil)
  212. c.Check(gettedKey, qt.DeepEquals, k)
  213. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(int64(7*2))))
  214. }
  215. func TestGenProofAndVerify(t *testing.T) {
  216. c := qt.New(t)
  217. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  218. c.Assert(err, qt.IsNil)
  219. defer tree.db.Close()
  220. for i := 0; i < 10; i++ {
  221. k := BigIntToBytes(big.NewInt(int64(i)))
  222. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  223. if err := tree.Add(k, v); err != nil {
  224. t.Fatal(err)
  225. }
  226. }
  227. k := BigIntToBytes(big.NewInt(int64(7)))
  228. _, siblings, err := tree.GenProof(k)
  229. c.Assert(err, qt.IsNil)
  230. k = BigIntToBytes(big.NewInt(int64(7)))
  231. v := BigIntToBytes(big.NewInt(int64(14)))
  232. verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings)
  233. c.Assert(err, qt.IsNil)
  234. c.Check(verif, qt.IsTrue)
  235. }
  236. func TestDumpAndImportDump(t *testing.T) {
  237. c := qt.New(t)
  238. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  239. c.Assert(err, qt.IsNil)
  240. defer tree1.db.Close()
  241. for i := 0; i < 16; i++ {
  242. k := BigIntToBytes(big.NewInt(int64(i)))
  243. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  244. if err := tree1.Add(k, v); err != nil {
  245. t.Fatal(err)
  246. }
  247. }
  248. e, err := tree1.Dump()
  249. c.Assert(err, qt.IsNil)
  250. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  251. c.Assert(err, qt.IsNil)
  252. defer tree2.db.Close()
  253. err = tree2.ImportDump(e)
  254. c.Assert(err, qt.IsNil)
  255. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  256. c.Check(hex.EncodeToString(tree2.Root()), qt.Equals,
  257. "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08")
  258. }
  259. func TestRWMutex(t *testing.T) {
  260. c := qt.New(t)
  261. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  262. c.Assert(err, qt.IsNil)
  263. defer tree.db.Close()
  264. var keys, values [][]byte
  265. for i := 0; i < 1000; i++ {
  266. k := BigIntToBytes(big.NewInt(int64(i)))
  267. v := BigIntToBytes(big.NewInt(0))
  268. keys = append(keys, k)
  269. values = append(values, v)
  270. }
  271. go func() {
  272. _, err = tree.AddBatch(keys, values)
  273. if err != nil {
  274. panic(err)
  275. }
  276. }()
  277. time.Sleep(500 * time.Millisecond)
  278. k := BigIntToBytes(big.NewInt(int64(99999)))
  279. v := BigIntToBytes(big.NewInt(int64(99999)))
  280. if err := tree.Add(k, v); err != nil {
  281. t.Fatal(err)
  282. }
  283. }
  284. func TestSetGetNLeafs(t *testing.T) {
  285. c := qt.New(t)
  286. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  287. c.Assert(err, qt.IsNil)
  288. // 0
  289. tree.tx, err = tree.db.NewTx()
  290. c.Assert(err, qt.IsNil)
  291. err = tree.setNLeafs(0)
  292. c.Assert(err, qt.IsNil)
  293. err = tree.tx.Commit()
  294. c.Assert(err, qt.IsNil)
  295. n, err := tree.GetNLeafs()
  296. c.Assert(err, qt.IsNil)
  297. c.Assert(n, qt.Equals, 0)
  298. // 1024
  299. tree.tx, err = tree.db.NewTx()
  300. c.Assert(err, qt.IsNil)
  301. err = tree.setNLeafs(1024)
  302. c.Assert(err, qt.IsNil)
  303. err = tree.tx.Commit()
  304. c.Assert(err, qt.IsNil)
  305. n, err = tree.GetNLeafs()
  306. c.Assert(err, qt.IsNil)
  307. c.Assert(n, qt.Equals, 1024)
  308. // 2**64 -1
  309. tree.tx, err = tree.db.NewTx()
  310. c.Assert(err, qt.IsNil)
  311. maxUint := ^uint(0)
  312. maxInt := int(maxUint >> 1)
  313. err = tree.setNLeafs(maxInt)
  314. c.Assert(err, qt.IsNil)
  315. err = tree.tx.Commit()
  316. c.Assert(err, qt.IsNil)
  317. n, err = tree.GetNLeafs()
  318. c.Assert(err, qt.IsNil)
  319. c.Assert(n, qt.Equals, maxInt)
  320. }
  321. func BenchmarkAdd(b *testing.B) {
  322. // prepare inputs
  323. var ks, vs [][]byte
  324. for i := 0; i < 1000; i++ {
  325. k := BigIntToBytes(big.NewInt(int64(i)))
  326. v := BigIntToBytes(big.NewInt(int64(i)))
  327. ks = append(ks, k)
  328. vs = append(vs, v)
  329. }
  330. b.Run("Poseidon", func(b *testing.B) {
  331. benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
  332. })
  333. b.Run("Sha256", func(b *testing.B) {
  334. benchmarkAdd(b, HashFunctionSha256, ks, vs)
  335. })
  336. }
  337. func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
  338. c := qt.New(b)
  339. tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
  340. c.Assert(err, qt.IsNil)
  341. defer tree.db.Close()
  342. for i := 0; i < len(ks); i++ {
  343. if err := tree.Add(ks[i], vs[i]); err != nil {
  344. b.Fatal(err)
  345. }
  346. }
  347. }