Browse Source

flat circuit code to R1CS working

pull/5/head
arnaucube 5 years ago
parent
commit
0806af6b80
5 changed files with 267 additions and 36 deletions
  1. +23
    -11
      README.md
  2. +86
    -8
      circuitcompiler/circuit.go
  3. +36
    -1
      circuitcompiler/circuit_test.go
  4. +53
    -15
      circuitcompiler/parser.go
  5. +69
    -1
      snark_test.go

+ 23
- 11
README.md

@ -30,24 +30,36 @@ fqR := fields.NewFq(bn.R)
// new Polynomial Field // new Polynomial Field
pf := r1csqap.NewPolynomialField(f) pf := r1csqap.NewPolynomialField(f)
/*
suppose that we have the following variables with *big.Int elements:
a = [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]]
b = [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]]
c = [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]]
// compile circuit and get the R1CS
flatCode := `
func test(x):
aux = x*x
y = aux*x
z = x + y
out = z + 5
`
// parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
circuit, err := parser.Parse()
assert.Nil(t, err)
fmt.Println(circuit)
// flat code to R1CS
fmt.Println("generating R1CS from flat code")
a, b, c := circuit.GenerateR1CS()
w = [1, 3, 35, 9, 27, 30]
/*
now we have the R1CS from the circuit:
a == [[0 1 0 0 0 0] [0 0 0 1 0 0] [0 1 0 0 1 0] [5 0 0 0 0 1]]
b == [[0 1 0 0 0 0] [0 1 0 0 0 0] [1 0 0 0 0 0] [1 0 0 0 0 0]]
c == [[0 0 0 1 0 0] [0 0 0 0 1 0] [0 0 0 0 0 1] [0 0 1 0 0 0]]
*/ */
alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c) alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
// wittness = 1, 3, 35, 9, 27, 30 // wittness = 1, 3, 35, 9, 27, 30
w := []*big.Int{b1, b3, b35, b9, b27, b30} w := []*big.Int{b1, b3, b35, b9, b27, b30}
circuit := compiler.Circuit{
NVars: 6,
NPublic: 0,
NSignals: len(w),
}
ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas) ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
hx := pf.DivisorPolinomial(px, zx) hx := pf.DivisorPolinomial(px, zx)

+ 86
- 8
circuitcompiler/circuit.go

@ -1,8 +1,11 @@
package circuitcompiler package circuitcompiler
import ( import (
"fmt"
"errors"
"math/big" "math/big"
"strconv"
"github.com/arnaucube/go-snark/r1csqap"
) )
type Circuit struct { type Circuit struct {
@ -19,14 +22,89 @@ type Circuit struct {
C [][]*big.Int C [][]*big.Int
} }
} }
type Constraint struct {
// v1 op v2 = out
Op string
V1 string
V2 string
Out string
Literal string
Inputs []string // in func delcaration case
}
func indexInArray(arr []string, e string) int {
for i, a := range arr {
if a == e {
return i
}
}
return -1
}
func isValue(a string) (bool, int) {
v, err := strconv.Atoi(a)
if err != nil {
return false, 0
}
return true, v
}
func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
isVal, value := isValue(v)
valueBigInt := big.NewInt(int64(value))
if isVal {
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
} else {
if !used[v] {
panic(errors.New("using variable before it's set"))
}
arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
}
return arr, used
}
func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
// from flat code to R1CS
var a [][]*big.Int
var b [][]*big.Int
var c [][]*big.Int
used := make(map[string]bool)
for _, constraint := range circ.Constraints {
aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
// if existInArray(constraint.Out) {
if used[constraint.Out] {
panic(errors.New("out variable already used: " + constraint.Out))
}
used[constraint.Out] = true
if constraint.Op == "in" {
for i := 0; i < len(constraint.Inputs); i++ {
aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
bConstraint[0] = big.NewInt(int64(1))
}
continue
} else if constraint.Op == "+" {
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
bConstraint[0] = big.NewInt(int64(1))
} else if constraint.Op == "*" {
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
}
func (c *Circuit) GenerateR1CS() {
fmt.Print("function with inputs: ")
fmt.Println(c.Inputs)
fmt.Print("signals: ")
fmt.Println(c.Signals)
for _, constraint := range c.Constraints {
fmt.Println(constraint.Literal)
a = append(a, aConstraint)
b = append(b, bConstraint)
c = append(c, cConstraint)
} }
return a, b, c
} }

+ 36
- 1
circuitcompiler/circuit_test.go

@ -2,6 +2,7 @@ package circuitcompiler
import ( import (
"fmt" "fmt"
"math/big"
"strings" "strings"
"testing" "testing"
@ -37,6 +38,40 @@ func TestCircuitParser(t *testing.T) {
// flat code to R1CS // flat code to R1CS
fmt.Println("generating R1CS from flat code") fmt.Println("generating R1CS from flat code")
circuit.GenerateR1CS()
a, b, c := circuit.GenerateR1CS()
fmt.Print("function with inputs: ")
fmt.Println(circuit.Inputs) fmt.Println(circuit.Inputs)
fmt.Print("signals: ")
fmt.Println(circuit.Signals)
// expected result
b0 := big.NewInt(int64(0))
b1 := big.NewInt(int64(1))
b5 := big.NewInt(int64(5))
aExpected := [][]*big.Int{
[]*big.Int{b0, b1, b0, b0, b0, b0},
[]*big.Int{b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b1, b0, b0, b1, b0},
[]*big.Int{b5, b0, b0, b0, b0, b1},
}
bExpected := [][]*big.Int{
[]*big.Int{b0, b1, b0, b0, b0, b0},
[]*big.Int{b0, b1, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0},
[]*big.Int{b1, b0, b0, b0, b0, b0},
}
cExpected := [][]*big.Int{
[]*big.Int{b0, b0, b0, b1, b0, b0},
[]*big.Int{b0, b0, b0, b0, b1, b0},
[]*big.Int{b0, b0, b0, b0, b0, b1},
[]*big.Int{b0, b0, b1, b0, b0, b0},
}
assert.Equal(t, aExpected, a)
assert.Equal(t, bExpected, b)
assert.Equal(t, cExpected, c)
fmt.Println(a)
fmt.Println(b)
fmt.Println(c)
} }

+ 53
- 15
circuitcompiler/parser.go

@ -16,17 +16,6 @@ type Parser struct {
} }
} }
type Constraint struct {
// v1 op v2 = out
Op Token
V1 string
V2 string
Out string
Literal string
Inputs []string // in func delcaration case
}
func NewParser(r io.Reader) *Parser { func NewParser(r io.Reader) *Parser {
return &Parser{s: NewScanner(r)} return &Parser{s: NewScanner(r)}
} }
@ -90,7 +79,8 @@ func (p *Parser) ParseLine() (*Constraint, error) {
c.V1 = lit c.V1 = lit
c.Literal += lit c.Literal += lit
// operator // operator
c.Op, lit = p.scanIgnoreWhitespace()
_, lit = p.scanIgnoreWhitespace()
c.Op = lit
c.Literal += lit c.Literal += lit
// v2 // v2
_, lit = p.scanIgnoreWhitespace() _, lit = p.scanIgnoreWhitespace()
@ -102,6 +92,15 @@ func (p *Parser) ParseLine() (*Constraint, error) {
return c, nil return c, nil
} }
func existInArray(arr []string, elem string) bool {
for _, v := range arr {
if v == elem {
return true
}
}
return false
}
func addToArrayIfNotExist(arr []string, elem string) []string { func addToArrayIfNotExist(arr []string, elem string) []string {
for _, v := range arr { for _, v := range arr {
if v == elem { if v == elem {
@ -111,22 +110,61 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
arr = append(arr, elem) arr = append(arr, elem)
return arr return arr
} }
func (p *Parser) Parse() (*Circuit, error) { func (p *Parser) Parse() (*Circuit, error) {
circuit := &Circuit{} circuit := &Circuit{}
circuit.Signals = append(circuit.Signals, "one") circuit.Signals = append(circuit.Signals, "one")
nInputs := 0
for { for {
constraint, err := p.ParseLine() constraint, err := p.ParseLine()
if err != nil { if err != nil {
break break
} }
if constraint.Literal == "func" { if constraint.Literal == "func" {
// one constraint for each input
for _, in := range constraint.Inputs {
newConstr := &Constraint{
Op: "in",
Out: in,
}
circuit.Constraints = append(circuit.Constraints, *newConstr)
nInputs++
}
circuit.Inputs = constraint.Inputs circuit.Inputs = constraint.Inputs
continue continue
} }
circuit.Constraints = append(circuit.Constraints, *constraint) circuit.Constraints = append(circuit.Constraints, *constraint)
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
isVal, _ := isValue(constraint.V1)
if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
}
isVal, _ = isValue(constraint.V2)
if !isVal {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
}
if constraint.Out == "out" {
// if Out is "out", put it after the inputs
if !existInArray(circuit.Signals, constraint.Out) {
signalsCopy := copyArray(circuit.Signals)
var auxSignals []string
auxSignals = append(auxSignals, signalsCopy[0:nInputs+1]...)
auxSignals = append(auxSignals, constraint.Out)
auxSignals = append(auxSignals, signalsCopy[nInputs+1:]...)
circuit.Signals = auxSignals
}
} else {
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
}
} }
circuit.NVars = len(circuit.Signals)
circuit.NSignals = len(circuit.Signals)
circuit.NPublic = 0
return circuit, nil return circuit, nil
} }
func copyArray(in []string) []string { // tmp
var out []string
for _, e := range in {
out = append(out, e)
}
return out
}

+ 69
- 1
snark_test.go

@ -3,6 +3,7 @@ package snark
import ( import (
"fmt" "fmt"
"math/big" "math/big"
"strings"
"testing" "testing"
"github.com/arnaucube/go-snark/bn128" "github.com/arnaucube/go-snark/bn128"
@ -12,7 +13,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestZk(t *testing.T) {
func TestZkFromHardcodedR1CS(t *testing.T) {
bn, err := bn128.NewBn128() bn, err := bn128.NewBn128()
assert.Nil(t, err) assert.Nil(t, err)
@ -85,3 +86,70 @@ func TestZk(t *testing.T) {
assert.True(t, VerifyProof(bn, circuit, setup, proof)) assert.True(t, VerifyProof(bn, circuit, setup, proof))
} }
func TestZkFromFlatCircuitCode(t *testing.T) {
bn, err := bn128.NewBn128()
assert.Nil(t, err)
// new Finite Field
fqR := fields.NewFq(bn.R)
// new Polynomial Field
pf := r1csqap.NewPolynomialField(fqR)
// compile circuit and get the R1CS
flatCode := `
func test(x):
aux = x*x
y = aux*x
z = x + y
out = z + 5
`
// parse the code
parser := circuitcompiler.NewParser(strings.NewReader(flatCode))
circuit, err := parser.Parse()
assert.Nil(t, err)
fmt.Println(circuit)
// flat code to R1CS
fmt.Println("generating R1CS from flat code")
a, b, c := circuit.GenerateR1CS()
alphas, betas, gammas, zx := pf.R1CSToQAP(a, b, c)
// wittness = 1, 3, 35, 9, 27, 30
b1 := big.NewInt(int64(1))
b3 := big.NewInt(int64(3))
b9 := big.NewInt(int64(9))
b27 := big.NewInt(int64(27))
b30 := big.NewInt(int64(30))
b35 := big.NewInt(int64(35))
w := []*big.Int{b1, b3, b35, b9, b27, b30}
ax, bx, cx, px := pf.CombinePolynomials(w, alphas, betas, gammas)
hx := pf.DivisorPolinomial(px, zx)
// hx==px/zx so px==hx*zx
assert.Equal(t, px, pf.Mul(hx, zx))
// p(x) = a(x) * b(x) - c(x) == h(x) * z(x)
abc := pf.Sub(pf.Mul(ax, bx), cx)
assert.Equal(t, abc, px)
hz := pf.Mul(hx, zx)
assert.Equal(t, abc, hz)
div, rem := pf.Div(px, zx)
assert.Equal(t, hx, div)
assert.Equal(t, rem, r1csqap.ArrayOfBigZeros(4))
// calculate trusted setup
setup, err := GenerateTrustedSetup(bn, fqR, pf, len(w), *circuit, alphas, betas, gammas, zx)
assert.Nil(t, err)
fmt.Println("t", setup.Toxic.T)
// piA = g1 * A(t), piB = g2 * B(t), piC = g1 * C(t), piH = g1 * H(t)
proof, err := GenerateProofs(bn, fqR, *circuit, setup, hx, w)
assert.Nil(t, err)
assert.True(t, VerifyProof(bn, *circuit, setup, proof))
}

Loading…
Cancel
Save