optimized goldilocks (#22)

* cleaned up qe api

* modified goldilocks poseidon to use optimized goldilocks operations

* better comment

* added goldilocks test cases

* some cleanup and comments

* changed poseidon constaints to frontend.Variable

* fixed double cast

* fixed bug in challenger
This commit is contained in:
Kevin Jue
2023-06-08 14:22:42 -07:00
committed by GitHub
parent ecfc4a7b2b
commit 15b7dcbcdb
11 changed files with 1462 additions and 1228 deletions

View File

@@ -1,9 +1,14 @@
package field
import (
"fmt"
"math/big"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark-crypto/field/goldilocks"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/bits"
"github.com/consensys/gnark/std/math/emulated"
)
@@ -38,6 +43,7 @@ var NEG_ONE_F = NewFieldConst(EmulatedField{}.Modulus().Uint64() - 1)
var GOLDILOCKS_MULTIPLICATIVE_GROUP_GENERATOR = goldilocks.NewElement(7)
var GOLDILOCKS_TWO_ADICITY = uint64(32)
var GOLDILOCKS_POWER_OF_TWO_GENERATOR = goldilocks.NewElement(1753635133440165772)
var GOLDILOCKS_MODULUS = EmulatedField{}.Modulus()
func GoldilocksPrimitiveRootOfUnity(nLog uint64) goldilocks.Element {
if nLog > GOLDILOCKS_TWO_ADICITY {
@@ -81,3 +87,107 @@ func IsZero(api frontend.API, fieldAPI *emulated.Field[emulated.Goldilocks], x F
return isZero
}
func init() {
// register hints
hint.Register(GoldilocksMulAddHint)
}
func GoldilocksRangeCheck(api frontend.API, x frontend.Variable) {
// Goldilocks' modulus is 2^64 - 2^32 + 1,
// which is "1111111111111111111111111111111100000000000000000000000000000001' in big endian binary
// This function will first verify that x is at most 64 bits wide.
// Then it checks that if the bits[0:31] (in big-endian) are all 1, then bits[32:64] are all zero
// First decompose x into 64 bits. The bits will be in little-endian order.
bits, err := api.Compiler().NewHint(bits.NBits, 64, x)
if err != nil {
panic(err)
}
// Those bits should compose back to x
reconstructedX := frontend.Variable(0)
c := uint64(1)
for i := 0; i < 64; i++ {
reconstructedX = api.Add(reconstructedX, api.Mul(bits[i], c))
c = c << 1
api.AssertIsBoolean(bits[i])
}
api.AssertIsEqual(x, reconstructedX)
mostSigBits32Sum := frontend.Variable(0)
for i := 32; i < 64; i++ {
mostSigBits32Sum = api.Add(mostSigBits32Sum, bits[i])
}
leastSigBits32Sum := frontend.Variable(0)
for i := 0; i < 32; i++ {
leastSigBits32Sum = api.Add(leastSigBits32Sum, bits[i])
}
// If mostSigBits32Sum < 32, then we know that x < (2^63 + ... + 2^32 + 0 * 2^31 + ... + 0 * 2^0), which equals to 2^64 - 2^32
// So in that case, we don't need to do any more checks.
// If mostSigBits32Sum == 32, then we need to check that x == 2^64 - 2^32 (max GL value)
shouldCheck := api.IsZero(api.Sub(mostSigBits32Sum, 32))
api.AssertIsEqual(
api.Select(
shouldCheck,
leastSigBits32Sum,
frontend.Variable(0),
),
frontend.Variable(0),
)
}
// Calculates operands[0] * operands[1] + operands[2]
// This function assumes that all operands are within goldilocks, and will panic otherwise
// It will ensure that the result is within goldilocks
func GoldilocksMulAdd(api frontend.API, operand1, operand2, operand3 frontend.Variable) frontend.Variable {
result, err := api.Compiler().NewHint(GoldilocksMulAddHint, 2, operand1, operand2, operand3)
if err != nil {
panic(err)
}
quotient := result[0]
remainder := result[1]
// Verify the calculated value
lhs := api.Mul(operand1, operand2)
lhs = api.Add(lhs, operand3)
rhs := api.Add(api.Mul(quotient, GOLDILOCKS_MODULUS), remainder)
api.AssertIsEqual(lhs, rhs)
GoldilocksRangeCheck(api, quotient)
GoldilocksRangeCheck(api, remainder)
return remainder
}
func GoldilocksMulAddHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error {
if len(inputs) != 3 {
return fmt.Errorf("GoldilocksMulAddHint expects 3 input operands")
}
for _, operand := range inputs {
if operand.Cmp(GOLDILOCKS_MODULUS) >= 0 {
return fmt.Errorf("%s is not in the field", operand.String())
}
}
product := new(big.Int).Mul(inputs[0], inputs[1])
sum := new(big.Int).Add(product, inputs[2])
quotient := new(big.Int).Div(sum, GOLDILOCKS_MODULUS)
remainder := new(big.Int).Rem(sum, GOLDILOCKS_MODULUS)
results[0] = quotient
results[1] = remainder
return nil
}
func GoldilocksReduce(api frontend.API, x frontend.Variable) frontend.Variable {
// Use gnark's emulated field library.
fieldAPI := NewFieldAPI(api)
element := fieldAPI.NewElement(x)
return fieldAPI.Reduce(element).Limbs[0]
}

72
field/goldilocks_test.go Normal file
View File

@@ -0,0 +1,72 @@
package field
import (
"math/big"
"testing"
"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/test"
)
type TestGoldilocksRangeCheckCircuit struct {
X frontend.Variable
}
func (c *TestGoldilocksRangeCheckCircuit) Define(api frontend.API) error {
GoldilocksRangeCheck(api, c.X)
return nil
}
func TestGoldilocksRangeCheck(t *testing.T) {
assert := test.NewAssert(t)
var circuit, witness TestGoldilocksRangeCheckCircuit
witness.X = 1
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
witness.X = 0
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
witness.X = EmulatedField{}.Modulus()
assert.ProverFailed(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoSerialization())
one := big.NewInt(1)
maxValidVal := new(big.Int).Sub(EmulatedField{}.Modulus(), one)
witness.X = maxValidVal
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16))
}
type TestGoldilocksMulAddCircuit struct {
X, Y, Z frontend.Variable
ExpectedResult frontend.Variable
}
func (c *TestGoldilocksMulAddCircuit) Define(api frontend.API) error {
calculateValue := GoldilocksMulAdd(api, c.X, c.Y, c.Z)
api.AssertIsEqual(calculateValue, c.ExpectedResult)
return nil
}
func TestGoldilocksMulAdd(t *testing.T) {
assert := test.NewAssert(t)
var circuit, witness TestGoldilocksMulAddCircuit
witness.X = 1
witness.Y = 2
witness.Z = 3
witness.ExpectedResult = 5
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoFuzzing())
bigOperand := new(big.Int).SetUint64(9223372036854775808)
expectedValue, _ := new(big.Int).SetString("18446744068340842500", 10)
witness.X = bigOperand
witness.Y = bigOperand
witness.Z = 3
witness.ExpectedResult = expectedValue
assert.ProverSucceeded(&circuit, &witness, test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), test.NoFuzzing())
}

View File

@@ -97,6 +97,10 @@ func (c *QuadraticExtensionAPI) ScalarMulExtension(a QuadraticExtension, scalar
return QuadraticExtension{c.fieldAPI.Mul(a[0], scalar), c.fieldAPI.Mul(a[1], scalar)}
}
func (c *QuadraticExtensionAPI) VarToQE(a frontend.Variable) QuadraticExtension {
return c.FieldToQE(c.fieldAPI.NewElement(a))
}
func (c *QuadraticExtensionAPI) FieldToQE(a F) QuadraticExtension {
return QuadraticExtension{a, ZERO_F}
}