mirror of
https://github.com/arnaucube/go-snark-study.git
synced 2026-02-02 17:26:41 +01:00
circuit CalculateWitness, added - & / in GenerateR1CS(), added doc
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/arnaucube/go-snark/r1csqap"
|
||||
)
|
||||
|
||||
// Circuit is the data structure of the compiled circuit
|
||||
type Circuit struct {
|
||||
NVars int
|
||||
NPublic int
|
||||
@@ -22,6 +23,8 @@ type Circuit struct {
|
||||
C [][]*big.Int
|
||||
}
|
||||
}
|
||||
|
||||
// Constraint is the data structure of a flat code operation
|
||||
type Constraint struct {
|
||||
// v1 op v2 = out
|
||||
Op string
|
||||
@@ -61,7 +64,21 @@ func insertVar(arr []*big.Int, signals []string, v string, used map[string]bool)
|
||||
}
|
||||
return arr, used
|
||||
}
|
||||
func insertVarNeg(arr []*big.Int, signals []string, v string, used map[string]bool) ([]*big.Int, map[string]bool) {
|
||||
isVal, value := isValue(v)
|
||||
valueBigInt := big.NewInt(int64(value))
|
||||
if isVal {
|
||||
arr[0] = new(big.Int).Add(arr[0], valueBigInt)
|
||||
} else {
|
||||
if !used[v] {
|
||||
panic(errors.New("using variable before it's set"))
|
||||
}
|
||||
arr[indexInArray(signals, v)] = new(big.Int).Add(arr[indexInArray(signals, v)], big.NewInt(int64(-1)))
|
||||
}
|
||||
return arr, used
|
||||
}
|
||||
|
||||
// GenerateR1CS generates the R1CS polynomials from the Circuit
|
||||
func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
// from flat code to R1CS
|
||||
|
||||
@@ -71,7 +88,6 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
|
||||
used := make(map[string]bool)
|
||||
for _, constraint := range circ.Constraints {
|
||||
|
||||
aConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
bConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
cConstraint := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
@@ -86,7 +102,6 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
aConstraint[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(aConstraint[indexInArray(circ.Signals, constraint.Out)], big.NewInt(int64(1)))
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.Out, used)
|
||||
bConstraint[0] = big.NewInt(int64(1))
|
||||
|
||||
}
|
||||
continue
|
||||
|
||||
@@ -95,10 +110,19 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V2, used)
|
||||
bConstraint[0] = big.NewInt(int64(1))
|
||||
} else if constraint.Op == "-" {
|
||||
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||
aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V1, used)
|
||||
aConstraint, used = insertVarNeg(aConstraint, circ.Signals, constraint.V2, used)
|
||||
bConstraint[0] = big.NewInt(int64(1))
|
||||
} else if constraint.Op == "*" {
|
||||
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||
aConstraint, used = insertVar(aConstraint, circ.Signals, constraint.V1, used)
|
||||
bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
|
||||
} else if constraint.Op == "/" {
|
||||
cConstraint, used = insertVar(cConstraint, circ.Signals, constraint.V1, used)
|
||||
cConstraint[indexInArray(circ.Signals, constraint.Out)] = big.NewInt(int64(1))
|
||||
bConstraint, used = insertVar(bConstraint, circ.Signals, constraint.V2, used)
|
||||
}
|
||||
|
||||
a = append(a, aConstraint)
|
||||
@@ -108,3 +132,35 @@ func (circ *Circuit) GenerateR1CS() ([][]*big.Int, [][]*big.Int, [][]*big.Int) {
|
||||
}
|
||||
return a, b, c
|
||||
}
|
||||
|
||||
func grabVar(signals []string, w []*big.Int, vStr string) *big.Int {
|
||||
isVal, v := isValue(vStr)
|
||||
vBig := big.NewInt(int64(v))
|
||||
if isVal {
|
||||
return vBig
|
||||
} else {
|
||||
return w[indexInArray(signals, vStr)]
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateWitness calculates the Witness of a Circuit based on the given inputs
|
||||
func (circ *Circuit) CalculateWitness(inputs []*big.Int) []*big.Int {
|
||||
w := r1csqap.ArrayOfBigZeros(len(circ.Signals))
|
||||
w[0] = big.NewInt(int64(1))
|
||||
for i, input := range inputs {
|
||||
w[i+1] = input
|
||||
}
|
||||
for _, constraint := range circ.Constraints {
|
||||
if constraint.Op == "in" {
|
||||
} else if constraint.Op == "+" {
|
||||
w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Add(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
|
||||
} else if constraint.Op == "-" {
|
||||
w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Sub(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
|
||||
} else if constraint.Op == "*" {
|
||||
w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Mul(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
|
||||
} else if constraint.Op == "/" {
|
||||
w[indexInArray(circ.Signals, constraint.Out)] = new(big.Int).Div(grabVar(circ.Signals, w, constraint.V1), grabVar(circ.Signals, w, constraint.V2))
|
||||
}
|
||||
}
|
||||
return w
|
||||
}
|
||||
|
||||
@@ -74,4 +74,10 @@ func TestCircuitParser(t *testing.T) {
|
||||
fmt.Println(a)
|
||||
fmt.Println(b)
|
||||
fmt.Println(c)
|
||||
|
||||
b3 := big.NewInt(int64(3))
|
||||
inputs := []*big.Int{b3}
|
||||
// Calculate Witness
|
||||
w := circuit.CalculateWitness(inputs)
|
||||
fmt.Println("w", w)
|
||||
}
|
||||
|
||||
@@ -42,10 +42,12 @@ func isDigit(ch rune) bool {
|
||||
return (ch >= '0' && ch <= '9')
|
||||
}
|
||||
|
||||
// Scanner holds the bufio.Reader
|
||||
type Scanner struct {
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
// NewScanner creates a new Scanner with the given io.Reader
|
||||
func NewScanner(r io.Reader) *Scanner {
|
||||
return &Scanner{r: bufio.NewReader(r)}
|
||||
}
|
||||
@@ -62,7 +64,8 @@ func (s *Scanner) unread() {
|
||||
_ = s.r.UnreadRune()
|
||||
}
|
||||
|
||||
func (s *Scanner) Scan() (tok Token, lit string) {
|
||||
// Scan returns the Token and literal string of the current value
|
||||
func (s *Scanner) scan() (tok Token, lit string) {
|
||||
ch := s.read()
|
||||
|
||||
if isWhitespace(ch) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Parser data structure holds the Scanner and the Parsing functions
|
||||
type Parser struct {
|
||||
s *Scanner
|
||||
buf struct {
|
||||
@@ -16,6 +17,7 @@ type Parser struct {
|
||||
}
|
||||
}
|
||||
|
||||
// NewParser creates a new parser from a io.Reader
|
||||
func NewParser(r io.Reader) *Parser {
|
||||
return &Parser{s: NewScanner(r)}
|
||||
}
|
||||
@@ -26,7 +28,7 @@ func (p *Parser) scan() (tok Token, lit string) {
|
||||
p.buf.n = 0
|
||||
return p.buf.tok, p.buf.lit
|
||||
}
|
||||
tok, lit = p.s.Scan()
|
||||
tok, lit = p.s.scan()
|
||||
|
||||
p.buf.tok, p.buf.lit = tok, lit
|
||||
|
||||
@@ -45,7 +47,8 @@ func (p *Parser) scanIgnoreWhitespace() (tok Token, lit string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (p *Parser) ParseLine() (*Constraint, error) {
|
||||
// parseLine parses the current line
|
||||
func (p *Parser) parseLine() (*Constraint, error) {
|
||||
/*
|
||||
in this version,
|
||||
line will be for example s3 = s1 * s4
|
||||
@@ -111,12 +114,13 @@ func addToArrayIfNotExist(arr []string, elem string) []string {
|
||||
return arr
|
||||
}
|
||||
|
||||
// Parse parses the lines and returns the compiled Circuit
|
||||
func (p *Parser) Parse() (*Circuit, error) {
|
||||
circuit := &Circuit{}
|
||||
circuit.Signals = append(circuit.Signals, "one")
|
||||
nInputs := 0
|
||||
for {
|
||||
constraint, err := p.ParseLine()
|
||||
constraint, err := p.parseLine()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user