You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

240 lines
6.6 KiB

3 years ago
  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "math/big"
  6. "testing"
  7. "github.com/iden3/go-merkletree/db/memory"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/stretchr/testify/require"
  10. )
  11. func TestAddTestVectors(t *testing.T) {
  12. // Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
  13. testVectorsPoseidon := []string{
  14. "0000000000000000000000000000000000000000000000000000000000000000",
  15. "13578938674299138072471463694055224830892726234048532520316387704878000008795",
  16. "5412393676474193513566895793055462193090331607895808993925969873307089394741",
  17. "14204494359367183802864593755198662203838502594566452929175967972147978322084",
  18. }
  19. testAdd(t, HashFunctionPoseidon, testVectorsPoseidon)
  20. testVectorsSha256 := []string{
  21. "0000000000000000000000000000000000000000000000000000000000000000",
  22. "46910109172468462938850740851377282682950237270676610513794735904325820156367",
  23. "59481735341404520835410489183267411392292882901306595567679529387376287440550",
  24. "20573794434149960984975763118181266662429997821552560184909083010514790081771",
  25. }
  26. testAdd(t, HashFunctionSha256, testVectorsSha256)
  27. }
  28. func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
  29. tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
  30. assert.Nil(t, err)
  31. defer tree.db.Close()
  32. assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root()))
  33. err = tree.Add(
  34. BigIntToBytes(big.NewInt(1)),
  35. BigIntToBytes(big.NewInt(2)))
  36. assert.Nil(t, err)
  37. rootBI := BytesToBigInt(tree.Root())
  38. assert.Equal(t, testVectors[1], rootBI.String())
  39. err = tree.Add(
  40. BigIntToBytes(big.NewInt(33)),
  41. BigIntToBytes(big.NewInt(44)))
  42. assert.Nil(t, err)
  43. rootBI = BytesToBigInt(tree.Root())
  44. assert.Equal(t, testVectors[2], rootBI.String())
  45. err = tree.Add(
  46. BigIntToBytes(big.NewInt(1234)),
  47. BigIntToBytes(big.NewInt(9876)))
  48. assert.Nil(t, err)
  49. rootBI = BytesToBigInt(tree.Root())
  50. assert.Equal(t, testVectors[3], rootBI.String())
  51. }
  52. func TestAddBatch(t *testing.T) {
  53. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  54. require.Nil(t, err)
  55. defer tree.db.Close()
  56. for i := 0; i < 1000; i++ {
  57. k := BigIntToBytes(big.NewInt(int64(i)))
  58. v := BigIntToBytes(big.NewInt(0))
  59. if err := tree.Add(k, v); err != nil {
  60. t.Fatal(err)
  61. }
  62. }
  63. rootBI := BytesToBigInt(tree.Root())
  64. assert.Equal(t,
  65. "296519252211642170490407814696803112091039265640052570497930797516015811235",
  66. rootBI.String())
  67. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  68. require.Nil(t, err)
  69. defer tree2.db.Close()
  70. var keys, values [][]byte
  71. for i := 0; i < 1000; i++ {
  72. k := BigIntToBytes(big.NewInt(int64(i)))
  73. v := BigIntToBytes(big.NewInt(0))
  74. keys = append(keys, k)
  75. values = append(values, v)
  76. }
  77. indexes, err := tree2.AddBatch(keys, values)
  78. assert.Nil(t, err)
  79. assert.Equal(t, 0, len(indexes))
  80. rootBI = BytesToBigInt(tree2.Root())
  81. assert.Equal(t,
  82. "296519252211642170490407814696803112091039265640052570497930797516015811235",
  83. rootBI.String())
  84. }
  85. func TestAddDifferentOrder(t *testing.T) {
  86. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  87. require.Nil(t, err)
  88. defer tree1.db.Close()
  89. for i := 0; i < 16; i++ {
  90. k := SwapEndianness(big.NewInt(int64(i)).Bytes())
  91. v := SwapEndianness(big.NewInt(0).Bytes())
  92. if err := tree1.Add(k, v); err != nil {
  93. t.Fatal(err)
  94. }
  95. }
  96. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  97. require.Nil(t, err)
  98. defer tree2.db.Close()
  99. for i := 16 - 1; i >= 0; i-- {
  100. k := big.NewInt(int64(i)).Bytes()
  101. v := big.NewInt(0).Bytes()
  102. if err := tree2.Add(k, v); err != nil {
  103. t.Fatal(err)
  104. }
  105. }
  106. assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root()))
  107. assert.Equal(t,
  108. "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
  109. hex.EncodeToString(tree1.Root()))
  110. }
  111. func TestAddRepeatedIndex(t *testing.T) {
  112. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  113. require.Nil(t, err)
  114. defer tree.db.Close()
  115. k := big.NewInt(int64(3)).Bytes()
  116. v := big.NewInt(int64(12)).Bytes()
  117. if err := tree.Add(k, v); err != nil {
  118. t.Fatal(err)
  119. }
  120. err = tree.Add(k, v)
  121. assert.NotNil(t, err)
  122. assert.Equal(t, fmt.Errorf("max virtual level 100"), err)
  123. }
  124. func TestAux(t *testing.T) {
  125. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  126. require.Nil(t, err)
  127. defer tree.db.Close()
  128. k := BigIntToBytes(big.NewInt(int64(1)))
  129. v := BigIntToBytes(big.NewInt(int64(0)))
  130. err = tree.Add(k, v)
  131. assert.Nil(t, err)
  132. k = BigIntToBytes(big.NewInt(int64(256)))
  133. err = tree.Add(k, v)
  134. assert.Nil(t, err)
  135. k = BigIntToBytes(big.NewInt(int64(257)))
  136. err = tree.Add(k, v)
  137. assert.Nil(t, err)
  138. k = BigIntToBytes(big.NewInt(int64(515)))
  139. err = tree.Add(k, v)
  140. assert.Nil(t, err)
  141. k = BigIntToBytes(big.NewInt(int64(770)))
  142. err = tree.Add(k, v)
  143. assert.Nil(t, err)
  144. }
  145. func TestGet(t *testing.T) {
  146. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  147. require.Nil(t, err)
  148. defer tree.db.Close()
  149. for i := 0; i < 10; i++ {
  150. k := BigIntToBytes(big.NewInt(int64(i)))
  151. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  152. if err := tree.Add(k, v); err != nil {
  153. t.Fatal(err)
  154. }
  155. }
  156. k := BigIntToBytes(big.NewInt(int64(7)))
  157. gettedKey, gettedValue, err := tree.Get(k)
  158. assert.Nil(t, err)
  159. assert.Equal(t, k, gettedKey)
  160. assert.Equal(t, BigIntToBytes(big.NewInt(int64(7*2))), gettedValue)
  161. }
  162. func TestGenProofAndVerify(t *testing.T) {
  163. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  164. require.Nil(t, err)
  165. defer tree.db.Close()
  166. for i := 0; i < 10; i++ {
  167. k := BigIntToBytes(big.NewInt(int64(i)))
  168. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  169. if err := tree.Add(k, v); err != nil {
  170. t.Fatal(err)
  171. }
  172. }
  173. k := BigIntToBytes(big.NewInt(int64(7)))
  174. siblings, err := tree.GenProof(k)
  175. assert.Nil(t, err)
  176. k = BigIntToBytes(big.NewInt(int64(7)))
  177. v := BigIntToBytes(big.NewInt(int64(14)))
  178. verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings)
  179. require.Nil(t, err)
  180. assert.True(t, verif)
  181. }
  182. func BenchmarkAdd(b *testing.B) {
  183. // prepare inputs
  184. var ks, vs [][]byte
  185. for i := 0; i < 1000; i++ {
  186. k := BigIntToBytes(big.NewInt(int64(i)))
  187. v := BigIntToBytes(big.NewInt(int64(i)))
  188. ks = append(ks, k)
  189. vs = append(vs, v)
  190. }
  191. b.Run("Poseidon", func(b *testing.B) {
  192. benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
  193. })
  194. b.Run("Sha256", func(b *testing.B) {
  195. benchmarkAdd(b, HashFunctionSha256, ks, vs)
  196. })
  197. }
  198. func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
  199. tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
  200. require.Nil(b, err)
  201. defer tree.db.Close()
  202. for i := 0; i < len(ks); i++ {
  203. if err := tree.Add(ks[i], vs[i]); err != nil {
  204. b.Fatal(err)
  205. }
  206. }
  207. }