You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

730 lines
22 KiB

3 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "encoding/json"
  6. "fmt"
  7. "math/big"
  8. "testing"
  9. "github.com/iden3/go-iden3-crypto/constants"
  10. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  11. "github.com/iden3/go-merkletree/db/memory"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  // debug, when set, makes the tests print intermediate values (such as the
// generated proofs) to stdout. It is off by default.
var debug bool
  // Fatalable is anything that can abort the running test with a message.
// *testing.T satisfies it, which is how newTestingMerkle is called.
type Fatalable interface {
	// Fatal reports v and stops the current test.
	Fatal(v ...interface{})
}
  19. func newTestingMerkle(f Fatalable, numLevels int) *MerkleTree {
  20. mt, err := NewMerkleTree(memory.NewMemoryStorage(), numLevels)
  21. if err != nil {
  22. f.Fatal(err)
  23. return nil
  24. }
  25. return mt
  26. }
  27. func TestHashParsers(t *testing.T) {
  28. h0 := NewHashFromBigInt(big.NewInt(0))
  29. assert.Equal(t, "0", h0.String())
  30. h1 := NewHashFromBigInt(big.NewInt(1))
  31. assert.Equal(t, "1", h1.String())
  32. h10 := NewHashFromBigInt(big.NewInt(10))
  33. assert.Equal(t, "10", h10.String())
  34. h7l := NewHashFromBigInt(big.NewInt(1234567))
  35. assert.Equal(t, "1234567", h7l.String())
  36. h8l := NewHashFromBigInt(big.NewInt(12345678))
  37. assert.Equal(t, "12345678...", h8l.String())
  38. b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10)
  39. assert.True(t, ok)
  40. h := NewHashFromBigInt(b)
  41. assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String())
  42. assert.Equal(t, "49322979...", h.String())
  43. assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
  44. b1, err := NewBigIntFromHashBytes(b.Bytes())
  45. assert.Nil(t, err)
  46. assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
  47. b2, err := NewHashFromBytes(b.Bytes())
  48. assert.Nil(t, err)
  49. assert.Equal(t, b.String(), b2.BigInt().String())
  50. h2, err := NewHashFromHex(h.Hex())
  51. assert.Nil(t, err)
  52. assert.Equal(t, h, h2)
  53. _, err = NewHashFromHex("0x12")
  54. assert.NotNil(t, err)
  55. // check limits
  56. a := new(big.Int).Sub(constants.Q, big.NewInt(1))
  57. testHashParsers(t, a)
  58. a = big.NewInt(int64(1))
  59. testHashParsers(t, a)
  60. }
  61. func testHashParsers(t *testing.T, a *big.Int) {
  62. require.True(t, cryptoUtils.CheckBigIntInField(a))
  63. h := NewHashFromBigInt(a)
  64. assert.Equal(t, a, h.BigInt())
  65. hFromBytes, err := NewHashFromBytes(h.Bytes())
  66. assert.Nil(t, err)
  67. assert.Equal(t, h, hFromBytes)
  68. assert.Equal(t, a, hFromBytes.BigInt())
  69. assert.Equal(t, a.String(), hFromBytes.BigInt().String())
  70. hFromHex, err := NewHashFromHex(h.Hex())
  71. assert.Nil(t, err)
  72. assert.Equal(t, h, hFromHex)
  73. aBIFromHBytes, err := NewBigIntFromHashBytes(h.Bytes())
  74. assert.Nil(t, err)
  75. assert.Equal(t, a, aBIFromHBytes)
  76. assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
  77. }
  78. func TestNewTree(t *testing.T) {
  79. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  80. assert.Nil(t, err)
  81. assert.Equal(t, "0", mt.Root().String())
  82. // test vectors generated using https://github.com/iden3/circomlib smt.js
  83. err = mt.Add(big.NewInt(1), big.NewInt(2))
  84. assert.Nil(t, err)
  85. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
  86. err = mt.Add(big.NewInt(33), big.NewInt(44))
  87. assert.Nil(t, err)
  88. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  89. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  90. assert.Nil(t, err)
  91. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
  92. dbRoot, err := mt.dbGetRoot()
  93. require.Nil(t, err)
  94. assert.Equal(t, mt.Root(), dbRoot)
  95. proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
  96. assert.Nil(t, err)
  97. assert.Equal(t, big.NewInt(44), v)
  98. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
  99. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
  100. }
  101. func TestAddDifferentOrder(t *testing.T) {
  102. mt1 := newTestingMerkle(t, 140)
  103. defer mt1.db.Close()
  104. for i := 0; i < 16; i++ {
  105. k := big.NewInt(int64(i))
  106. v := big.NewInt(0)
  107. if err := mt1.Add(k, v); err != nil {
  108. t.Fatal(err)
  109. }
  110. }
  111. mt2 := newTestingMerkle(t, 140)
  112. defer mt2.db.Close()
  113. for i := 16 - 1; i >= 0; i-- {
  114. k := big.NewInt(int64(i))
  115. v := big.NewInt(0)
  116. if err := mt2.Add(k, v); err != nil {
  117. t.Fatal(err)
  118. }
  119. }
  120. assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
  121. assert.Equal(t, "268e25964aa9d6ba42d66ae9eb44b5528540acb19a3644d1367d8c6f7cb23006", mt1.Root().Hex())
  122. }
  123. func TestAddRepeatedIndex(t *testing.T) {
  124. mt := newTestingMerkle(t, 140)
  125. defer mt.db.Close()
  126. k := big.NewInt(int64(3))
  127. v := big.NewInt(int64(12))
  128. if err := mt.Add(k, v); err != nil {
  129. t.Fatal(err)
  130. }
  131. err := mt.Add(k, v)
  132. assert.NotNil(t, err)
  133. assert.Equal(t, err, ErrEntryIndexAlreadyExists)
  134. }
  135. func TestGet(t *testing.T) {
  136. mt := newTestingMerkle(t, 140)
  137. defer mt.db.Close()
  138. for i := 0; i < 16; i++ {
  139. k := big.NewInt(int64(i))
  140. v := big.NewInt(int64(i * 2))
  141. if err := mt.Add(k, v); err != nil {
  142. t.Fatal(err)
  143. }
  144. }
  145. k, v, _, err := mt.Get(big.NewInt(10))
  146. assert.Nil(t, err)
  147. assert.Equal(t, big.NewInt(10), k)
  148. assert.Equal(t, big.NewInt(20), v)
  149. k, v, _, err = mt.Get(big.NewInt(15))
  150. assert.Nil(t, err)
  151. assert.Equal(t, big.NewInt(15), k)
  152. assert.Equal(t, big.NewInt(30), v)
  153. k, v, _, err = mt.Get(big.NewInt(16))
  154. assert.NotNil(t, err)
  155. assert.Equal(t, ErrKeyNotFound, err)
  156. assert.Equal(t, "0", k.String())
  157. assert.Equal(t, "0", v.String())
  158. }
  159. func TestUpdate(t *testing.T) {
  160. mt := newTestingMerkle(t, 140)
  161. defer mt.db.Close()
  162. for i := 0; i < 16; i++ {
  163. k := big.NewInt(int64(i))
  164. v := big.NewInt(int64(i * 2))
  165. if err := mt.Add(k, v); err != nil {
  166. t.Fatal(err)
  167. }
  168. }
  169. _, v, _, err := mt.Get(big.NewInt(10))
  170. assert.Nil(t, err)
  171. assert.Equal(t, big.NewInt(20), v)
  172. _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
  173. assert.Nil(t, err)
  174. _, v, _, err = mt.Get(big.NewInt(10))
  175. assert.Nil(t, err)
  176. assert.Equal(t, big.NewInt(1024), v)
  177. _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
  178. assert.Equal(t, ErrKeyNotFound, err)
  179. dbRoot, err := mt.dbGetRoot()
  180. require.Nil(t, err)
  181. assert.Equal(t, mt.Root(), dbRoot)
  182. }
  183. func TestUpdate2(t *testing.T) {
  184. mt1 := newTestingMerkle(t, 140)
  185. defer mt1.db.Close()
  186. mt2 := newTestingMerkle(t, 140)
  187. defer mt2.db.Close()
  188. err := mt1.Add(big.NewInt(1), big.NewInt(119))
  189. assert.Nil(t, err)
  190. err = mt1.Add(big.NewInt(2), big.NewInt(229))
  191. assert.Nil(t, err)
  192. err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
  193. assert.Nil(t, err)
  194. err = mt2.Add(big.NewInt(1), big.NewInt(11))
  195. assert.Nil(t, err)
  196. err = mt2.Add(big.NewInt(2), big.NewInt(22))
  197. assert.Nil(t, err)
  198. err = mt2.Add(big.NewInt(9876), big.NewInt(10))
  199. assert.Nil(t, err)
  200. _, err = mt1.Update(big.NewInt(1), big.NewInt(11))
  201. assert.Nil(t, err)
  202. _, err = mt1.Update(big.NewInt(2), big.NewInt(22))
  203. assert.Nil(t, err)
  204. _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
  205. assert.Nil(t, err)
  206. assert.Equal(t, mt1.Root(), mt2.Root())
  207. }
  208. func TestGenerateAndVerifyProof128(t *testing.T) {
  209. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  210. require.Nil(t, err)
  211. defer mt.db.Close()
  212. for i := 0; i < 128; i++ {
  213. k := big.NewInt(int64(i))
  214. v := big.NewInt(0)
  215. if err := mt.Add(k, v); err != nil {
  216. t.Fatal(err)
  217. }
  218. }
  219. proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
  220. assert.Nil(t, err)
  221. assert.Equal(t, "0", v.String())
  222. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
  223. }
  224. func TestTreeLimit(t *testing.T) {
  225. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 5)
  226. require.Nil(t, err)
  227. defer mt.db.Close()
  228. for i := 0; i < 16; i++ {
  229. err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
  230. assert.Nil(t, err)
  231. }
  232. // here the tree is full, should not allow to add more data as reaches the maximum number of levels
  233. err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
  234. assert.NotNil(t, err)
  235. assert.Equal(t, ErrReachedMaxLevel, err)
  236. }
  237. func TestSiblingsFromProof(t *testing.T) {
  238. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  239. require.Nil(t, err)
  240. defer mt.db.Close()
  241. for i := 0; i < 64; i++ {
  242. k := big.NewInt(int64(i))
  243. v := big.NewInt(0)
  244. if err := mt.Add(k, v); err != nil {
  245. t.Fatal(err)
  246. }
  247. }
  248. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  249. if err != nil {
  250. t.Fatal(err)
  251. }
  252. siblings := SiblingsFromProof(proof)
  253. assert.Equal(t, 6, len(siblings))
  254. assert.Equal(t, "5b478bdd58595ead03ebf494a74014cbb576ba0d9456aa0916885b9eefae592f", siblings[0].Hex())
  255. assert.Equal(t, "c1e8ab120a4e475ea1bf00633228bfb9d248f7ddec2aa6367f98d0defb9fb22e", siblings[1].Hex())
  256. assert.Equal(t, "f4dafd8ac2b9165adc3f6d125af67d5a4d8a7a263dcc90a373d0338929e16e0c", siblings[2].Hex())
  257. assert.Equal(t, "a94aa346bd85f96aba2e85b67920e44fe6ed767b0e13bea602784e0b8b897515", siblings[3].Hex())
  258. assert.Equal(t, "54791d7514030ded79301dbf221f5bf186facbc5800912411852fdc101b7151d", siblings[4].Hex())
  259. assert.Equal(t, "435d28bc0511f8feb93b5f1649a049b460947702ce0baaefcf596175370fe01e", siblings[5].Hex())
  260. }
  261. func TestVerifyProofCases(t *testing.T) {
  262. mt := newTestingMerkle(t, 140)
  263. defer mt.DB().Close()
  264. for i := 0; i < 8; i++ {
  265. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  266. t.Fatal(err)
  267. }
  268. }
  269. // Existence proof
  270. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  271. if err != nil {
  272. t.Fatal(err)
  273. }
  274. assert.Equal(t, proof.Existence, true)
  275. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
  276. assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e", hex.EncodeToString(proof.Bytes()))
  277. for i := 8; i < 32; i++ {
  278. proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
  279. assert.Nil(t, err)
  280. if debug {
  281. fmt.Println(i, proof)
  282. }
  283. }
  284. // Non-existence proof, empty aux
  285. proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
  286. if err != nil {
  287. t.Fatal(err)
  288. }
  289. assert.Equal(t, proof.Existence, false)
  290. // assert.True(t, proof.nodeAux == nil)
  291. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
  292. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
  293. // Non-existence proof, diff. node aux
  294. proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
  295. if err != nil {
  296. t.Fatal(err)
  297. }
  298. assert.Equal(t, proof.Existence, false)
  299. assert.True(t, proof.NodeAux != nil)
  300. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
  301. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730e667e2ca15909c4a23beff18e3cc74348fbd3c1a4c765a5bbbca126c9607a42b77e008a73926f1280f8531b139dc1cacf8d83fcec31d405f5c51b7cbddfe152902000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes()))
  302. }
  303. func TestVerifyProofFalse(t *testing.T) {
  304. mt := newTestingMerkle(t, 140)
  305. defer mt.DB().Close()
  306. for i := 0; i < 8; i++ {
  307. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  308. t.Fatal(err)
  309. }
  310. }
  311. // Invalid existence proof (node used for verification doesn't
  312. // correspond to node in the proof)
  313. proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
  314. if err != nil {
  315. t.Fatal(err)
  316. }
  317. assert.Equal(t, proof.Existence, true)
  318. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
  319. // Invalid non-existence proof (Non-existence proof, diff. node aux)
  320. proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
  321. if err != nil {
  322. t.Fatal(err)
  323. }
  324. assert.Equal(t, proof.Existence, true)
  325. // Now we change the proof from existence to non-existence, and add e's
  326. // data as auxiliary node.
  327. proof.Existence = false
  328. proof.NodeAux = &NodeAux{Key: NewHashFromBigInt(big.NewInt(int64(4))), Value: NewHashFromBigInt(big.NewInt(4))}
  329. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
  330. }
  331. func TestGraphViz(t *testing.T) {
  332. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  333. assert.Nil(t, err)
  334. _ = mt.Add(big.NewInt(1), big.NewInt(0))
  335. _ = mt.Add(big.NewInt(2), big.NewInt(0))
  336. _ = mt.Add(big.NewInt(3), big.NewInt(0))
  337. _ = mt.Add(big.NewInt(4), big.NewInt(0))
  338. _ = mt.Add(big.NewInt(5), big.NewInt(0))
  339. _ = mt.Add(big.NewInt(100), big.NewInt(0))
  340. // mt.PrintGraphViz(nil)
  341. expected := `digraph hierarchy {
  342. node [fontname=Monospace,fontsize=10,shape=box]
  343. "16053348..." -> {"19137630..." "14119616..."}
  344. "19137630..." -> {"19543983..." "19746229..."}
  345. "19543983..." -> {"empty0" "65773153..."}
  346. "empty0" [style=dashed,label=0];
  347. "65773153..." -> {"73498412..." "empty1"}
  348. "empty1" [style=dashed,label=0];
  349. "73498412..." -> {"53169236..." "empty2"}
  350. "empty2" [style=dashed,label=0];
  351. "53169236..." -> {"73522717..." "34811870..."}
  352. "73522717..." [style=filled];
  353. "34811870..." [style=filled];
  354. "19746229..." [style=filled];
  355. "14119616..." -> {"19419204..." "15569531..."}
  356. "19419204..." -> {"78154875..." "34589916..."}
  357. "78154875..." [style=filled];
  358. "34589916..." [style=filled];
  359. "15569531..." [style=filled];
  360. }
  361. `
  362. w := bytes.NewBufferString("")
  363. err = mt.GraphViz(w, nil)
  364. assert.Nil(t, err)
  365. assert.Equal(t, []byte(expected), w.Bytes())
  366. }
  367. func TestDelete(t *testing.T) {
  368. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  369. assert.Nil(t, err)
  370. assert.Equal(t, "0", mt.Root().String())
  371. // test vectors generated using https://github.com/iden3/circomlib smt.js
  372. err = mt.Add(big.NewInt(1), big.NewInt(2))
  373. assert.Nil(t, err)
  374. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String())
  375. err = mt.Add(big.NewInt(33), big.NewInt(44))
  376. assert.Nil(t, err)
  377. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  378. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  379. assert.Nil(t, err)
  380. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String())
  381. // mt.PrintGraphViz(nil)
  382. err = mt.Delete(big.NewInt(33))
  383. // mt.PrintGraphViz(nil)
  384. assert.Nil(t, err)
  385. assert.Equal(t, "16195585003843604118922861401064871511855368913846540536604351220077317790615", mt.Root().BigInt().String())
  386. err = mt.Delete(big.NewInt(1234))
  387. assert.Nil(t, err)
  388. err = mt.Delete(big.NewInt(1))
  389. assert.Nil(t, err)
  390. assert.Equal(t, "0", mt.Root().String())
  391. dbRoot, err := mt.dbGetRoot()
  392. require.Nil(t, err)
  393. assert.Equal(t, mt.Root(), dbRoot)
  394. }
  395. func TestDelete2(t *testing.T) {
  396. mt := newTestingMerkle(t, 140)
  397. defer mt.db.Close()
  398. for i := 0; i < 8; i++ {
  399. k := big.NewInt(int64(i))
  400. v := big.NewInt(0)
  401. if err := mt.Add(k, v); err != nil {
  402. t.Fatal(err)
  403. }
  404. }
  405. expectedRoot := mt.Root()
  406. k := big.NewInt(8)
  407. v := big.NewInt(0)
  408. err := mt.Add(k, v)
  409. require.Nil(t, err)
  410. err = mt.Delete(big.NewInt(8))
  411. assert.Nil(t, err)
  412. assert.Equal(t, expectedRoot, mt.Root())
  413. mt2 := newTestingMerkle(t, 140)
  414. defer mt2.db.Close()
  415. for i := 0; i < 8; i++ {
  416. k := big.NewInt(int64(i))
  417. v := big.NewInt(0)
  418. if err := mt2.Add(k, v); err != nil {
  419. t.Fatal(err)
  420. }
  421. }
  422. assert.Equal(t, mt2.Root(), mt.Root())
  423. }
  424. func TestDelete3(t *testing.T) {
  425. mt := newTestingMerkle(t, 140)
  426. defer mt.db.Close()
  427. err := mt.Add(big.NewInt(1), big.NewInt(1))
  428. assert.Nil(t, err)
  429. err = mt.Add(big.NewInt(2), big.NewInt(2))
  430. assert.Nil(t, err)
  431. assert.Equal(t, "6701939280963330813043570145125351311131831356446202146710280245621673558344", mt.Root().BigInt().String())
  432. err = mt.Delete(big.NewInt(1))
  433. assert.Nil(t, err)
  434. assert.Equal(t, "10304354743004778619823249005484018655542356856535590307973732141291410579841", mt.Root().BigInt().String())
  435. mt2 := newTestingMerkle(t, 140)
  436. defer mt2.db.Close()
  437. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  438. assert.Nil(t, err)
  439. assert.Equal(t, mt2.Root(), mt.Root())
  440. }
  441. func TestDelete4(t *testing.T) {
  442. mt := newTestingMerkle(t, 140)
  443. defer mt.db.Close()
  444. err := mt.Add(big.NewInt(1), big.NewInt(1))
  445. assert.Nil(t, err)
  446. err = mt.Add(big.NewInt(2), big.NewInt(2))
  447. assert.Nil(t, err)
  448. err = mt.Add(big.NewInt(3), big.NewInt(3))
  449. assert.Nil(t, err)
  450. assert.Equal(t, "6989694633650442615746486460134957295274675622748484439660143938730686550248", mt.Root().BigInt().String())
  451. err = mt.Delete(big.NewInt(1))
  452. assert.Nil(t, err)
  453. assert.Equal(t, "1192610901536912535888866440319084773171371421781091005185759505381507049136", mt.Root().BigInt().String())
  454. mt2 := newTestingMerkle(t, 140)
  455. defer mt2.db.Close()
  456. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  457. assert.Nil(t, err)
  458. err = mt2.Add(big.NewInt(3), big.NewInt(3))
  459. assert.Nil(t, err)
  460. assert.Equal(t, mt2.Root(), mt.Root())
  461. }
  462. func TestDelete5(t *testing.T) {
  463. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  464. assert.Nil(t, err)
  465. err = mt.Add(big.NewInt(1), big.NewInt(2))
  466. assert.Nil(t, err)
  467. err = mt.Add(big.NewInt(33), big.NewInt(44))
  468. assert.Nil(t, err)
  469. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String())
  470. err = mt.Delete(big.NewInt(1))
  471. assert.Nil(t, err)
  472. assert.Equal(t, "12802904154263054831102426711825443668153853847661287611768065280921698471037", mt.Root().BigInt().String())
  473. mt2 := newTestingMerkle(t, 140)
  474. defer mt2.db.Close()
  475. err = mt2.Add(big.NewInt(33), big.NewInt(44))
  476. assert.Nil(t, err)
  477. assert.Equal(t, mt2.Root(), mt.Root())
  478. }
  479. func TestDeleteNonExistingKeys(t *testing.T) {
  480. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  481. assert.Nil(t, err)
  482. err = mt.Add(big.NewInt(1), big.NewInt(2))
  483. assert.Nil(t, err)
  484. err = mt.Add(big.NewInt(33), big.NewInt(44))
  485. assert.Nil(t, err)
  486. err = mt.Delete(big.NewInt(33))
  487. assert.Nil(t, err)
  488. err = mt.Delete(big.NewInt(33))
  489. assert.Equal(t, ErrKeyNotFound, err)
  490. err = mt.Delete(big.NewInt(1))
  491. assert.Nil(t, err)
  492. assert.Equal(t, "0", mt.Root().String())
  493. err = mt.Delete(big.NewInt(33))
  494. assert.Equal(t, ErrKeyNotFound, err)
  495. }
  496. func TestDumpLeafsImportLeafs(t *testing.T) {
  497. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  498. require.Nil(t, err)
  499. defer mt.db.Close()
  500. q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
  501. for i := 0; i < 10; i++ {
  502. // use numbers near under Q
  503. k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
  504. v := big.NewInt(0)
  505. err = mt.Add(k, v)
  506. require.Nil(t, err)
  507. // use numbers near above 0
  508. k = big.NewInt(int64(i))
  509. err = mt.Add(k, v)
  510. require.Nil(t, err)
  511. }
  512. d, err := mt.DumpLeafs(nil)
  513. assert.Nil(t, err)
  514. mt2, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  515. require.Nil(t, err)
  516. defer mt2.db.Close()
  517. err = mt2.ImportDumpedLeafs(d)
  518. assert.Nil(t, err)
  519. assert.Equal(t, mt.Root(), mt2.Root())
  520. }
  521. func TestAddAndGetCircomProof(t *testing.T) {
  522. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  523. assert.Nil(t, err)
  524. assert.Equal(t, "0", mt.Root().String())
  525. // test vectors generated using https://github.com/iden3/circomlib smt.js
  526. cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
  527. assert.Nil(t, err)
  528. assert.Equal(t, "0", cpp.OldRoot.String())
  529. assert.Equal(t, "64497120...", cpp.NewRoot.String())
  530. assert.Equal(t, "0", cpp.OldKey.String())
  531. assert.Equal(t, "0", cpp.OldValue.String())
  532. assert.Equal(t, "1", cpp.NewKey.String())
  533. assert.Equal(t, "2", cpp.NewValue.String())
  534. assert.Equal(t, true, cpp.IsOld0)
  535. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  536. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  537. cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
  538. assert.Nil(t, err)
  539. assert.Equal(t, "64497120...", cpp.OldRoot.String())
  540. assert.Equal(t, "11404118...", cpp.NewRoot.String())
  541. assert.Equal(t, "1", cpp.OldKey.String())
  542. assert.Equal(t, "2", cpp.OldValue.String())
  543. assert.Equal(t, "33", cpp.NewKey.String())
  544. assert.Equal(t, "44", cpp.NewValue.String())
  545. assert.Equal(t, false, cpp.IsOld0)
  546. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  547. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  548. cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
  549. assert.Nil(t, err)
  550. assert.Equal(t, "11404118...", cpp.OldRoot.String())
  551. assert.Equal(t, "18284203...", cpp.NewRoot.String())
  552. assert.Equal(t, "0", cpp.OldKey.String())
  553. assert.Equal(t, "0", cpp.OldValue.String())
  554. assert.Equal(t, "55", cpp.NewKey.String())
  555. assert.Equal(t, "66", cpp.NewValue.String())
  556. assert.Equal(t, true, cpp.IsOld0)
  557. assert.Equal(t, "[0 42948778... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  558. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  559. }
  560. func TestUpdateCircomProcessorProof(t *testing.T) {
  561. mt := newTestingMerkle(t, 10)
  562. defer mt.db.Close()
  563. for i := 0; i < 16; i++ {
  564. k := big.NewInt(int64(i))
  565. v := big.NewInt(int64(i * 2))
  566. if err := mt.Add(k, v); err != nil {
  567. t.Fatal(err)
  568. }
  569. }
  570. _, v, _, err := mt.Get(big.NewInt(10))
  571. assert.Nil(t, err)
  572. assert.Equal(t, big.NewInt(20), v)
  573. // test vectors generated using https://github.com/iden3/circomlib smt.js
  574. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  575. assert.Nil(t, err)
  576. assert.Equal(t, "14895645...", cpp.OldRoot.String())
  577. assert.Equal(t, "75223641...", cpp.NewRoot.String())
  578. assert.Equal(t, "10", cpp.OldKey.String())
  579. assert.Equal(t, "20", cpp.OldValue.String())
  580. assert.Equal(t, "10", cpp.NewKey.String())
  581. assert.Equal(t, "1024", cpp.NewValue.String())
  582. assert.Equal(t, false, cpp.IsOld0)
  583. assert.Equal(t, "[19625419... 46910949... 18399594... 20473908... 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  584. }
  585. func TestTypesMarshalers(t *testing.T) {
  586. // test Hash marshalers
  587. h, err := NewHashFromString("42")
  588. assert.Nil(t, err)
  589. s, err := json.Marshal(h)
  590. assert.Nil(t, err)
  591. var h2 *Hash
  592. err = json.Unmarshal(s, &h2)
  593. assert.Nil(t, err)
  594. assert.Equal(t, h, h2)
  595. // create CircomProcessorProof
  596. mt := newTestingMerkle(t, 10)
  597. defer mt.db.Close()
  598. for i := 0; i < 16; i++ {
  599. k := big.NewInt(int64(i))
  600. v := big.NewInt(int64(i * 2))
  601. if err := mt.Add(k, v); err != nil {
  602. t.Fatal(err)
  603. }
  604. }
  605. _, v, _, err := mt.Get(big.NewInt(10))
  606. assert.Nil(t, err)
  607. assert.Equal(t, big.NewInt(20), v)
  608. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  609. assert.Nil(t, err)
  610. // test CircomProcessorProof marshalers
  611. b, err := json.Marshal(&cpp)
  612. assert.Nil(t, err)
  613. var cpp2 *CircomProcessorProof
  614. err = json.Unmarshal(b, &cpp2)
  615. assert.Nil(t, err)
  616. assert.Equal(t, cpp, cpp2)
  617. }