diff --git a/fri/fri.go b/fri/fri.go index 80d300f..b43eb2c 100644 --- a/fri/fri.go +++ b/fri/fri.go @@ -238,9 +238,11 @@ func (f *Chip) friCombineInitial( numerator := f.gl.SubExtensionNoReduce(reducedEvals, reducedOpenings) denominator := f.gl.SubExtension(subgroupX_QE, point) sum = f.gl.MulExtension(f.gl.ExpExtension(friAlpha, uint64(len(evals))), sum) + inv, hasInv := f.gl.InverseExtension(denominator) + f.api.AssertIsEqual(hasInv, frontend.Variable(1)) sum = f.gl.MulAddExtension( numerator, - f.gl.InverseExtension(denominator), + inv, sum, ) } @@ -272,17 +274,23 @@ func (f *Chip) interpolate( } sum := gl.ZeroExtension() + + lookupFromPoints := frontend.Variable(1) for i := 0; i < len(xPoints); i++ { + quotient, hasQuotient := f.gl.DivExtension( + barycentricWeights[i], + f.gl.SubExtension( + x, + xPoints[i], + ), + ) + + lookupFromPoints = f.api.Mul(hasQuotient, lookupFromPoints) + sum = f.gl.AddExtension( f.gl.MulExtension( - f.gl.DivExtension( - barycentricWeights[i], - f.gl.SubExtension( - x, - xPoints[i], - ), - ), yPoints[i], + quotient, ), sum, ) @@ -290,17 +298,17 @@ func (f *Chip) interpolate( interpolation := f.gl.MulExtension(lX, sum) - returnField := interpolation + lookupVal := gl.ZeroExtension() // Now check if x is already within the xPoints for i := 0; i < len(xPoints); i++ { - returnField = f.gl.Lookup( + lookupVal = f.gl.Lookup( f.gl.IsZero(f.gl.SubExtension(x, xPoints[i])), - returnField, + lookupVal, yPoints[i], ) } - return returnField + return f.gl.Lookup(lookupFromPoints, lookupVal, interpolation) } func (f *Chip) computeEvaluation( @@ -367,7 +375,9 @@ func (f *Chip) computeEvaluation( } // Take the inverse of the barycentric weights // OPTIMIZE: Can provide a witness to this value - barycentricWeights[i] = f.gl.InverseExtension(barycentricWeights[i]) + inv, hasInv := f.gl.InverseExtension(barycentricWeights[i]) + f.api.AssertIsEqual(hasInv, frontend.Variable(1)) + barycentricWeights[i] = inv } return f.interpolate(beta, xPoints, yPoints, barycentricWeights) diff --git a/goldilocks/base.go b/goldilocks/base.go index 9a53ab6..2bcd8cd 100644 --- a/goldilocks/base.go +++ b/goldilocks/base.go @@ -237,18 +237,21 @@ func ReduceHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { } // Computes the inverse of a field element x such that x * x^-1 = 1. -func (p *Chip) Inverse(x Variable) Variable { - result, err := p.api.Compiler().NewHint(InverseHint, 1, x.Limb) +func (p *Chip) Inverse(x Variable) (Variable, frontend.Variable) { + result, err := p.api.Compiler().NewHint(InverseHint, 2, x.Limb) if err != nil { panic(err) } inverse := NewVariable(result[0]) + hasInv := frontend.Variable(result[1]) p.RangeCheck(inverse) product := p.Mul(inverse, x) - p.api.AssertIsEqual(product.Limb, frontend.Variable(1)) - return inverse + productToCheck := p.api.Select(hasInv, product.Limb, frontend.Variable(1)) + p.api.AssertIsEqual(productToCheck, frontend.Variable(1)) + + return inverse, hasInv } // The hint used to compute Inverse. @@ -264,11 +267,19 @@ func InverseHint(_ *big.Int, inputs []*big.Int, results []*big.Int) error { inputGl := goldilocks.NewElement(input.Uint64()) resultGl := goldilocks.NewElement(0) + + // Will set resultGL if inputGL == 0 resultGl.Inverse(&inputGl) result := big.NewInt(0) results[0] = resultGl.BigInt(result) + hasInvInt64 := int64(0) + if !inputGl.IsZero() { + hasInvInt64 = 1 + } + results[1] = big.NewInt(hasInvInt64) + return nil } diff --git a/goldilocks/quadratic_extension.go b/goldilocks/quadratic_extension.go index d384300..38e77d3 100644 --- a/goldilocks/quadratic_extension.go +++ b/goldilocks/quadratic_extension.go @@ -126,7 +126,7 @@ func (p *Chip) InnerProductExtension( } // Computes the inverse of a quadratic extension variable in the Goldilocks field. -func (p *Chip) InverseExtension(a QuadraticExtensionVariable) QuadraticExtensionVariable { +func (p *Chip) InverseExtension(a QuadraticExtensionVariable) (QuadraticExtensionVariable, frontend.Variable) { a0IsZero := p.api.IsZero(a[0].Limb) a1IsZero := p.api.IsZero(a[1].Limb) p.api.AssertIsEqual(p.api.Mul(a0IsZero, a1IsZero), frontend.Variable(0)) @@ -135,12 +135,15 @@ func (p *Chip) InverseExtension(a QuadraticExtensionVariable) QuadraticExtension p.Mul(a[1], NewVariable(DTH_ROOT)), } aPowR := p.MulExtension(aPowRMinus1, a) - return p.ScalarMulExtension(aPowRMinus1, p.Inverse(aPowR[0])) + + aPowRInv, hasInv := p.Inverse(aPowR[0]) + return p.ScalarMulExtension(aPowRMinus1, aPowRInv), hasInv } // Divides two quadratic extension variables in the Goldilocks field. -func (p *Chip) DivExtension(a, b QuadraticExtensionVariable) QuadraticExtensionVariable { - return p.MulExtension(a, p.InverseExtension(b)) +func (p *Chip) DivExtension(a, b QuadraticExtensionVariable) (QuadraticExtensionVariable, frontend.Variable) { + bInv, hasInv := p.InverseExtension(b) + return p.MulExtension(a, bInv), hasInv } // Exponentiates a quadratic extension variable to some exponent in the Golidlocks field. diff --git a/goldilocks/quadratic_extension_test.go b/goldilocks/quadratic_extension_test.go index 9521daf..54e94ee 100644 --- a/goldilocks/quadratic_extension_test.go +++ b/goldilocks/quadratic_extension_test.go @@ -59,7 +59,7 @@ type TestQuadraticExtensionDivCircuit struct { func (c *TestQuadraticExtensionDivCircuit) Define(api frontend.API) error { glAPI := New(api) - actualRes := glAPI.DivExtension(c.Operand1, c.Operand2) + actualRes, _ := glAPI.DivExtension(c.Operand1, c.Operand2) glAPI.AssertIsEqual(actualRes[0], c.ExpectedResult[0]) glAPI.AssertIsEqual(actualRes[1], c.ExpectedResult[1]) return nil diff --git a/plonk/plonk.go b/plonk/plonk.go index 00cfc86..ac2348c 100644 --- a/plonk/plonk.go +++ b/plonk/plonk.go @@ -71,10 +71,15 @@ func (p *PlonkChip) evalL0(x gl.QuadraticExtensionVariable, xPowN gl.QuadraticEx glApi.ScalarMulExtension(x, p.DEGREE), p.DEGREE_QE, ) - return glApi.DivExtension( + + quotient, hasQuotient := glApi.DivExtension( evalZeroPoly, denominator, ) + + p.api.AssertIsEqual(hasQuotient, frontend.Variable(1)) + + return quotient } func (p *PlonkChip) checkPartialProducts(