mirror of
https://github.com/arnaucube/go-snark-study.git
synced 2026-02-02 17:26:41 +01:00
flat circuit code to R1CS working
This commit is contained in:
34
README.md
34
README.md
@@ -30,24 +30,36 @@ fqR := fields.NewFq(bn.R)
|
|||||||
// new Polynomial Field
|
// new Polynomial Field
|
||||||
pf := r1csqap.NewPolynomialField(f)
|
pf := r1csqap.NewPolynomialField(f)
|
||||||
|
|
||||||
/*
|
// compile circuit and get the R1CS
|
||||||
suppose that we have the following variables with *big.Int elements:
|
flatCode := `
|
||||||
a = [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]]
|
func test(x):
|
||||||
b = [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]]
|
aux = x*x
|
||||||
c = [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]]
|
y = aux*x
|
||||||
|
z = x + y
|
||||||
|
out = z + 5
|
||||||
|
`
|
||||||
|
// parse the code
|
||||||
|
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
|
||||||
|
circuit, err := parser.Parse()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
fmt.Println(circuit)
|
||||||
|
// flat code to R1CS
|
||||||
|
fmt.Println("generating R1CS from flat code")
|
||||||
|
a, b, c := circuit.GenerateR1CS()
|
||||||
|
|
||||||
w = [1, 3, 35, 9, 27, 30]
|
/*
|
||||||
|
now we have the R1CS from the circuit:
|
||||||
|
a == [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]]
|
||||||
|
b == [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]]
|
||||||
|
c == [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]]
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
|
||||||
alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
|
alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
|
||||||
|
|
||||||
// wittness = 1, 3, 35, 9, 27, 30
|
// wittness = 1, 3, 35, 9, 27, 30
|
||||||
w := []*big.Int{b1, b3, b35, b9, b27, b30}
|
w := []*big.Int{b1, b3, b35, b9, b27, b30}
|
||||||
circuit := compiler.Circuit{
|
|
||||||
NVars: 6,
|
|
||||||
NPublic: 0,
|
|
||||||
NSignals: len(w),
|
|
||||||
}
|
|
||||||
ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
|
ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
|
||||||
|
|
||||||
hx := pf.DivisorPolinomial(px, zx)
|
hx := pf.DivisorPolinomial(px, zx)
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
package circuitcompiler
|
package circuitcompiler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"errors"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/arnaucube/go-snark/r1csqap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Circuit struct {
|
type Circuit struct {
|
||||||
@@ -19,14 +22,89 @@ type Circuit struct {
|
|||||||
C [][]*big.Int
|
C [][]*big.Int
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
type Constraint struct {
|
||||||
|
// v1 op v2 = out
|
||||||
|
Op string
|
||||||
|
V1 string
|
||||||
|
V2 string
|
||||||
|
Out string
|
||||||
|
Literal string
|
||||||
|
|
||||||
func (c *Circuit) GenerateR1CS() {
|
Inputs []string // in func delcaration case
|
||||||
fmt.Print("function with inputs: ")
|
}
|
||||||
fmt.Println(c.Inputs)
|
|
||||||
fmt.Print("signals: ")
|
func indexInArray(arr []string, e string) int {
|
||||||
fmt.Println(c.Signals)
|
for i, a := range arr {
|
||||||
for _, constraint := range c.Constraints {
|
if a == e {
|
||||||
fmt.Println(constraint.Literal)
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
func isValue(a string) (bool, int) {
|
||||||
|
v, err := strconv.Atoi(a)
|
||||||
|
if err != nil {
|
||||||
|
return false, 0
|
||||||
|
}
|
||||||
|
return true, v
|
||||||
|
}
|
||||||
|
func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
|
||||||
|
isVal, value := isValue(v)
|
||||||
|
valueBigInt := big.NewInt(int64(value))
|
||||||
|
if isVal {
|
||||||
|
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
|
||||||
|
} else {
|
||||||
|
if !used[v] {
|
||||||
|
panic(errors.New("using variable before it's set"))
|
||||||
|
}
|
||||||
|
arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
|
||||||
|
}
|
||||||
|
return arr, used
|
||||||
|
}
|
||||||
|
|
||||||
|
func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||||
|
// from flat code to R1CS
|
||||||
|
|
||||||
|
var a [][]*big.Int
|
||||||
|
var b [][]*big.Int
|
||||||
|
var c [][]*big.Int
|
||||||
|
|
||||||
|
used := make(map[string]bool)
|
||||||
|
for _, constraint := range circ.Constraints {
|
||||||
|
|
||||||
|
aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||||
|
bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||||
|
cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||||
|
|
||||||
|
// if existInArray(constraint.Out) {
|
||||||
|
if used[constraint.Out] {
|
||||||
|
panic(errors.New("out variable already used: " + constraint.Out))
|
||||||
|
}
|
||||||
|
used[constraint.Out] = true
|
||||||
|
if constraint.Op == "in" {
|
||||||
|
for i := 0; i < len(constraint.Inputs); i++ {
|
||||||
|
aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
|
||||||
|
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
|
||||||
|
bConstraint[0] = big.NewInt(int64(1))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
} else if constraint.Op == "+" {
|
||||||
|
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||||
|
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||||
|
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
|
||||||
|
bConstraint[0] = big.NewInt(int64(1))
|
||||||
|
} else if constraint.Op == "*" {
|
||||||
|
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||||
|
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||||
|
bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
|
||||||
|
}
|
||||||
|
|
||||||
|
a = append(a, aConstraint)
|
||||||
|
b = append(b, bConstraint)
|
||||||
|
c = append(c, cConstraint)
|
||||||
|
|
||||||
|
}
|
||||||
|
return a, b, c
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package circuitcompiler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -37,6 +38,40 @@ func TestCircuitParser(t *testing.T) {
|
|||||||
|
|
||||||
// flat code to R1CS
|
// flat code to R1CS
|
||||||
fmt.Println("generating R1CS from flat code")
|
fmt.Println("generating R1CS from flat code")
|
||||||
circuit.GenerateR1CS()
|
a, b, c := circuit.GenerateR1CS()
|
||||||
|
fmt.Print("function with inputs: ")
|
||||||
fmt.Println(circuit.Inputs)
|
fmt.Println(circuit.Inputs)
|
||||||
|
|
||||||
|
fmt.Print("signals: ")
|
||||||
|
fmt.Println(circuit.Signals)
|
||||||
|
|
||||||
|
// expected result
|
||||||
|
b0 := big.NewInt(int64(0))
|
||||||
|
b1 := big.NewInt(int64(1))
|
||||||
|
b5 := big.NewInt(int64(5))
|
||||||
|
aExpected := [][]*big.Int{
|
||||||
|
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||||
|
[]*big.Int{b0, b0, b0, b1, b0, b0},
|
||||||
|
[]*big.Int{b0, b1, b0, b0, b1, b0},
|
||||||
|
[]*big.Int{b5, b0, b0, b0, b0, b1},
|
||||||
|
}
|
||||||
|
bExpected := [][]*big.Int{
|
||||||
|
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||||
|
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||||
|
[]*big.Int{b1, b0, b0, b0, b0, b0},
|
||||||
|
[]*big.Int{b1, b0, b0, b0, b0, b0},
|
||||||
|
}
|
||||||
|
cExpected := [][]*big.Int{
|
||||||
|
[]*big.Int{b0, b0, b0, b1, b0, b0},
|
||||||
|
[]*big.Int{b0, b0, b0, b0, b1, b0},
|
||||||
|
[]*big.Int{b0, b0, b0, b0, b0, b1},
|
||||||
|
[]*big.Int{b0, b0, b1, b0, b0, b0},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, aExpected, a)
|
||||||
|
assert.Equal(t, bExpected, b)
|
||||||
|
assert.Equal(t, cExpected, c)
|
||||||
|
fmt.Println(a)
|
||||||
|
fmt.Println(b)
|
||||||
|
fmt.Println(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,17 +16,6 @@ type Parser struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type Constraint struct {
|
|
||||||
// v1 op v2 = out
|
|
||||||
Op Token
|
|
||||||
V1 string
|
|
||||||
V2 string
|
|
||||||
Out string
|
|
||||||
Literal string
|
|
||||||
|
|
||||||
Inputs []string // in func delcaration case
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewParser(r io.Reader) *Parser {
|
func NewParser(r io.Reader) *Parser {
|
||||||
return &Parser{s: NewScanner(r)}
|
return &Parser{s: NewScanner(r)}
|
||||||
}
|
}
|
||||||
@@ -90,7 +79,8 @@ func (p *Parser) ParseLine() (*Constraint, error) {
|
|||||||
c.V1 = lit
|
c.V1 = lit
|
||||||
c.Literal += lit
|
c.Literal += lit
|
||||||
// operator
|
// operator
|
||||||
c.Op, lit = p.scanIgnoreWhitespace()
|
_, lit = p.scanIgnoreWhitespace()
|
||||||
|
c.Op = lit
|
||||||
c.Literal += lit
|
c.Literal += lit
|
||||||
// v2
|
// v2
|
||||||
_, lit = p.scanIgnoreWhitespace()
|
_, lit = p.scanIgnoreWhitespace()
|
||||||
@@ -102,6 +92,15 @@ func (p *Parser) ParseLine() (*Constraint, error) {
|
|||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func existInArray(arr []string, elem string) bool {
|
||||||
|
for _, v := range arr {
|
||||||
|
if v == elem {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func addToArrayIfNotExist(arr []string, elem string) []string {
|
func addToArrayIfNotExist(arr []string, elem string) []string {
|
||||||
for _, v := range arr {
|
for _, v := range arr {
|
||||||
if v == elem {
|
if v == elem {
|
||||||
@@ -111,22 +110,61 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
|
|||||||
arr = append(arr, elem)
|
arr = append(arr, elem)
|
||||||
return arr
|
return arr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Parser) Parse() (*Circuit, error) {
|
func (p *Parser) Parse() (*Circuit, error) {
|
||||||
circuit := &Circuit{}
|
circuit := &Circuit{}
|
||||||
circuit.Signals = append(circuit.Signals, "one")
|
circuit.Signals = append(circuit.Signals, "one")
|
||||||
|
nInputs := 0
|
||||||
for {
|
for {
|
||||||
constraint, err := p.ParseLine()
|
constraint, err := p.ParseLine()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if constraint.Literal == "func" {
|
if constraint.Literal == "func" {
|
||||||
|
// one constraint for each input
|
||||||
|
for _, in := range constraint.Inputs {
|
||||||
|
newConstr := &Constraint{
|
||||||
|
Op: "in",
|
||||||
|
Out: in,
|
||||||
|
}
|
||||||
|
circuit.Constraints = append(circuit.Constraints, *newConstr)
|
||||||
|
nInputs++
|
||||||
|
}
|
||||||
circuit.Inputs = constraint.Inputs
|
circuit.Inputs = constraint.Inputs
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
circuit.Constraints = append(circuit.Constraints, *constraint)
|
circuit.Constraints = append(circuit.Constraints, *constraint)
|
||||||
|
isVal, _ := isValue(constraint.V1)
|
||||||
|
if !isVal {
|
||||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
|
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
|
||||||
|
}
|
||||||
|
isVal, _ = isValue(constraint.V2)
|
||||||
|
if !isVal {
|
||||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
|
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
|
||||||
|
}
|
||||||
|
if constraint.Out == "out" {
|
||||||
|
// if Out is "out", put it after the inputs
|
||||||
|
if !existInArray(circuit.Signals, constraint.Out) {
|
||||||
|
signalsCopy := copyArray(circuit.Signals)
|
||||||
|
var auxSignals []string
|
||||||
|
auxSignals = append(auxSignals, signalsCopy[0:nInputs+1]...)
|
||||||
|
auxSignals = append(auxSignals, constraint.Out)
|
||||||
|
auxSignals = append(auxSignals, signalsCopy[nInputs+1:]...)
|
||||||
|
circuit.Signals = auxSignals
|
||||||
|
}
|
||||||
|
} else {
|
||||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
|
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
circuit.NVars = len(circuit.Signals)
|
||||||
|
circuit.NSignals = len(circuit.Signals)
|
||||||
|
circuit.NPublic = 0
|
||||||
return circuit, nil
|
return circuit, nil
|
||||||
}
|
}
|
||||||
|
func copyArray(in []string) []string { // tmp
|
||||||
|
var out []string
|
||||||
|
for _, e := range in {
|
||||||
|
out = append(out, e)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package snark
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/arnaucube/go-snark/bn128"
|
"github.com/arnaucube/go-snark/bn128"
|
||||||
@@ -12,7 +13,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestZk(t *testing.T) {
|
func TestZkFromHardcodedR1CS(t *testing.T) {
|
||||||
bn, err := bn128.NewBn128()
|
bn, err := bn128.NewBn128()
|
||||||
assert.Nil(t, err)
|
assert.Nil(t, err)
|
||||||
|
|
||||||
@@ -85,3 +86,70 @@ func TestZk(t *testing.T) {
|
|||||||
|
|
||||||
assert.True(t, VerifyProof(bn, circuit, setup, proof))
|
assert.True(t, VerifyProof(bn, circuit, setup, proof))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestZkFromFlatCircuitCode(t *testing.T) {
|
||||||
|
bn, err := bn128.NewBn128()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
// new Finite Field
|
||||||
|
fqR := fields.NewFq(bn.R)
|
||||||
|
|
||||||
|
// new Polynomial Field
|
||||||
|
pf := r1csqap.NewPolynomialField(fqR)
|
||||||
|
|
||||||
|
// compile circuit and get the R1CS
|
||||||
|
flatCode := `
|
||||||
|
func test(x):
|
||||||
|
aux = x*x
|
||||||
|
y = aux*x
|
||||||
|
z = x + y
|
||||||
|
out = z + 5
|
||||||
|
`
|
||||||
|
// parse the code
|
||||||
|
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
|
||||||
|
circuit, err := parser.Parse()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
fmt.Println(circuit)
|
||||||
|
// flat code to R1CS
|
||||||
|
fmt.Println("generating R1CS from flat code")
|
||||||
|
a, b, c := circuit.GenerateR1CS()
|
||||||
|
|
||||||
|
alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
|
||||||
|
|
||||||
|
// wittness = 1, 3, 35, 9, 27, 30
|
||||||
|
b1 := big.NewInt(int64(1))
|
||||||
|
b3 := big.NewInt(int64(3))
|
||||||
|
b9 := big.NewInt(int64(9))
|
||||||
|
b27 := big.NewInt(int64(27))
|
||||||
|
b30 := big.NewInt(int64(30))
|
||||||
|
b35 := big.NewInt(int64(35))
|
||||||
|
w := []*big.Int{b1, b3, b35, b9, b27, b30}
|
||||||
|
|
||||||
|
ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
|
||||||
|
|
||||||
|
hx := pf.DivisorPolinomial(px, zx)
|
||||||
|
|
||||||
|
// hx==px/zx so px==hx*zx
|
||||||
|
assert.Equal(t, px, pf.Mul(hx, zx))
|
||||||
|
|
||||||
|
// p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
|
||||||
|
abc := pf.Sub(pf.Mul(ax, bx), cx)
|
||||||
|
assert.Equal(t, abc, px)
|
||||||
|
hz := pf.Mul(hx, zx)
|
||||||
|
assert.Equal(t, abc, hz)
|
||||||
|
|
||||||
|
div, rem := pf.Div(px, zx)
|
||||||
|
assert.Equal(t, hx, div)
|
||||||
|
assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
|
||||||
|
|
||||||
|
// calculate trusted setup
|
||||||
|
setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
fmt.Println("t", setup.Toxic.T)
|
||||||
|
|
||||||
|
// piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
|
||||||
|
proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
|
||||||
|
assert.True(t, VerifyProof(bn, *circuit, setup, proof))
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user