Significant refactor and all tests passing, as well as optimized range check for Goldilocks (#37)

This commit is contained in:
puma314
2023-10-11 18:02:46 -07:00
committed by GitHub
parent 13624e4daf
commit 940c81b212
50 changed files with 1089 additions and 1146 deletions

1
goldilocks/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
gnark.pprof

View File

@@ -11,15 +11,19 @@ package goldilocks
// be very beneficial to use the no reduction methods and keep track of the maximum number of bits
// your computation uses.
// This implementation is based on the following plonky2 implementation of Goldilocks
// Available here: https://github.com/0xPolygonZero/plonky2/blob/main/field/src/goldilocks_field.rs#L70
import (
"fmt"
"math"
"math/big"
"github.com/consensys/gnark-crypto/field/goldilocks"
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
"github.com/consensys/gnark/std/rangecheck"
)
// The multiplicative group generator of the field.
@@ -45,6 +49,7 @@ func init() {
solver.RegisterHint(MulAddHint)
solver.RegisterHint(ReduceHint)
solver.RegisterHint(InverseHint)
solver.RegisterHint(SplitLimbsHint)
}
// A type alias used to represent Goldilocks field elements.
@@ -75,12 +80,14 @@ func NegOne() Variable {
// The chip used for Goldilocks field operations.
type Chip struct {
api frontend.API
api frontend.API
rangeChecker frontend.Rangechecker
}
// Creates a new Goldilocks chip.
func NewChip(api frontend.API) *Chip {
return &Chip{api: api}
// Creates a new Goldilocks Chip.
func New(api frontend.API) *Chip {
rangeChecker := rangecheck.New(api)
return &Chip{api: api, rangeChecker: rangeChecker}
}
// Adds two field elements such that x + y = z within the Golidlocks field.
@@ -180,8 +187,7 @@ func (p *Chip) Reduce(x Variable) Variable {
}
quotient := result[0]
rangeCheckNbBits := RANGE_CHECK_NB_BITS
p.api.ToBinary(quotient, rangeCheckNbBits)
p.rangeChecker.Check(quotient, RANGE_CHECK_NB_BITS)
remainder := NewVariable(result[1])
p.RangeCheck(remainder)
@@ -203,7 +209,7 @@ func (p *Chip) ReduceWithMaxBits(x Variable, maxNbBits uint64) Variable {
}
quotient := result[0]
p.api.ToBinary(quotient, int(maxNbBits))
p.rangeChecker.Check(quotient, int(maxNbBits))
remainder := NewVariable(result[1])
p.RangeCheck(remainder)
@@ -279,6 +285,29 @@ func (p *Chip) Exp(x Variable, k *big.Int) Variable {
return z
}
// The hint used to split a GoldilocksVariable into 2 32 bit limbs.
func SplitLimbsHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
if len(inputs) != 1 {
panic("SplitLimbsHint expects 1 input operand")
}
// The Goldilocks field element
input := inputs[0]
if input.Cmp(MODULUS) == 0 || input.Cmp(MODULUS) == 1 {
return fmt.Errorf("input is not in the field")
}
two_32 := big.NewInt(int64(math.Pow(2, 32)))
// The most significant bits
results[0] = new(big.Int).Quo(input, two_32)
// The least significant bits
results[1] = new(big.Int).Rem(input, two_32)
return nil
}
// Range checks a field element x to be less than the Golidlocks modulus 2 ^ 64 - 2 ^ 32 + 1.
func (p *Chip) RangeCheck(x Variable) {
// The Goldilocks' modulus is 2^64 - 2^32 + 1, which is:
@@ -288,40 +317,32 @@ func (p *Chip) RangeCheck(x Variable) {
// in big endian binary. This function will first verify that x is at most 64 bits wide. Then it
// checks that if the bits[0:31] (in big-endian) are all 1, then bits[32:64] are all zero.
// First decompose x into 64 bits. The bits will be in little-endian order.
bits := bits.ToBinary(p.api, x.Limb, bits.WithNbDigits(64))
// Those bits should compose back to x.
reconstructedX := frontend.Variable(0)
c := uint64(1)
for i := 0; i < 64; i++ {
reconstructedX = p.api.Add(reconstructedX, p.api.Mul(bits[i], c))
c = c << 1
p.api.AssertIsBoolean(bits[i])
}
p.api.AssertIsEqual(x.Limb, reconstructedX)
mostSigBits32Sum := frontend.Variable(0)
for i := 32; i < 64; i++ {
mostSigBits32Sum = p.api.Add(mostSigBits32Sum, bits[i])
result, err := p.api.Compiler().NewHint(SplitLimbsHint, 2, x.Limb)
if err != nil {
panic(err)
}
leastSigBits32Sum := frontend.Variable(0)
for i := 0; i < 32; i++ {
leastSigBits32Sum = p.api.Add(leastSigBits32Sum, bits[i])
}
// We check that this is a valid decomposition of the Goldilock's element and range-check each limb.
mostSigLimb := result[0]
leastSigLimb := result[1]
p.api.AssertIsEqual(
p.api.Add(
p.api.Mul(mostSigLimb, uint64(math.Pow(2, 32))),
leastSigLimb,
),
x.Limb,
)
p.rangeChecker.Check(mostSigLimb, 32)
p.rangeChecker.Check(leastSigLimb, 32)
// If mostSigBits32Sum < 32, then we know that:
//
// x < (2^63 + ... + 2^32 + 0 * 2^31 + ... + 0 * 2^0)
//
// which equals to 2^64 - 2^32. So in that case, we don't need to do any more checks. If
// mostSigBits32Sum == 32, then we need to check that x == 2^64 - 2^32 (max GL value).
shouldCheck := p.api.IsZero(p.api.Sub(mostSigBits32Sum, 32))
// If the most significant bits are all 1, then we need to check that the least significant bits are all zero
// in order for element to be less than the Goldilock's modulus.
// Otherwise, we don't need to do any checks, since we already know that the element is less than the Goldilocks modulus.
shouldCheck := p.api.IsZero(p.api.Sub(mostSigLimb, uint64(math.Pow(2, 32))-1))
p.api.AssertIsEqual(
p.api.Select(
shouldCheck,
leastSigBits32Sum,
leastSigLimb,
frontend.Variable(0),
),
frontend.Variable(0),

View File

@@ -1,12 +1,16 @@
package goldilocks
import (
"fmt"
"math/big"
"os"
"testing"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/cs/scs"
"github.com/consensys/gnark/profile"
"github.com/consensys/gnark/test"
)
@@ -15,8 +19,8 @@ type TestGoldilocksRangeCheckCircuit struct {
}
func (c *TestGoldilocksRangeCheckCircuit) Define(api frontend.API) error {
chip := NewChip(api)
chip.RangeCheck(NewVariable(c.X))
glApi := New(api)
glApi.RangeCheck(NewVariable(c.X))
return nil
}
func TestGoldilocksRangeCheck(t *testing.T) {
@@ -25,13 +29,13 @@ func TestGoldilocksRangeCheck(t *testing.T) {
var circuit, witness TestGoldilocksRangeCheckCircuit
witness.X = 1
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerializationChecks())
witness.X = 0
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerializationChecks())
witness.X = MODULUS
assert.ProverFailed(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
assert.ProverFailed(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerializationChecks())
one := big.NewInt(1)
maxValidVal := new(big.Int).Sub(MODULUS, one)
@@ -39,14 +43,53 @@ func TestGoldilocksRangeCheck(t *testing.T) {
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
}
type TestGoldilocksRangeCheckBenchmarkCircuit struct {
X []frontend.Variable
}
func (c *TestGoldilocksRangeCheckBenchmarkCircuit) Define(api frontend.API) error {
glApi := New(api)
for _, x := range c.X {
glApi.RangeCheck(NewVariable(x))
glApi.Reduce(NewVariable(x))
}
return nil
}
func BenchmarkGoldilocksRangeCheck(b *testing.B) {
var sizes = []int{5, 10, 15}
for i := 0; i < len(sizes); i++ {
var circuit, witness TestGoldilocksRangeCheckBenchmarkCircuit
circuit.X = make([]frontend.Variable, 2<<sizes[i])
witness.X = make([]frontend.Variable, 2<<sizes[i])
for j := 0; j < len(circuit.X); j++ {
witness.X[j] = 1
}
p := profile.Start()
r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &circuit)
if err != nil {
fmt.Println("error in building circuit", err)
os.Exit(1)
}
p.Stop()
p.Top()
println("r1cs.GetNbCoefficients(): ", r1cs.GetNbCoefficients())
println("r1cs.GetNbConstraints(): ", r1cs.GetNbConstraints())
println("r1cs.GetNbSecretVariables(): ", r1cs.GetNbSecretVariables())
println("r1cs.GetNbPublicVariables(): ", r1cs.GetNbPublicVariables())
println("r1cs.GetNbInternalVariables(): ", r1cs.GetNbInternalVariables())
}
}
type TestGoldilocksMulAddCircuit struct {
X, Y, Z frontend.Variable
ExpectedResult frontend.Variable
}
func (c *TestGoldilocksMulAddCircuit) Define(api frontend.API) error {
chip := NewChip(api)
calculateValue := chip.MulAdd(NewVariable(c.X), NewVariable(c.Y), NewVariable(c.Z))
glApi := New(api)
calculateValue := glApi.MulAdd(NewVariable(c.X), NewVariable(c.Y), NewVariable(c.Z))
api.AssertIsEqual(calculateValue.Limb, c.ExpectedResult)
return nil
}

View File

@@ -15,7 +15,7 @@ type TestQuadraticExtensionMulCircuit struct {
}
func (c *TestQuadraticExtensionMulCircuit) Define(api frontend.API) error {
glApi := NewChip(api)
glApi := New(api)
actualRes := glApi.MulExtension(c.Operand1, c.Operand2)
glApi.AssertIsEqual(actualRes[0], c.ExpectedResult[0])
glApi.AssertIsEqual(actualRes[1], c.ExpectedResult[1])
@@ -58,7 +58,7 @@ type TestQuadraticExtensionDivCircuit struct {
}
func (c *TestQuadraticExtensionDivCircuit) Define(api frontend.API) error {
glAPI := NewChip(api)
glAPI := New(api)
actualRes := glAPI.DivExtension(c.Operand1, c.Operand2)
glAPI.AssertIsEqual(actualRes[0], c.ExpectedResult[0])
glAPI.AssertIsEqual(actualRes[1], c.ExpectedResult[1])