You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

816 lines
25 KiB

  1. package sql
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "encoding/json"
  6. "fmt"
  7. "github.com/iden3/go-iden3-crypto/constants"
  8. cryptoUtils "github.com/iden3/go-iden3-crypto/utils"
  9. "github.com/iden3/go-merkletree"
  10. "github.com/iden3/go-merkletree/db/memory"
  11. "github.com/jmoiron/sqlx"
  12. "github.com/stretchr/testify/assert"
  13. "github.com/stretchr/testify/require"
  14. "math/big"
  15. "os"
  16. "strconv"
  17. "testing"
  18. )
  19. func sqlStorage(t *testing.T) merkletree.Storage {
  20. host := os.Getenv("PGHOST")
  21. if host == "" {
  22. host = "localhost"
  23. }
  24. port, _ := strconv.Atoi(os.Getenv("PGPORT"))
  25. if port == 0 {
  26. port = 5432
  27. }
  28. user := os.Getenv("PGUSER")
  29. if user == "" {
  30. user = "user"
  31. }
  32. password := os.Getenv("PGPASSWORD")
  33. if password == "" {
  34. panic("No PGPASSWORD envvar specified")
  35. }
  36. dbname := os.Getenv("PGDATABASE")
  37. if dbname == "" {
  38. dbname = "test"
  39. }
  40. psqlconn := fmt.Sprintf(
  41. "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
  42. host,
  43. port,
  44. user,
  45. password,
  46. dbname,
  47. )
  48. dbx, err := sqlx.Connect("postgres", psqlconn)
  49. if err != nil {
  50. t.Fatal(err)
  51. return nil
  52. }
  53. // clear MerkleTree table
  54. dbx.Exec("TRUNCATE TABLE mt_roots")
  55. dbx.Exec("TRUNCATE TABLE mt_nodes")
  56. sto, err := NewSqlStorage(dbx, false)
  57. if err != nil {
  58. t.Fatal(err)
  59. return nil
  60. }
  61. t.Cleanup(func() {
  62. })
  63. return sto
  64. }
  // TestSql is a placeholder for running the shared storage conformance
// checks against the SQL storage. Every call below is commented out, so the
// test currently performs no checks and always passes.
// NOTE(review): the `test` package referenced in the commented-out calls is
// not imported by this file; re-enabling them requires adding that import.
func TestSql(t *testing.T) {
	//sto := sqlStorage(t)
	//t.Run("TestReturnKnownErrIfNotExists", func(t *testing.T) {
	// test.TestReturnKnownErrIfNotExists(t, sqlStorage(t))
	//})
	//t.Run("TestStorageInsertGet", func(t *testing.T) {
	// test.TestStorageInsertGet(t, sqlStorage(t))
	//})
	//test.TestStorageWithPrefix(t, sqlStorage(t))
	//test.TestConcatTx(t, sqlStorage(t))
	//test.TestList(t, sqlStorage(t))
	//test.TestIterate(t, sqlStorage(t))
}
  // debug, when true, makes TestVerifyProofCases print each generated proof.
// It is off by default.
var debug bool

// Fatalable is the minimal interface for something that can abort a test
// with a message. *testing.T satisfies it.
type Fatalable interface {
	Fatal(args ...interface{})
}
  82. func newTestingMerkle(f *testing.T, maxLevels int) *merkletree.MerkleTree {
  83. sto := sqlStorage(f)
  84. mt, err := merkletree.NewMerkleTree(sto, maxLevels)
  85. if err != nil {
  86. f.Fatal(err)
  87. return nil
  88. }
  89. return mt
  90. }
  91. func TestHashParsers(t *testing.T) {
  92. h0 := merkletree.NewHashFromBigInt(big.NewInt(0))
  93. assert.Equal(t, "0", h0.String())
  94. h1 := merkletree.NewHashFromBigInt(big.NewInt(1))
  95. assert.Equal(t, "1", h1.String())
  96. h10 := merkletree.NewHashFromBigInt(big.NewInt(10))
  97. assert.Equal(t, "10", h10.String())
  98. h7l := merkletree.NewHashFromBigInt(big.NewInt(1234567))
  99. assert.Equal(t, "1234567", h7l.String())
  100. h8l := merkletree.NewHashFromBigInt(big.NewInt(12345678))
  101. assert.Equal(t, "12345678...", h8l.String())
  102. b, ok := new(big.Int).SetString("4932297968297298434239270129193057052722409868268166443802652458940273154854", 10) //nolint:lll
  103. assert.True(t, ok)
  104. h := merkletree.NewHashFromBigInt(b)
  105. assert.Equal(t, "4932297968297298434239270129193057052722409868268166443802652458940273154854", h.BigInt().String()) //nolint:lll
  106. assert.Equal(t, "49322979...", h.String())
  107. assert.Equal(t, "265baaf161e875c372d08e50f52abddc01d32efc93e90290bb8b3d9ceb94e70a", h.Hex())
  108. b1, err := merkletree.NewBigIntFromHashBytes(b.Bytes())
  109. assert.Nil(t, err)
  110. assert.Equal(t, new(big.Int).SetBytes(b.Bytes()).String(), b1.String())
  111. b2, err := merkletree.NewHashFromBytes(b.Bytes())
  112. assert.Nil(t, err)
  113. assert.Equal(t, b.String(), b2.BigInt().String())
  114. h2, err := merkletree.NewHashFromHex(h.Hex())
  115. assert.Nil(t, err)
  116. assert.Equal(t, h, h2)
  117. _, err = merkletree.NewHashFromHex("0x12")
  118. assert.NotNil(t, err)
  119. // check limits
  120. a := new(big.Int).Sub(constants.Q, big.NewInt(1))
  121. testHashParsers(t, a)
  122. a = big.NewInt(int64(1))
  123. testHashParsers(t, a)
  124. }
  125. func testHashParsers(t *testing.T, a *big.Int) {
  126. require.True(t, cryptoUtils.CheckBigIntInField(a))
  127. h := merkletree.NewHashFromBigInt(a)
  128. assert.Equal(t, a, h.BigInt())
  129. hFromBytes, err := merkletree.NewHashFromBytes(h.Bytes())
  130. assert.Nil(t, err)
  131. assert.Equal(t, h, hFromBytes)
  132. assert.Equal(t, a, hFromBytes.BigInt())
  133. assert.Equal(t, a.String(), hFromBytes.BigInt().String())
  134. hFromHex, err := merkletree.NewHashFromHex(h.Hex())
  135. assert.Nil(t, err)
  136. assert.Equal(t, h, hFromHex)
  137. aBIFromHBytes, err := merkletree.NewBigIntFromHashBytes(h.Bytes())
  138. assert.Nil(t, err)
  139. assert.Equal(t, a, aBIFromHBytes)
  140. assert.Equal(t, new(big.Int).SetBytes(a.Bytes()).String(), aBIFromHBytes.String())
  141. }
  142. func TestNewTree(t *testing.T) {
  143. mt := newTestingMerkle(t, 10)
  144. mt, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 10)
  145. assert.Nil(t, err)
  146. assert.Equal(t, "0", mt.Root().String())
  147. // test vectors generated using https://github.com/iden3/circomlib smt.js
  148. err = mt.Add(big.NewInt(1), big.NewInt(2))
  149. assert.Nil(t, err)
  150. assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
  151. err = mt.Add(big.NewInt(33), big.NewInt(44))
  152. assert.Nil(t, err)
  153. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  154. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  155. assert.Nil(t, err)
  156. assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
  157. proof, v, err := mt.GenerateProof(big.NewInt(33), nil)
  158. assert.Nil(t, err)
  159. assert.Equal(t, big.NewInt(44), v)
  160. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(44)))
  161. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(33), big.NewInt(45)))
  162. }
  163. func TestAddDifferentOrder(t *testing.T) {
  164. mt1 := newTestingMerkle(t, 140)
  165. for i := 0; i < 16; i++ {
  166. k := big.NewInt(int64(i))
  167. v := big.NewInt(0)
  168. if err := mt1.Add(k, v); err != nil {
  169. t.Fatal(err)
  170. }
  171. }
  172. mt2 := newTestingMerkle(t, 140)
  173. for i := 16 - 1; i >= 0; i-- {
  174. k := big.NewInt(int64(i))
  175. v := big.NewInt(0)
  176. if err := mt2.Add(k, v); err != nil {
  177. t.Fatal(err)
  178. }
  179. }
  180. assert.Equal(t, mt1.Root().Hex(), mt2.Root().Hex())
  181. assert.Equal(t, "3b89100bec24da9275c87bc188740389e1d5accfc7d88ba5688d7fa96a00d82f", mt1.Root().Hex()) //nolint:lll
  182. }
  183. func TestAddRepeatedIndex(t *testing.T) {
  184. mt := newTestingMerkle(t, 140)
  185. k := big.NewInt(int64(3))
  186. v := big.NewInt(int64(12))
  187. if err := mt.Add(k, v); err != nil {
  188. t.Fatal(err)
  189. }
  190. err := mt.Add(k, v)
  191. assert.NotNil(t, err)
  192. assert.Equal(t, merkletree.ErrEntryIndexAlreadyExists, err)
  193. }
  194. func TestGet(t *testing.T) {
  195. mt := newTestingMerkle(t, 140)
  196. for i := 0; i < 16; i++ {
  197. k := big.NewInt(int64(i))
  198. v := big.NewInt(int64(i * 2))
  199. if err := mt.Add(k, v); err != nil {
  200. t.Fatal(err)
  201. }
  202. }
  203. k, v, _, err := mt.Get(big.NewInt(10))
  204. assert.Nil(t, err)
  205. assert.Equal(t, big.NewInt(10), k)
  206. assert.Equal(t, big.NewInt(20), v)
  207. k, v, _, err = mt.Get(big.NewInt(15))
  208. assert.Nil(t, err)
  209. assert.Equal(t, big.NewInt(15), k)
  210. assert.Equal(t, big.NewInt(30), v)
  211. k, v, _, err = mt.Get(big.NewInt(16))
  212. assert.NotNil(t, err)
  213. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  214. assert.Equal(t, "0", k.String())
  215. assert.Equal(t, "0", v.String())
  216. }
  217. func TestUpdate(t *testing.T) {
  218. mt := newTestingMerkle(t, 140)
  219. for i := 0; i < 16; i++ {
  220. k := big.NewInt(int64(i))
  221. v := big.NewInt(int64(i * 2))
  222. if err := mt.Add(k, v); err != nil {
  223. t.Fatal(err)
  224. }
  225. }
  226. _, v, _, err := mt.Get(big.NewInt(10))
  227. assert.Nil(t, err)
  228. assert.Equal(t, big.NewInt(20), v)
  229. _, err = mt.Update(big.NewInt(10), big.NewInt(1024))
  230. assert.Nil(t, err)
  231. _, v, _, err = mt.Get(big.NewInt(10))
  232. assert.Nil(t, err)
  233. assert.Equal(t, big.NewInt(1024), v)
  234. _, err = mt.Update(big.NewInt(1000), big.NewInt(1024))
  235. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  236. }
  237. func TestUpdate2(t *testing.T) {
  238. mt1 := newTestingMerkle(t, 140)
  239. mt2 := newTestingMerkle(t, 140)
  240. err := mt1.Add(big.NewInt(1), big.NewInt(119))
  241. assert.Nil(t, err)
  242. err = mt1.Add(big.NewInt(2), big.NewInt(229))
  243. assert.Nil(t, err)
  244. err = mt1.Add(big.NewInt(9876), big.NewInt(6789))
  245. assert.Nil(t, err)
  246. err = mt2.Add(big.NewInt(1), big.NewInt(11))
  247. assert.Nil(t, err)
  248. err = mt2.Add(big.NewInt(2), big.NewInt(22))
  249. assert.Nil(t, err)
  250. err = mt2.Add(big.NewInt(9876), big.NewInt(10))
  251. assert.Nil(t, err)
  252. _, err = mt1.Update(big.NewInt(1), big.NewInt(11))
  253. assert.Nil(t, err)
  254. _, err = mt1.Update(big.NewInt(2), big.NewInt(22))
  255. assert.Nil(t, err)
  256. _, err = mt2.Update(big.NewInt(9876), big.NewInt(6789))
  257. assert.Nil(t, err)
  258. assert.Equal(t, mt1.Root(), mt2.Root())
  259. }
  260. func TestGenerateAndVerifyProof128(t *testing.T) {
  261. mt := newTestingMerkle(t, 140)
  262. for i := 0; i < 128; i++ {
  263. k := big.NewInt(int64(i))
  264. v := big.NewInt(0)
  265. if err := mt.Add(k, v); err != nil {
  266. t.Fatal(err)
  267. }
  268. }
  269. proof, v, err := mt.GenerateProof(big.NewInt(42), nil)
  270. assert.Nil(t, err)
  271. assert.Equal(t, "0", v.String())
  272. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(42), big.NewInt(0)))
  273. }
  274. func TestTreeLimit(t *testing.T) {
  275. mt := newTestingMerkle(t, 5)
  276. for i := 0; i < 16; i++ {
  277. err := mt.Add(big.NewInt(int64(i)), big.NewInt(int64(i)))
  278. assert.Nil(t, err)
  279. }
  280. // here the tree is full, should not allow to add more data as reaches the maximum number of levels
  281. err := mt.Add(big.NewInt(int64(16)), big.NewInt(int64(16)))
  282. assert.NotNil(t, err)
  283. assert.Equal(t, merkletree.ErrReachedMaxLevel, err)
  284. }
  285. func TestSiblingsFromProof(t *testing.T) {
  286. mt := newTestingMerkle(t, 140)
  287. for i := 0; i < 64; i++ {
  288. k := big.NewInt(int64(i))
  289. v := big.NewInt(0)
  290. if err := mt.Add(k, v); err != nil {
  291. t.Fatal(err)
  292. }
  293. }
  294. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  295. if err != nil {
  296. t.Fatal(err)
  297. }
  298. siblings := merkletree.SiblingsFromProof(proof)
  299. assert.Equal(t, 6, len(siblings))
  300. assert.Equal(t,
  301. "d6e368bda90c5ee3e910222c1fc1c0d9e23f2d350dbc47f4a92de30f1be3c60b",
  302. siblings[0].Hex())
  303. assert.Equal(t,
  304. "9dbd03b1bcd580e0f3e6668d80d55288f04464126feb1624ec8ee30be8df9c16",
  305. siblings[1].Hex())
  306. assert.Equal(t,
  307. "de866af9545dcd1c5bb7811e7f27814918e037eb9fead40919e8f19525896e27",
  308. siblings[2].Hex())
  309. assert.Equal(t,
  310. "5f4182212a84741d1174ba7c42e369f2e3ad8ade7d04eea2d0f98e3ed8b7a317",
  311. siblings[3].Hex())
  312. assert.Equal(t,
  313. "77639098d513f7aef9730fdb1d1200401af5fe9da91b61772f4dd142ac89a122",
  314. siblings[4].Hex())
  315. assert.Equal(t,
  316. "943ee501f4ba2137c79b54af745dfc5f105f539fcc449cd2a356eb5c030e3c07",
  317. siblings[5].Hex())
  318. }
  319. func TestVerifyProofCases(t *testing.T) {
  320. mt := newTestingMerkle(t, 140)
  321. defer mt.DB().Close()
  322. for i := 0; i < 8; i++ {
  323. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  324. t.Fatal(err)
  325. }
  326. }
  327. // Existence proof
  328. proof, _, err := mt.GenerateProof(big.NewInt(4), nil)
  329. if err != nil {
  330. t.Fatal(err)
  331. }
  332. assert.Equal(t, proof.Existence, true)
  333. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(4), big.NewInt(0)))
  334. assert.Equal(t, "0003000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df06", hex.EncodeToString(proof.Bytes())) //nolint:lll
  335. for i := 8; i < 32; i++ {
  336. proof, _, err = mt.GenerateProof(big.NewInt(int64(i)), nil)
  337. assert.Nil(t, err)
  338. if debug {
  339. fmt.Println(i, proof)
  340. }
  341. }
  342. // Non-existence proof, empty aux
  343. proof, _, err = mt.GenerateProof(big.NewInt(12), nil)
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. assert.Equal(t, proof.Existence, false)
  348. // assert.True(t, proof.nodeAux == nil)
  349. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(12), big.NewInt(0)))
  350. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e40344ad686a18ba78b502c0b6f285c5c8393bde2f7a3e2abe586515e4d84533e3037b062539bde2d80749746986cf8f0001fd2cdbf9a89fcbf981a769daef49df0604000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  351. // Non-existence proof, diff. node aux
  352. proof, _, err = mt.GenerateProof(big.NewInt(10), nil)
  353. if err != nil {
  354. t.Fatal(err)
  355. }
  356. assert.Equal(t, proof.Existence, false)
  357. assert.True(t, proof.NodeAux != nil)
  358. assert.True(t, merkletree.VerifyProof(mt.Root(), proof, big.NewInt(10), big.NewInt(0)))
  359. assert.Equal(t, "0303000000000000000000000000000000000000000000000000000000000007529cbedbda2bdd25fd6455551e55245fa6dc11a9d0c27dc0cd38fca44c17e4030acfcdd2617df9eb5aef744c5f2e03eb8c92c61f679007dc1f2707fd908ea41a9433745b469c101edca814c498e7f388100d497b24f1d2ac935bced3572f591d02000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", hex.EncodeToString(proof.Bytes())) //nolint:lll
  360. }
  361. func TestVerifyProofFalse(t *testing.T) {
  362. mt := newTestingMerkle(t, 140)
  363. defer mt.DB().Close()
  364. for i := 0; i < 8; i++ {
  365. if err := mt.Add(big.NewInt(int64(i)), big.NewInt(0)); err != nil {
  366. t.Fatal(err)
  367. }
  368. }
  369. // Invalid existence proof (node used for verification doesn't
  370. // correspond to node in the proof)
  371. proof, _, err := mt.GenerateProof(big.NewInt(int64(4)), nil)
  372. if err != nil {
  373. t.Fatal(err)
  374. }
  375. assert.Equal(t, proof.Existence, true)
  376. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(5)), big.NewInt(int64(5))))
  377. // Invalid non-existence proof (Non-existence proof, diff. node aux)
  378. proof, _, err = mt.GenerateProof(big.NewInt(int64(4)), nil)
  379. if err != nil {
  380. t.Fatal(err)
  381. }
  382. assert.Equal(t, proof.Existence, true)
  383. // Now we change the proof from existence to non-existence, and add e's
  384. // data as auxiliary node.
  385. proof.Existence = false
  386. proof.NodeAux = &merkletree.NodeAux{Key: merkletree.NewHashFromBigInt(big.NewInt(int64(4))),
  387. Value: merkletree.NewHashFromBigInt(big.NewInt(4))}
  388. assert.True(t, !merkletree.VerifyProof(mt.Root(), proof, big.NewInt(int64(4)), big.NewInt(0)))
  389. }
  390. func TestGraphViz(t *testing.T) {
  391. mt := newTestingMerkle(t, 140)
  392. _ = mt.Add(big.NewInt(1), big.NewInt(0))
  393. _ = mt.Add(big.NewInt(2), big.NewInt(0))
  394. _ = mt.Add(big.NewInt(3), big.NewInt(0))
  395. _ = mt.Add(big.NewInt(4), big.NewInt(0))
  396. _ = mt.Add(big.NewInt(5), big.NewInt(0))
  397. _ = mt.Add(big.NewInt(100), big.NewInt(0))
  398. // mt.PrintGraphViz(nil)
  399. expected := `digraph hierarchy {
  400. node [fontname=Monospace,fontsize=10,shape=box]
  401. "56332309..." -> {"18483622..." "20902180..."}
  402. "18483622..." -> {"75768243..." "16893244..."}
  403. "75768243..." -> {"empty0" "21857056..."}
  404. "empty0" [style=dashed,label=0];
  405. "21857056..." -> {"51072523..." "empty1"}
  406. "empty1" [style=dashed,label=0];
  407. "51072523..." -> {"17311038..." "empty2"}
  408. "empty2" [style=dashed,label=0];
  409. "17311038..." -> {"69499803..." "21008290..."}
  410. "69499803..." [style=filled];
  411. "21008290..." [style=filled];
  412. "16893244..." [style=filled];
  413. "20902180..." -> {"12496585..." "18055627..."}
  414. "12496585..." -> {"19374975..." "15739329..."}
  415. "19374975..." [style=filled];
  416. "15739329..." [style=filled];
  417. "18055627..." [style=filled];
  418. }
  419. `
  420. w := bytes.NewBufferString("")
  421. err := mt.GraphViz(w, nil)
  422. assert.Nil(t, err)
  423. assert.Equal(t, []byte(expected), w.Bytes())
  424. }
  425. func TestDelete(t *testing.T) {
  426. mt := newTestingMerkle(t, 10)
  427. assert.Equal(t, "0", mt.Root().String())
  428. // test vectors generated using https://github.com/iden3/circomlib smt.js
  429. err := mt.Add(big.NewInt(1), big.NewInt(2))
  430. assert.Nil(t, err)
  431. assert.Equal(t, "13578938674299138072471463694055224830892726234048532520316387704878000008795", mt.Root().BigInt().String()) //nolint:lll
  432. err = mt.Add(big.NewInt(33), big.NewInt(44))
  433. assert.Nil(t, err)
  434. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  435. err = mt.Add(big.NewInt(1234), big.NewInt(9876))
  436. assert.Nil(t, err)
  437. assert.Equal(t, "14204494359367183802864593755198662203838502594566452929175967972147978322084", mt.Root().BigInt().String()) //nolint:lll
  438. // mt.PrintGraphViz(nil)
  439. err = mt.Delete(big.NewInt(33))
  440. // mt.PrintGraphViz(nil)
  441. assert.Nil(t, err)
  442. assert.Equal(t, "15550352095346187559699212771793131433118240951738528922418613687814377955591", mt.Root().BigInt().String()) //nolint:lll
  443. err = mt.Delete(big.NewInt(1234))
  444. assert.Nil(t, err)
  445. err = mt.Delete(big.NewInt(1))
  446. assert.Nil(t, err)
  447. assert.Equal(t, "0", mt.Root().String())
  448. }
  449. func TestDelete2(t *testing.T) {
  450. mt := newTestingMerkle(t, 140)
  451. for i := 0; i < 8; i++ {
  452. k := big.NewInt(int64(i))
  453. v := big.NewInt(0)
  454. if err := mt.Add(k, v); err != nil {
  455. t.Fatal(err)
  456. }
  457. }
  458. expectedRoot := mt.Root()
  459. k := big.NewInt(8)
  460. v := big.NewInt(0)
  461. err := mt.Add(k, v)
  462. require.Nil(t, err)
  463. err = mt.Delete(big.NewInt(8))
  464. assert.Nil(t, err)
  465. assert.Equal(t, expectedRoot, mt.Root())
  466. mt2 := newTestingMerkle(t, 140)
  467. for i := 0; i < 8; i++ {
  468. k := big.NewInt(int64(i))
  469. v := big.NewInt(0)
  470. if err := mt2.Add(k, v); err != nil {
  471. t.Fatal(err)
  472. }
  473. }
  474. assert.Equal(t, mt2.Root(), mt.Root())
  475. }
  476. func TestDelete3(t *testing.T) {
  477. mt := newTestingMerkle(t, 140)
  478. err := mt.Add(big.NewInt(1), big.NewInt(1))
  479. assert.Nil(t, err)
  480. err = mt.Add(big.NewInt(2), big.NewInt(2))
  481. assert.Nil(t, err)
  482. assert.Equal(t, "19060075022714027595905950662613111880864833370144986660188929919683258088314", mt.Root().BigInt().String()) //nolint:lll
  483. err = mt.Delete(big.NewInt(1))
  484. assert.Nil(t, err)
  485. assert.Equal(t, "849831128489032619062850458217693666094013083866167024127442191257793527951", mt.Root().BigInt().String()) //nolint:lll
  486. mt2 := newTestingMerkle(t, 140)
  487. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  488. assert.Nil(t, err)
  489. assert.Equal(t, mt2.Root(), mt.Root())
  490. }
  491. func TestDelete4(t *testing.T) {
  492. mt := newTestingMerkle(t, 140)
  493. err := mt.Add(big.NewInt(1), big.NewInt(1))
  494. assert.Nil(t, err)
  495. err = mt.Add(big.NewInt(2), big.NewInt(2))
  496. assert.Nil(t, err)
  497. err = mt.Add(big.NewInt(3), big.NewInt(3))
  498. assert.Nil(t, err)
  499. assert.Equal(t, "14109632483797541575275728657193822866549917334388996328141438956557066918117", mt.Root().BigInt().String()) //nolint:lll
  500. err = mt.Delete(big.NewInt(1))
  501. assert.Nil(t, err)
  502. assert.Equal(t, "159935162486187606489815340465698714590556679404589449576549073038844694972", mt.Root().BigInt().String()) //nolint:lll
  503. mt2 := newTestingMerkle(t, 140)
  504. err = mt2.Add(big.NewInt(2), big.NewInt(2))
  505. assert.Nil(t, err)
  506. err = mt2.Add(big.NewInt(3), big.NewInt(3))
  507. assert.Nil(t, err)
  508. assert.Equal(t, mt2.Root(), mt.Root())
  509. }
  510. func TestDelete5(t *testing.T) {
  511. mt := newTestingMerkle(t, 10)
  512. err := mt.Add(big.NewInt(1), big.NewInt(2))
  513. assert.Nil(t, err)
  514. err = mt.Add(big.NewInt(33), big.NewInt(44))
  515. assert.Nil(t, err)
  516. assert.Equal(t, "5412393676474193513566895793055462193090331607895808993925969873307089394741", mt.Root().BigInt().String()) //nolint:lll
  517. err = mt.Delete(big.NewInt(1))
  518. assert.Nil(t, err)
  519. assert.Equal(t, "18869260084287237667925661423624848342947598951870765316380602291081195309822", mt.Root().BigInt().String()) //nolint:lll
  520. mt2 := newTestingMerkle(t, 140)
  521. err = mt2.Add(big.NewInt(33), big.NewInt(44))
  522. assert.Nil(t, err)
  523. assert.Equal(t, mt2.Root(), mt.Root())
  524. }
  525. func TestDeleteNonExistingKeys(t *testing.T) {
  526. mt := newTestingMerkle(t, 10)
  527. err := mt.Add(big.NewInt(1), big.NewInt(2))
  528. assert.Nil(t, err)
  529. err = mt.Add(big.NewInt(33), big.NewInt(44))
  530. assert.Nil(t, err)
  531. err = mt.Delete(big.NewInt(33))
  532. assert.Nil(t, err)
  533. err = mt.Delete(big.NewInt(33))
  534. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  535. err = mt.Delete(big.NewInt(1))
  536. assert.Nil(t, err)
  537. assert.Equal(t, "0", mt.Root().String())
  538. err = mt.Delete(big.NewInt(33))
  539. assert.Equal(t, merkletree.ErrKeyNotFound, err)
  540. }
  541. func TestDumpLeafsImportLeafs(t *testing.T) {
  542. mt := newTestingMerkle(t, 140)
  543. q1 := new(big.Int).Sub(constants.Q, big.NewInt(1))
  544. for i := 0; i < 10; i++ {
  545. // use numbers near under Q
  546. k := new(big.Int).Sub(q1, big.NewInt(int64(i)))
  547. v := big.NewInt(0)
  548. err := mt.Add(k, v)
  549. require.Nil(t, err)
  550. // use numbers near above 0
  551. k = big.NewInt(int64(i))
  552. err = mt.Add(k, v)
  553. require.Nil(t, err)
  554. }
  555. d, err := mt.DumpLeafs(nil)
  556. assert.Nil(t, err)
  557. mt2, err := merkletree.NewMerkleTree(memory.NewMemoryStorage(), 140)
  558. require.Nil(t, err)
  559. err = mt2.ImportDumpedLeafs(d)
  560. assert.Nil(t, err)
  561. assert.Equal(t, mt.Root(), mt2.Root())
  562. }
  563. func TestAddAndGetCircomProof(t *testing.T) {
  564. mt := newTestingMerkle(t, 10)
  565. assert.Equal(t, "0", mt.Root().String())
  566. // test vectors generated using https://github.com/iden3/circomlib smt.js
  567. cpp, err := mt.AddAndGetCircomProof(big.NewInt(1), big.NewInt(2))
  568. assert.Nil(t, err)
  569. assert.Equal(t, "0", cpp.OldRoot.String())
  570. assert.Equal(t, "13578938...", cpp.NewRoot.String())
  571. assert.Equal(t, "0", cpp.OldKey.String())
  572. assert.Equal(t, "0", cpp.OldValue.String())
  573. assert.Equal(t, "1", cpp.NewKey.String())
  574. assert.Equal(t, "2", cpp.NewValue.String())
  575. assert.Equal(t, true, cpp.IsOld0)
  576. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  577. cpp, err = mt.AddAndGetCircomProof(big.NewInt(33), big.NewInt(44))
  578. assert.Nil(t, err)
  579. assert.Equal(t, "13578938...", cpp.OldRoot.String())
  580. assert.Equal(t, "54123936...", cpp.NewRoot.String())
  581. assert.Equal(t, "1", cpp.OldKey.String())
  582. assert.Equal(t, "2", cpp.OldValue.String())
  583. assert.Equal(t, "33", cpp.NewKey.String())
  584. assert.Equal(t, "44", cpp.NewValue.String())
  585. assert.Equal(t, false, cpp.IsOld0)
  586. assert.Equal(t, "[0 0 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  587. cpp, err = mt.AddAndGetCircomProof(big.NewInt(55), big.NewInt(66))
  588. assert.Nil(t, err)
  589. assert.Equal(t, "54123936...", cpp.OldRoot.String())
  590. assert.Equal(t, "50943640...", cpp.NewRoot.String())
  591. assert.Equal(t, "0", cpp.OldKey.String())
  592. assert.Equal(t, "0", cpp.OldValue.String())
  593. assert.Equal(t, "55", cpp.NewKey.String())
  594. assert.Equal(t, "66", cpp.NewValue.String())
  595. assert.Equal(t, true, cpp.IsOld0)
  596. assert.Equal(t, "[0 21312042... 0 0 0 0 0 0 0 0 0]", fmt.Sprintf("%v", cpp.Siblings))
  597. }
  598. func TestUpdateCircomProcessorProof(t *testing.T) {
  599. mt := newTestingMerkle(t, 10)
  600. for i := 0; i < 16; i++ {
  601. k := big.NewInt(int64(i))
  602. v := big.NewInt(int64(i * 2))
  603. if err := mt.Add(k, v); err != nil {
  604. t.Fatal(err)
  605. }
  606. }
  607. _, v, _, err := mt.Get(big.NewInt(10))
  608. assert.Nil(t, err)
  609. assert.Equal(t, big.NewInt(20), v)
  610. // test vectors generated using https://github.com/iden3/circomlib smt.js
  611. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  612. assert.Nil(t, err)
  613. assert.Equal(t, "39010880...", cpp.OldRoot.String())
  614. assert.Equal(t, "18587862...", cpp.NewRoot.String())
  615. assert.Equal(t, "10", cpp.OldKey.String())
  616. assert.Equal(t, "20", cpp.OldValue.String())
  617. assert.Equal(t, "10", cpp.NewKey.String())
  618. assert.Equal(t, "1024", cpp.NewValue.String())
  619. assert.Equal(t, false, cpp.IsOld0)
  620. assert.Equal(t,
  621. "[34930557... 20201609... 18790542... 15930030... 0 0 0 0 0 0 0]",
  622. fmt.Sprintf("%v", cpp.Siblings))
  623. }
  624. func TestSmtVerifier(t *testing.T) {
  625. mt := newTestingMerkle(t, 4)
  626. err := mt.Add(big.NewInt(1), big.NewInt(11))
  627. assert.Nil(t, err)
  628. cvp, err := mt.GenerateSCVerifierProof(big.NewInt(1), nil)
  629. assert.Nil(t, err)
  630. jCvp, err := json.Marshal(cvp)
  631. assert.Nil(t, err)
  632. // expect siblings to be '[]', instead of 'null'
  633. expected := `{"root":"6525056641794203554583616941316772618766382307684970171204065038799368146416","siblings":[],"oldKey":"0","oldValue":"0","isOld0":false,"key":"1","value":"11","fnc":0}` //nolint:lll
  634. assert.Equal(t, expected, string(jCvp))
  635. err = mt.Add(big.NewInt(2), big.NewInt(22))
  636. assert.Nil(t, err)
  637. err = mt.Add(big.NewInt(3), big.NewInt(33))
  638. assert.Nil(t, err)
  639. err = mt.Add(big.NewInt(4), big.NewInt(44))
  640. assert.Nil(t, err)
  641. cvp, err = mt.GenerateCircomVerifierProof(big.NewInt(2), nil)
  642. assert.Nil(t, err)
  643. jCvp, err = json.Marshal(cvp)
  644. assert.Nil(t, err)
  645. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  646. // Expect siblings with the extra 0 that the circom circuits need
  647. expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700","0","0","0"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  648. assert.Equal(t, expected, string(jCvp))
  649. cvp, err = mt.GenerateSCVerifierProof(big.NewInt(2), nil)
  650. assert.Nil(t, err)
  651. jCvp, err = json.Marshal(cvp)
  652. assert.Nil(t, err)
  653. // Test vectors generated using https://github.com/iden3/circomlib smt.js
  654. // Without the extra 0 that the circom circuits need, but that are not
  655. // needed at a smart contract verification
  656. expected = `{"root":"13558168455220559042747853958949063046226645447188878859760119761585093422436","siblings":["11620130507635441932056895853942898236773847390796721536119314875877874016518","5158240518874928563648144881543092238925265313977134167935552944620041388700"],"oldKey":"0","oldValue":"0","isOld0":false,"key":"2","value":"22","fnc":0}` //nolint:lll
  657. assert.Equal(t, expected, string(jCvp))
  658. }
  659. func TestTypesMarshalers(t *testing.T) {
  660. // test Hash marshalers
  661. h, err := merkletree.NewHashFromString("42")
  662. assert.Nil(t, err)
  663. s, err := json.Marshal(h)
  664. assert.Nil(t, err)
  665. var h2 *merkletree.Hash
  666. err = json.Unmarshal(s, &h2)
  667. assert.Nil(t, err)
  668. assert.Equal(t, h, h2)
  669. // create CircomProcessorProof
  670. mt := newTestingMerkle(t, 10)
  671. for i := 0; i < 16; i++ {
  672. k := big.NewInt(int64(i))
  673. v := big.NewInt(int64(i * 2))
  674. if err := mt.Add(k, v); err != nil {
  675. t.Fatal(err)
  676. }
  677. }
  678. _, v, _, err := mt.Get(big.NewInt(10))
  679. assert.Nil(t, err)
  680. assert.Equal(t, big.NewInt(20), v)
  681. cpp, err := mt.Update(big.NewInt(10), big.NewInt(1024))
  682. assert.Nil(t, err)
  683. // test CircomProcessorProof marshalers
  684. b, err := json.Marshal(&cpp)
  685. assert.Nil(t, err)
  686. var cpp2 *merkletree.CircomProcessorProof
  687. err = json.Unmarshal(b, &cpp2)
  688. assert.Nil(t, err)
  689. assert.Equal(t, cpp, cpp2)
  690. }