You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

678 lines
21 KiB

4 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "fmt"
  6. "math/big"
  7. "testing"
  8. "github.com/iden3/go-iden3-crypto/constants"
  9. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  10. "github.com/iden3/go-merkletree/db/memory"
  11. "github.com/stretchr/testify/assert"
  12. "github.com/stretchr/testify/require"
  13. )
  14. var debug = false
  15. type Fatalable interface {
  16. Fatal(args ...interface{})
  17. }
  18. func newTestingMerkle(f Fatalable, numLevels int) *MerkleTree {
  19. mt, err := NewMerkleTree(memory.NewMemoryStorage(), numLevels)
  20. if err != nil {
  21. f.Fatal(err)
  22. return nil
  23. }
  24. return mt
  25. }
  26. func TestHashParsers(t *testing.T) {
  27. h0 := NewHashFromBigInt(big.NewInt(0))
  28. assert.Equal(t, "0", h0.String())
  29. h1 := NewHashFromBigInt(big.NewInt(1))
  30. assert.Equal(t, "1", h1.String())
  31. h10 := NewHashFromBigInt(big.NewInt(10))
  32. assert.Equal(t, "10", h10.String())
  33. h7l := NewHashFromBigInt(big.NewInt(1234567))
  34. assert.Equal(t, "1234567", h7l.String())
  35. h8l := NewHashFromBigInt(big.NewInt(12345678))
  36. assert.Equal(t, "12345678...", h8l.String())
  37. b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10)
  38. assert.True(t, ok)
  39. h := NewHashFromBigInt(b)
  40. assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String())
  41. assert.Equal(t, "49322979...", h.String())
  42. assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
  43. b1, err := NewBigIntFromHashBytes(b.Bytes())
  44. assert.Nil(t, err)
  45. assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
  46. b2, err := NewHashFromBytes(b.Bytes())
  47. assert.Nil(t, err)
  48. assert.Equal(t, b.String(), b2.BigInt().String())
  49. h2, err := NewHashFromHex(h.Hex())
  50. assert.Nil(t, err)
  51. assert.Equal(t, h, h2)
  52. _, err = NewHashFromHex("0x12")
  53. assert.NotNil(t, err)
  54. // check limits
  55. a := new(big.Int).Sub(constants.Q, big.NewInt(1))
  56. testHashParsers(t, a)
  57. a = big.NewInt(int64(1))
  58. testHashParsers(t, a)
  59. }
  60. func testHashParsers(t *testing.T, a *big.Int) {
  61. require.True(t, cryptoUtils.CheckBigIntInField(a))
  62. h := NewHashFromBigInt(a)
  63. assert.Equal(t, a, h.BigInt())
  64. hFromBytes, err := NewHashFromBytes(h.Bytes())
  65. assert.Nil(t, err)
  66. assert.Equal(t, h, hFromBytes)
  67. assert.Equal(t, a, hFromBytes.BigInt())
  68. assert.Equal(t, a.String(), hFromBytes.BigInt().String())
  69. hFromHex, err := NewHashFromHex(h.Hex())
  70. assert.Nil(t, err)
  71. assert.Equal(t, h, hFromHex)
  72. aBIFromHBytes, err := NewBigIntFromHashBytes(h.Bytes())
  73. assert.Nil(t, err)
  74. assert.Equal(t, a, aBIFromHBytes)
  75. assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
  76. }
  77. func TestNewTree(t *testing.T) {
  78. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  79. assert.Nil(t, err)
  80. assert.Equal(t, "0", mt.Root().String())
  81. // test vectors generated using https://github.com/iden3/circomlib smt.js
  82. err = mt.Add(big.NewInt(1), big.NewInt(2))
  83. assert.Nil(t, err)
  84. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
  85. err = mt.Add(big.NewInt(33), big.NewInt(44))
  86. assert.Nil(t, err)
  87. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  88. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  89. assert.Nil(t, err)
  90. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
  91. proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
  92. assert.Nil(t, err)
  93. assert.Equal(t, big.NewInt(44), v)
  94. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
  95. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
  96. }
  97. func TestAddDifferentOrder(t *testing.T) {
  98. mt1 := newTestingMerkle(t, 140)
  99. defer mt1.db.Close()
  100. for i := 0; i < 16; i++ {
  101. k := big.NewInt(int64(i))
  102. v := big.NewInt(0)
  103. if err := mt1.Add(k, v); err != nil {
  104. t.Fatal(err)
  105. }
  106. }
  107. mt2 := newTestingMerkle(t, 140)
  108. defer mt2.db.Close()
  109. for i := 16 - 1; i >= 0; i-- {
  110. k := big.NewInt(int64(i))
  111. v := big.NewInt(0)
  112. if err := mt2.Add(k, v); err != nil {
  113. t.Fatal(err)
  114. }
  115. }
  116. assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
  117. assert.Equal(t, "268e25964aa9d6ba42d66ae9eb44b5528540acb19a3644d1367d8c6f7cb23006", mt1.Root().Hex())
  118. }
  119. func TestAddRepeatedIndex(t *testing.T) {
  120. mt := newTestingMerkle(t, 140)
  121. defer mt.db.Close()
  122. k := big.NewInt(int64(3))
  123. v := big.NewInt(int64(12))
  124. if err := mt.Add(k, v); err != nil {
  125. t.Fatal(err)
  126. }
  127. err := mt.Add(k, v)
  128. assert.NotNil(t, err)
  129. assert.Equal(t, err, ErrEntryIndexAlreadyExists)
  130. }
  131. func TestGet(t *testing.T) {
  132. mt := newTestingMerkle(t, 140)
  133. defer mt.db.Close()
  134. for i := 0; i < 16; i++ {
  135. k := big.NewInt(int64(i))
  136. v := big.NewInt(int64(i * 2))
  137. if err := mt.Add(k, v); err != nil {
  138. t.Fatal(err)
  139. }
  140. }
  141. k, v, _, err := mt.Get(big.NewInt(10))
  142. assert.Nil(t, err)
  143. assert.Equal(t, big.NewInt(10), k)
  144. assert.Equal(t, big.NewInt(20), v)
  145. k, v, _, err = mt.Get(big.NewInt(15))
  146. assert.Nil(t, err)
  147. assert.Equal(t, big.NewInt(15), k)
  148. assert.Equal(t, big.NewInt(30), v)
  149. k, v, _, err = mt.Get(big.NewInt(16))
  150. assert.NotNil(t, err)
  151. assert.Equal(t, ErrKeyNotFound, err)
  152. assert.Equal(t, "0", k.String())
  153. assert.Equal(t, "0", v.String())
  154. }
  155. func TestUpdate(t *testing.T) {
  156. mt := newTestingMerkle(t, 140)
  157. defer mt.db.Close()
  158. for i := 0; i < 16; i++ {
  159. k := big.NewInt(int64(i))
  160. v := big.NewInt(int64(i * 2))
  161. if err := mt.Add(k, v); err != nil {
  162. t.Fatal(err)
  163. }
  164. }
  165. _, v, _, err := mt.Get(big.NewInt(10))
  166. assert.Nil(t, err)
  167. assert.Equal(t, big.NewInt(20), v)
  168. _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
  169. assert.Nil(t, err)
  170. _, v, _, err = mt.Get(big.NewInt(10))
  171. assert.Nil(t, err)
  172. assert.Equal(t, big.NewInt(1024), v)
  173. _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
  174. assert.Equal(t, ErrKeyNotFound, err)
  175. }
  176. func TestUpdate2(t *testing.T) {
  177. mt1 := newTestingMerkle(t, 140)
  178. defer mt1.db.Close()
  179. mt2 := newTestingMerkle(t, 140)
  180. defer mt2.db.Close()
  181. err := mt1.Add(big.NewInt(1), big.NewInt(119))
  182. assert.Nil(t, err)
  183. err = mt1.Add(big.NewInt(2), big.NewInt(229))
  184. assert.Nil(t, err)
  185. err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
  186. assert.Nil(t, err)
  187. err = mt2.Add(big.NewInt(1), big.NewInt(11))
  188. assert.Nil(t, err)
  189. err = mt2.Add(big.NewInt(2), big.NewInt(22))
  190. assert.Nil(t, err)
  191. err = mt2.Add(big.NewInt(9876), big.NewInt(10))
  192. assert.Nil(t, err)
  193. _, err = mt1.Update(big.NewInt(1), big.NewInt(11))
  194. assert.Nil(t, err)
  195. _, err = mt1.Update(big.NewInt(2), big.NewInt(22))
  196. assert.Nil(t, err)
  197. _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
  198. assert.Nil(t, err)
  199. assert.Equal(t, mt1.Root(), mt2.Root())
  200. }
  201. func TestGenerateAndVerifyProof128(t *testing.T) {
  202. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  203. require.Nil(t, err)
  204. defer mt.db.Close()
  205. for i := 0; i < 128; i++ {
  206. k := big.NewInt(int64(i))
  207. v := big.NewInt(0)
  208. if err := mt.Add(k, v); err != nil {
  209. t.Fatal(err)
  210. }
  211. }
  212. proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
  213. assert.Nil(t, err)
  214. assert.Equal(t, "0", v.String())
  215. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
  216. }
  217. func TestTreeLimit(t *testing.T) {
  218. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 5)
  219. require.Nil(t, err)
  220. defer mt.db.Close()
  221. for i := 0; i < 16; i++ {
  222. err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
  223. assert.Nil(t, err)
  224. }
  225. // here the tree is full, should not allow to add more data as reaches the maximum number of levels
  226. err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
  227. assert.NotNil(t, err)
  228. assert.Equal(t, ErrReachedMaxLevel, err)
  229. }
  230. func TestSiblingsFromProof(t *testing.T) {
  231. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  232. require.Nil(t, err)
  233. defer mt.db.Close()
  234. for i := 0; i < 64; i++ {
  235. k := big.NewInt(int64(i))
  236. v := big.NewInt(0)
  237. if err := mt.Add(k, v); err != nil {
  238. t.Fatal(err)
  239. }
  240. }
  241. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  242. if err != nil {
  243. t.Fatal(err)
  244. }
  245. siblings := SiblingsFromProof(proof)
  246. assert.Equal(t, 6, len(siblings))
  247. assert.Equal(t, "5b478bdd58595ead03ebf494a74014cbb576ba0d9456aa0916885b9eefae592f", siblings[0].Hex())
  248. assert.Equal(t, "c1e8ab120a4e475ea1bf00633228bfb9d248f7ddec2aa6367f98d0defb9fb22e", siblings[1].Hex())
  249. assert.Equal(t, "f4dafd8ac2b9165adc3f6d125af67d5a4d8a7a263dcc90a373d0338929e16e0c", siblings[2].Hex())
  250. assert.Equal(t, "a94aa346bd85f96aba2e85b67920e44fe6ed767b0e13bea602784e0b8b897515", siblings[3].Hex())
  251. assert.Equal(t, "54791d7514030ded79301dbf221f5bf186facbc5800912411852fdc101b7151d", siblings[4].Hex())
  252. assert.Equal(t, "435d28bc0511f8feb93b5f1649a049b460947702ce0baaefcf596175370fe01e", siblings[5].Hex())
  253. }
  254. func TestVerifyProofCases(t *testing.T) {
  255. mt := newTestingMerkle(t, 140)
  256. defer mt.DB().Close()
  257. for i := 0; i < 8; i++ {
  258. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  259. t.Fatal(err)
  260. }
  261. }
  262. // Existence proof
  263. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  264. if err != nil {
  265. t.Fatal(err)
  266. }
  267. assert.Equal(t, proof.Existence, true)
  268. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
  269. assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e", hex.EncodeToString(proof.Bytes()))
  270. for i := 8; i < 32; i++ {
  271. proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
  272. assert.Nil(t, err)
  273. if debug {
  274. fmt.Println(i, proof)
  275. }
  276. }
  277. // Non-existence proof, empty aux
  278. proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
  279. if err != nil {
  280. t.Fatal(err)
  281. }
  282. assert.Equal(t, proof.Existence, false)
  283. // assert.True(t, proof.nodeAux == nil)
  284. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
  285. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
  286. // Non-existence proof, diff. node aux
  287. proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
  288. if err != nil {
  289. t.Fatal(err)
  290. }
  291. assert.Equal(t, proof.Existence, false)
  292. assert.True(t, proof.NodeAux != nil)
  293. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
  294. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730e667e2ca15909c4a23beff18e3cc74348fbd3c1a4c765a5bbbca126c9607a42b77e008a73926f1280f8531b139dc1cacf8d83fcec31d405f5c51b7cbddfe152902000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
  295. }
  296. func TestVerifyProofFalse(t *testing.T) {
  297. mt := newTestingMerkle(t, 140)
  298. defer mt.DB().Close()
  299. for i := 0; i < 8; i++ {
  300. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  301. t.Fatal(err)
  302. }
  303. }
  304. // Invalid existence proof (node used for verification doesn't
  305. // correspond to node in the proof)
  306. proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
  307. if err != nil {
  308. t.Fatal(err)
  309. }
  310. assert.Equal(t, proof.Existence, true)
  311. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
  312. // Invalid non-existence proof (Non-existence proof, diff. node aux)
  313. proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
  314. if err != nil {
  315. t.Fatal(err)
  316. }
  317. assert.Equal(t, proof.Existence, true)
  318. // Now we change the proof from existence to non-existence, and add e's
  319. // data as auxiliary node.
  320. proof.Existence = false
  321. proof.NodeAux = &NodeAux{Key: NewHashFromBigInt(big.NewInt(int64(4))), Value: NewHashFromBigInt(big.NewInt(4))}
  322. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
  323. }
  324. func TestGraphViz(t *testing.T) {
  325. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  326. assert.Nil(t, err)
  327. _ = mt.Add(big.NewInt(1), big.NewInt(0))
  328. _ = mt.Add(big.NewInt(2), big.NewInt(0))
  329. _ = mt.Add(big.NewInt(3), big.NewInt(0))
  330. _ = mt.Add(big.NewInt(4), big.NewInt(0))
  331. _ = mt.Add(big.NewInt(5), big.NewInt(0))
  332. _ = mt.Add(big.NewInt(100), big.NewInt(0))
  333. // mt.PrintGraphViz(nil)
  334. expected := `digraph hierarchy {
  335. node [fontname=Monospace,fontsize=10,shape=box]
  336. "16053348..." -> {"19137630..." "14119616..."}
  337. "19137630..." -> {"19543983..." "19746229..."}
  338. "19543983..." -> {"empty0" "65773153..."}
  339. "empty0" [style=dashed,label=0];
  340. "65773153..." -> {"73498412..." "empty1"}
  341. "empty1" [style=dashed,label=0];
  342. "73498412..." -> {"53169236..." "empty2"}
  343. "empty2" [style=dashed,label=0];
  344. "53169236..." -> {"73522717..." "34811870..."}
  345. "73522717..." [style=filled];
  346. "34811870..." [style=filled];
  347. "19746229..." [style=filled];
  348. "14119616..." -> {"19419204..." "15569531..."}
  349. "19419204..." -> {"78154875..." "34589916..."}
  350. "78154875..." [style=filled];
  351. "34589916..." [style=filled];
  352. "15569531..." [style=filled];
  353. }
  354. `
  355. w := bytes.NewBufferString("")
  356. err = mt.GraphViz(w, nil)
  357. assert.Nil(t, err)
  358. assert.Equal(t, []byte(expected), w.Bytes())
  359. }
  360. func TestDelete(t *testing.T) {
  361. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  362. assert.Nil(t, err)
  363. assert.Equal(t, "0", mt.Root().String())
  364. // test vectors generated using https://github.com/iden3/circomlib smt.js
  365. err = mt.Add(big.NewInt(1), big.NewInt(2))
  366. assert.Nil(t, err)
  367. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
  368. err = mt.Add(big.NewInt(33), big.NewInt(44))
  369. assert.Nil(t, err)
  370. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  371. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  372. assert.Nil(t, err)
  373. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
  374. // mt.PrintGraphViz(nil)
  375. err = mt.Delete(big.NewInt(33))
  376. // mt.PrintGraphViz(nil)
  377. assert.Nil(t, err)
  378. assert.Equal(t, "16195585003843604118922861401064871511855368913846540536604351220077317790615", mt.Root().BigInt().String())
  379. err = mt.Delete(big.NewInt(1234))
  380. assert.Nil(t, err)
  381. err = mt.Delete(big.NewInt(1))
  382. assert.Nil(t, err)
  383. assert.Equal(t, "0", mt.Root().String())
  384. }
  385. func TestDelete2(t *testing.T) {
  386. mt := newTestingMerkle(t, 140)
  387. defer mt.db.Close()
  388. for i := 0; i < 8; i++ {
  389. k := big.NewInt(int64(i))
  390. v := big.NewInt(0)
  391. if err := mt.Add(k, v); err != nil {
  392. t.Fatal(err)
  393. }
  394. }
  395. expectedRoot := mt.Root()
  396. k := big.NewInt(8)
  397. v := big.NewInt(0)
  398. err := mt.Add(k, v)
  399. require.Nil(t, err)
  400. err = mt.Delete(big.NewInt(8))
  401. assert.Nil(t, err)
  402. assert.Equal(t, expectedRoot, mt.Root())
  403. mt2 := newTestingMerkle(t, 140)
  404. defer mt2.db.Close()
  405. for i := 0; i < 8; i++ {
  406. k := big.NewInt(int64(i))
  407. v := big.NewInt(0)
  408. if err := mt2.Add(k, v); err != nil {
  409. t.Fatal(err)
  410. }
  411. }
  412. assert.Equal(t, mt2.Root(), mt.Root())
  413. }
  414. func TestDelete3(t *testing.T) {
  415. mt := newTestingMerkle(t, 140)
  416. defer mt.db.Close()
  417. err := mt.Add(big.NewInt(1), big.NewInt(1))
  418. assert.Nil(t, err)
  419. err = mt.Add(big.NewInt(2), big.NewInt(2))
  420. assert.Nil(t, err)
  421. assert.Equal(t, "6701939280963330813043570145125351311131831356446202146710280245621673558344", mt.Root().BigInt().String())
  422. err = mt.Delete(big.NewInt(1))
  423. assert.Nil(t, err)
  424. assert.Equal(t, "10304354743004778619823249005484018655542356856535590307973732141291410579841", mt.Root().BigInt().String())
  425. mt2 := newTestingMerkle(t, 140)
  426. defer mt2.db.Close()
  427. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  428. assert.Nil(t, err)
  429. assert.Equal(t, mt2.Root(), mt.Root())
  430. }
  431. func TestDelete4(t *testing.T) {
  432. mt := newTestingMerkle(t, 140)
  433. defer mt.db.Close()
  434. err := mt.Add(big.NewInt(1), big.NewInt(1))
  435. assert.Nil(t, err)
  436. err = mt.Add(big.NewInt(2), big.NewInt(2))
  437. assert.Nil(t, err)
  438. err = mt.Add(big.NewInt(3), big.NewInt(3))
  439. assert.Nil(t, err)
  440. assert.Equal(t, "6989694633650442615746486460134957295274675622748484439660143938730686550248", mt.Root().BigInt().String())
  441. err = mt.Delete(big.NewInt(1))
  442. assert.Nil(t, err)
  443. assert.Equal(t, "1192610901536912535888866440319084773171371421781091005185759505381507049136", mt.Root().BigInt().String())
  444. mt2 := newTestingMerkle(t, 140)
  445. defer mt2.db.Close()
  446. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  447. assert.Nil(t, err)
  448. err = mt2.Add(big.NewInt(3), big.NewInt(3))
  449. assert.Nil(t, err)
  450. assert.Equal(t, mt2.Root(), mt.Root())
  451. }
  452. func TestDelete5(t *testing.T) {
  453. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  454. assert.Nil(t, err)
  455. err = mt.Add(big.NewInt(1), big.NewInt(2))
  456. assert.Nil(t, err)
  457. err = mt.Add(big.NewInt(33), big.NewInt(44))
  458. assert.Nil(t, err)
  459. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  460. err = mt.Delete(big.NewInt(1))
  461. assert.Nil(t, err)
  462. assert.Equal(t, "12802904154263054831102426711825443668153853847661287611768065280921698471037", mt.Root().BigInt().String())
  463. mt2 := newTestingMerkle(t, 140)
  464. defer mt2.db.Close()
  465. err = mt2.Add(big.NewInt(33), big.NewInt(44))
  466. assert.Nil(t, err)
  467. assert.Equal(t, mt2.Root(), mt.Root())
  468. }
  469. func TestDeleteNonExistingKeys(t *testing.T) {
  470. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  471. assert.Nil(t, err)
  472. err = mt.Add(big.NewInt(1), big.NewInt(2))
  473. assert.Nil(t, err)
  474. err = mt.Add(big.NewInt(33), big.NewInt(44))
  475. assert.Nil(t, err)
  476. err = mt.Delete(big.NewInt(33))
  477. assert.Nil(t, err)
  478. err = mt.Delete(big.NewInt(33))
  479. assert.Equal(t, ErrKeyNotFound, err)
  480. err = mt.Delete(big.NewInt(1))
  481. assert.Nil(t, err)
  482. assert.Equal(t, "0", mt.Root().String())
  483. err = mt.Delete(big.NewInt(33))
  484. assert.Equal(t, ErrKeyNotFound, err)
  485. }
  486. func TestDumpLeafsImportLeafs(t *testing.T) {
  487. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  488. require.Nil(t, err)
  489. defer mt.db.Close()
  490. q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
  491. for i := 0; i < 10; i++ {
  492. // use numbers near under Q
  493. k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
  494. v := big.NewInt(0)
  495. err = mt.Add(k, v)
  496. require.Nil(t, err)
  497. // use numbers near above 0
  498. k = big.NewInt(int64(i))
  499. err = mt.Add(k, v)
  500. require.Nil(t, err)
  501. }
  502. d, err := mt.DumpLeafs(nil)
  503. assert.Nil(t, err)
  504. mt2, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  505. require.Nil(t, err)
  506. defer mt2.db.Close()
  507. err = mt2.ImportDumpedLeafs(d)
  508. assert.Nil(t, err)
  509. assert.Equal(t, mt.Root(), mt2.Root())
  510. }
  511. func TestAddAndGetCircomProof(t *testing.T) {
  512. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  513. assert.Nil(t, err)
  514. assert.Equal(t, "0", mt.Root().String())
  515. // test vectors generated using https://github.com/iden3/circomlib smt.js
  516. cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
  517. assert.Nil(t, err)
  518. assert.Equal(t, "0", cpp.OldRoot.String())
  519. assert.Equal(t, "64497120...", cpp.NewRoot.String())
  520. assert.Equal(t, "0", cpp.OldKey.String())
  521. assert.Equal(t, "0", cpp.OldValue.String())
  522. assert.Equal(t, "1", cpp.NewKey.String())
  523. assert.Equal(t, "2", cpp.NewValue.String())
  524. assert.Equal(t, true, cpp.IsOld0)
  525. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  526. cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
  527. assert.Nil(t, err)
  528. assert.Equal(t, "64497120...", cpp.OldRoot.String())
  529. assert.Equal(t, "11404118...", cpp.NewRoot.String())
  530. assert.Equal(t, "1", cpp.OldKey.String())
  531. assert.Equal(t, "2", cpp.OldValue.String())
  532. assert.Equal(t, "33", cpp.NewKey.String())
  533. assert.Equal(t, "44", cpp.NewValue.String())
  534. assert.Equal(t, false, cpp.IsOld0)
  535. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  536. cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
  537. assert.Nil(t, err)
  538. assert.Equal(t, "11404118...", cpp.OldRoot.String())
  539. assert.Equal(t, "18284203...", cpp.NewRoot.String())
  540. assert.Equal(t, "0", cpp.OldKey.String())
  541. assert.Equal(t, "0", cpp.OldValue.String())
  542. assert.Equal(t, "55", cpp.NewKey.String())
  543. assert.Equal(t, "66", cpp.NewValue.String())
  544. assert.Equal(t, true, cpp.IsOld0)
  545. assert.Equal(t, "[0 42948778... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  546. // fmt.Println(cpp)
  547. }
  548. func TestUpdateCircomProcessorProof(t *testing.T) {
  549. mt := newTestingMerkle(t, 10)
  550. defer mt.db.Close()
  551. for i := 0; i < 16; i++ {
  552. k := big.NewInt(int64(i))
  553. v := big.NewInt(int64(i * 2))
  554. if err := mt.Add(k, v); err != nil {
  555. t.Fatal(err)
  556. }
  557. }
  558. _, v, _, err := mt.Get(big.NewInt(10))
  559. assert.Nil(t, err)
  560. assert.Equal(t, big.NewInt(20), v)
  561. // test vectors generated using https://github.com/iden3/circomlib smt.js
  562. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  563. assert.Nil(t, err)
  564. assert.Equal(t, "14895645...", cpp.OldRoot.String())
  565. assert.Equal(t, "75223641...", cpp.NewRoot.String())
  566. assert.Equal(t, "10", cpp.OldKey.String())
  567. assert.Equal(t, "20", cpp.OldValue.String())
  568. assert.Equal(t, "10", cpp.NewKey.String())
  569. assert.Equal(t, "1024", cpp.NewValue.String())
  570. assert.Equal(t, false, cpp.IsOld0)
  571. assert.Equal(t, "[19625419... 46910949... 18399594... 20473908... 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  572. }