You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

200 lines
5.5 KiB

  1. package arbo
  2. import (
  3. "encoding/hex"
  4. "fmt"
  5. "math/big"
  6. "testing"
  7. "github.com/iden3/go-merkletree/db/memory"
  8. "github.com/stretchr/testify/assert"
  9. "github.com/stretchr/testify/require"
  10. )
  11. func TestAddTestVectors(t *testing.T) {
  12. // Poseidon test vectors generated using https://github.com/iden3/circomlib smt.js
  13. testVectorsPoseidon := []string{
  14. "0000000000000000000000000000000000000000000000000000000000000000",
  15. "13578938674299138072471463694055224830892726234048532520316387704878000008795",
  16. "5412393676474193513566895793055462193090331607895808993925969873307089394741",
  17. "14204494359367183802864593755198662203838502594566452929175967972147978322084",
  18. }
  19. testAdd(t, HashFunctionPoseidon, testVectorsPoseidon)
  20. testVectorsSha256 := []string{
  21. "0000000000000000000000000000000000000000000000000000000000000000",
  22. "46910109172468462938850740851377282682950237270676610513794735904325820156367",
  23. "59481735341404520835410489183267411392292882901306595567679529387376287440550",
  24. "20573794434149960984975763118181266662429997821552560184909083010514790081771",
  25. }
  26. testAdd(t, HashFunctionSha256, testVectorsSha256)
  27. }
  28. func testAdd(t *testing.T, hashFunc HashFunction, testVectors []string) {
  29. tree, err := NewTree(memory.NewMemoryStorage(), 10, hashFunc)
  30. assert.Nil(t, err)
  31. defer tree.db.Close()
  32. assert.Equal(t, testVectors[0], hex.EncodeToString(tree.Root()))
  33. err = tree.Add(
  34. BigIntToBytes(big.NewInt(1)),
  35. BigIntToBytes(big.NewInt(2)))
  36. assert.Nil(t, err)
  37. rootBI := BytesToBigInt(tree.Root())
  38. assert.Equal(t, testVectors[1], rootBI.String())
  39. err = tree.Add(
  40. BigIntToBytes(big.NewInt(33)),
  41. BigIntToBytes(big.NewInt(44)))
  42. assert.Nil(t, err)
  43. rootBI = BytesToBigInt(tree.Root())
  44. assert.Equal(t, testVectors[2], rootBI.String())
  45. err = tree.Add(
  46. BigIntToBytes(big.NewInt(1234)),
  47. BigIntToBytes(big.NewInt(9876)))
  48. assert.Nil(t, err)
  49. rootBI = BytesToBigInt(tree.Root())
  50. assert.Equal(t, testVectors[3], rootBI.String())
  51. }
  52. func TestAdd1000(t *testing.T) {
  53. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  54. require.Nil(t, err)
  55. defer tree.db.Close()
  56. for i := 0; i < 1000; i++ {
  57. k := BigIntToBytes(big.NewInt(int64(i)))
  58. v := BigIntToBytes(big.NewInt(0))
  59. if err := tree.Add(k, v); err != nil {
  60. t.Fatal(err)
  61. }
  62. }
  63. rootBI := BytesToBigInt(tree.Root())
  64. assert.Equal(t,
  65. "296519252211642170490407814696803112091039265640052570497930797516015811235",
  66. rootBI.String())
  67. }
  68. func TestAddDifferentOrder(t *testing.T) {
  69. tree1, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  70. require.Nil(t, err)
  71. defer tree1.db.Close()
  72. for i := 0; i < 16; i++ {
  73. k := SwapEndianness(big.NewInt(int64(i)).Bytes())
  74. v := SwapEndianness(big.NewInt(0).Bytes())
  75. if err := tree1.Add(k, v); err != nil {
  76. t.Fatal(err)
  77. }
  78. }
  79. tree2, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  80. require.Nil(t, err)
  81. defer tree2.db.Close()
  82. for i := 16 - 1; i >= 0; i-- {
  83. k := big.NewInt(int64(i)).Bytes()
  84. v := big.NewInt(0).Bytes()
  85. if err := tree2.Add(k, v); err != nil {
  86. t.Fatal(err)
  87. }
  88. }
  89. assert.Equal(t, hex.EncodeToString(tree1.Root()), hex.EncodeToString(tree2.Root()))
  90. assert.Equal(t,
  91. "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f",
  92. hex.EncodeToString(tree1.Root()))
  93. }
  94. func TestAddRepeatedIndex(t *testing.T) {
  95. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  96. require.Nil(t, err)
  97. defer tree.db.Close()
  98. k := big.NewInt(int64(3)).Bytes()
  99. v := big.NewInt(int64(12)).Bytes()
  100. if err := tree.Add(k, v); err != nil {
  101. t.Fatal(err)
  102. }
  103. err = tree.Add(k, v)
  104. assert.NotNil(t, err)
  105. assert.Equal(t, fmt.Errorf("max virtual level 100"), err)
  106. }
  107. func TestAux(t *testing.T) {
  108. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  109. require.Nil(t, err)
  110. defer tree.db.Close()
  111. k := BigIntToBytes(big.NewInt(int64(1)))
  112. v := BigIntToBytes(big.NewInt(int64(0)))
  113. err = tree.Add(k, v)
  114. assert.Nil(t, err)
  115. k = BigIntToBytes(big.NewInt(int64(256)))
  116. err = tree.Add(k, v)
  117. assert.Nil(t, err)
  118. k = BigIntToBytes(big.NewInt(int64(257)))
  119. err = tree.Add(k, v)
  120. assert.Nil(t, err)
  121. k = BigIntToBytes(big.NewInt(int64(515)))
  122. err = tree.Add(k, v)
  123. assert.Nil(t, err)
  124. k = BigIntToBytes(big.NewInt(int64(770)))
  125. err = tree.Add(k, v)
  126. assert.Nil(t, err)
  127. }
  128. func TestGenProofAndVerify(t *testing.T) {
  129. tree, err := NewTree(memory.NewMemoryStorage(), 100, HashFunctionPoseidon)
  130. require.Nil(t, err)
  131. defer tree.db.Close()
  132. for i := 0; i < 10; i++ {
  133. k := BigIntToBytes(big.NewInt(int64(i)))
  134. v := BigIntToBytes(big.NewInt(int64(i * 2)))
  135. if err := tree.Add(k, v); err != nil {
  136. t.Fatal(err)
  137. }
  138. }
  139. k := BigIntToBytes(big.NewInt(int64(7)))
  140. siblings, err := tree.GenProof(k)
  141. assert.Nil(t, err)
  142. k = BigIntToBytes(big.NewInt(int64(7)))
  143. v := BigIntToBytes(big.NewInt(int64(14)))
  144. verif, err := CheckProof(tree.hashFunction, k, v, tree.Root(), siblings)
  145. require.Nil(t, err)
  146. assert.True(t, verif)
  147. }
  148. func BenchmarkAdd(b *testing.B) {
  149. // prepare inputs
  150. var ks, vs [][]byte
  151. for i := 0; i < 1000; i++ {
  152. k := BigIntToBytes(big.NewInt(int64(i)))
  153. v := BigIntToBytes(big.NewInt(int64(i)))
  154. ks = append(ks, k)
  155. vs = append(vs, v)
  156. }
  157. b.Run("Poseidon", func(b *testing.B) {
  158. benchmarkAdd(b, HashFunctionPoseidon, ks, vs)
  159. })
  160. b.Run("Sha256", func(b *testing.B) {
  161. benchmarkAdd(b, HashFunctionSha256, ks, vs)
  162. })
  163. }
  164. func benchmarkAdd(b *testing.B, hashFunc HashFunction, ks, vs [][]byte) {
  165. tree, err := NewTree(memory.NewMemoryStorage(), 140, hashFunc)
  166. require.Nil(b, err)
  167. defer tree.db.Close()
  168. for i := 0; i < len(ks); i++ {
  169. if err := tree.Add(ks[i], vs[i]); err != nil {
  170. b.Fatal(err)
  171. }
  172. }
  173. }