You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

325 lines
9.1 KiB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "math/big"
  5. "testing"
  6. qt "github.com/frankban/quicktest"
  7. "github.com/iden3/go-merkletree/db/memory"
  8. )
  9. func TestAddTestVectors(t *testing.T) {
  10. c := qt.New(t)
  11. // Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
  12. testVectorsPoseidon := []string{
  13. "0000000000000000000000000000000000000000000000000000000000000000",
  14. "13578938674299138072471463694055224830892726234048532520316387704878000008795",
  15. "5412393676474193513566895793055462193090331607895808993925969873307089394741",
  16. "14204494359367183802864593755198662203838502594566452929175967972147978322084",
  17. }
  18. testAdd(c, HashFunctionPoseidon, testVectorsPoseidon)
  19. testVectorsSha256 := []string{
  20. "0000000000000000000000000000000000000000000000000000000000000000",
  21. "46910109172468462938850740851377282682950237270676610513794735904325820156367",
  22. "59481735341404520835410489183267411392292882901306595567679529387376287440550",
  23. "20573794434149960984975763118181266662429997821552560184909083010514790081771",
  24. }
  25. testAdd(c, HashFunctionSha256, testVectorsSha256)
  26. }
  27. func testAdd(c *qt.C, hashFunc HashFunction, testVectors []string) {
  28. tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
  29. c.Assert(err, qt.IsNil)
  30. defer tree.db.Close()
  31. c.Check(hex.EncodeToString(tree.Root()), qt.Equals, testVectors[0])
  32. err = tree.Add(
  33. BigIntToBytes(big.NewInt(1)),
  34. BigIntToBytes(big.NewInt(2)))
  35. c.Assert(err, qt.IsNil)
  36. rootBI := BytesToBigInt(tree.Root())
  37. c.Check(rootBI.String(), qt.Equals, testVectors[1])
  38. err = tree.Add(
  39. BigIntToBytes(big.NewInt(33)),
  40. BigIntToBytes(big.NewInt(44)))
  41. c.Assert(err, qt.IsNil)
  42. rootBI = BytesToBigInt(tree.Root())
  43. c.Check(rootBI.String(), qt.Equals, testVectors[2])
  44. err = tree.Add(
  45. BigIntToBytes(big.NewInt(1234)),
  46. BigIntToBytes(big.NewInt(9876)))
  47. c.Assert(err, qt.IsNil)
  48. rootBI = BytesToBigInt(tree.Root())
  49. c.Check(rootBI.String(), qt.Equals, testVectors[3])
  50. }
  51. func TestAddBatch(t *testing.T) {
  52. c := qt.New(t)
  53. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  54. c.Assert(err, qt.IsNil)
  55. defer tree.db.Close()
  56. for i := 0; i < 1000; i++ {
  57. k := BigIntToBytes(big.NewInt(int64(i)))
  58. v := BigIntToBytes(big.NewInt(0))
  59. if err := tree.Add(k, v); err != nil {
  60. t.Fatal(err)
  61. }
  62. }
  63. rootBI := BytesToBigInt(tree.Root())
  64. c.Check(rootBI.String(), qt.Equals,
  65. "296519252211642170490407814696803112091039265640052570497930797516015811235")
  66. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  67. c.Assert(err, qt.IsNil)
  68. defer tree2.db.Close()
  69. var keys, values [][]byte
  70. for i := 0; i < 1000; i++ {
  71. k := BigIntToBytes(big.NewInt(int64(i)))
  72. v := BigIntToBytes(big.NewInt(0))
  73. keys = append(keys, k)
  74. values = append(values, v)
  75. }
  76. indexes, err := tree2.AddBatch(keys, values)
  77. c.Assert(err, qt.IsNil)
  78. c.Check(len(indexes), qt.Equals, 0)
  79. rootBI = BytesToBigInt(tree2.Root())
  80. c.Check(rootBI.String(), qt.Equals,
  81. "296519252211642170490407814696803112091039265640052570497930797516015811235")
  82. }
  83. func TestAddDifferentOrder(t *testing.T) {
  84. c := qt.New(t)
  85. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  86. c.Assert(err, qt.IsNil)
  87. defer tree1.db.Close()
  88. for i := 0; i < 16; i++ {
  89. k := BigIntToBytes(big.NewInt(int64(i)))
  90. v := BigIntToBytes(big.NewInt(0))
  91. if err := tree1.Add(k, v); err != nil {
  92. t.Fatal(err)
  93. }
  94. }
  95. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  96. c.Assert(err, qt.IsNil)
  97. defer tree2.db.Close()
  98. for i := 16 - 1; i >= 0; i-- {
  99. k := BigIntToBytes(big.NewInt(int64(i)))
  100. v := BigIntToBytes(big.NewInt(0))
  101. if err := tree2.Add(k, v); err != nil {
  102. t.Fatal(err)
  103. }
  104. }
  105. c.Check(hex.EncodeToString(tree2.Root()), qt.Equals, hex.EncodeToString(tree1.Root()))
  106. c.Check(hex.EncodeToString(tree1.Root()), qt.Equals,
  107. "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f")
  108. }
  109. func TestAddRepeatedIndex(t *testing.T) {
  110. c := qt.New(t)
  111. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  112. c.Assert(err, qt.IsNil)
  113. defer tree.db.Close()
  114. k := BigIntToBytes(big.NewInt(int64(3)))
  115. v := BigIntToBytes(big.NewInt(int64(12)))
  116. if err := tree.Add(k, v); err != nil {
  117. t.Fatal(err)
  118. }
  119. err = tree.Add(k, v)
  120. c.Assert(err, qt.Not(qt.IsNil))
  121. c.Check(err, qt.ErrorMatches, "max virtual level 100")
  122. }
  123. func TestUpdate(t *testing.T) {
  124. c := qt.New(t)
  125. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  126. c.Assert(err, qt.IsNil)
  127. defer tree.db.Close()
  128. k := BigIntToBytes(big.NewInt(int64(20)))
  129. v := BigIntToBytes(big.NewInt(int64(12)))
  130. if err := tree.Add(k, v); err != nil {
  131. t.Fatal(err)
  132. }
  133. v = BigIntToBytes(big.NewInt(int64(11)))
  134. err = tree.Update(k, v)
  135. c.Assert(err, qt.IsNil)
  136. gettedKey, gettedValue, err := tree.Get(k)
  137. c.Assert(err, qt.IsNil)
  138. c.Check(gettedKey, qt.DeepEquals, k)
  139. c.Check(gettedValue, qt.DeepEquals, v)
  140. // add more leafs to the tree to do another test
  141. for i := 0; i < 16; i++ {
  142. k := BigIntToBytes(big.NewInt(int64(i)))
  143. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  144. if err := tree.Add(k, v); err != nil {
  145. t.Fatal(err)
  146. }
  147. }
  148. k = BigIntToBytes(big.NewInt(int64(3)))
  149. v = BigIntToBytes(big.NewInt(int64(11)))
  150. // check that before the Update, value for 3 is !=11
  151. gettedKey, gettedValue, err = tree.Get(k)
  152. c.Assert(err, qt.IsNil)
  153. c.Check(gettedKey, qt.DeepEquals, k)
  154. c.Check(gettedValue, qt.Not(qt.DeepEquals), v)
  155. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(6)))
  156. err = tree.Update(k, v)
  157. c.Assert(err, qt.IsNil)
  158. // check that after Update, the value for 3 is ==11
  159. gettedKey, gettedValue, err = tree.Get(k)
  160. c.Assert(err, qt.IsNil)
  161. c.Check(gettedKey, qt.DeepEquals, k)
  162. c.Check(gettedValue, qt.DeepEquals, v)
  163. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(11)))
  164. }
  165. func TestAux(t *testing.T) {
  166. c := qt.New(t)
  167. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  168. c.Assert(err, qt.IsNil)
  169. defer tree.db.Close()
  170. k := BigIntToBytes(big.NewInt(int64(1)))
  171. v := BigIntToBytes(big.NewInt(int64(0)))
  172. err = tree.Add(k, v)
  173. c.Assert(err, qt.IsNil)
  174. k = BigIntToBytes(big.NewInt(int64(256)))
  175. err = tree.Add(k, v)
  176. c.Assert(err, qt.IsNil)
  177. k = BigIntToBytes(big.NewInt(int64(257)))
  178. err = tree.Add(k, v)
  179. c.Assert(err, qt.IsNil)
  180. k = BigIntToBytes(big.NewInt(int64(515)))
  181. err = tree.Add(k, v)
  182. c.Assert(err, qt.IsNil)
  183. k = BigIntToBytes(big.NewInt(int64(770)))
  184. err = tree.Add(k, v)
  185. c.Assert(err, qt.IsNil)
  186. }
  187. func TestGet(t *testing.T) {
  188. c := qt.New(t)
  189. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  190. c.Assert(err, qt.IsNil)
  191. defer tree.db.Close()
  192. for i := 0; i < 10; i++ {
  193. k := BigIntToBytes(big.NewInt(int64(i)))
  194. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  195. if err := tree.Add(k, v); err != nil {
  196. t.Fatal(err)
  197. }
  198. }
  199. k := BigIntToBytes(big.NewInt(int64(7)))
  200. gettedKey, gettedValue, err := tree.Get(k)
  201. c.Assert(err, qt.IsNil)
  202. c.Check(gettedKey, qt.DeepEquals, k)
  203. c.Check(gettedValue, qt.DeepEquals, BigIntToBytes(big.NewInt(int64(7*2))))
  204. }
  205. func TestGenProofAndVerify(t *testing.T) {
  206. c := qt.New(t)
  207. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  208. c.Assert(err, qt.IsNil)
  209. defer tree.db.Close()
  210. for i := 0; i < 10; i++ {
  211. k := BigIntToBytes(big.NewInt(int64(i)))
  212. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  213. if err := tree.Add(k, v); err != nil {
  214. t.Fatal(err)
  215. }
  216. }
  217. k := BigIntToBytes(big.NewInt(int64(7)))
  218. siblings, err := tree.GenProof(k)
  219. c.Assert(err, qt.IsNil)
  220. k = BigIntToBytes(big.NewInt(int64(7)))
  221. v := BigIntToBytes(big.NewInt(int64(14)))
  222. verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings)
  223. c.Assert(err, qt.IsNil)
  224. c.Check(verif, qt.IsTrue)
  225. }
  226. func TestDumpAndImportDump(t *testing.T) {
  227. c := qt.New(t)
  228. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  229. c.Assert(err, qt.IsNil)
  230. defer tree1.db.Close()
  231. for i := 0; i < 16; i++ {
  232. k := BigIntToBytes(big.NewInt(int64(i)))
  233. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  234. if err := tree1.Add(k, v); err != nil {
  235. t.Fatal(err)
  236. }
  237. }
  238. e, err := tree1.Dump()
  239. c.Assert(err, qt.IsNil)
  240. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  241. c.Assert(err, qt.IsNil)
  242. defer tree2.db.Close()
  243. err = tree2.ImportDump(e)
  244. c.Assert(err, qt.IsNil)
  245. c.Check(tree2.Root(), qt.DeepEquals, tree1.Root())
  246. c.Check(hex.EncodeToString(tree2.Root()), qt.Equals,
  247. "0d93aaa3362b2f999f15e15728f123087c2eee716f01c01f56e23aae07f09f08")
  248. }
  249. func BenchmarkAdd(b *testing.B) {
  250. // prepare inputs
  251. var ks, vs [][]byte
  252. for i := 0; i < 1000; i++ {
  253. k := BigIntToBytes(big.NewInt(int64(i)))
  254. v := BigIntToBytes(big.NewInt(int64(i)))
  255. ks = append(ks, k)
  256. vs = append(vs, v)
  257. }
  258. b.Run("Poseidon", func(b *testing.B) {
  259. benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
  260. })
  261. b.Run("Sha256", func(b *testing.B) {
  262. benchmarkAdd(b, HashFunctionSha256, ks, vs)
  263. })
  264. }
  265. func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
  266. c := qt.New(b)
  267. tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
  268. c.Assert(err, qt.IsNil)
  269. defer tree.db.Close()
  270. for i := 0; i < len(ks); i++ {
  271. if err := tree.Add(ks[i], vs[i]); err != nil {
  272. b.Fatal(err)
  273. }
  274. }
  275. }