diff --git a/fri/fri.go b/fri/fri.go index e71d6e3..2e37f0f 100644 --- a/fri/fri.go +++ b/fri/fri.go @@ -16,7 +16,7 @@ import ( type Chip struct { api frontend.API `gnark:"-"` - gl gl.Chip `gnark:"-"` + gl *gl.Chip `gnark:"-"` poseidonBN254Chip *poseidon.BN254Chip `gnark:"-"` commonData *types.CommonCircuitData `gnark:"-"` friParams *types.FriParams `gnark:"-"` @@ -33,7 +33,7 @@ func NewChip( poseidonBN254Chip: poseidonBN254Chip, commonData: commonData, friParams: friParams, - gl: *gl.New(api), + gl: gl.New(api), } } diff --git a/goldilocks/base.go b/goldilocks/base.go index 1908325..ed63eb6 100644 --- a/goldilocks/base.go +++ b/goldilocks/base.go @@ -19,6 +19,8 @@ import ( "math" "math/big" "os" + "strconv" + "sync" "github.com/consensys/gnark-crypto/field/goldilocks" "github.com/consensys/gnark/constraint/solver" @@ -40,7 +42,13 @@ var POWER_OF_TWO_GENERATOR goldilocks.Element = goldilocks.NewElement(1753635133 var MODULUS *big.Int = emulated.Goldilocks{}.Modulus() // The number of bits to use for range checks on inner products of field elements. -var RANGE_CHECK_NB_BITS int = 140 +// This MUST be a multiple of EXPECTED_OPTIMAL_BASEWIDTH if the commit based range checker is used. +// There is a bug in the pre 0.9.2 gnark range checker where it wouldn't appropriately range check a bitwidth that +// is misaligned from EXPECTED_OPTIMAL_BASEWIDTH: https://github.com/Consensys/gnark/security/advisories/GHSA-rjjm-x32p-m3f7 +var RANGE_CHECK_NB_BITS int = 144 + +// The bit width size that the gnark commit based range checker should use. +var EXPECTED_OPTIMAL_BASEWIDTH int = 16 // Registers the hint functions with the solver. func init() { @@ -76,25 +84,78 @@ func NegOne() Variable { return NewVariable(MODULUS.Uint64() - 1) } +type RangeCheckerType int + +const ( + NATIVE_RANGE_CHECKER RangeCheckerType = iota + COMMIT_RANGE_CHECKER + BIT_DECOMP_RANGE_CHECKER +) + // The chip used for Goldilocks field operations. type Chip struct { - api frontend.API - rangeChecker frontend.Rangechecker + api frontend.API + + rangeChecker frontend.Rangechecker + rangeCheckerType RangeCheckerType + + rangeCheckCollected []checkedVariable // These field are used if rangeCheckerType == commit_range_checker + collectedMutex sync.Mutex } +var ( + poseidonChips = make(map[frontend.API]*Chip) + mutex sync.Mutex +) + // Creates a new Goldilocks Chip. func New(api frontend.API) *Chip { - use_bit_decomp := os.Getenv("USE_BIT_DECOMPOSITION_RANGE_CHECK") + mutex.Lock() + defer mutex.Unlock() - var rangeChecker frontend.Rangechecker + if chip, ok := poseidonChips[api]; ok { + return chip + } - // If USE_BIT_DECOMPOSITION_RANGE_CHECK is not set, then use the std.rangecheck New function - if use_bit_decomp == "" { - rangeChecker = rangecheck.New(api) + c := &Chip{api: api} + + // Instantiate the range checker gadget + // Per Gnark's range checker gadget's New function, there are three possible range checkers: + // 1. The native range checker + // 2. The commit range checker + // 3. The bit decomposition range checker + // + // See https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck.go#L3 + + // This function will emulate gnark's range checker selection logic (within the gnarkRangeCheckSelector func). However, + // if the USE_BIT_DECOMPOSITION_RANGE_CHECK env var is set, then it will explicitly use the bit decomposition range checker. + + rangeCheckerType := gnarkRangeCheckerSelector(api) + useBitDecomp := os.Getenv("USE_BIT_DECOMPOSITION_RANGE_CHECK") + if useBitDecomp == "true" { + fmt.Println("The USE_BIT_DECOMPOSITION_RANGE_CHECK env var is set to true. Using the bit decomposition range checker.") + rangeCheckerType = BIT_DECOMP_RANGE_CHECKER + } + + c.rangeCheckerType = rangeCheckerType + + // If we are using the bit decomposition range checker, then create bitDecompChecker object + if c.rangeCheckerType == BIT_DECOMP_RANGE_CHECKER { + c.rangeChecker = bitDecompChecker{api: api} } else { - rangeChecker = bitDecompChecker{api: api} + if c.rangeCheckerType == COMMIT_RANGE_CHECKER { + api.Compiler().Defer(c.checkCollected) + } + + // If we are using the native or commit range checker, then have gnark's range checker gadget's New function create it. + // Also, note that the range checker will need to be created AFTER the c.checkCollected function is deferred. + // The commit range checker gadget will also call a deferred function, which needs to be called after c.checkCollected. + c.rangeChecker = rangecheck.New(api) } - return &Chip{api: api, rangeChecker: rangeChecker} + + poseidonChips[api] = c + + return c } // Adds two goldilocks field elements and returns a value within the goldilocks field. @@ -209,7 +270,7 @@ func (p *Chip) ReduceWithMaxBits(x Variable, maxNbBits uint64) Variable { } quotient := result[0] - p.rangeChecker.Check(quotient, int(maxNbBits)) + p.rangeCheckerCheck(quotient, int(maxNbBits)) remainder := NewVariable(result[1]) p.RangeCheck(remainder) @@ -321,8 +382,8 @@ func (p *Chip) RangeCheck(x Variable) { ), x.Limb, ) - p.rangeChecker.Check(mostSigLimb, 32) - p.rangeChecker.Check(leastSigLimb, 32) + p.rangeCheckerCheck(mostSigLimb, 32) + p.rangeCheckerCheck(leastSigLimb, 32) // If the most significant bits are all 1, then we need to check that the least significant bits are all zero // in order for element to be less than the Goldilock's modulus. @@ -340,13 +401,46 @@ func (p *Chip) RangeCheck(x Variable) { // This function will assert that the field element x is less than 2^maxNbBits. func (p *Chip) RangeCheckWithMaxBits(x Variable, maxNbBits uint64) { - p.rangeChecker.Check(x.Limb, int(maxNbBits)) + p.rangeCheckerCheck(x.Limb, int(maxNbBits)) } func (p *Chip) AssertIsEqual(x, y Variable) { p.api.AssertIsEqual(x.Limb, y.Limb) } +func (p *Chip) rangeCheckerCheck(x frontend.Variable, nbBits int) { + switch p.rangeCheckerType { + case NATIVE_RANGE_CHECKER: + case BIT_DECOMP_RANGE_CHECKER: + p.rangeChecker.Check(x, nbBits) + case COMMIT_RANGE_CHECKER: + p.collectedMutex.Lock() + defer p.collectedMutex.Unlock() + p.rangeCheckCollected = append(p.rangeCheckCollected, checkedVariable{v: x, bits: nbBits}) + } +} + +func (p *Chip) checkCollected(api frontend.API) error { + if p.rangeCheckerType != COMMIT_RANGE_CHECKER { + panic("checkCollected should only be called when using the commit range checker") + } + + nbBits := getOptimalBasewidth(p.api, p.rangeCheckCollected) + if nbBits != EXPECTED_OPTIMAL_BASEWIDTH { + panic("nbBits should be " + strconv.Itoa(EXPECTED_OPTIMAL_BASEWIDTH)) + } + + for _, v := range p.rangeCheckCollected { + if v.bits%nbBits != 0 { + panic("v.bits is not nbBits aligned") + } + + p.rangeChecker.Check(v.v, v.bits) + } + + return nil +} + // Computes the n'th primitive root of unity for the Goldilocks field. func PrimitiveRootOfUnity(nLog uint64) goldilocks.Element { if nLog > TWO_ADICITY { diff --git a/goldilocks/quadratic_extension_algebra.go b/goldilocks/quadratic_extension_algebra.go index d3d5aa2..f9f28c0 100644 --- a/goldilocks/quadratic_extension_algebra.go +++ b/goldilocks/quadratic_extension_algebra.go @@ -47,7 +47,7 @@ func (p *Chip) SubExtensionAlgebra( return diff } -func (p Chip) MulExtensionAlgebra( +func (p *Chip) MulExtensionAlgebra( a QuadraticExtensionAlgebraVariable, b QuadraticExtensionAlgebraVariable, ) QuadraticExtensionAlgebraVariable { diff --git a/goldilocks/range_checker_utils.go b/goldilocks/range_checker_utils.go new file mode 100644 index 0000000..594d83e --- /dev/null +++ b/goldilocks/range_checker_utils.go @@ -0,0 +1,89 @@ +package goldilocks + +import ( + "math" + + "github.com/consensys/gnark/frontend" +) + +// The types, structs, and functions in this file were ported over from the gnark library +// https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck_commit.go +type Type int + +const ( + R1CS Type = iota + SCS +) + +type FrontendTyper interface { + FrontendType() Type +} + +type checkedVariable struct { + v frontend.Variable + bits int +} + +func getOptimalBasewidth(api frontend.API, collected []checkedVariable) int { + if ft, ok := api.(FrontendTyper); ok { + switch ft.FrontendType() { + case R1CS: + return optimalWidth(nbR1CSConstraints, collected) + case SCS: + return optimalWidth(nbPLONKConstraints, collected) + } + } + return optimalWidth(nbR1CSConstraints, collected) +} + +func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int { + min := math.MaxInt64 + minVal := 0 + for j := 2; j < 18; j++ { + current := countFn(j, collected) + if current < min { + min = current + minVal = j + } + } + + return minVal +} + +func decompSize(varSize int, limbSize int) int { + return (varSize + limbSize - 1) / limbSize +} + +func nbR1CSConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := len(collected) // correctness of decomposition + nbRight := nbDecomposed // inverse per decomposed + nbleft := (1 << baseLength) // div per table + return nbleft + nbRight + eqs + 1 +} + +func nbPLONKConstraints(baseLength int, collected []checkedVariable) int { + nbDecomposed := 0 + for i := range collected { + nbDecomposed += int(decompSize(collected[i].bits, baseLength)) + } + eqs := nbDecomposed // check correctness of every decomposition. this is nbDecomp adds + eq cost per collected + nbRight := 3 * nbDecomposed // denominator sub, inv and large sum per table entry + nbleft := 3 * (1 << baseLength) // denominator sub, div and large sum per table entry + return nbleft + nbRight + eqs + 1 // and the final assert +} + +func gnarkRangeCheckerSelector(api frontend.API) RangeCheckerType { + // Emulate the logic within rangecheck.New + // https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck.go#L24 + if _, ok := api.(frontend.Rangechecker); ok { + return NATIVE_RANGE_CHECKER + } else if _, ok := api.(frontend.Committer); ok { + return COMMIT_RANGE_CHECKER + } else { + return BIT_DECOMP_RANGE_CHECKER + } +} diff --git a/plonk/gates/arithmetic_extension_gate.go b/plonk/gates/arithmetic_extension_gate.go index 7391798..3e689e5 100644 --- a/plonk/gates/arithmetic_extension_gate.go +++ b/plonk/gates/arithmetic_extension_gate.go @@ -58,7 +58,7 @@ func (g *ArithmeticExtensionGate) wiresIthOutput(i uint64) Range { func (g *ArithmeticExtensionGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { const0 := vars.localConstants[0] diff --git a/plonk/gates/arithmetic_gate.go b/plonk/gates/arithmetic_gate.go index f248bc4..3abb62d 100644 --- a/plonk/gates/arithmetic_gate.go +++ b/plonk/gates/arithmetic_gate.go @@ -59,7 +59,7 @@ func (g *ArithmeticGate) WireIthOutput(i uint64) uint64 { func (g *ArithmeticGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { const0 := vars.localConstants[0] diff --git a/plonk/gates/base_sum_gate.go b/plonk/gates/base_sum_gate.go index 444d188..9b7d23a 100644 --- a/plonk/gates/base_sum_gate.go +++ b/plonk/gates/base_sum_gate.go @@ -65,7 +65,7 @@ func (g *BaseSumGate) limbs() []uint64 { func (g *BaseSumGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { sum := vars.localWires[BASESUM_GATE_WIRE_SUM] diff --git a/plonk/gates/constant_gate.go b/plonk/gates/constant_gate.go index da7d852..c5280f2 100644 --- a/plonk/gates/constant_gate.go +++ b/plonk/gates/constant_gate.go @@ -56,7 +56,7 @@ func (g *ConstantGate) WireOutput(i uint64) uint64 { func (g *ConstantGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { constraints := []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/coset_interpolation_gate.go b/plonk/gates/coset_interpolation_gate.go index a9a835d..26d4dd9 100644 --- a/plonk/gates/coset_interpolation_gate.go +++ b/plonk/gates/coset_interpolation_gate.go @@ -150,7 +150,7 @@ func (g *CosetInterpolationGate) wiresShiftedEvaluationPoint() Range { func (g *CosetInterpolationGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { constraints := []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/evaluate_gates.go b/plonk/gates/evaluate_gates.go index 835eb6e..d95f9a3 100644 --- a/plonk/gates/evaluate_gates.go +++ b/plonk/gates/evaluate_gates.go @@ -67,7 +67,7 @@ func (g *EvaluateGatesChip) evalFiltered( vars.RemovePrefix(numSelectors) - unfiltered := gate.EvalUnfiltered(g.api, *glApi, vars) + unfiltered := gate.EvalUnfiltered(g.api, glApi, vars) for i := range unfiltered { unfiltered[i] = glApi.MulExtension(unfiltered[i], filter) } diff --git a/plonk/gates/exponentiation_gate.go b/plonk/gates/exponentiation_gate.go index eff38b4..f4f1fab 100644 --- a/plonk/gates/exponentiation_gate.go +++ b/plonk/gates/exponentiation_gate.go @@ -79,7 +79,7 @@ func (g *ExponentiationGate) wireIntermediateValue(i uint64) uint64 { func (g *ExponentiationGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { base := vars.localWires[g.wireBase()] diff --git a/plonk/gates/gates.go b/plonk/gates/gates.go index a10aee8..fc76580 100644 --- a/plonk/gates/gates.go +++ b/plonk/gates/gates.go @@ -12,7 +12,7 @@ type Gate interface { Id() string EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable } diff --git a/plonk/gates/gates_test.go b/plonk/gates/gates_test.go index 8e793c9..f9b4672 100644 --- a/plonk/gates/gates_test.go +++ b/plonk/gates/gates_test.go @@ -697,7 +697,7 @@ func (circuit *TestGateCircuit) Define(api frontend.API) error { vars := gates.NewEvaluationVars(localConstants[numSelectors:], localWires, publicInputsHash) - constraints := circuit.testGate.EvalUnfiltered(api, *glApi, *vars) + constraints := circuit.testGate.EvalUnfiltered(api, glApi, *vars) if len(constraints) != len(circuit.ExpectedConstraints) { return errors.New("gate constraints length mismatch") diff --git a/plonk/gates/multiplication_extension_gate.go b/plonk/gates/multiplication_extension_gate.go index 7948138..febc0f6 100644 --- a/plonk/gates/multiplication_extension_gate.go +++ b/plonk/gates/multiplication_extension_gate.go @@ -54,7 +54,7 @@ func (g *MultiplicationExtensionGate) wiresIthOutput(i uint64) Range { func (g *MultiplicationExtensionGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { const0 := vars.localConstants[0] diff --git a/plonk/gates/noop_gate.go b/plonk/gates/noop_gate.go index f7c67e0..6df7630 100644 --- a/plonk/gates/noop_gate.go +++ b/plonk/gates/noop_gate.go @@ -27,7 +27,7 @@ func (g *NoopGate) Id() string { func (g *NoopGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { return []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/poseidon_gate.go b/plonk/gates/poseidon_gate.go index 4576b2f..2d3dbe4 100644 --- a/plonk/gates/poseidon_gate.go +++ b/plonk/gates/poseidon_gate.go @@ -91,7 +91,7 @@ func (g *PoseidonGate) WiresEnd() uint64 { func (g *PoseidonGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { constraints := []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/poseidon_mds_gate.go b/plonk/gates/poseidon_mds_gate.go index db23e9a..48cf20e 100644 --- a/plonk/gates/poseidon_mds_gate.go +++ b/plonk/gates/poseidon_mds_gate.go @@ -75,7 +75,7 @@ func (g *PoseidonMdsGate) mdsLayerAlgebra( func (g *PoseidonMdsGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { constraints := []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/public_input_gate.go b/plonk/gates/public_input_gate.go index caf780e..cdb8cdf 100644 --- a/plonk/gates/public_input_gate.go +++ b/plonk/gates/public_input_gate.go @@ -31,7 +31,7 @@ func (g *PublicInputGate) WiresPublicInputsHash() []uint64 { func (g *PublicInputGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { constraints := []gl.QuadraticExtensionVariable{} diff --git a/plonk/gates/random_access_gate.go b/plonk/gates/random_access_gate.go index ca49934..706f47f 100644 --- a/plonk/gates/random_access_gate.go +++ b/plonk/gates/random_access_gate.go @@ -130,7 +130,7 @@ func (g *RandomAccessGate) WireBit(i uint64, copy uint64) uint64 { func (g *RandomAccessGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { two := gl.NewVariable(2).ToQuadraticExtension() diff --git a/plonk/gates/reducing_extension_gate.go b/plonk/gates/reducing_extension_gate.go index fd92f2d..7bb9aee 100644 --- a/plonk/gates/reducing_extension_gate.go +++ b/plonk/gates/reducing_extension_gate.go @@ -76,7 +76,7 @@ func (g *ReducingExtensionGate) wiresAccs(i uint64) Range { func (g *ReducingExtensionGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { alpha := vars.GetLocalExtAlgebra(g.wiresAlpha()) diff --git a/plonk/gates/reducing_gate.go b/plonk/gates/reducing_gate.go index cb212e6..57c29a3 100644 --- a/plonk/gates/reducing_gate.go +++ b/plonk/gates/reducing_gate.go @@ -76,7 +76,7 @@ func (g *ReducingGate) wiresAccs(i uint64) Range { func (g *ReducingGate) EvalUnfiltered( api frontend.API, - glApi gl.Chip, + glApi *gl.Chip, vars EvaluationVars, ) []gl.QuadraticExtensionVariable { alpha := vars.GetLocalExtAlgebra(g.wiresAlpha()) diff --git a/poseidon/bn254.go b/poseidon/bn254.go index 25ad83c..80dab15 100644 --- a/poseidon/bn254.go +++ b/poseidon/bn254.go @@ -22,7 +22,7 @@ const BN254_SPONGE_RATE int = 3 type BN254Chip struct { api frontend.API `gnark:"-"` - gl gl.Chip `gnark:"-"` + gl *gl.Chip `gnark:"-"` } type BN254State = [BN254_SPONGE_WIDTH]frontend.Variable @@ -33,7 +33,7 @@ func NewBN254Chip(api frontend.API) *BN254Chip { panic("Gnark compiler not set to BN254 scalar field") } - return &BN254Chip{api: api, gl: *gl.New(api)} + return &BN254Chip{api: api, gl: gl.New(api)} } func (c *BN254Chip) Poseidon(state BN254State) BN254State { diff --git a/poseidon/goldilocks.go b/poseidon/goldilocks.go index 7fe658c..93e1ebe 100644 --- a/poseidon/goldilocks.go +++ b/poseidon/goldilocks.go @@ -17,11 +17,11 @@ type GoldilocksHashOut = [POSEIDON_GL_HASH_SIZE]gl.Variable type GoldilocksChip struct { api frontend.API `gnark:"-"` - gl gl.Chip `gnark:"-"` + gl *gl.Chip `gnark:"-"` } func NewGoldilocksChip(api frontend.API) *GoldilocksChip { - return &GoldilocksChip{api: api, gl: *gl.New(api)} + return &GoldilocksChip{api: api, gl: gl.New(api)} } // The permutation function. diff --git a/poseidon/public_inputs_hash_test.go b/poseidon/public_inputs_hash_test.go index 2939c6d..51ef4dd 100644 --- a/poseidon/public_inputs_hash_test.go +++ b/poseidon/public_inputs_hash_test.go @@ -26,7 +26,7 @@ func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error { input[i] = gl.NewVariable(api.FromBinary(api.ToBinary(circuit.In[i], 64)...)) } - poseidonChip := &GoldilocksChip{api: api, gl: *glAPI} + poseidonChip := &GoldilocksChip{api: api, gl: glAPI} output := poseidonChip.HashNoPad(input[:]) // Check that output is correct