You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

871 lines
27 KiB

  1. package sql
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "github.com/iden3/go-iden3-crypto/constants"
  9. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  10. "github.com/iden3/go-merkletree"
  11. "github.com/iden3/go-merkletree/db/memory"
  12. "github.com/iden3/go-merkletree/db/test"
  13. "github.com/jmoiron/sqlx"
  14. "github.com/stretchr/testify/assert"
  15. "github.com/stretchr/testify/require"
  16. "math/big"
  17. "os"
  18. "strconv"
  19. "testing"
  20. )
  21. var maxMTId uint64 = 0
  22. var cleared = false
  23. func setupDB() (*sqlx.DB, error) {
  24. var err error
  25. host := os.Getenv("PGHOST")
  26. if host == "" {
  27. host = "localhost"
  28. }
  29. port, _ := strconv.Atoi(os.Getenv("PGPORT"))
  30. if port == 0 {
  31. port = 5432
  32. }
  33. user := os.Getenv("PGUSER")
  34. if user == "" {
  35. user = "user"
  36. }
  37. password := os.Getenv("PGPASSWORD")
  38. if password == "" {
  39. return nil, errors.New("No PGPASSWORD envvar specified")
  40. }
  41. dbname := os.Getenv("PGDATABASE")
  42. if dbname == "" {
  43. dbname = "test"
  44. }
  45. psqlconn := fmt.Sprintf(
  46. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  47. host,
  48. port,
  49. user,
  50. password,
  51. dbname,
  52. )
  53. dbx, err := sqlx.Connect("postgres", psqlconn)
  54. if err != nil {
  55. return nil, err
  56. }
  57. // clear MerkleTree table
  58. //if !cleared {
  59. dbx.Exec("TRUNCATE TABLE mt_roots")
  60. dbx.Exec("TRUNCATE TABLE mt_nodes")
  61. cleared = true
  62. //}
  63. return dbx, nil
  64. }
  65. func sqlStorage(t *testing.T) merkletree.Storage {
  66. dbx, err := setupDB()
  67. if err != nil {
  68. t.Fatal(err)
  69. return nil
  70. }
  71. sto, err := NewSqlStorage(dbx, false)
  72. if err != nil {
  73. t.Fatal(err)
  74. return nil
  75. }
  76. sto.mtId = maxMTId
  77. maxMTId++
  78. t.Cleanup(func() {
  79. })
  80. return sto
  81. }
  82. func TestReturnKnownErrIfNotExists(t *testing.T) {
  83. test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
  84. }
  85. func TestStorageInsertGet(t *testing.T) {
  86. test.TestStorageInsertGet(t, sqlStorage(t))
  87. }
  88. func TestStorageWithPrefix(t *testing.T) {
  89. test.TestStorageWithPrefix(t, sqlStorage(t))
  90. }
  91. func TestSql(t *testing.T) {
  92. //sto := sqlStorage(t)
  93. t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
  94. test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
  95. })
  96. t.Run("TestStorageInsertGet", func(t *testing.T) {
  97. test.TestStorageInsertGet(t, sqlStorage(t))
  98. })
  99. t.Run("TestStorageWithPrefix", func(t *testing.T) {
  100. test.TestStorageWithPrefix(t, sqlStorage(t))
  101. })
  102. test.TestConcatTx(t, sqlStorage(t))
  103. test.TestList(t, sqlStorage(t))
  104. test.TestIterate(t, sqlStorage(t))
  105. test.TestNewTree(t, sqlStorage(t))
  106. test.TestAddDifferentOrder(t, sqlStorage(t), sqlStorage(t))
  107. test.TestAddRepeatedIndex(t, sqlStorage(t))
  108. test.TestGet(t, sqlStorage(t))
  109. test.TestUpdate(t, sqlStorage(t))
  110. test.TestUpdate2(t, sqlStorage(t))
  111. test.TestGenerateAndVerifyProof128(t, sqlStorage(t))
  112. test.TestTreeLimit(t, sqlStorage(t))
  113. test.TestSiblingsFromProof(t, sqlStorage(t))
  114. test.TestVerifyProofCases(t, sqlStorage(t))
  115. test.TestVerifyProofFalse(t, sqlStorage(t))
  116. test.TestGraphViz(t, sqlStorage(t))
  117. test.TestDelete(t, sqlStorage(t))
  118. test.TestDelete2(t, sqlStorage(t), sqlStorage(t))
  119. test.TestDelete3(t, sqlStorage(t), sqlStorage(t))
  120. test.TestDelete4(t, sqlStorage(t), sqlStorage(t))
  121. test.TestDelete5(t, sqlStorage(t), sqlStorage(t))
  122. test.TestDeleteNonExistingKeys(t, sqlStorage(t))
  123. test.TestDumpLeafsImportLeafs(t, sqlStorage(t), sqlStorage(t))
  124. test.TestAddAndGetCircomProof(t, sqlStorage(t))
  125. test.TestUpdateCircomProcessorProof(t, sqlStorage(t))
  126. test.TestSmtVerifier(t, sqlStorage(t))
  127. test.TestTypesMarshalers(t, sqlStorage(t))
  128. }
  129. var debug = false
  130. func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
  131. sto := sqlStorage(f)
  132. mt, err := merkletree.NewMerkleTree(sto, maxLevels)
  133. if err != nil {
  134. f.Fatal(err)
  135. return nil
  136. }
  137. return mt
  138. }
  139. func TestHashParsers(t *testing.T) {
  140. h0 := merkletree.NewHashFromBigInt(big.NewInt(0))
  141. assert.Equal(t, "0", h0.String())
  142. h1 := merkletree.NewHashFromBigInt(big.NewInt(1))
  143. assert.Equal(t, "1", h1.String())
  144. h10 := merkletree.NewHashFromBigInt(big.NewInt(10))
  145. assert.Equal(t, "10", h10.String())
  146. h7l := merkletree.NewHashFromBigInt(big.NewInt(1234567))
  147. assert.Equal(t, "1234567", h7l.String())
  148. h8l := merkletree.NewHashFromBigInt(big.NewInt(12345678))
  149. assert.Equal(t, "12345678...", h8l.String())
  150. b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll
  151. assert.True(t, ok)
  152. h := merkletree.NewHashFromBigInt(b)
  153. assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll
  154. assert.Equal(t, "49322979...", h.String())
  155. assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
  156. b1, err := merkletree.NewBigIntFromHashBytes(b.Bytes())
  157. assert.Nil(t, err)
  158. assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
  159. b2, err := merkletree.NewHashFromBytes(b.Bytes())
  160. assert.Nil(t, err)
  161. assert.Equal(t, b.String(), b2.BigInt().String())
  162. h2, err := merkletree.NewHashFromHex(h.Hex())
  163. assert.Nil(t, err)
  164. assert.Equal(t, h, h2)
  165. _, err = merkletree.NewHashFromHex("0x12")
  166. assert.NotNil(t, err)
  167. // check limits
  168. a := new(big.Int).Sub(constants.Q, big.NewInt(1))
  169. testHashParsers(t, a)
  170. a = big.NewInt(int64(1))
  171. testHashParsers(t, a)
  172. }
  173. func testHashParsers(t *testing.T, a *big.Int) {
  174. require.True(t, cryptoUtils.CheckBigIntInField(a))
  175. h := merkletree.NewHashFromBigInt(a)
  176. assert.Equal(t, a, h.BigInt())
  177. hFromBytes, err := merkletree.NewHashFromBytes(h.Bytes())
  178. assert.Nil(t, err)
  179. assert.Equal(t, h, hFromBytes)
  180. assert.Equal(t, a, hFromBytes.BigInt())
  181. assert.Equal(t, a.String(), hFromBytes.BigInt().String())
  182. hFromHex, err := merkletree.NewHashFromHex(h.Hex())
  183. assert.Nil(t, err)
  184. assert.Equal(t, h, hFromHex)
  185. aBIFromHBytes, err := merkletree.NewBigIntFromHashBytes(h.Bytes())
  186. assert.Nil(t, err)
  187. assert.Equal(t, a, aBIFromHBytes)
  188. assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
  189. }
  190. func TestNewTree(t *testing.T) {
  191. mt := newTestingMerkle(t, 10)
  192. mt, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 10)
  193. assert.Nil(t, err)
  194. assert.Equal(t, "0", mt.Root().String())
  195. // test vectors generated using https://github.com/iden3/circomlib smt.js
  196. err = mt.Add(big.NewInt(1), big.NewInt(2))
  197. assert.Nil(t, err)
  198. assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
  199. err = mt.Add(big.NewInt(33), big.NewInt(44))
  200. assert.Nil(t, err)
  201. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  202. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  203. assert.Nil(t, err)
  204. assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
  205. proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
  206. assert.Nil(t, err)
  207. assert.Equal(t, big.NewInt(44), v)
  208. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
  209. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
  210. }
  211. func TestAddDifferentOrder(t *testing.T) {
  212. mt1 := newTestingMerkle(t, 140)
  213. for i := 0; i < 16; i++ {
  214. k := big.NewInt(int64(i))
  215. v := big.NewInt(0)
  216. if err := mt1.Add(k, v); err != nil {
  217. t.Fatal(err)
  218. }
  219. }
  220. mt2 := newTestingMerkle(t, 140)
  221. for i := 16 - 1; i >= 0; i-- {
  222. k := big.NewInt(int64(i))
  223. v := big.NewInt(0)
  224. if err := mt2.Add(k, v); err != nil {
  225. t.Fatal(err)
  226. }
  227. }
  228. assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
  229. assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
  230. }
  231. func TestAddRepeatedIndex(t *testing.T) {
  232. mt := newTestingMerkle(t, 140)
  233. k := big.NewInt(int64(3))
  234. v := big.NewInt(int64(12))
  235. if err := mt.Add(k, v); err != nil {
  236. t.Fatal(err)
  237. }
  238. err := mt.Add(k, v)
  239. assert.NotNil(t, err)
  240. assert.Equal(t, merkletree.ErrEntryIndexAlreadyExists, err)
  241. }
  242. func TestGet(t *testing.T) {
  243. mt := newTestingMerkle(t, 140)
  244. for i := 0; i < 16; i++ {
  245. k := big.NewInt(int64(i))
  246. v := big.NewInt(int64(i * 2))
  247. if err := mt.Add(k, v); err != nil {
  248. t.Fatal(err)
  249. }
  250. }
  251. k, v, _, err := mt.Get(big.NewInt(10))
  252. assert.Nil(t, err)
  253. assert.Equal(t, big.NewInt(10), k)
  254. assert.Equal(t, big.NewInt(20), v)
  255. k, v, _, err = mt.Get(big.NewInt(15))
  256. assert.Nil(t, err)
  257. assert.Equal(t, big.NewInt(15), k)
  258. assert.Equal(t, big.NewInt(30), v)
  259. k, v, _, err = mt.Get(big.NewInt(16))
  260. assert.NotNil(t, err)
  261. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  262. assert.Equal(t, "0", k.String())
  263. assert.Equal(t, "0", v.String())
  264. }
  265. func TestUpdate(t *testing.T) {
  266. mt := newTestingMerkle(t, 140)
  267. for i := 0; i < 16; i++ {
  268. k := big.NewInt(int64(i))
  269. v := big.NewInt(int64(i * 2))
  270. if err := mt.Add(k, v); err != nil {
  271. t.Fatal(err)
  272. }
  273. }
  274. _, v, _, err := mt.Get(big.NewInt(10))
  275. assert.Nil(t, err)
  276. assert.Equal(t, big.NewInt(20), v)
  277. _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
  278. assert.Nil(t, err)
  279. _, v, _, err = mt.Get(big.NewInt(10))
  280. assert.Nil(t, err)
  281. assert.Equal(t, big.NewInt(1024), v)
  282. _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
  283. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  284. }
  285. func TestUpdate2(t *testing.T) {
  286. mt1 := newTestingMerkle(t, 140)
  287. mt2 := newTestingMerkle(t, 140)
  288. err := mt1.Add(big.NewInt(1), big.NewInt(119))
  289. assert.Nil(t, err)
  290. err = mt1.Add(big.NewInt(2), big.NewInt(229))
  291. assert.Nil(t, err)
  292. err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
  293. assert.Nil(t, err)
  294. err = mt2.Add(big.NewInt(1), big.NewInt(11))
  295. assert.Nil(t, err)
  296. err = mt2.Add(big.NewInt(2), big.NewInt(22))
  297. assert.Nil(t, err)
  298. err = mt2.Add(big.NewInt(9876), big.NewInt(10))
  299. assert.Nil(t, err)
  300. _, err = mt1.Update(big.NewInt(1), big.NewInt(11))
  301. assert.Nil(t, err)
  302. _, err = mt1.Update(big.NewInt(2), big.NewInt(22))
  303. assert.Nil(t, err)
  304. _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
  305. assert.Nil(t, err)
  306. assert.Equal(t, mt1.Root(), mt2.Root())
  307. }
  308. func TestGenerateAndVerifyProof128(t *testing.T) {
  309. mt := newTestingMerkle(t, 140)
  310. for i := 0; i < 128; i++ {
  311. k := big.NewInt(int64(i))
  312. v := big.NewInt(0)
  313. if err := mt.Add(k, v); err != nil {
  314. t.Fatal(err)
  315. }
  316. }
  317. proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
  318. assert.Nil(t, err)
  319. assert.Equal(t, "0", v.String())
  320. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
  321. }
  322. func TestTreeLimit(t *testing.T) {
  323. mt := newTestingMerkle(t, 5)
  324. for i := 0; i < 16; i++ {
  325. err := mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
  326. assert.Nil(t, err)
  327. }
  328. // here the tree is full, should not allow to add more data as reaches the maximum number of levels
  329. err := mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
  330. assert.NotNil(t, err)
  331. assert.Equal(t, merkletree.ErrReachedMaxLevel, err)
  332. }
  333. func TestSiblingsFromProof(t *testing.T) {
  334. mt := newTestingMerkle(t, 140)
  335. for i := 0; i < 64; i++ {
  336. k := big.NewInt(int64(i))
  337. v := big.NewInt(0)
  338. if err := mt.Add(k, v); err != nil {
  339. t.Fatal(err)
  340. }
  341. }
  342. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  343. if err != nil {
  344. t.Fatal(err)
  345. }
  346. siblings := merkletree.SiblingsFromProof(proof)
  347. assert.Equal(t, 6, len(siblings))
  348. assert.Equal(t,
  349. "d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b",
  350. siblings[0].Hex())
  351. assert.Equal(t,
  352. "9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16",
  353. siblings[1].Hex())
  354. assert.Equal(t,
  355. "de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27",
  356. siblings[2].Hex())
  357. assert.Equal(t,
  358. "5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317",
  359. siblings[3].Hex())
  360. assert.Equal(t,
  361. "77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122",
  362. siblings[4].Hex())
  363. assert.Equal(t,
  364. "943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07",
  365. siblings[5].Hex())
  366. }
  367. func TestVerifyProofCases(t *testing.T) {
  368. mt := newTestingMerkle(t, 140)
  369. defer mt.DB().Close()
  370. for i := 0; i < 8; i++ {
  371. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  372. t.Fatal(err)
  373. }
  374. }
  375. // Existence proof
  376. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  377. if err != nil {
  378. t.Fatal(err)
  379. }
  380. assert.Equal(t, proof.Existence, true)
  381. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
  382. assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
  383. for i := 8; i < 32; i++ {
  384. proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
  385. assert.Nil(t, err)
  386. if debug {
  387. fmt.Println(i, proof)
  388. }
  389. }
  390. // Non-existence proof, empty aux
  391. proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
  392. if err != nil {
  393. t.Fatal(err)
  394. }
  395. assert.Equal(t, proof.Existence, false)
  396. // assert.True(t, proof.nodeAux == nil)
  397. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
  398. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  399. // Non-existence proof, diff. node aux
  400. proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
  401. if err != nil {
  402. t.Fatal(err)
  403. }
  404. assert.Equal(t, proof.Existence, false)
  405. assert.True(t, proof.NodeAux != nil)
  406. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
  407. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  408. }
  409. func TestVerifyProofFalse(t *testing.T) {
  410. mt := newTestingMerkle(t, 140)
  411. defer mt.DB().Close()
  412. for i := 0; i < 8; i++ {
  413. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  414. t.Fatal(err)
  415. }
  416. }
  417. // Invalid existence proof (node used for verification doesn't
  418. // correspond to node in the proof)
  419. proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
  420. if err != nil {
  421. t.Fatal(err)
  422. }
  423. assert.Equal(t, proof.Existence, true)
  424. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
  425. // Invalid non-existence proof (Non-existence proof, diff. node aux)
  426. proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
  427. if err != nil {
  428. t.Fatal(err)
  429. }
  430. assert.Equal(t, proof.Existence, true)
  431. // Now we change the proof from existence to non-existence, and add e's
  432. // data as auxiliary node.
  433. proof.Existence = false
  434. proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))),
  435. Value: merkletree.NewHashFromBigInt(big.NewInt(4))}
  436. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
  437. }
  438. func TestGraphViz(t *testing.T) {
  439. mt := newTestingMerkle(t, 140)
  440. _ = mt.Add(big.NewInt(1), big.NewInt(0))
  441. _ = mt.Add(big.NewInt(2), big.NewInt(0))
  442. _ = mt.Add(big.NewInt(3), big.NewInt(0))
  443. _ = mt.Add(big.NewInt(4), big.NewInt(0))
  444. _ = mt.Add(big.NewInt(5), big.NewInt(0))
  445. _ = mt.Add(big.NewInt(100), big.NewInt(0))
  446. // mt.PrintGraphViz(nil)
  447. expected := `digraph hierarchy {
  448. node [fontname=Monospace,fontsize=10,shape=box]
  449. "56332309..." -> {"18483622..." "20902180..."}
  450. "18483622..." -> {"75768243..." "16893244..."}
  451. "75768243..." -> {"empty0" "21857056..."}
  452. "empty0" [style=dashed,label=0];
  453. "21857056..." -> {"51072523..." "empty1"}
  454. "empty1" [style=dashed,label=0];
  455. "51072523..." -> {"17311038..." "empty2"}
  456. "empty2" [style=dashed,label=0];
  457. "17311038..." -> {"69499803..." "21008290..."}
  458. "69499803..." [style=filled];
  459. "21008290..." [style=filled];
  460. "16893244..." [style=filled];
  461. "20902180..." -> {"12496585..." "18055627..."}
  462. "12496585..." -> {"19374975..." "15739329..."}
  463. "19374975..." [style=filled];
  464. "15739329..." [style=filled];
  465. "18055627..." [style=filled];
  466. }
  467. `
  468. w := bytes.NewBufferString("")
  469. err := mt.GraphViz(w, nil)
  470. assert.Nil(t, err)
  471. assert.Equal(t, []byte(expected), w.Bytes())
  472. }
  473. func TestDelete(t *testing.T) {
  474. mt := newTestingMerkle(t, 10)
  475. assert.Equal(t, "0", mt.Root().String())
  476. // test vectors generated using https://github.com/iden3/circomlib smt.js
  477. err := mt.Add(big.NewInt(1), big.NewInt(2))
  478. assert.Nil(t, err)
  479. assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
  480. err = mt.Add(big.NewInt(33), big.NewInt(44))
  481. assert.Nil(t, err)
  482. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  483. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  484. assert.Nil(t, err)
  485. assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
  486. // mt.PrintGraphViz(nil)
  487. err = mt.Delete(big.NewInt(33))
  488. // mt.PrintGraphViz(nil)
  489. assert.Nil(t, err)
  490. assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
  491. err = mt.Delete(big.NewInt(1234))
  492. assert.Nil(t, err)
  493. err = mt.Delete(big.NewInt(1))
  494. assert.Nil(t, err)
  495. assert.Equal(t, "0", mt.Root().String())
  496. }
  497. func TestDelete2(t *testing.T) {
  498. mt := newTestingMerkle(t, 140)
  499. for i := 0; i < 8; i++ {
  500. k := big.NewInt(int64(i))
  501. v := big.NewInt(0)
  502. if err := mt.Add(k, v); err != nil {
  503. t.Fatal(err)
  504. }
  505. }
  506. expectedRoot := mt.Root()
  507. k := big.NewInt(8)
  508. v := big.NewInt(0)
  509. err := mt.Add(k, v)
  510. require.Nil(t, err)
  511. err = mt.Delete(big.NewInt(8))
  512. assert.Nil(t, err)
  513. assert.Equal(t, expectedRoot, mt.Root())
  514. mt2 := newTestingMerkle(t, 140)
  515. for i := 0; i < 8; i++ {
  516. k := big.NewInt(int64(i))
  517. v := big.NewInt(0)
  518. if err := mt2.Add(k, v); err != nil {
  519. t.Fatal(err)
  520. }
  521. }
  522. assert.Equal(t, mt2.Root(), mt.Root())
  523. }
  524. func TestDelete3(t *testing.T) {
  525. mt := newTestingMerkle(t, 140)
  526. err := mt.Add(big.NewInt(1), big.NewInt(1))
  527. assert.Nil(t, err)
  528. err = mt.Add(big.NewInt(2), big.NewInt(2))
  529. assert.Nil(t, err)
  530. assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
  531. err = mt.Delete(big.NewInt(1))
  532. assert.Nil(t, err)
  533. assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
  534. mt2 := newTestingMerkle(t, 140)
  535. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  536. assert.Nil(t, err)
  537. assert.Equal(t, mt2.Root(), mt.Root())
  538. }
  539. func TestDelete4(t *testing.T) {
  540. mt := newTestingMerkle(t, 140)
  541. err := mt.Add(big.NewInt(1), big.NewInt(1))
  542. assert.Nil(t, err)
  543. err = mt.Add(big.NewInt(2), big.NewInt(2))
  544. assert.Nil(t, err)
  545. err = mt.Add(big.NewInt(3), big.NewInt(3))
  546. assert.Nil(t, err)
  547. assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
  548. err = mt.Delete(big.NewInt(1))
  549. assert.Nil(t, err)
  550. assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
  551. mt2 := newTestingMerkle(t, 140)
  552. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  553. assert.Nil(t, err)
  554. err = mt2.Add(big.NewInt(3), big.NewInt(3))
  555. assert.Nil(t, err)
  556. assert.Equal(t, mt2.Root(), mt.Root())
  557. }
  558. func TestDelete5(t *testing.T) {
  559. mt := newTestingMerkle(t, 10)
  560. err := mt.Add(big.NewInt(1), big.NewInt(2))
  561. assert.Nil(t, err)
  562. err = mt.Add(big.NewInt(33), big.NewInt(44))
  563. assert.Nil(t, err)
  564. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  565. err = mt.Delete(big.NewInt(1))
  566. assert.Nil(t, err)
  567. assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
  568. mt2 := newTestingMerkle(t, 140)
  569. err = mt2.Add(big.NewInt(33), big.NewInt(44))
  570. assert.Nil(t, err)
  571. assert.Equal(t, mt2.Root(), mt.Root())
  572. }
  573. func TestDeleteNonExistingKeys(t *testing.T) {
  574. mt := newTestingMerkle(t, 10)
  575. err := mt.Add(big.NewInt(1), big.NewInt(2))
  576. assert.Nil(t, err)
  577. err = mt.Add(big.NewInt(33), big.NewInt(44))
  578. assert.Nil(t, err)
  579. err = mt.Delete(big.NewInt(33))
  580. assert.Nil(t, err)
  581. err = mt.Delete(big.NewInt(33))
  582. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  583. err = mt.Delete(big.NewInt(1))
  584. assert.Nil(t, err)
  585. assert.Equal(t, "0", mt.Root().String())
  586. err = mt.Delete(big.NewInt(33))
  587. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  588. }
  589. func TestDumpLeafsImportLeafs(t *testing.T) {
  590. mt := newTestingMerkle(t, 140)
  591. q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
  592. for i := 0; i < 10; i++ {
  593. // use numbers near under Q
  594. k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
  595. v := big.NewInt(0)
  596. err := mt.Add(k, v)
  597. require.Nil(t, err)
  598. // use numbers near above 0
  599. k = big.NewInt(int64(i))
  600. err = mt.Add(k, v)
  601. require.Nil(t, err)
  602. }
  603. d, err := mt.DumpLeafs(nil)
  604. assert.Nil(t, err)
  605. mt2, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 140)
  606. require.Nil(t, err)
  607. err = mt2.ImportDumpedLeafs(d)
  608. assert.Nil(t, err)
  609. assert.Equal(t, mt.Root(), mt2.Root())
  610. }
  611. func TestAddAndGetCircomProof(t *testing.T) {
  612. mt := newTestingMerkle(t, 10)
  613. assert.Equal(t, "0", mt.Root().String())
  614. // test vectors generated using https://github.com/iden3/circomlib smt.js
  615. cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
  616. assert.Nil(t, err)
  617. assert.Equal(t, "0", cpp.OldRoot.String())
  618. assert.Equal(t, "13578938...", cpp.NewRoot.String())
  619. assert.Equal(t, "0", cpp.OldKey.String())
  620. assert.Equal(t, "0", cpp.OldValue.String())
  621. assert.Equal(t, "1", cpp.NewKey.String())
  622. assert.Equal(t, "2", cpp.NewValue.String())
  623. assert.Equal(t, true, cpp.IsOld0)
  624. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  625. cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
  626. assert.Nil(t, err)
  627. assert.Equal(t, "13578938...", cpp.OldRoot.String())
  628. assert.Equal(t, "54123936...", cpp.NewRoot.String())
  629. assert.Equal(t, "1", cpp.OldKey.String())
  630. assert.Equal(t, "2", cpp.OldValue.String())
  631. assert.Equal(t, "33", cpp.NewKey.String())
  632. assert.Equal(t, "44", cpp.NewValue.String())
  633. assert.Equal(t, false, cpp.IsOld0)
  634. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  635. cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
  636. assert.Nil(t, err)
  637. assert.Equal(t, "54123936...", cpp.OldRoot.String())
  638. assert.Equal(t, "50943640...", cpp.NewRoot.String())
  639. assert.Equal(t, "0", cpp.OldKey.String())
  640. assert.Equal(t, "0", cpp.OldValue.String())
  641. assert.Equal(t, "55", cpp.NewKey.String())
  642. assert.Equal(t, "66", cpp.NewValue.String())
  643. assert.Equal(t, true, cpp.IsOld0)
  644. assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  645. }
  646. func TestUpdateCircomProcessorProof(t *testing.T) {
  647. mt := newTestingMerkle(t, 10)
  648. for i := 0; i < 16; i++ {
  649. k := big.NewInt(int64(i))
  650. v := big.NewInt(int64(i * 2))
  651. if err := mt.Add(k, v); err != nil {
  652. t.Fatal(err)
  653. }
  654. }
  655. _, v, _, err := mt.Get(big.NewInt(10))
  656. assert.Nil(t, err)
  657. assert.Equal(t, big.NewInt(20), v)
  658. // test vectors generated using https://github.com/iden3/circomlib smt.js
  659. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  660. assert.Nil(t, err)
  661. assert.Equal(t, "39010880...", cpp.OldRoot.String())
  662. assert.Equal(t, "18587862...", cpp.NewRoot.String())
  663. assert.Equal(t, "10", cpp.OldKey.String())
  664. assert.Equal(t, "20", cpp.OldValue.String())
  665. assert.Equal(t, "10", cpp.NewKey.String())
  666. assert.Equal(t, "1024", cpp.NewValue.String())
  667. assert.Equal(t, false, cpp.IsOld0)
  668. assert.Equal(t,
  669. "[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]",
  670. fmt.Sprintf("%v", cpp.Siblings))
  671. }
  672. func TestSmtVerifier(t *testing.T) {
  673. mt := newTestingMerkle(t, 4)
  674. err := mt.Add(big.NewInt(1), big.NewInt(11))
  675. assert.Nil(t, err)
  676. cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
  677. assert.Nil(t, err)
  678. jCvp, err := json.Marshal(cvp)
  679. assert.Nil(t, err)
  680. // expect siblings to be '[]', instead of 'null'
  681. expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
  682. assert.Equal(t, expected, string(jCvp))
  683. err = mt.Add(big.NewInt(2), big.NewInt(22))
  684. assert.Nil(t, err)
  685. err = mt.Add(big.NewInt(3), big.NewInt(33))
  686. assert.Nil(t, err)
  687. err = mt.Add(big.NewInt(4), big.NewInt(44))
  688. assert.Nil(t, err)
  689. cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
  690. assert.Nil(t, err)
  691. jCvp, err = json.Marshal(cvp)
  692. assert.Nil(t, err)
  693. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  694. // Expect siblings with the extra 0 that the circom circuits need
  695. expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  696. assert.Equal(t, expected, string(jCvp))
  697. cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
  698. assert.Nil(t, err)
  699. jCvp, err = json.Marshal(cvp)
  700. assert.Nil(t, err)
  701. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  702. // Without the extra 0 that the circom circuits need, but that are not
  703. // needed at a smart contract verification
  704. expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  705. assert.Equal(t, expected, string(jCvp))
  706. }
  707. func TestTypesMarshalers(t *testing.T) {
  708. // test Hash marshalers
  709. h, err := merkletree.NewHashFromString("42")
  710. assert.Nil(t, err)
  711. s, err := json.Marshal(h)
  712. assert.Nil(t, err)
  713. var h2 *merkletree.Hash
  714. err = json.Unmarshal(s, &h2)
  715. assert.Nil(t, err)
  716. assert.Equal(t, h, h2)
  717. // create CircomProcessorProof
  718. mt := newTestingMerkle(t, 10)
  719. for i := 0; i < 16; i++ {
  720. k := big.NewInt(int64(i))
  721. v := big.NewInt(int64(i * 2))
  722. if err := mt.Add(k, v); err != nil {
  723. t.Fatal(err)
  724. }
  725. }
  726. _, v, _, err := mt.Get(big.NewInt(10))
  727. assert.Nil(t, err)
  728. assert.Equal(t, big.NewInt(20), v)
  729. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  730. assert.Nil(t, err)
  731. // test CircomProcessorProof marshalers
  732. b, err := json.Marshal(&cpp)
  733. assert.Nil(t, err)
  734. var cpp2 *merkletree.CircomProcessorProof
  735. err = json.Unmarshal(b, &cpp2)
  736. assert.Nil(t, err)
  737. assert.Equal(t, cpp, cpp2)
  738. }