
Merge pull request #406 from hermeznetwork/feature/implforgebatchargs

Implement Pipeline.prepareForgeBatchArgs()
arnau 3 years ago
committed by GitHub
parent commit 3cf615a769
13 changed files with 382 additions and 53 deletions
 1. common/zk.go                      (+3 -2)
 2. coordinator/batch.go              (+12 -9)
 3. coordinator/coordinator.go        (+43 -7)
 4. coordinator/coordinator_test.go   (+76 -1)
 5. db/statedb/txprocessors.go        (+6 -11)
 6. db/statedb/zkinputsgen_test.go    (+149 -1)
 7. go.mod                            (+1 -1)
 8. go.sum                            (+4 -0)
 9. prover/prover.go                  (+18 -5)
10. prover/prover_test.go             (+6 -1)
11. test/proofserver/proofserver.go   (+48 -4)
12. txselector/txselector.go          (+14 -9)
13. txselector/txselector_test.go     (+2 -2)

+3 -2  common/zk.go

@@ -431,8 +431,9 @@ func (z ZKInputs) ToHashGlobalData() ([]byte, error) {
 	// [MAX_NLEVELS bits] oldLastIdx
 	oldLastIdx := make([]byte, bytesMaxLevels)
-	copy(oldLastIdx, z.OldLastIdx.Bytes())
-	b = append(b, SwapEndianness(oldLastIdx)...)
+	oldLastIdxBytes := z.OldLastIdx.Bytes()
+	copy(oldLastIdx[len(oldLastIdx)-len(oldLastIdxBytes):], oldLastIdxBytes)
+	b = append(b, oldLastIdx...)
 	// [MAX_NLEVELS bits] newLastIdx
 	newLastIdx := make([]byte, bytesMaxLevels)
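The change above serializes OldLastIdx by right-aligning its big-endian bytes inside the fixed-width buffer instead of left-copying and swapping endianness; big.Int.Bytes() returns big-endian bytes with no leading zeros, so the copy destination must start at the end of the buffer. A minimal standalone sketch of that padding pattern (names are illustrative, not from the PR):

package main

import (
	"fmt"
	"math/big"
)

// padBigEndian right-aligns the big-endian bytes of v inside a buffer of
// `size` bytes, producing v encoded as a fixed-width big-endian integer.
func padBigEndian(v *big.Int, size int) []byte {
	buf := make([]byte, size)
	vBytes := v.Bytes() // big-endian, no leading zeros
	copy(buf[len(buf)-len(vBytes):], vBytes)
	return buf
}

func main() {
	fmt.Println(padBigEndian(big.NewInt(0x0102), 6)) // [0 0 0 0 1 2]
}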

+12 -9  coordinator/batch.go

@@ -26,15 +26,18 @@ const (
 // BatchInfo contans the Batch information
 type BatchInfo struct {
-	BatchNum       common.BatchNum
-	ServerProof    prover.Client
-	ZKInputs       *common.ZKInputs
-	Proof          *prover.Proof
-	PublicInputs   []*big.Int
-	L1UserTxsExtra []common.L1Tx
-	L1CoordTxs     []common.L1Tx
-	L2Txs          []common.PoolL2Tx
-	ForgeBatchArgs *eth.RollupForgeBatchArgs
+	BatchNum              common.BatchNum
+	ServerProof           prover.Client
+	ZKInputs              *common.ZKInputs
+	Proof                 *prover.Proof
+	PublicInputs          []*big.Int
+	L1Batch               bool
+	L1UserTxsExtra        []common.L1Tx
+	L1CoordTxs            []common.L1Tx
+	L1CoordinatorTxsAuths [][]byte
+	L2Txs                 []common.L2Tx
+	CoordIdxs             []common.Idx
+	ForgeBatchArgs        *eth.RollupForgeBatchArgs
 	// FeesInfo
 	TxStatus TxStatus
 	EthTx    *types.Transaction

+43 -7  coordinator/coordinator.go

@@ -3,6 +3,7 @@ package coordinator
 import (
 	"context"
 	"fmt"
+	"math/big"
 	"strings"
 	"sync"
 	"time"
@@ -772,7 +773,15 @@ func (p *Pipeline) Stop(ctx context.Context) {
 	}
 }

-func l2TxsIDs(txs []common.PoolL2Tx) []common.TxID {
+func poolL2TxsIDs(txs []common.PoolL2Tx) []common.TxID {
+	txIDs := make([]common.TxID, len(txs))
+	for i, tx := range txs {
+		txIDs[i] = tx.TxID
+	}
+	return txIDs
+}
+
+func l2TxsIDs(txs []common.L2Tx) []common.TxID {
 	txIDs := make([]common.TxID, len(txs))
 	for i, tx := range txs {
 		txIDs[i] = tx.TxID
@@ -810,9 +819,11 @@ func (p *Pipeline) forgeBatch(ctx context.Context, batchNum common.BatchNum, sel
 	var poolL2Txs []common.PoolL2Tx
 	// var feesInfo
 	var l1UserTxsExtra, l1CoordTxs []common.L1Tx
+	var auths [][]byte
 	var coordIdxs []common.Idx
 	// 1. Decide if we forge L2Tx or L1+L2Tx
 	if p.shouldL1L2Batch() {
+		batchInfo.L1Batch = true
 		p.lastScheduledL1BatchBlockNum = p.stats.Eth.LastBlock.Num
 		// 2a: L1+L2 txs
 		p.lastForgeL1TxsNum++
@@ -821,14 +832,14 @@ func (p *Pipeline) forgeBatch(ctx context.Context, batchNum common.BatchNum, sel
 			return nil, tracerr.Wrap(err)
 		}
 		// TODO once feesInfo is added to method return, add the var
-		coordIdxs, l1UserTxsExtra, l1CoordTxs, poolL2Txs, err =
+		coordIdxs, auths, l1UserTxsExtra, l1CoordTxs, poolL2Txs, err =
 			p.txSelector.GetL1L2TxSelection(selectionConfig, batchNum, l1UserTxs)
 		if err != nil {
 			return nil, tracerr.Wrap(err)
 		}
 	} else {
 		// 2b: only L2 txs
-		coordIdxs, l1CoordTxs, poolL2Txs, err =
+		coordIdxs, auths, l1CoordTxs, poolL2Txs, err =
 			p.txSelector.GetL2TxSelection(selectionConfig, batchNum)
 		if err != nil {
 			return nil, tracerr.Wrap(err)
@@ -840,9 +851,10 @@ func (p *Pipeline) forgeBatch(ctx context.Context, batchNum common.BatchNum, sel
 	// TODO feesInfo
 	batchInfo.L1UserTxsExtra = l1UserTxsExtra
 	batchInfo.L1CoordTxs = l1CoordTxs
-	batchInfo.L2Txs = poolL2Txs
+	batchInfo.L1CoordinatorTxsAuths = auths
+	batchInfo.CoordIdxs = coordIdxs

-	if err := p.l2DB.StartForging(l2TxsIDs(batchInfo.L2Txs), batchInfo.BatchNum); err != nil {
+	if err := p.l2DB.StartForging(poolL2TxsIDs(poolL2Txs), batchInfo.BatchNum); err != nil {
 		return nil, tracerr.Wrap(err)
 	}
@@ -864,6 +876,11 @@ func (p *Pipeline) forgeBatch(ctx context.Context, batchNum common.BatchNum, sel
 	if err != nil {
 		return nil, tracerr.Wrap(err)
 	}
+	l2Txs, err := common.PoolL2TxsToL2Txs(poolL2Txs) // NOTE: This is a big uggly, find a better way
+	if err != nil {
+		return nil, tracerr.Wrap(err)
+	}
+	batchInfo.L2Txs = l2Txs

 	// 5. Save metadata from BatchBuilder output for BatchNum
 	batchInfo.ZKInputs = zkInputs
@@ -903,6 +920,25 @@ func (p *Pipeline) shouldL1L2Batch() bool {
 }

 func (p *Pipeline) prepareForgeBatchArgs(batchInfo *BatchInfo) *eth.RollupForgeBatchArgs {
-	// TODO
-	return &eth.RollupForgeBatchArgs{}
+	proof := batchInfo.Proof
+	zki := batchInfo.ZKInputs
+	return &eth.RollupForgeBatchArgs{
+		NewLastIdx:            int64(zki.Metadata.NewLastIdxRaw),
+		NewStRoot:             zki.Metadata.NewStateRootRaw.BigInt(),
+		NewExitRoot:           zki.Metadata.NewExitRootRaw.BigInt(),
+		L1UserTxs:             batchInfo.L1UserTxsExtra,
+		L1CoordinatorTxs:      batchInfo.L1CoordTxs,
+		L1CoordinatorTxsAuths: batchInfo.L1CoordinatorTxsAuths,
+		L2TxsData:             batchInfo.L2Txs,
+		FeeIdxCoordinator:     batchInfo.CoordIdxs,
+		// Circuit selector
+		VerifierIdx: 0, // TODO
+		L1Batch:     batchInfo.L1Batch,
+		ProofA:      [2]*big.Int{proof.PiA[0], proof.PiA[1]},
+		ProofB: [2][2]*big.Int{
+			{proof.PiB[0][0], proof.PiB[0][1]},
+			{proof.PiB[1][0], proof.PiB[1][1]},
+		},
+		ProofC: [2]*big.Int{proof.PiC[0], proof.PiC[1]},
+	}
 }
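Note that only the affine coordinates of the proof (PiA[0..1], PiC[0..1] and the first two pairs of PiB) are forwarded on-chain; the trailing projective components returned by the proof server are validated in prover.Proof.UnmarshalJSON and dropped here. A minimal sketch of how the prepared arguments would typically be consumed, assuming the eth client wrapper exposes a RollupForgeBatch call (the helper and its parameters are illustrative, not part of this PR):

// forgeOnChain is an illustrative helper, not part of this PR: it shows the
// intended flow once the forge arguments are prepared. RollupForgeBatch is
// assumed to be the forging entry point of the eth client wrapper used by the
// coordinator's TxManager.
func forgeOnChain(p *Pipeline, ethClient eth.ClientInterface, batchInfo *BatchInfo) error {
	batchInfo.ForgeBatchArgs = p.prepareForgeBatchArgs(batchInfo)
	ethTx, err := ethClient.RollupForgeBatch(batchInfo.ForgeBatchArgs)
	if err != nil {
		return tracerr.Wrap(err)
	}
	batchInfo.EthTx = ethTx
	return nil
}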

+76 -1  coordinator/coordinator_test.go

@@ -6,6 +6,7 @@ import (
 	"io/ioutil"
 	"math/big"
 	"os"
+	"sync"
 	"testing"
 	"time"
@@ -157,11 +158,21 @@ func newTestCoordinator(t *testing.T, forgerAddr ethCommon.Address, ethClient *t
 		ConfirmBlocks:          5,
 		L1BatchTimeoutPerc:     0.5,
 		EthClientAttempts:      5,
+		SyncRetryInterval:      400 * time.Microsecond,
 		EthClientAttemptsDelay: 100 * time.Millisecond,
 		TxManagerCheckInterval: 300 * time.Millisecond,
 		DebugBatchPath:         debugBatchPath,
+		Purger: PurgerCfg{
+			PurgeBatchDelay:      10,
+			PurgeBlockDelay:      10,
+			InvalidateBatchDelay: 4,
+			InvalidateBlockDelay: 4,
+		},
+	}
+	serverProofs := []prover.Client{
+		&prover.MockClient{Delay: 300 * time.Millisecond},
+		&prover.MockClient{Delay: 400 * time.Millisecond},
 	}
-	serverProofs := []prover.Client{&prover.MockClient{}, &prover.MockClient{}}

 	scConsts := &synchronizer.SCConsts{
 		Rollup: *ethClientSetup.RollupConstants,
@@ -628,6 +639,70 @@ PoolTransfer(0) User2-User3: 300 (126)
 	assert.Equal(t, 0, len(batchInfo.L2Txs))
 }

+func TestCoordinatorStress(t *testing.T) {
+	if os.Getenv("TEST_COORD_STRESS") == "" {
+		return
+	}
+	log.Info("Begin Test Coord Stress")
+	ethClientSetup := test.NewClientSetupExample()
+	var timer timer
+	ethClient := test.NewClient(true, &timer, &bidder, ethClientSetup)
+	modules := newTestModules(t)
+	coord := newTestCoordinator(t, forger, ethClient, ethClientSetup, modules)
+	syn := newTestSynchronizer(t, ethClient, ethClientSetup, modules)
+
+	coord.Start()
+	ctx, cancel := context.WithCancel(context.Background())
+
+	var wg sync.WaitGroup
+
+	// Synchronizer loop
+	wg.Add(1)
+	go func() {
+		for {
+			blockData, _, err := syn.Sync2(ctx, nil)
+			if ctx.Err() != nil {
+				wg.Done()
+				return
+			}
+			require.NoError(t, err)
+			if blockData != nil {
+				stats := syn.Stats()
+				coord.SendMsg(MsgSyncBlock{
+					Stats:   *stats,
+					Batches: blockData.Rollup.Batches,
+					Vars: synchronizer.SCVariablesPtr{
+						Rollup:   blockData.Rollup.Vars,
+						Auction:  blockData.Auction.Vars,
+						WDelayer: blockData.WDelayer.Vars,
+					},
+				})
+			} else {
+				time.Sleep(100 * time.Millisecond)
+			}
+		}
+	}()
+
+	// Blockchain mining loop
+	wg.Add(1)
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				wg.Done()
+				return
+			case <-time.After(100 * time.Millisecond):
+				ethClient.CtlMineBlock()
+			}
+		}
+	}()
+
+	time.Sleep(600 * time.Second)
+
+	cancel()
+	wg.Wait()
+	coord.Stop()
+}
+
 // TODO: Test Reorg
 // TODO: Test Pipeline
 // TODO: Test TxMonitor

+6 -11  db/statedb/txprocessors.go

@@ -82,16 +82,6 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
 	s.accumulatedFees = make(map[common.Idx]*big.Int)

 	nTx := len(l1usertxs) + len(l1coordinatortxs) + len(l2txs)
-	if nTx == 0 {
-		// TODO return ZKInputs of batch without txs
-		return &ProcessTxOutput{
-			ZKInputs:           nil,
-			ExitInfos:          nil,
-			CreatedAccounts:    nil,
-			CoordinatorIdxsMap: nil,
-			CollectedFees:      nil,
-		}, nil
-	}

 	if nTx > int(ptc.MaxTx) {
 		return nil, tracerr.Wrap(fmt.Errorf("L1UserTx + L1CoordinatorTx + L2Tx (%d) can not be bigger than MaxTx (%d)", nTx, ptc.MaxTx))
@@ -106,6 +96,7 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
 		s.zki = common.NewZKInputs(ptc.MaxTx, ptc.MaxL1Tx, ptc.MaxTx, ptc.MaxFeeTx, ptc.NLevels, s.currentBatch.BigInt())
 		s.zki.OldLastIdx = s.idx.BigInt()
 		s.zki.OldStateRoot = s.mt.Root().BigInt()
+		s.zki.Metadata.NewLastIdxRaw = s.idx
 	}

 	// TBD if ExitTree is only in memory or stored in disk, for the moment
@@ -272,7 +263,11 @@ func (s *StateDB) ProcessTxs(ptc ProcessTxsConfig, coordIdxs []common.Idx, l1use
 	}

 	if s.zki != nil {
-		for i := s.i - 1; i < int(ptc.MaxTx); i++ {
+		last := s.i - 1
+		if s.i == 0 {
+			last = 0
+		}
+		for i := last; i < int(ptc.MaxTx); i++ {
 			if i < int(ptc.MaxTx)-1 {
 				s.zki.ISOutIdx[i] = s.idx.BigInt()
 				s.zki.ISStateRoot[i] = s.mt.Root().BigInt()
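Together with the removal of the nTx == 0 early return in the first hunk of this file, the loop-bound guard above lets a batch with no transactions run through the normal path and still emit fully padded ZKInputs. A minimal sketch of that call; sdb stands for an initialized *statedb.StateDB and the config values are illustrative, not from the PR:

// Illustrative config: field names are those used by ProcessTxs above,
// the numeric values are placeholders.
ptc := statedb.ProcessTxsConfig{
	NLevels:  32,
	MaxTx:    512,
	MaxL1Tx:  64,
	MaxFeeTx: 64,
}
// No coordinator idxs and no txs of any kind: an empty batch.
out, err := sdb.ProcessTxs(ptc, nil, nil, nil, nil)
if err != nil {
	return tracerr.Wrap(err)
}
_ = out.ZKInputs // fully padded ZKInputs for the empty batch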

+149 -1  db/statedb/zkinputsgen_test.go
File diff suppressed because it is too large


+1 -1  go.mod

@@ -12,7 +12,7 @@ require (
 	github.com/gin-contrib/cors v1.3.1
 	github.com/gin-gonic/gin v1.5.0
 	github.com/go-sql-driver/mysql v1.5.0 // indirect
-	github.com/gobuffalo/packr/v2 v2.8.0
+	github.com/gobuffalo/packr/v2 v2.8.1
 	github.com/hermeznetwork/tracerr v0.3.1-0.20201126162137-de9930d0cf29
 	github.com/iden3/go-iden3-crypto v0.0.6-0.20201221160344-58e589b6eb4c
 	github.com/iden3/go-merkletree v0.0.0-20201215142017-730707e5659a

+4 -0  go.sum

@@ -237,6 +237,8 @@ github.com/gobuffalo/packd v1.0.0/go.mod h1:6VTc4htmJRFB7u1m/4LeMTWjFoYrUiBkU9Fd
 github.com/gobuffalo/packr/v2 v2.7.1/go.mod h1:qYEvAazPaVxy7Y7KR0W8qYEE+RymX74kETFqjFoFlOc=
 github.com/gobuffalo/packr/v2 v2.8.0 h1:IULGd15bQL59ijXLxEvA5wlMxsmx/ZkQv9T282zNVIY=
 github.com/gobuffalo/packr/v2 v2.8.0/go.mod h1:PDk2k3vGevNE3SwVyVRgQCCXETC9SaONCNSXT1Q8M1g=
+github.com/gobuffalo/packr/v2 v2.8.1 h1:tkQpju6i3EtMXJ9uoF5GT6kB+LMTimDWD8Xvbz6zDVA=
+github.com/gobuffalo/packr/v2 v2.8.1/go.mod h1:c/PLlOuTU+p3SybaJATW3H6lX/iK7xEz5OeMf+NnJpg=
 github.com/godror/godror v0.13.3/go.mod h1:2ouUT4kdhUBk7TAkHWD4SN0CdI0pgEQbo8FVHhbSKWg=
 github.com/gofrs/flock v0.7.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
 github.com/gogo/googleapis v1.1.0/go.mod h1:gf4bu3Q80BeJ6H1S1vYPm8/ELATdvryBaNFGgqEef3s=
@@ -392,6 +394,8 @@ github.com/karalabe/usb v0.0.0-20190919080040-51dc0efba356 h1:I/yrLt2WilKxlQKCM5
 github.com/karalabe/usb v0.0.0-20190919080040-51dc0efba356/go.mod h1:Od972xHfMJowv7NGVDiWVxk2zxnWgjLlJzE+F4F7AGU=
 github.com/karrick/godirwalk v1.15.3 h1:0a2pXOgtB16CqIqXTiT7+K9L73f74n/aNQUnH6Ortew=
 github.com/karrick/godirwalk v1.15.3/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk=
+github.com/karrick/godirwalk v1.15.8 h1:7+rWAZPn9zuRxaIqqT8Ohs2Q2Ac0msBqwRdxNCr2VVs=
+github.com/karrick/godirwalk v1.15.8/go.mod h1:j4mkqPuvaLI8mp1DroR3P6ad7cyYd4c1qeJ3RV7ULlk=
 github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM=
 github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
 github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=

+18 -5  prover/prover.go

@@ -16,9 +16,9 @@ import (
 // Proof TBD this type will be received from the proof server
 type Proof struct {
-	PiA      [2]*big.Int    `json:"pi_a"`
+	PiA      [3]*big.Int    `json:"pi_a"`
 	PiB      [3][2]*big.Int `json:"pi_b"`
-	PiC      [2]*big.Int    `json:"pi_c"`
+	PiC      [3]*big.Int    `json:"pi_c"`
 	Protocol string         `json:"protocol"`
 }
@@ -36,9 +36,9 @@ func (b *bigInt) UnmarshalText(text []byte) error {
 // ints as strings
 func (p *Proof) UnmarshalJSON(data []byte) error {
 	proof := struct {
-		PiA      [2]*bigInt    `json:"pi_a"`
+		PiA      [3]*bigInt    `json:"pi_a"`
 		PiB      [3][2]*bigInt `json:"pi_b"`
-		PiC      [2]*bigInt    `json:"pi_c"`
+		PiC      [3]*bigInt    `json:"pi_c"`
 		Protocol string        `json:"protocol"`
 	}{}
 	if err := json.Unmarshal(data, &proof); err != nil {
@@ -46,14 +46,26 @@ func (p *Proof) UnmarshalJSON(data []byte) error {
 	}
 	p.PiA[0] = (*big.Int)(proof.PiA[0])
 	p.PiA[1] = (*big.Int)(proof.PiA[1])
+	p.PiA[2] = (*big.Int)(proof.PiA[2])
+	if p.PiA[2].Int64() != 1 {
+		return fmt.Errorf("Expected PiA[2] == 1, but got %v", p.PiA[2])
+	}
 	p.PiB[0][0] = (*big.Int)(proof.PiB[0][0])
 	p.PiB[0][1] = (*big.Int)(proof.PiB[0][1])
 	p.PiB[1][0] = (*big.Int)(proof.PiB[1][0])
 	p.PiB[1][1] = (*big.Int)(proof.PiB[1][1])
 	p.PiB[2][0] = (*big.Int)(proof.PiB[2][0])
 	p.PiB[2][1] = (*big.Int)(proof.PiB[2][1])
+	if p.PiB[2][0].Int64() != 1 || p.PiB[2][1].Int64() != 0 {
+		return fmt.Errorf("Expected PiB[2] == [1, 0], but got %v", p.PiB[2])
+	}
 	p.PiC[0] = (*big.Int)(proof.PiC[0])
 	p.PiC[1] = (*big.Int)(proof.PiC[1])
+	p.PiC[2] = (*big.Int)(proof.PiC[2])
+	if p.PiC[2].Int64() != 1 {
+		return fmt.Errorf("Expected PiC[2] == 1, but got %v", p.PiC[2])
+	}
+	// TODO: Assert ones and zeroes
 	p.Protocol = proof.Protocol
 	return nil
 }
@@ -276,6 +288,7 @@ func (p *ProofServerClient) WaitReady(ctx context.Context) error {
 // MockClient is a mock ServerProof to be used in tests. It doesn't calculate anything
 type MockClient struct {
+	Delay time.Duration
 }

 // CalculateProof sends the *common.ZKInputs to the ServerProof to compute the
@@ -288,7 +301,7 @@ func (p *MockClient) CalculateProof(ctx context.Context, zkInputs *common.ZKInpu
 func (p *MockClient) GetProof(ctx context.Context) (*Proof, []*big.Int, error) {
 	// Simulate a delay
 	select {
-	case <-time.After(500 * time.Millisecond): //nolint:gomnd
+	case <-time.After(p.Delay): //nolint:gomnd
 		return &Proof{}, []*big.Int{big.NewInt(1234)}, nil //nolint:gomnd
 	case <-ctx.Done():
 		return nil, nil, tracerr.Wrap(common.ErrDone)
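The proof server returns snarkjs-style groth16 proofs in projective form, so pi_a and pi_c carry a trailing "1" and pi_b a trailing ["1", "0"]; the updated UnmarshalJSON accepts and validates those extra components. A minimal decoding sketch (the literal values and the import path are illustrative):

package main

import (
	"encoding/json"
	"log"

	"github.com/hermeznetwork/hermez-node/prover"
)

func main() {
	// Projective groth16 proof as returned by the proof server (values illustrative).
	raw := []byte(`{
		"pi_a": ["3", "4", "1"],
		"pi_b": [["5", "6"], ["7", "8"], ["1", "0"]],
		"pi_c": ["9", "10", "1"],
		"protocol": "groth"
	}`)
	var proof prover.Proof
	// A trailing pi_a/pi_c value other than "1", or pi_b[2] != ["1", "0"], is rejected.
	if err := json.Unmarshal(raw, &proof); err != nil {
		log.Fatal(err)
	}
	log.Printf("pi_a affine coordinates: %v %v", proof.PiA[0], proof.PiA[1])
}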

+6 -1  prover/prover_test.go

@@ -12,7 +12,8 @@ import (
 	"github.com/stretchr/testify/require"
 )

-const apiURL = "http://localhost:3000/api"
+var apiURL = "http://localhost:3000/api"
+
 const pollInterval = 1 * time.Second

 var proofServerClient *ProofServerClient
@@ -20,6 +21,10 @@ var proofServerClient *ProofServerClient
 func TestMain(m *testing.M) {
 	exitVal := 0
 	if os.Getenv("INTEGRATION") != "" {
+		_apiURL := os.Getenv("PROOF_SERVER_URL")
+		if _apiURL != "" {
+			apiURL = _apiURL
+		}
 		proofServerClient = NewProofServerClient(apiURL, pollInterval)
 		err := proofServerClient.WaitReady(context.Background())
 		if err != nil {
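With apiURL now a var, the integration test can be pointed at any proof server through the environment, for example (illustrative invocation): INTEGRATION=1 PROOF_SERVER_URL=http://my-server:3000/api go test ./prover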

+48 -4  test/proofserver/proofserver.go

@@ -67,6 +67,50 @@ func (s *Mock) handleCancel(c *gin.Context) {
 	c.JSON(http.StatusOK, "OK")
 }

+/* Status example from the real server proof:
+
+Status:
+{
+  "proof": "{\n \"pi_a\": [\n \"1368015179489954701390400359078579693043519447331113978918064868415326638035\",\n \"9918110051302171585080402603319702774565515993150576347155970296011118125764\",\n \"1\"\n ],\n \"pi_b\": [\n [\n \"10857046999023057135944570762232829481370756359578518086990519993285655852781\",\n \"11559732032986387107991004021392285783925812861821192530917403151452391805634\"\n ],\n [\n \"8495653923123431417604973247489272438418190587263600148770280649306958101930\",\n \"4082367875863433681332203403145435568316851327593401208105741076214120093531\"\n ],\n [\n \"1\",\n \"0\"\n ]\n ],\n \"pi_c\": [\n \"1368015179489954701390400359078579693043519447331113978918064868415326638035\",\n \"9918110051302171585080402603319702774565515993150576347155970296011118125764\",\n \"1\"\n ],\n \"protocol\": \"groth\"\n}\n",
+  "pubData": "[\n \"8863150934551775031093873719629424744398133643983814385850330952980893030086\"\n]\n",
+  "status": "success"
+}
+
+proof:
+{
+  "pi_a": [
+    "1368015179489954701390400359078579693043519447331113978918064868415326638035",
+    "9918110051302171585080402603319702774565515993150576347155970296011118125764",
+    "1"
+  ],
+  "pi_b": [
+    [
+      "10857046999023057135944570762232829481370756359578518086990519993285655852781",
+      "11559732032986387107991004021392285783925812861821192530917403151452391805634"
+    ],
+    [
+      "8495653923123431417604973247489272438418190587263600148770280649306958101930",
+      "4082367875863433681332203403145435568316851327593401208105741076214120093531"
+    ],
+    [
+      "1",
+      "0"
+    ]
+  ],
+  "pi_c": [
+    "1368015179489954701390400359078579693043519447331113978918064868415326638035",
+    "9918110051302171585080402603319702774565515993150576347155970296011118125764",
+    "1"
+  ],
+  "protocol": "groth"
+}
+
+pubData:
+[
+  "8863150934551775031093873719629424744398133643983814385850330952980893030086"
+]
+*/
+
 func (s *Mock) handleStatus(c *gin.Context) {
 	s.RLock()
 	c.JSON(http.StatusOK, prover.Status{
@@ -133,11 +177,11 @@ func (s *Mock) runProver(ctx context.Context) {
 		s.counter++
 		// Mock data
 		s.proof = fmt.Sprintf(`{
-			"pi_a": ["%v", "%v"],
-			"pi_b": [["%v", "%v"],["%v", "%v"],["%v", "%v"]],
-			"pi_c": ["%v", "%v"],
+			"pi_a": ["%v", "%v", "1"],
+			"pi_b": [["%v", "%v"],["%v", "%v"],["1", "0"]],
+			"pi_c": ["%v", "%v", "1"],
 			"protocol": "groth16"
-		}`, i, i+1, i+2, i+3, i+4, i+5, i+6, i+7, i+8, i+9) //nolint:gomnd
+		}`, i, i+1, i+2, i+3, i+4, i+5, i+6, i+7) //nolint:gomnd
 		s.pubData = fmt.Sprintf(`[
 			"%v"
 		]`, i+42) //nolint:gomnd

+14 -9  txselector/txselector.go

@@ -151,15 +151,15 @@ func (txsel *TxSelector) GetCoordIdxs() (map[common.TokenID]common.Idx, error) {
 // GetL2TxSelection returns the L1CoordinatorTxs and a selection of the L2Txs
 // for the next batch, from the L2DB pool
 func (txsel *TxSelector) GetL2TxSelection(selectionConfig *SelectionConfig,
-	batchNum common.BatchNum) ([]common.Idx, []common.L1Tx, []common.PoolL2Tx, error) {
-	coordIdxs, _, l1CoordinatorTxs, l2Txs, err := txsel.GetL1L2TxSelection(selectionConfig, batchNum,
+	batchNum common.BatchNum) ([]common.Idx, [][]byte, []common.L1Tx, []common.PoolL2Tx, error) {
+	coordIdxs, auths, _, l1CoordinatorTxs, l2Txs, err := txsel.GetL1L2TxSelection(selectionConfig, batchNum,
 		[]common.L1Tx{})
-	return coordIdxs, l1CoordinatorTxs, l2Txs, tracerr.Wrap(err)
+	return coordIdxs, auths, l1CoordinatorTxs, l2Txs, tracerr.Wrap(err)
 }

 // GetL1L2TxSelection returns the selection of L1 + L2 txs
 func (txsel *TxSelector) GetL1L2TxSelection(selectionConfig *SelectionConfig,
-	batchNum common.BatchNum, l1Txs []common.L1Tx) ([]common.Idx, []common.L1Tx, []common.L1Tx,
+	batchNum common.BatchNum, l1Txs []common.L1Tx) ([]common.Idx, [][]byte, []common.L1Tx, []common.L1Tx,
 	[]common.PoolL2Tx, error) {
 	// apply l1-user-tx to localAccountDB
 	// create new leaves
@@ -169,7 +169,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(selectionConfig *SelectionConfig,
 	// get existing CoordIdxs
 	coordIdxsMap, err := txsel.GetCoordIdxs()
 	if err != nil {
-		return nil, nil, nil, nil, tracerr.Wrap(err)
+		return nil, nil, nil, nil, nil, tracerr.Wrap(err)
 	}
 	var coordIdxs []common.Idx
 	for tokenID := range coordIdxsMap {
@@ -179,7 +179,7 @@ func (txsel *TxSelector) GetL1L2TxSelection(selectionConfig *SelectionConfig,
 	// get pending l2-tx from tx-pool
 	l2TxsRaw, err := txsel.l2db.GetPendingTxs() // (batchID)
 	if err != nil {
-		return nil, nil, nil, nil, tracerr.Wrap(err)
+		return nil, nil, nil, nil, nil, tracerr.Wrap(err)
 	}

 	var validTxs txs
@@ -235,14 +235,19 @@ func (txsel *TxSelector) GetL1L2TxSelection(selectionConfig *SelectionConfig,
 	// process the txs in the local AccountsDB
 	_, err = txsel.localAccountsDB.ProcessTxs(ptc, coordIdxs, l1Txs, l1CoordinatorTxs, l2Txs)
 	if err != nil {
-		return nil, nil, nil, nil, tracerr.Wrap(err)
+		return nil, nil, nil, nil, nil, tracerr.Wrap(err)
 	}
 	err = txsel.localAccountsDB.MakeCheckpoint()
 	if err != nil {
-		return nil, nil, nil, nil, tracerr.Wrap(err)
+		return nil, nil, nil, nil, nil, tracerr.Wrap(err)
 	}

-	return nil, l1Txs, l1CoordinatorTxs, l2Txs, nil
+	// TODO
+	auths := make([][]byte, len(l1CoordinatorTxs))
+	for i := range auths {
+		auths[i] = make([]byte, 65)
+	}
+	return nil, auths, l1Txs, l1CoordinatorTxs, l2Txs, nil
 }

 // processTxsToEthAddrBJJ process the common.PoolL2Tx in the case where

+2 -2  txselector/txselector_test.go

@@ -133,12 +133,12 @@ func TestGetL2TxSelection(t *testing.T) {
 	// add the 1st batch of transactions to the TxSelector
 	addL2Txs(t, txsel, common.L2TxsToPoolL2Txs(blocks[0].Rollup.Batches[0].L2Txs))

-	_, l1CoordTxs, l2Txs, err := txsel.GetL2TxSelection(selectionConfig, 0)
+	_, _, l1CoordTxs, l2Txs, err := txsel.GetL2TxSelection(selectionConfig, 0)
 	assert.NoError(t, err)
 	assert.Equal(t, 0, len(l2Txs))
 	assert.Equal(t, 0, len(l1CoordTxs))

-	_, _, _, _, err = txsel.GetL1L2TxSelection(selectionConfig, 0, blocks[0].Rollup.L1UserTxs)
+	_, _, _, _, _, err = txsel.GetL1L2TxSelection(selectionConfig, 0, blocks[0].Rollup.L1UserTxs)
 	assert.NoError(t, err)

 	// TODO once L2DB is updated to return error in case that AddTxTest
