You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

788 lines
25 KiB

4 years ago
  1. package merkletree
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "encoding/json"
  6. "fmt"
  7. "math/big"
  8. "testing"
  9. "github.com/iden3/go-iden3-crypto/constants"
  10. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  11. "github.com/iden3/go-merkletree/db/memory"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. )
  15. var debug = false
  16. type Fatalable interface {
  17. Fatal(args ...interface{})
  18. }
  19. func newTestingMerkle(f Fatalable, numLevels int) *MerkleTree {
  20. mt, err := NewMerkleTree(memory.NewMemoryStorage(), numLevels)
  21. if err != nil {
  22. f.Fatal(err)
  23. return nil
  24. }
  25. return mt
  26. }
  27. func TestHashParsers(t *testing.T) {
  28. h0 := NewHashFromBigInt(big.NewInt(0))
  29. assert.Equal(t, "0", h0.String())
  30. h1 := NewHashFromBigInt(big.NewInt(1))
  31. assert.Equal(t, "1", h1.String())
  32. h10 := NewHashFromBigInt(big.NewInt(10))
  33. assert.Equal(t, "10", h10.String())
  34. h7l := NewHashFromBigInt(big.NewInt(1234567))
  35. assert.Equal(t, "1234567", h7l.String())
  36. h8l := NewHashFromBigInt(big.NewInt(12345678))
  37. assert.Equal(t, "12345678...", h8l.String())
  38. b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll
  39. assert.True(t, ok)
  40. h := NewHashFromBigInt(b)
  41. assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll
  42. assert.Equal(t, "49322979...", h.String())
  43. assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
  44. b1, err := NewBigIntFromHashBytes(b.Bytes())
  45. assert.Nil(t, err)
  46. assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
  47. b2, err := NewHashFromBytes(b.Bytes())
  48. assert.Nil(t, err)
  49. assert.Equal(t, b.String(), b2.BigInt().String())
  50. h2, err := NewHashFromHex(h.Hex())
  51. assert.Nil(t, err)
  52. assert.Equal(t, h, h2)
  53. _, err = NewHashFromHex("0x12")
  54. assert.NotNil(t, err)
  55. // check limits
  56. a := new(big.Int).Sub(constants.Q, big.NewInt(1))
  57. testHashParsers(t, a)
  58. a = big.NewInt(int64(1))
  59. testHashParsers(t, a)
  60. }
  61. func testHashParsers(t *testing.T, a *big.Int) {
  62. require.True(t, cryptoUtils.CheckBigIntInField(a))
  63. h := NewHashFromBigInt(a)
  64. assert.Equal(t, a, h.BigInt())
  65. hFromBytes, err := NewHashFromBytes(h.Bytes())
  66. assert.Nil(t, err)
  67. assert.Equal(t, h, hFromBytes)
  68. assert.Equal(t, a, hFromBytes.BigInt())
  69. assert.Equal(t, a.String(), hFromBytes.BigInt().String())
  70. hFromHex, err := NewHashFromHex(h.Hex())
  71. assert.Nil(t, err)
  72. assert.Equal(t, h, hFromHex)
  73. aBIFromHBytes, err := NewBigIntFromHashBytes(h.Bytes())
  74. assert.Nil(t, err)
  75. assert.Equal(t, a, aBIFromHBytes)
  76. assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
  77. }
  78. func TestNewTree(t *testing.T) {
  79. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  80. assert.Nil(t, err)
  81. assert.Equal(t, "0", mt.Root().String())
  82. // test vectors generated using https://github.com/iden3/circomlib smt.js
  83. err = mt.Add(big.NewInt(1), big.NewInt(2))
  84. assert.Nil(t, err)
  85. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String()) //nolint:lll
  86. err = mt.Add(big.NewInt(33), big.NewInt(44))
  87. assert.Nil(t, err)
  88. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String()) //nolint:lll
  89. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  90. assert.Nil(t, err)
  91. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String()) //nolint:lll
  92. dbRoot, err := mt.dbGetRoot()
  93. require.Nil(t, err)
  94. assert.Equal(t, mt.Root(), dbRoot)
  95. proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
  96. assert.Nil(t, err)
  97. assert.Equal(t, big.NewInt(44), v)
  98. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
  99. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
  100. }
  101. func TestAddDifferentOrder(t *testing.T) {
  102. mt1 := newTestingMerkle(t, 140)
  103. defer mt1.db.Close()
  104. for i := 0; i < 16; i++ {
  105. k := big.NewInt(int64(i))
  106. v := big.NewInt(0)
  107. if err := mt1.Add(k, v); err != nil {
  108. t.Fatal(err)
  109. }
  110. }
  111. mt2 := newTestingMerkle(t, 140)
  112. defer mt2.db.Close()
  113. for i := 16 - 1; i >= 0; i-- {
  114. k := big.NewInt(int64(i))
  115. v := big.NewInt(0)
  116. if err := mt2.Add(k, v); err != nil {
  117. t.Fatal(err)
  118. }
  119. }
  120. assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
  121. assert.Equal(t, "268e25964aa9d6ba42d66ae9eb44b5528540acb19a3644d1367d8c6f7cb23006", mt1.Root().Hex()) //nolint:lll
  122. }
  123. func TestAddRepeatedIndex(t *testing.T) {
  124. mt := newTestingMerkle(t, 140)
  125. defer mt.db.Close()
  126. k := big.NewInt(int64(3))
  127. v := big.NewInt(int64(12))
  128. if err := mt.Add(k, v); err != nil {
  129. t.Fatal(err)
  130. }
  131. err := mt.Add(k, v)
  132. assert.NotNil(t, err)
  133. assert.Equal(t, err, ErrEntryIndexAlreadyExists)
  134. }
  135. func TestGet(t *testing.T) {
  136. mt := newTestingMerkle(t, 140)
  137. defer mt.db.Close()
  138. for i := 0; i < 16; i++ {
  139. k := big.NewInt(int64(i))
  140. v := big.NewInt(int64(i * 2))
  141. if err := mt.Add(k, v); err != nil {
  142. t.Fatal(err)
  143. }
  144. }
  145. k, v, _, err := mt.Get(big.NewInt(10))
  146. assert.Nil(t, err)
  147. assert.Equal(t, big.NewInt(10), k)
  148. assert.Equal(t, big.NewInt(20), v)
  149. k, v, _, err = mt.Get(big.NewInt(15))
  150. assert.Nil(t, err)
  151. assert.Equal(t, big.NewInt(15), k)
  152. assert.Equal(t, big.NewInt(30), v)
  153. k, v, _, err = mt.Get(big.NewInt(16))
  154. assert.NotNil(t, err)
  155. assert.Equal(t, ErrKeyNotFound, err)
  156. assert.Equal(t, "0", k.String())
  157. assert.Equal(t, "0", v.String())
  158. }
  159. func TestUpdate(t *testing.T) {
  160. mt := newTestingMerkle(t, 140)
  161. defer mt.db.Close()
  162. for i := 0; i < 16; i++ {
  163. k := big.NewInt(int64(i))
  164. v := big.NewInt(int64(i * 2))
  165. if err := mt.Add(k, v); err != nil {
  166. t.Fatal(err)
  167. }
  168. }
  169. _, v, _, err := mt.Get(big.NewInt(10))
  170. assert.Nil(t, err)
  171. assert.Equal(t, big.NewInt(20), v)
  172. _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
  173. assert.Nil(t, err)
  174. _, v, _, err = mt.Get(big.NewInt(10))
  175. assert.Nil(t, err)
  176. assert.Equal(t, big.NewInt(1024), v)
  177. _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
  178. assert.Equal(t, ErrKeyNotFound, err)
  179. dbRoot, err := mt.dbGetRoot()
  180. require.Nil(t, err)
  181. assert.Equal(t, mt.Root(), dbRoot)
  182. }
  183. func TestUpdate2(t *testing.T) {
  184. mt1 := newTestingMerkle(t, 140)
  185. defer mt1.db.Close()
  186. mt2 := newTestingMerkle(t, 140)
  187. defer mt2.db.Close()
  188. err := mt1.Add(big.NewInt(1), big.NewInt(119))
  189. assert.Nil(t, err)
  190. err = mt1.Add(big.NewInt(2), big.NewInt(229))
  191. assert.Nil(t, err)
  192. err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
  193. assert.Nil(t, err)
  194. err = mt2.Add(big.NewInt(1), big.NewInt(11))
  195. assert.Nil(t, err)
  196. err = mt2.Add(big.NewInt(2), big.NewInt(22))
  197. assert.Nil(t, err)
  198. err = mt2.Add(big.NewInt(9876), big.NewInt(10))
  199. assert.Nil(t, err)
  200. _, err = mt1.Update(big.NewInt(1), big.NewInt(11))
  201. assert.Nil(t, err)
  202. _, err = mt1.Update(big.NewInt(2), big.NewInt(22))
  203. assert.Nil(t, err)
  204. _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
  205. assert.Nil(t, err)
  206. assert.Equal(t, mt1.Root(), mt2.Root())
  207. }
  208. func TestGenerateAndVerifyProof128(t *testing.T) {
  209. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  210. require.Nil(t, err)
  211. defer mt.db.Close()
  212. for i := 0; i < 128; i++ {
  213. k := big.NewInt(int64(i))
  214. v := big.NewInt(0)
  215. if err := mt.Add(k, v); err != nil {
  216. t.Fatal(err)
  217. }
  218. }
  219. proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
  220. assert.Nil(t, err)
  221. assert.Equal(t, "0", v.String())
  222. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
  223. }
  224. func TestTreeLimit(t *testing.T) {
  225. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 5)
  226. require.Nil(t, err)
  227. defer mt.db.Close()
  228. for i := 0; i < 16; i++ {
  229. err = mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
  230. assert.Nil(t, err)
  231. }
  232. // here the tree is full, should not allow to add more data as reaches the maximum number of levels
  233. err = mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
  234. assert.NotNil(t, err)
  235. assert.Equal(t, ErrReachedMaxLevel, err)
  236. }
  237. func TestSiblingsFromProof(t *testing.T) {
  238. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  239. require.Nil(t, err)
  240. defer mt.db.Close()
  241. for i := 0; i < 64; i++ {
  242. k := big.NewInt(int64(i))
  243. v := big.NewInt(0)
  244. if err := mt.Add(k, v); err != nil {
  245. t.Fatal(err)
  246. }
  247. }
  248. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  249. if err != nil {
  250. t.Fatal(err)
  251. }
  252. siblings := SiblingsFromProof(proof)
  253. assert.Equal(t, 6, len(siblings))
  254. assert.Equal(t,
  255. "5b478bdd58595ead03ebf494a74014cbb576ba0d9456aa0916885b9eefae592f",
  256. siblings[0].Hex())
  257. assert.Equal(t,
  258. "c1e8ab120a4e475ea1bf00633228bfb9d248f7ddec2aa6367f98d0defb9fb22e",
  259. siblings[1].Hex())
  260. assert.Equal(t,
  261. "f4dafd8ac2b9165adc3f6d125af67d5a4d8a7a263dcc90a373d0338929e16e0c",
  262. siblings[2].Hex())
  263. assert.Equal(t,
  264. "a94aa346bd85f96aba2e85b67920e44fe6ed767b0e13bea602784e0b8b897515",
  265. siblings[3].Hex())
  266. assert.Equal(t,
  267. "54791d7514030ded79301dbf221f5bf186facbc5800912411852fdc101b7151d",
  268. siblings[4].Hex())
  269. assert.Equal(t,
  270. "435d28bc0511f8feb93b5f1649a049b460947702ce0baaefcf596175370fe01e",
  271. siblings[5].Hex())
  272. }
  273. func TestVerifyProofCases(t *testing.T) {
  274. mt := newTestingMerkle(t, 140)
  275. defer mt.DB().Close()
  276. for i := 0; i < 8; i++ {
  277. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  278. t.Fatal(err)
  279. }
  280. }
  281. // Existence proof
  282. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  283. if err != nil {
  284. t.Fatal(err)
  285. }
  286. assert.Equal(t, proof.Existence, true)
  287. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
  288. assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e", hex.EncodeToString(proof.Bytes())) //nolint:lll
  289. for i := 8; i < 32; i++ {
  290. proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
  291. assert.Nil(t, err)
  292. if debug {
  293. fmt.Println(i, proof)
  294. }
  295. }
  296. // Non-existence proof, empty aux
  297. proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
  298. if err != nil {
  299. t.Fatal(err)
  300. }
  301. assert.Equal(t, proof.Existence, false)
  302. // assert.True(t, proof.nodeAux == nil)
  303. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
  304. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730211ddfcc8d30ebd157d1d6912769b8e4abdca41e5dc2b57b026a361c091a8c14c748530e61bf8ea80c987657c3d24b134ece1ef8e2d4bd3f74437bf4392a6b1e04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  305. // Non-existence proof, diff. node aux
  306. proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
  307. if err != nil {
  308. t.Fatal(err)
  309. }
  310. assert.Equal(t, proof.Existence, false)
  311. assert.True(t, proof.NodeAux != nil)
  312. assert.True(t, VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
  313. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007a6d6b46fefe213a6b579844a1bb7ab5c2db4a13f8662d9c5e729c36728f42730e667e2ca15909c4a23beff18e3cc74348fbd3c1a4c765a5bbbca126c9607a42b77e008a73926f1280f8531b139dc1cacf8d83fcec31d405f5c51b7cbddfe152902000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  314. }
  315. func TestVerifyProofFalse(t *testing.T) {
  316. mt := newTestingMerkle(t, 140)
  317. defer mt.DB().Close()
  318. for i := 0; i < 8; i++ {
  319. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  320. t.Fatal(err)
  321. }
  322. }
  323. // Invalid existence proof (node used for verification doesn't
  324. // correspond to node in the proof)
  325. proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
  326. if err != nil {
  327. t.Fatal(err)
  328. }
  329. assert.Equal(t, proof.Existence, true)
  330. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
  331. // Invalid non-existence proof (Non-existence proof, diff. node aux)
  332. proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
  333. if err != nil {
  334. t.Fatal(err)
  335. }
  336. assert.Equal(t, proof.Existence, true)
  337. // Now we change the proof from existence to non-existence, and add e's
  338. // data as auxiliary node.
  339. proof.Existence = false
  340. proof.NodeAux = &NodeAux{Key: NewHashFromBigInt(big.NewInt(int64(4))),
  341. Value: NewHashFromBigInt(big.NewInt(4))}
  342. assert.True(t, !VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
  343. }
  344. func TestGraphViz(t *testing.T) {
  345. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  346. assert.Nil(t, err)
  347. _ = mt.Add(big.NewInt(1), big.NewInt(0))
  348. _ = mt.Add(big.NewInt(2), big.NewInt(0))
  349. _ = mt.Add(big.NewInt(3), big.NewInt(0))
  350. _ = mt.Add(big.NewInt(4), big.NewInt(0))
  351. _ = mt.Add(big.NewInt(5), big.NewInt(0))
  352. _ = mt.Add(big.NewInt(100), big.NewInt(0))
  353. // mt.PrintGraphViz(nil)
  354. expected := `digraph hierarchy {
  355. node [fontname=Monospace,fontsize=10,shape=box]
  356. "16053348..." -> {"19137630..." "14119616..."}
  357. "19137630..." -> {"19543983..." "19746229..."}
  358. "19543983..." -> {"empty0" "65773153..."}
  359. "empty0" [style=dashed,label=0];
  360. "65773153..." -> {"73498412..." "empty1"}
  361. "empty1" [style=dashed,label=0];
  362. "73498412..." -> {"53169236..." "empty2"}
  363. "empty2" [style=dashed,label=0];
  364. "53169236..." -> {"73522717..." "34811870..."}
  365. "73522717..." [style=filled];
  366. "34811870..." [style=filled];
  367. "19746229..." [style=filled];
  368. "14119616..." -> {"19419204..." "15569531..."}
  369. "19419204..." -> {"78154875..." "34589916..."}
  370. "78154875..." [style=filled];
  371. "34589916..." [style=filled];
  372. "15569531..." [style=filled];
  373. }
  374. `
  375. w := bytes.NewBufferString("")
  376. err = mt.GraphViz(w, nil)
  377. assert.Nil(t, err)
  378. assert.Equal(t, []byte(expected), w.Bytes())
  379. }
  380. func TestDelete(t *testing.T) {
  381. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  382. assert.Nil(t, err)
  383. assert.Equal(t, "0", mt.Root().String())
  384. // test vectors generated using https://github.com/iden3/circomlib smt.js
  385. err = mt.Add(big.NewInt(1), big.NewInt(2))
  386. assert.Nil(t, err)
  387. assert.Equal(t, "6449712043256457369579901840927028403950625973089336675272087704159094984964", mt.Root().BigInt().String()) //nolint:lll
  388. err = mt.Add(big.NewInt(33), big.NewInt(44))
  389. assert.Nil(t, err)
  390. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String()) //nolint:lll
  391. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  392. assert.Nil(t, err)
  393. assert.Equal(t, "12841932325181810040554102151615400973767747666110051836366805309524360490677", mt.Root().BigInt().String()) //nolint:lll
  394. // mt.PrintGraphViz(nil)
  395. err = mt.Delete(big.NewInt(33))
  396. // mt.PrintGraphViz(nil)
  397. assert.Nil(t, err)
  398. assert.Equal(t, "16195585003843604118922861401064871511855368913846540536604351220077317790615", mt.Root().BigInt().String()) //nolint:lll
  399. err = mt.Delete(big.NewInt(1234))
  400. assert.Nil(t, err)
  401. err = mt.Delete(big.NewInt(1))
  402. assert.Nil(t, err)
  403. assert.Equal(t, "0", mt.Root().String())
  404. dbRoot, err := mt.dbGetRoot()
  405. require.Nil(t, err)
  406. assert.Equal(t, mt.Root(), dbRoot)
  407. }
  408. func TestDelete2(t *testing.T) {
  409. mt := newTestingMerkle(t, 140)
  410. defer mt.db.Close()
  411. for i := 0; i < 8; i++ {
  412. k := big.NewInt(int64(i))
  413. v := big.NewInt(0)
  414. if err := mt.Add(k, v); err != nil {
  415. t.Fatal(err)
  416. }
  417. }
  418. expectedRoot := mt.Root()
  419. k := big.NewInt(8)
  420. v := big.NewInt(0)
  421. err := mt.Add(k, v)
  422. require.Nil(t, err)
  423. err = mt.Delete(big.NewInt(8))
  424. assert.Nil(t, err)
  425. assert.Equal(t, expectedRoot, mt.Root())
  426. mt2 := newTestingMerkle(t, 140)
  427. defer mt2.db.Close()
  428. for i := 0; i < 8; i++ {
  429. k := big.NewInt(int64(i))
  430. v := big.NewInt(0)
  431. if err := mt2.Add(k, v); err != nil {
  432. t.Fatal(err)
  433. }
  434. }
  435. assert.Equal(t, mt2.Root(), mt.Root())
  436. }
  437. func TestDelete3(t *testing.T) {
  438. mt := newTestingMerkle(t, 140)
  439. defer mt.db.Close()
  440. err := mt.Add(big.NewInt(1), big.NewInt(1))
  441. assert.Nil(t, err)
  442. err = mt.Add(big.NewInt(2), big.NewInt(2))
  443. assert.Nil(t, err)
  444. assert.Equal(t, "6701939280963330813043570145125351311131831356446202146710280245621673558344", mt.Root().BigInt().String()) //nolint:lll
  445. err = mt.Delete(big.NewInt(1))
  446. assert.Nil(t, err)
  447. assert.Equal(t, "10304354743004778619823249005484018655542356856535590307973732141291410579841", mt.Root().BigInt().String()) //nolint:lll
  448. mt2 := newTestingMerkle(t, 140)
  449. defer mt2.db.Close()
  450. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  451. assert.Nil(t, err)
  452. assert.Equal(t, mt2.Root(), mt.Root())
  453. }
  454. func TestDelete4(t *testing.T) {
  455. mt := newTestingMerkle(t, 140)
  456. defer mt.db.Close()
  457. err := mt.Add(big.NewInt(1), big.NewInt(1))
  458. assert.Nil(t, err)
  459. err = mt.Add(big.NewInt(2), big.NewInt(2))
  460. assert.Nil(t, err)
  461. err = mt.Add(big.NewInt(3), big.NewInt(3))
  462. assert.Nil(t, err)
  463. assert.Equal(t, "6989694633650442615746486460134957295274675622748484439660143938730686550248", mt.Root().BigInt().String()) //nolint:lll
  464. err = mt.Delete(big.NewInt(1))
  465. assert.Nil(t, err)
  466. assert.Equal(t, "1192610901536912535888866440319084773171371421781091005185759505381507049136", mt.Root().BigInt().String()) //nolint:lll
  467. mt2 := newTestingMerkle(t, 140)
  468. defer mt2.db.Close()
  469. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  470. assert.Nil(t, err)
  471. err = mt2.Add(big.NewInt(3), big.NewInt(3))
  472. assert.Nil(t, err)
  473. assert.Equal(t, mt2.Root(), mt.Root())
  474. }
  475. func TestDelete5(t *testing.T) {
  476. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  477. assert.Nil(t, err)
  478. err = mt.Add(big.NewInt(1), big.NewInt(2))
  479. assert.Nil(t, err)
  480. err = mt.Add(big.NewInt(33), big.NewInt(44))
  481. assert.Nil(t, err)
  482. assert.Equal(t, "11404118908468506234838877883514126008995570353394659302846433035311596046064", mt.Root().BigInt().String()) //nolint:lll
  483. err = mt.Delete(big.NewInt(1))
  484. assert.Nil(t, err)
  485. assert.Equal(t, "12802904154263054831102426711825443668153853847661287611768065280921698471037", mt.Root().BigInt().String()) //nolint:lll
  486. mt2 := newTestingMerkle(t, 140)
  487. defer mt2.db.Close()
  488. err = mt2.Add(big.NewInt(33), big.NewInt(44))
  489. assert.Nil(t, err)
  490. assert.Equal(t, mt2.Root(), mt.Root())
  491. }
  492. func TestDeleteNonExistingKeys(t *testing.T) {
  493. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  494. assert.Nil(t, err)
  495. err = mt.Add(big.NewInt(1), big.NewInt(2))
  496. assert.Nil(t, err)
  497. err = mt.Add(big.NewInt(33), big.NewInt(44))
  498. assert.Nil(t, err)
  499. err = mt.Delete(big.NewInt(33))
  500. assert.Nil(t, err)
  501. err = mt.Delete(big.NewInt(33))
  502. assert.Equal(t, ErrKeyNotFound, err)
  503. err = mt.Delete(big.NewInt(1))
  504. assert.Nil(t, err)
  505. assert.Equal(t, "0", mt.Root().String())
  506. err = mt.Delete(big.NewInt(33))
  507. assert.Equal(t, ErrKeyNotFound, err)
  508. }
  509. func TestDumpLeafsImportLeafs(t *testing.T) {
  510. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  511. require.Nil(t, err)
  512. defer mt.db.Close()
  513. q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
  514. for i := 0; i < 10; i++ {
  515. // use numbers near under Q
  516. k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
  517. v := big.NewInt(0)
  518. err = mt.Add(k, v)
  519. require.Nil(t, err)
  520. // use numbers near above 0
  521. k = big.NewInt(int64(i))
  522. err = mt.Add(k, v)
  523. require.Nil(t, err)
  524. }
  525. d, err := mt.DumpLeafs(nil)
  526. assert.Nil(t, err)
  527. mt2, err := NewMerkleTree(memory.NewMemoryStorage(), 140)
  528. require.Nil(t, err)
  529. defer mt2.db.Close()
  530. err = mt2.ImportDumpedLeafs(d)
  531. assert.Nil(t, err)
  532. assert.Equal(t, mt.Root(), mt2.Root())
  533. }
  534. func TestAddAndGetCircomProof(t *testing.T) {
  535. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 10)
  536. assert.Nil(t, err)
  537. assert.Equal(t, "0", mt.Root().String())
  538. // test vectors generated using https://github.com/iden3/circomlib smt.js
  539. cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
  540. assert.Nil(t, err)
  541. assert.Equal(t, "0", cpp.OldRoot.String())
  542. assert.Equal(t, "64497120...", cpp.NewRoot.String())
  543. assert.Equal(t, "0", cpp.OldKey.String())
  544. assert.Equal(t, "0", cpp.OldValue.String())
  545. assert.Equal(t, "1", cpp.NewKey.String())
  546. assert.Equal(t, "2", cpp.NewValue.String())
  547. assert.Equal(t, true, cpp.IsOld0)
  548. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  549. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  550. cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
  551. assert.Nil(t, err)
  552. assert.Equal(t, "64497120...", cpp.OldRoot.String())
  553. assert.Equal(t, "11404118...", cpp.NewRoot.String())
  554. assert.Equal(t, "1", cpp.OldKey.String())
  555. assert.Equal(t, "2", cpp.OldValue.String())
  556. assert.Equal(t, "33", cpp.NewKey.String())
  557. assert.Equal(t, "44", cpp.NewValue.String())
  558. assert.Equal(t, false, cpp.IsOld0)
  559. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  560. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  561. cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
  562. assert.Nil(t, err)
  563. assert.Equal(t, "11404118...", cpp.OldRoot.String())
  564. assert.Equal(t, "18284203...", cpp.NewRoot.String())
  565. assert.Equal(t, "0", cpp.OldKey.String())
  566. assert.Equal(t, "0", cpp.OldValue.String())
  567. assert.Equal(t, "55", cpp.NewKey.String())
  568. assert.Equal(t, "66", cpp.NewValue.String())
  569. assert.Equal(t, true, cpp.IsOld0)
  570. assert.Equal(t, "[0 42948778... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  571. assert.Equal(t, mt.maxLevels+1, len(cpp.Siblings))
  572. }
  573. func TestUpdateCircomProcessorProof(t *testing.T) {
  574. mt := newTestingMerkle(t, 10)
  575. defer mt.db.Close()
  576. for i := 0; i < 16; i++ {
  577. k := big.NewInt(int64(i))
  578. v := big.NewInt(int64(i * 2))
  579. if err := mt.Add(k, v); err != nil {
  580. t.Fatal(err)
  581. }
  582. }
  583. _, v, _, err := mt.Get(big.NewInt(10))
  584. assert.Nil(t, err)
  585. assert.Equal(t, big.NewInt(20), v)
  586. // test vectors generated using https://github.com/iden3/circomlib smt.js
  587. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  588. assert.Nil(t, err)
  589. assert.Equal(t, "14895645...", cpp.OldRoot.String())
  590. assert.Equal(t, "75223641...", cpp.NewRoot.String())
  591. assert.Equal(t, "10", cpp.OldKey.String())
  592. assert.Equal(t, "20", cpp.OldValue.String())
  593. assert.Equal(t, "10", cpp.NewKey.String())
  594. assert.Equal(t, "1024", cpp.NewValue.String())
  595. assert.Equal(t, false, cpp.IsOld0)
  596. assert.Equal(t,
  597. "[19625419... 46910949... 18399594... 20473908... 0 0 0 0 0 0 0]",
  598. fmt.Sprintf("%v", cpp.Siblings))
  599. }
  600. func TestSmtVerifier(t *testing.T) {
  601. mt, err := NewMerkleTree(memory.NewMemoryStorage(), 4)
  602. assert.Nil(t, err)
  603. err = mt.Add(big.NewInt(1), big.NewInt(11))
  604. assert.Nil(t, err)
  605. cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
  606. assert.Nil(t, err)
  607. jCvp, err := json.Marshal(cvp)
  608. assert.Nil(t, err)
  609. // expect siblings to be '[]', instead of 'null'
  610. expected := `{"root":"14137057030252181222327992235694793580963111268072013054745223667806564674729","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
  611. assert.Equal(t, expected, string(jCvp))
  612. err = mt.Add(big.NewInt(2), big.NewInt(22))
  613. assert.Nil(t, err)
  614. err = mt.Add(big.NewInt(3), big.NewInt(33))
  615. assert.Nil(t, err)
  616. err = mt.Add(big.NewInt(4), big.NewInt(44))
  617. assert.Nil(t, err)
  618. cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
  619. assert.Nil(t, err)
  620. jCvp, err = json.Marshal(cvp)
  621. assert.Nil(t, err)
  622. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  623. // Expect siblings with the extra 0 that the circom circuits need
  624. expected = `{"root":"10171140035965439966839815283432442651152991056297946102647688349369299124493","siblings":["12422661758472400223401299094238820777063458096110016599986781158438915645129","4330149052063565277182642012557086942088176847773467265587998154672740895682","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  625. assert.Equal(t, expected, string(jCvp))
  626. cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
  627. assert.Nil(t, err)
  628. jCvp, err = json.Marshal(cvp)
  629. assert.Nil(t, err)
  630. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  631. // Without the extra 0 that the circom circuits need, but that are not
  632. // needed at a smart contract verification
  633. expected = `{"root":"10171140035965439966839815283432442651152991056297946102647688349369299124493","siblings":["12422661758472400223401299094238820777063458096110016599986781158438915645129","4330149052063565277182642012557086942088176847773467265587998154672740895682"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  634. assert.Equal(t, expected, string(jCvp))
  635. }
  636. func TestTypesMarshalers(t *testing.T) {
  637. // test Hash marshalers
  638. h, err := NewHashFromString("42")
  639. assert.Nil(t, err)
  640. s, err := json.Marshal(h)
  641. assert.Nil(t, err)
  642. var h2 *Hash
  643. err = json.Unmarshal(s, &h2)
  644. assert.Nil(t, err)
  645. assert.Equal(t, h, h2)
  646. // create CircomProcessorProof
  647. mt := newTestingMerkle(t, 10)
  648. defer mt.db.Close()
  649. for i := 0; i < 16; i++ {
  650. k := big.NewInt(int64(i))
  651. v := big.NewInt(int64(i * 2))
  652. if err := mt.Add(k, v); err != nil {
  653. t.Fatal(err)
  654. }
  655. }
  656. _, v, _, err := mt.Get(big.NewInt(10))
  657. assert.Nil(t, err)
  658. assert.Equal(t, big.NewInt(20), v)
  659. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  660. assert.Nil(t, err)
  661. // test CircomProcessorProof marshalers
  662. b, err := json.Marshal(&cpp)
  663. assert.Nil(t, err)
  664. var cpp2 *CircomProcessorProof
  665. err = json.Unmarshal(b, &cpp2)
  666. assert.Nil(t, err)
  667. assert.Equal(t, cpp, cpp2)
  668. }