Browse Source

fix: Support range checking non aligned bitwidth values (#47)

* initial commit

* most of the code done

* made global poseidon chip

* changed decompSize and added some panics

* made all gl chip as pointers

* working code

* revert go.mod and go.sum

* cleanup and comments

* cleaned up range checker selection

* renamed gnarkRangeCheckSelector to gnarkRangeCheckerSelector

* addressed PR comment

* addressed overflow issue identified by Veridise

* added some comments

* fixed some comment typos

* restore change made from commit hash 85d20ce and 9617141
main
Kevin Jue 1 year ago
committed by GitHub
parent
commit
c01f530fe1
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 222 additions and 39 deletions
  1. +2
    -2
      fri/fri.go
  2. +108
    -14
      goldilocks/base.go
  3. +1
    -1
      goldilocks/quadratic_extension_algebra.go
  4. +89
    -0
      goldilocks/range_checker_utils.go
  5. +1
    -1
      plonk/gates/arithmetic_extension_gate.go
  6. +1
    -1
      plonk/gates/arithmetic_gate.go
  7. +1
    -1
      plonk/gates/base_sum_gate.go
  8. +1
    -1
      plonk/gates/constant_gate.go
  9. +1
    -1
      plonk/gates/coset_interpolation_gate.go
  10. +1
    -1
      plonk/gates/evaluate_gates.go
  11. +1
    -1
      plonk/gates/exponentiation_gate.go
  12. +1
    -1
      plonk/gates/gates.go
  13. +1
    -1
      plonk/gates/gates_test.go
  14. +1
    -1
      plonk/gates/multiplication_extension_gate.go
  15. +1
    -1
      plonk/gates/noop_gate.go
  16. +1
    -1
      plonk/gates/poseidon_gate.go
  17. +1
    -1
      plonk/gates/poseidon_mds_gate.go
  18. +1
    -1
      plonk/gates/public_input_gate.go
  19. +1
    -1
      plonk/gates/random_access_gate.go
  20. +1
    -1
      plonk/gates/reducing_extension_gate.go
  21. +1
    -1
      plonk/gates/reducing_gate.go
  22. +2
    -2
      poseidon/bn254.go
  23. +2
    -2
      poseidon/goldilocks.go
  24. +1
    -1
      poseidon/public_inputs_hash_test.go

+ 2
- 2
fri/fri.go

@ -16,7 +16,7 @@ import (
type Chip struct {
api frontend.API `gnark:"-"`
gl gl.Chip `gnark:"-"`
gl *gl.Chip `gnark:"-"`
poseidonBN254Chip *poseidon.BN254Chip `gnark:"-"`
commonData *types.CommonCircuitData `gnark:"-"`
friParams *types.FriParams `gnark:"-"`
@ -33,7 +33,7 @@ func NewChip(
poseidonBN254Chip: poseidonBN254Chip,
commonData: commonData,
friParams: friParams,
gl: *gl.New(api),
gl: gl.New(api),
}
}

+ 108
- 14
goldilocks/base.go

@ -19,6 +19,8 @@ import (
"math"
"math/big"
"os"
"strconv"
"sync"
"github.com/consensys/gnark-crypto/field/goldilocks"
"github.com/consensys/gnark/constraint/solver"
@ -40,7 +42,13 @@ var POWER_OF_TWO_GENERATOR goldilocks.Element = goldilocks.NewElement(1753635133
var MODULUS *big.Int = emulated.Goldilocks{}.Modulus()
// The number of bits to use for range checks on inner products of field elements.
var RANGE_CHECK_NB_BITS int = 140
// This MUST be a multiple of EXPECTED_OPTIMAL_BASEWIDTH if the commit based range checker is used.
// There is a bug in the pre 0.9.2 gnark range checker where it wouldn't appropriately range check a bitwidth that
// is misaligned from EXPECTED_OPTIMAL_BASEWIDTH: https://github.com/Consensys/gnark/security/advisories/GHSA-rjjm-x32p-m3f7
var RANGE_CHECK_NB_BITS int = 144
// The bit width size that the gnark commit based range checker should use.
var EXPECTED_OPTIMAL_BASEWIDTH int = 16
// Registers the hint functions with the solver.
func init() {
@ -76,25 +84,78 @@ func NegOne() Variable {
return NewVariable(MODULUS.Uint64() - 1)
}
type RangeCheckerType int
const (
NATIVE_RANGE_CHECKER RangeCheckerType = iota
COMMIT_RANGE_CHECKER
BIT_DECOMP_RANGE_CHECKER
)
// The chip used for Goldilocks field operations.
type Chip struct {
api frontend.API
rangeChecker frontend.Rangechecker
api frontend.API
rangeChecker frontend.Rangechecker
rangeCheckerType RangeCheckerType
rangeCheckCollected []checkedVariable // These field are used if rangeCheckerType == commit_range_checker
collectedMutex sync.Mutex
}
var (
poseidonChips = make(map[frontend.API]*Chip)
mutex sync.Mutex
)
// Creates a new Goldilocks Chip.
func New(api frontend.API) *Chip {
use_bit_decomp := os.Getenv("USE_BIT_DECOMPOSITION_RANGE_CHECK")
mutex.Lock()
defer mutex.Unlock()
var rangeChecker frontend.Rangechecker
if chip, ok := poseidonChips[api]; ok {
return chip
}
// If USE_BIT_DECOMPOSITION_RANGE_CHECK is not set, then use the std.rangecheck New function
if use_bit_decomp == "" {
rangeChecker = rangecheck.New(api)
c := &Chip{api: api}
// Instantiate the range checker gadget
// Per Gnark's range checker gadget's New function, there are three possible range checkers:
// 1. The native range checker
// 2. The commit range checker
// 3. The bit decomposition range checker
//
// See https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck.go#L3
// This function will emulate gnark's range checker selection logic (within the gnarkRangeCheckSelector func). However,
// if the USE_BIT_DECOMPOSITION_RANGE_CHECK env var is set, then it will explicitly use the bit decomposition range checker.
rangeCheckerType := gnarkRangeCheckerSelector(api)
useBitDecomp := os.Getenv("USE_BIT_DECOMPOSITION_RANGE_CHECK")
if useBitDecomp == "true" {
fmt.Println("The USE_BIT_DECOMPOSITION_RANGE_CHECK env var is set to true. Using the bit decomposition range checker.")
rangeCheckerType = BIT_DECOMP_RANGE_CHECKER
}
c.rangeCheckerType = rangeCheckerType
// If we are using the bit decomposition range checker, then create bitDecompChecker object
if c.rangeCheckerType == BIT_DECOMP_RANGE_CHECKER {
c.rangeChecker = bitDecompChecker{api: api}
} else {
rangeChecker = bitDecompChecker{api: api}
if c.rangeCheckerType == COMMIT_RANGE_CHECKER {
api.Compiler().Defer(c.checkCollected)
}
// If we are using the native or commit range checker, then have gnark's range checker gadget's New function create it.
// Also, note that the range checker will need to be created AFTER the c.checkCollected function is deferred.
// The commit range checker gadget will also call a deferred function, which needs to be called after c.checkCollected.
c.rangeChecker = rangecheck.New(api)
}
return &Chip{api: api, rangeChecker: rangeChecker}
poseidonChips[api] = c
return c
}
// Adds two goldilocks field elements and returns a value within the goldilocks field.
@ -209,7 +270,7 @@ func (p *Chip) ReduceWithMaxBits(x Variable, maxNbBits uint64) Variable {
}
quotient := result[0]
p.rangeChecker.Check(quotient, int(maxNbBits))
p.rangeCheckerCheck(quotient, int(maxNbBits))
remainder := NewVariable(result[1])
p.RangeCheck(remainder)
@ -321,8 +382,8 @@ func (p *Chip) RangeCheck(x Variable) {
),
x.Limb,
)
p.rangeChecker.Check(mostSigLimb, 32)
p.rangeChecker.Check(leastSigLimb, 32)
p.rangeCheckerCheck(mostSigLimb, 32)
p.rangeCheckerCheck(leastSigLimb, 32)
// If the most significant bits are all 1, then we need to check that the least significant bits are all zero
// in order for element to be less than the Goldilock's modulus.
@ -340,13 +401,46 @@ func (p *Chip) RangeCheck(x Variable) {
// This function will assert that the field element x is less than 2^maxNbBits.
func (p *Chip) RangeCheckWithMaxBits(x Variable, maxNbBits uint64) {
p.rangeChecker.Check(x.Limb, int(maxNbBits))
p.rangeCheckerCheck(x.Limb, int(maxNbBits))
}
func (p *Chip) AssertIsEqual(x, y Variable) {
p.api.AssertIsEqual(x.Limb, y.Limb)
}
func (p *Chip) rangeCheckerCheck(x frontend.Variable, nbBits int) {
switch p.rangeCheckerType {
case NATIVE_RANGE_CHECKER:
case BIT_DECOMP_RANGE_CHECKER:
p.rangeChecker.Check(x, nbBits)
case COMMIT_RANGE_CHECKER:
p.collectedMutex.Lock()
defer p.collectedMutex.Unlock()
p.rangeCheckCollected = append(p.rangeCheckCollected, checkedVariable{v: x, bits: nbBits})
}
}
func (p *Chip) checkCollected(api frontend.API) error {
if p.rangeCheckerType != COMMIT_RANGE_CHECKER {
panic("checkCollected should only be called when using the commit range checker")
}
nbBits := getOptimalBasewidth(p.api, p.rangeCheckCollected)
if nbBits != EXPECTED_OPTIMAL_BASEWIDTH {
panic("nbBits should be " + strconv.Itoa(EXPECTED_OPTIMAL_BASEWIDTH))
}
for _, v := range p.rangeCheckCollected {
if v.bits%nbBits != 0 {
panic("v.bits is not nbBits aligned")
}
p.rangeChecker.Check(v.v, v.bits)
}
return nil
}
// Computes the n'th primitive root of unity for the Goldilocks field.
func PrimitiveRootOfUnity(nLog uint64) goldilocks.Element {
if nLog > TWO_ADICITY {

+ 1
- 1
goldilocks/quadratic_extension_algebra.go

@ -47,7 +47,7 @@ func (p *Chip) SubExtensionAlgebra(
return diff
}
func (p Chip) MulExtensionAlgebra(
func (p *Chip) MulExtensionAlgebra(
a QuadraticExtensionAlgebraVariable,
b QuadraticExtensionAlgebraVariable,
) QuadraticExtensionAlgebraVariable {

+ 89
- 0
goldilocks/range_checker_utils.go

@ -0,0 +1,89 @@
package goldilocks
import (
"math"
"github.com/consensys/gnark/frontend"
)
// The types, structs, and functions in this file were ported over from the gnark library
// https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck_commit.go
type Type int
const (
R1CS Type = iota
SCS
)
type FrontendTyper interface {
FrontendType() Type
}
type checkedVariable struct {
v frontend.Variable
bits int
}
func getOptimalBasewidth(api frontend.API, collected []checkedVariable) int {
if ft, ok := api.(FrontendTyper); ok {
switch ft.FrontendType() {
case R1CS:
return optimalWidth(nbR1CSConstraints, collected)
case SCS:
return optimalWidth(nbPLONKConstraints, collected)
}
}
return optimalWidth(nbR1CSConstraints, collected)
}
func optimalWidth(countFn func(baseLength int, collected []checkedVariable) int, collected []checkedVariable) int {
min := math.MaxInt64
minVal := 0
for j := 2; j < 18; j++ {
current := countFn(j, collected)
if current < min {
min = current
minVal = j
}
}
return minVal
}
func decompSize(varSize int, limbSize int) int {
return (varSize + limbSize - 1) / limbSize
}
func nbR1CSConstraints(baseLength int, collected []checkedVariable) int {
nbDecomposed := 0
for i := range collected {
nbDecomposed += int(decompSize(collected[i].bits, baseLength))
}
eqs := len(collected) // correctness of decomposition
nbRight := nbDecomposed // inverse per decomposed
nbleft := (1 << baseLength) // div per table
return nbleft + nbRight + eqs + 1
}
func nbPLONKConstraints(baseLength int, collected []checkedVariable) int {
nbDecomposed := 0
for i := range collected {
nbDecomposed += int(decompSize(collected[i].bits, baseLength))
}
eqs := nbDecomposed // check correctness of every decomposition. this is nbDecomp adds + eq cost per collected
nbRight := 3 * nbDecomposed // denominator sub, inv and large sum per table entry
nbleft := 3 * (1 << baseLength) // denominator sub, div and large sum per table entry
return nbleft + nbRight + eqs + 1 // and the final assert
}
func gnarkRangeCheckerSelector(api frontend.API) RangeCheckerType {
// Emulate the logic within rangecheck.New
// https://github.com/Consensys/gnark/blob/3421eaa7d544286abf3de8c46282b8d4da6d5da0/std/rangecheck/rangecheck.go#L24
if _, ok := api.(frontend.Rangechecker); ok {
return NATIVE_RANGE_CHECKER
} else if _, ok := api.(frontend.Committer); ok {
return COMMIT_RANGE_CHECKER
} else {
return BIT_DECOMP_RANGE_CHECKER
}
}

+ 1
- 1
plonk/gates/arithmetic_extension_gate.go

@ -58,7 +58,7 @@ func (g *ArithmeticExtensionGate) wiresIthOutput(i uint64) Range {
func (g *ArithmeticExtensionGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
const0 := vars.localConstants[0]

+ 1
- 1
plonk/gates/arithmetic_gate.go

@ -59,7 +59,7 @@ func (g *ArithmeticGate) WireIthOutput(i uint64) uint64 {
func (g *ArithmeticGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
const0 := vars.localConstants[0]

+ 1
- 1
plonk/gates/base_sum_gate.go

@ -65,7 +65,7 @@ func (g *BaseSumGate) limbs() []uint64 {
func (g *BaseSumGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
sum := vars.localWires[BASESUM_GATE_WIRE_SUM]

+ 1
- 1
plonk/gates/constant_gate.go

@ -56,7 +56,7 @@ func (g *ConstantGate) WireOutput(i uint64) uint64 {
func (g *ConstantGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
constraints := []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/coset_interpolation_gate.go

@ -150,7 +150,7 @@ func (g *CosetInterpolationGate) wiresShiftedEvaluationPoint() Range {
func (g *CosetInterpolationGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
constraints := []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/evaluate_gates.go

@ -67,7 +67,7 @@ func (g *EvaluateGatesChip) evalFiltered(
vars.RemovePrefix(numSelectors)
unfiltered := gate.EvalUnfiltered(g.api, *glApi, vars)
unfiltered := gate.EvalUnfiltered(g.api, glApi, vars)
for i := range unfiltered {
unfiltered[i] = glApi.MulExtension(unfiltered[i], filter)
}

+ 1
- 1
plonk/gates/exponentiation_gate.go

@ -79,7 +79,7 @@ func (g *ExponentiationGate) wireIntermediateValue(i uint64) uint64 {
func (g *ExponentiationGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
base := vars.localWires[g.wireBase()]

+ 1
- 1
plonk/gates/gates.go

@ -12,7 +12,7 @@ type Gate interface {
Id() string
EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable
}

+ 1
- 1
plonk/gates/gates_test.go

@ -697,7 +697,7 @@ func (circuit *TestGateCircuit) Define(api frontend.API) error {
vars := gates.NewEvaluationVars(localConstants[numSelectors:], localWires, publicInputsHash)
constraints := circuit.testGate.EvalUnfiltered(api, *glApi, *vars)
constraints := circuit.testGate.EvalUnfiltered(api, glApi, *vars)
if len(constraints) != len(circuit.ExpectedConstraints) {
return errors.New("gate constraints length mismatch")

+ 1
- 1
plonk/gates/multiplication_extension_gate.go

@ -54,7 +54,7 @@ func (g *MultiplicationExtensionGate) wiresIthOutput(i uint64) Range {
func (g *MultiplicationExtensionGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
const0 := vars.localConstants[0]

+ 1
- 1
plonk/gates/noop_gate.go

@ -27,7 +27,7 @@ func (g *NoopGate) Id() string {
func (g *NoopGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
return []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/poseidon_gate.go

@ -91,7 +91,7 @@ func (g *PoseidonGate) WiresEnd() uint64 {
func (g *PoseidonGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
constraints := []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/poseidon_mds_gate.go

@ -75,7 +75,7 @@ func (g *PoseidonMdsGate) mdsLayerAlgebra(
func (g *PoseidonMdsGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
constraints := []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/public_input_gate.go

@ -31,7 +31,7 @@ func (g *PublicInputGate) WiresPublicInputsHash() []uint64 {
func (g *PublicInputGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
constraints := []gl.QuadraticExtensionVariable{}

+ 1
- 1
plonk/gates/random_access_gate.go

@ -130,7 +130,7 @@ func (g *RandomAccessGate) WireBit(i uint64, copy uint64) uint64 {
func (g *RandomAccessGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
two := gl.NewVariable(2).ToQuadraticExtension()

+ 1
- 1
plonk/gates/reducing_extension_gate.go

@ -76,7 +76,7 @@ func (g *ReducingExtensionGate) wiresAccs(i uint64) Range {
func (g *ReducingExtensionGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
alpha := vars.GetLocalExtAlgebra(g.wiresAlpha())

+ 1
- 1
plonk/gates/reducing_gate.go

@ -76,7 +76,7 @@ func (g *ReducingGate) wiresAccs(i uint64) Range {
func (g *ReducingGate) EvalUnfiltered(
api frontend.API,
glApi gl.Chip,
glApi *gl.Chip,
vars EvaluationVars,
) []gl.QuadraticExtensionVariable {
alpha := vars.GetLocalExtAlgebra(g.wiresAlpha())

+ 2
- 2
poseidon/bn254.go

@ -22,7 +22,7 @@ const BN254_SPONGE_RATE int = 3
type BN254Chip struct {
api frontend.API `gnark:"-"`
gl gl.Chip `gnark:"-"`
gl *gl.Chip `gnark:"-"`
}
type BN254State = [BN254_SPONGE_WIDTH]frontend.Variable
@ -33,7 +33,7 @@ func NewBN254Chip(api frontend.API) *BN254Chip {
panic("Gnark compiler not set to BN254 scalar field")
}
return &BN254Chip{api: api, gl: *gl.New(api)}
return &BN254Chip{api: api, gl: gl.New(api)}
}
func (c *BN254Chip) Poseidon(state BN254State) BN254State {

+ 2
- 2
poseidon/goldilocks.go

@ -17,11 +17,11 @@ type GoldilocksHashOut = [POSEIDON_GL_HASH_SIZE]gl.Variable
type GoldilocksChip struct {
api frontend.API `gnark:"-"`
gl gl.Chip `gnark:"-"`
gl *gl.Chip `gnark:"-"`
}
func NewGoldilocksChip(api frontend.API) *GoldilocksChip {
return &GoldilocksChip{api: api, gl: *gl.New(api)}
return &GoldilocksChip{api: api, gl: gl.New(api)}
}
// The permutation function.

+ 1
- 1
poseidon/public_inputs_hash_test.go

@ -26,7 +26,7 @@ func (circuit *TestPublicInputsHashCircuit) Define(api frontend.API) error {
input[i] = gl.NewVariable(api.FromBinary(api.ToBinary(circuit.In[i], 64)...))
}
poseidonChip := &GoldilocksChip{api: api, gl: *glAPI}
poseidonChip := &GoldilocksChip{api: api, gl: glAPI}
output := poseidonChip.HashNoPad(input[:])
// Check that output is correct

Loading…
Cancel
Save