mirror of
https://github.com/arnaucube/go-snark-study.git
synced 2026-02-02 17:26:41 +01:00
flat circuit code to R1CS working
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
package circuitcompiler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"errors"
|
||||
"math/big"
|
||||
"strconv"
|
||||
|
||||
"github.com/arnaucube/go-snark/r1csqap"
|
||||
)
|
||||
|
||||
type Circuit struct {
|
||||
@@ -19,14 +22,89 @@ type Circuit struct {
|
||||
C [][]*big.Int
|
||||
}
|
||||
}
|
||||
type Constraint struct {
|
||||
// v1 op v2 = out
|
||||
Op string
|
||||
V1 string
|
||||
V2 string
|
||||
Out string
|
||||
Literal string
|
||||
|
||||
func (c *Circuit) GenerateR1CS() {
|
||||
fmt.Print("function with inputs: ")
|
||||
fmt.Println(c.Inputs)
|
||||
fmt.Print("signals: ")
|
||||
fmt.Println(c.Signals)
|
||||
for _, constraint := range c.Constraints {
|
||||
fmt.Println(constraint.Literal)
|
||||
Inputs []string // in func delcaration case
|
||||
}
|
||||
|
||||
func indexInArray(arr []string, e string) int {
|
||||
for i, a := range arr {
|
||||
if a == e {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
func isValue(a string) (bool, int) {
|
||||
v, err := strconv.Atoi(a)
|
||||
if err != nil {
|
||||
return false, 0
|
||||
}
|
||||
return true, v
|
||||
}
|
||||
func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
|
||||
isVal, value := isValue(v)
|
||||
valueBigInt := big.NewInt(int64(value))
|
||||
if isVal {
|
||||
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
|
||||
} else {
|
||||
if !used[v] {
|
||||
panic(errors.New("using variable before it's set"))
|
||||
}
|
||||
arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(1)))
|
||||
}
|
||||
return arr, used
|
||||
}
|
||||
|
||||
func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
// from flat code to R1CS
|
||||
|
||||
var a [][]*big.Int
|
||||
var b [][]*big.Int
|
||||
var c [][]*big.Int
|
||||
|
||||
used := make(map[string]bool)
|
||||
for _, constraint := range circ.Constraints {
|
||||
|
||||
aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
|
||||
// if existInArray(constraint.Out) {
|
||||
if used[constraint.Out] {
|
||||
panic(errors.New("out variable already used: " + constraint.Out))
|
||||
}
|
||||
used[constraint.Out] = true
|
||||
if constraint.Op == "in" {
|
||||
for i := 0; i < len(constraint.Inputs); i++ {
|
||||
aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
|
||||
bConstraint[0] = big.NewInt(int64(1))
|
||||
|
||||
}
|
||||
continue
|
||||
|
||||
} else if constraint.Op == "+" {
|
||||
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
|
||||
bConstraint[0] = big.NewInt(int64(1))
|
||||
} else if constraint.Op == "*" {
|
||||
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||
bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
|
||||
}
|
||||
|
||||
a = append(a, aConstraint)
|
||||
b = append(b, bConstraint)
|
||||
c = append(c, cConstraint)
|
||||
|
||||
}
|
||||
return a, b, c
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package circuitcompiler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -37,6 +38,40 @@ func TestCircuitParser(t *testing.T) {
|
||||
|
||||
// flat code to R1CS
|
||||
fmt.Println("generating R1CS from flat code")
|
||||
circuit.GenerateR1CS()
|
||||
a, b, c := circuit.GenerateR1CS()
|
||||
fmt.Print("function with inputs: ")
|
||||
fmt.Println(circuit.Inputs)
|
||||
|
||||
fmt.Print("signals: ")
|
||||
fmt.Println(circuit.Signals)
|
||||
|
||||
// expected result
|
||||
b0 := big.NewInt(int64(0))
|
||||
b1 := big.NewInt(int64(1))
|
||||
b5 := big.NewInt(int64(5))
|
||||
aExpected := [][]*big.Int{
|
||||
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||
[]*big.Int{b0, b0, b0, b1, b0, b0},
|
||||
[]*big.Int{b0, b1, b0, b0, b1, b0},
|
||||
[]*big.Int{b5, b0, b0, b0, b0, b1},
|
||||
}
|
||||
bExpected := [][]*big.Int{
|
||||
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||
[]*big.Int{b0, b1, b0, b0, b0, b0},
|
||||
[]*big.Int{b1, b0, b0, b0, b0, b0},
|
||||
[]*big.Int{b1, b0, b0, b0, b0, b0},
|
||||
}
|
||||
cExpected := [][]*big.Int{
|
||||
[]*big.Int{b0, b0, b0, b1, b0, b0},
|
||||
[]*big.Int{b0, b0, b0, b0, b1, b0},
|
||||
[]*big.Int{b0, b0, b0, b0, b0, b1},
|
||||
[]*big.Int{b0, b0, b1, b0, b0, b0},
|
||||
}
|
||||
|
||||
assert.Equal(t, aExpected, a)
|
||||
assert.Equal(t, bExpected, b)
|
||||
assert.Equal(t, cExpected, c)
|
||||
fmt.Println(a)
|
||||
fmt.Println(b)
|
||||
fmt.Println(c)
|
||||
}
|
||||
|
||||
@@ -16,17 +16,6 @@ type Parser struct {
|
||||
}
|
||||
}
|
||||
|
||||
type Constraint struct {
|
||||
// v1 op v2 = out
|
||||
Op Token
|
||||
V1 string
|
||||
V2 string
|
||||
Out string
|
||||
Literal string
|
||||
|
||||
Inputs []string // in func delcaration case
|
||||
}
|
||||
|
||||
func NewParser(r io.Reader) *Parser {
|
||||
return &Parser{s: NewScanner(r)}
|
||||
}
|
||||
@@ -90,7 +79,8 @@ func (p *Parser) ParseLine() (*Constraint, error) {
|
||||
c.V1 = lit
|
||||
c.Literal += lit
|
||||
// operator
|
||||
c.Op, lit = p.scanIgnoreWhitespace()
|
||||
_, lit = p.scanIgnoreWhitespace()
|
||||
c.Op = lit
|
||||
c.Literal += lit
|
||||
// v2
|
||||
_, lit = p.scanIgnoreWhitespace()
|
||||
@@ -102,6 +92,15 @@ func (p *Parser) ParseLine() (*Constraint, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func existInArray(arr []string, elem string) bool {
|
||||
for _, v := range arr {
|
||||
if v == elem {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func addToArrayIfNotExist(arr []string, elem string) []string {
|
||||
for _, v := range arr {
|
||||
if v == elem {
|
||||
@@ -111,22 +110,61 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
|
||||
arr = append(arr, elem)
|
||||
return arr
|
||||
}
|
||||
|
||||
func (p *Parser) Parse() (*Circuit, error) {
|
||||
circuit := &Circuit{}
|
||||
circuit.Signals = append(circuit.Signals, "one")
|
||||
nInputs := 0
|
||||
for {
|
||||
constraint, err := p.ParseLine()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if constraint.Literal == "func" {
|
||||
// one constraint for each input
|
||||
for _, in := range constraint.Inputs {
|
||||
newConstr := &Constraint{
|
||||
Op: "in",
|
||||
Out: in,
|
||||
}
|
||||
circuit.Constraints = append(circuit.Constraints, *newConstr)
|
||||
nInputs++
|
||||
}
|
||||
circuit.Inputs = constraint.Inputs
|
||||
continue
|
||||
}
|
||||
circuit.Constraints = append(circuit.Constraints, *constraint)
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
|
||||
isVal, _ := isValue(constraint.V1)
|
||||
if !isVal {
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V1)
|
||||
}
|
||||
isVal, _ = isValue(constraint.V2)
|
||||
if !isVal {
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.V2)
|
||||
}
|
||||
if constraint.Out == "out" {
|
||||
// if Out is "out", put it after the inputs
|
||||
if !existInArray(circuit.Signals, constraint.Out) {
|
||||
signalsCopy := copyArray(circuit.Signals)
|
||||
var auxSignals []string
|
||||
auxSignals = append(auxSignals, signalsCopy[0:nInputs+1]...)
|
||||
auxSignals = append(auxSignals, constraint.Out)
|
||||
auxSignals = append(auxSignals, signalsCopy[nInputs+1:]...)
|
||||
circuit.Signals = auxSignals
|
||||
}
|
||||
} else {
|
||||
circuit.Signals = addToArrayIfNotExist(circuit.Signals, constraint.Out)
|
||||
}
|
||||
}
|
||||
circuit.NVars = len(circuit.Signals)
|
||||
circuit.NSignals = len(circuit.Signals)
|
||||
circuit.NPublic = 0
|
||||
return circuit, nil
|
||||
}
|
||||
func copyArray(in []string) []string { // tmp
|
||||
var out []string
|
||||
for _, e := range in {
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user