You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

408 lines
11 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math/big"
  5. "testing"
  6. "time"
  7. qt "github.com/frankban/quicktest"
  8. "github.com/iden3/go-merkletree/db/memory"
  9. )
  // TestAddTestVectors compares the roots produced by a fixed sequence of
// insertions (see testAdd) with reference vectors, once for each supported
// hash function.
func TestAddTestVectors(t *testing.T) {
	c := qt.New(t)

	// Poseidon reference roots, generated with the smt.js implementation
	// from https://github.com/iden3/circomlib.
	poseidonRoots := []string{
		"0000000000000000000000000000000000000000000000000000000000000000",
		"13578938674299138072471463694055224830892726234048532520316387704878000008795",
		"5412393676474193513566895793055462193090331607895808993925969873307089394741",
		"14204494359367183802864593755198662203838502594566452929175967972147978322084",
	}
	testAdd(c, HashFunctionPoseidon, poseidonRoots)

	sha256Roots := []string{
		"0000000000000000000000000000000000000000000000000000000000000000",
		"46910109172468462938850740851377282682950237270676610513794735904325820156367",
		"59481735341404520835410489183267411392292882901306595567679529387376287440550",
		"20573794434149960984975763118181266662429997821552560184909083010514790081771",
	}
	testAdd(c, HashFunctionSha256, sha256Roots)
}
  // testAdd builds an empty tree with hashFunc and checks its root against
// testVectors:
//   - testVectors[0] is the hex encoding of the empty tree's root;
//   - testVectors[1..3] are the decimal roots after adding the pairs
//     (1,2), (33,44) and (1234,9876), in that order.
func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) {
	tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0])

	steps := []struct{ key, value int64 }{
		{1, 2},
		{33, 44},
		{1234, 9876},
	}
	for idx, step := range steps {
		addErr := tree.Add(
			BigIntToBytes(big.NewInt(step.key)),
			BigIntToBytes(big.NewInt(step.value)))
		c.Assert(addErr, qt.IsNil)

		root := BytesToBigInt(tree.Root())
		c.Check(root.String(), qt.Equals, testVectors[idx+1])
	}
}
  // TestAddBatch checks that inserting 1000 leafs one at a time with Add and
// inserting the same leafs in a single AddBatch call produce the same root.
func TestAddBatch(t *testing.T) {
	c := qt.New(t)

	const (
		nLeafs       = 1000
		expectedRoot = "296519252211642170490407814696803112091039265640052570497930797516015811235"
	)

	// genLeafs returns keys 0..nLeafs-1, each with value 0. Each tree gets
	// its own freshly allocated slices, so neither tree can observe changes
	// the other tree's code might make to the byte slices it was given.
	genLeafs := func() (keys, values [][]byte) {
		keys = make([][]byte, 0, nLeafs)
		values = make([][]byte, 0, nLeafs)
		for i := 0; i < nLeafs; i++ {
			keys = append(keys, BigIntToBytes(big.NewInt(int64(i))))
			values = append(values, BigIntToBytes(big.NewInt(0)))
		}
		return keys, values
	}

	// Incremental insertion.
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()
	keys, values := genLeafs()
	for i := range keys {
		c.Assert(tree.Add(keys[i], values[i]), qt.IsNil)
	}
	c.Check(BytesToBigInt(tree.Root()).String(), qt.Equals, expectedRoot)

	// Batch insertion.
	tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree2.db.Close()
	keys2, values2 := genLeafs()
	indexes, err := tree2.AddBatch(keys2, values2)
	c.Assert(err, qt.IsNil)
	// All keys are distinct, so AddBatch is expected to report no indexes
	// (presumably the positions of pairs it could not insert — confirm).
	c.Check(len(indexes), qt.Equals, 0)
	c.Check(BytesToBigInt(tree2.Root()).String(), qt.Equals, expectedRoot)
}
  84. func TestAddDifferentOrder(t *testing.T) {
  85. c := qt.New(t)
  86. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  87. c.Assert(err, qt.IsNil)
  88. defer tree1.db.Close()
  89. for i := 0; i < 16; i++ {
  90. k := BigIntToBytes(big.NewInt(int64(i)))
  91. v := BigIntToBytes(big.NewInt(0))
  92. if err := tree1.Add(k, v); err != nil {
  93. t.Fatal(err)
  94. }
  95. }
  96. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  97. c.Assert(err, qt.IsNil)
  98. defer tree2.db.Close()
  99. for i := 16 - 1; i >= 0; i-- {
  100. k := BigIntToBytes(big.NewInt(int64(i)))
  101. v := BigIntToBytes(big.NewInt(0))
  102. if err := tree2.Add(k, v); err != nil {
  103. t.Fatal(err)
  104. }
  105. }
  106. c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, hex.EncodeToString(tree1.Root()))
  107. c.Check(hex.EncodeToString(tree1.Root()), qt.Equals,
  108. "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f")
  109. }
  // TestAddRepeatedIndex checks that adding a key that is already in the tree
// fails. The error is expected to read "max virtual level 100"; the 100
// presumably comes from the 100 passed to NewTree — confirm.
func TestAddRepeatedIndex(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	key := BigIntToBytes(big.NewInt(int64(3)))
	value := BigIntToBytes(big.NewInt(int64(12)))
	if addErr := tree.Add(key, value); addErr != nil {
		t.Fatal(addErr)
	}

	// Second insertion of the very same pair must be rejected.
	err = tree.Add(key, value)
	c.Assert(err, qt.Not(qt.IsNil))
	c.Check(err, qt.ErrorMatches, "max virtual level 100")
}
  // TestUpdate checks that Update replaces the value stored for an existing
// key. It does this first in a tree holding a single leaf, then in a tree
// holding 17 leafs.
func TestUpdate(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	// Single leaf: 20 -> 12, then updated to 20 -> 11.
	key := BigIntToBytes(big.NewInt(int64(20)))
	if addErr := tree.Add(key, BigIntToBytes(big.NewInt(int64(12)))); addErr != nil {
		t.Fatal(addErr)
	}
	newValue := BigIntToBytes(big.NewInt(int64(11)))
	err = tree.Update(key, newValue)
	c.Assert(err, qt.IsNil)
	gotKey, gotValue, err := tree.Get(key)
	c.Assert(err, qt.IsNil)
	c.Check(gotKey, qt.DeepEquals, key)
	c.Check(gotValue, qt.DeepEquals, newValue)

	// Populate the tree further with i -> 2*i for i in 0..15.
	for i := 0; i < 16; i++ {
		addErr := tree.Add(BigIntToBytes(big.NewInt(int64(i))), BigIntToBytes(big.NewInt(int64(i*2))))
		if addErr != nil {
			t.Fatal(addErr)
		}
	}

	key = BigIntToBytes(big.NewInt(int64(3)))
	newValue = BigIntToBytes(big.NewInt(int64(11)))

	// Before the update, key 3 still maps to 6 (3*2), not 11.
	gotKey, gotValue, err = tree.Get(key)
	c.Assert(err, qt.IsNil)
	c.Check(gotKey, qt.DeepEquals, key)
	c.Check(gotValue, qt.Not(qt.DeepEquals), newValue)
	c.Check(gotValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6)))

	err = tree.Update(key, newValue)
	c.Assert(err, qt.IsNil)

	// After the update, key 3 maps to 11.
	gotKey, gotValue, err = tree.Get(key)
	c.Assert(err, qt.IsNil)
	c.Check(gotKey, qt.DeepEquals, key)
	c.Check(gotValue, qt.DeepEquals, newValue)
	c.Check(gotValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
}
  // TestAux adds the keys 1, 256, 257, 515 and 770, all sharing one value-0
// byte slice, to a Poseidon tree and asserts that every insertion succeeds.
// The keys presumably share path prefixes so that they exercise deeper tree
// levels — confirm.
// TODO split in proper tests
func TestAux(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	zeroValue := BigIntToBytes(big.NewInt(int64(0)))
	for _, key := range []int64{1, 256, 257, 515, 770} {
		err = tree.Add(BigIntToBytes(big.NewInt(key)), zeroValue)
		c.Assert(err, qt.IsNil)
	}

	// Uncomment to inspect the resulting tree:
	// err = tree.PrintGraphviz(nil)
	// c.Assert(err, qt.IsNil)
}
  191. func TestGet(t *testing.T) {
  192. c := qt.New(t)
  193. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  194. c.Assert(err, qt.IsNil)
  195. defer tree.db.Close()
  196. for i := 0; i < 10; i++ {
  197. k := BigIntToBytes(big.NewInt(int64(i)))
  198. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  199. if err := tree.Add(k, v); err != nil {
  200. t.Fatal(err)
  201. }
  202. }
  203. k := BigIntToBytes(big.NewInt(int64(7)))
  204. gettedKey, gettedValue, err := tree.Get(k)
  205. c.Assert(err, qt.IsNil)
  206. c.Check(gettedKey, qt.DeepEquals, k)
  207. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(int64(7*2))))
  208. }
  // TestGenProofAndVerify generates a proof for key 7 in a tree holding
// i -> 2*i for i in 0..9. It then checks that CheckProof accepts that proof
// for the pair (7, 14) against the current root.
func TestGenProofAndVerify(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	for i := 0; i < 10; i++ {
		addErr := tree.Add(BigIntToBytes(big.NewInt(int64(i))), BigIntToBytes(big.NewInt(int64(i*2))))
		if addErr != nil {
			t.Fatal(addErr)
		}
	}

	key := BigIntToBytes(big.NewInt(int64(7)))
	siblings, err := tree.GenProof(key)
	c.Assert(err, qt.IsNil)

	value := BigIntToBytes(big.NewInt(int64(14)))
	verified, err := CheckProof(tree.hashFunction, key, value, tree.Root(), siblings)
	c.Assert(err, qt.IsNil)
	c.Check(verified, qt.IsTrue)
}
  // TestDumpAndImportDump dumps a tree holding i -> 2*i for i in 0..15 and
// imports the dump into a fresh tree. The imported tree must have the same
// root as the source, and that root must equal a known value.
func TestDumpAndImportDump(t *testing.T) {
	c := qt.New(t)

	source, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer source.db.Close()
	for i := 0; i < 16; i++ {
		addErr := source.Add(BigIntToBytes(big.NewInt(int64(i))), BigIntToBytes(big.NewInt(int64(i*2))))
		if addErr != nil {
			t.Fatal(addErr)
		}
	}

	dump, err := source.Dump()
	c.Assert(err, qt.IsNil)

	imported, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer imported.db.Close()
	err = imported.ImportDump(dump)
	c.Assert(err, qt.IsNil)

	c.Check(imported.Root(), qt.DeepEquals, source.Root())
	c.Check(hex.EncodeToString(imported.Root()), qt.Equals,
		"0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08")
}
  // TestRWMutex runs AddBatch with 1000 leafs in a separate goroutine while the
// test goroutine calls Add on the same tree, and expects both to succeed.
func TestRWMutex(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	var keys, values [][]byte
	for i := 0; i < 1000; i++ {
		keys = append(keys, BigIntToBytes(big.NewInt(int64(i))))
		values = append(values, BigIntToBytes(big.NewInt(0)))
	}

	// The goroutine reports its result over a buffered channel rather than
	// assigning the test's err variable (a data race) or panicking (which
	// kills the whole test binary). The test always receives from the
	// channel before returning. That keeps the deferred db Close from
	// running while AddBatch is still in progress.
	batchErr := make(chan error, 1)
	go func() {
		_, err := tree.AddBatch(keys, values)
		batchErr <- err
	}()

	// Give AddBatch a head start. Presumably this means it holds the tree's
	// lock when Add is called below — confirm.
	time.Sleep(500 * time.Millisecond)

	k := BigIntToBytes(big.NewInt(int64(99999)))
	v := BigIntToBytes(big.NewInt(int64(99999)))
	addErr := tree.Add(k, v)

	// Wait for the batch before failing on either error.
	c.Assert(<-batchErr, qt.IsNil)
	c.Assert(addErr, qt.IsNil)
}
  // TestSetGetNLeafs checks that a value written with setNLeafs inside a
// committed transaction is returned unchanged by GetNLeafs. It tries 0, 1024
// and the largest int value.
func TestSetGetNLeafs(t *testing.T) {
	c := qt.New(t)
	tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
	c.Assert(err, qt.IsNil)
	defer tree.db.Close()

	// Largest value representable by int: 2**63-1 on 64-bit platforms
	// (2**31-1 on 32-bit ones), not 2**64-1.
	maxInt := int(^uint(0) >> 1)

	for _, want := range []int{0, 1024, maxInt} {
		tree.tx, err = tree.db.NewTx()
		c.Assert(err, qt.IsNil)
		c.Assert(tree.setNLeafs(want), qt.IsNil)
		c.Assert(tree.tx.Commit(), qt.IsNil)

		got, err := tree.GetNLeafs()
		c.Assert(err, qt.IsNil)
		c.Assert(got, qt.Equals, want)
	}
}
  // BenchmarkAdd measures inserting the pairs i -> i for i in 0..999, once as a
// sub-benchmark per supported hash function.
func BenchmarkAdd(b *testing.B) {
	const nPairs = 1000
	keys := make([][]byte, 0, nPairs)
	vals := make([][]byte, 0, nPairs)
	for i := 0; i < nPairs; i++ {
		// Key and value are equal numbers but distinct byte slices.
		keys = append(keys, BigIntToBytes(big.NewInt(int64(i))))
		vals = append(vals, BigIntToBytes(big.NewInt(int64(i))))
	}

	hashes := []struct {
		name string
		fn   HashFunction
	}{
		{"Poseidon", HashFunctionPoseidon},
		{"Sha256", HashFunctionSha256},
	}
	for _, h := range hashes {
		h := h
		b.Run(h.name, func(b *testing.B) {
			benchmarkAdd(b, h.fn, keys, vals)
		})
	}
}
  // benchmarkAdd measures adding every (ks[i], vs[i]) pair to an empty tree
// built with hashFunc. The previous version ignored b.N and did the work
// once, so the reported ns/op was meaningless. Now each of the b.N
// iterations starts from a fresh in-memory tree, and one op is the cost of
// len(ks) insertions. Tree creation and teardown are excluded from timing.
func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
	c := qt.New(b)
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		b.StopTimer()
		tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
		c.Assert(err, qt.IsNil)
		b.StartTimer()

		for i := range ks {
			if err := tree.Add(ks[i], vs[i]); err != nil {
				b.Fatal(err)
			}
		}

		b.StopTimer()
		tree.db.Close()
		b.StartTimer()
	}
}