Browse Source

New Multiplication Gate reduction algorithm!

Extracting coefficients from each output, s.t. each gate has a higher chance of being reused. See new Readme
pull/8/head
mottla 4 years ago
parent
commit
49fead2197
6 changed files with 711 additions and 192 deletions
  1. +11
    -5
      README.md
  2. +39
    -184
      circuitcompiler/Programm.go
  3. +154
    -1
      circuitcompiler/Programm_test.go
  4. +2
    -2
      circuitcompiler/circuit.go
  5. +300
    -0
      circuitcompiler/factorHandling.go
  6. +205
    -0
      circuitcompiler/factorHandling_test.go

+ 11
- 5
README.md

@ -4,6 +4,7 @@ Fork UNDER CONSTRUCTION! Will ask for merge soon
Current implementation status:
- [x] optimized gate reduction!! Reusing gates as often as possible! See the awesome results below :)
- [x] extended circuit code compiler
- [x] move witness calculation outside the setup phase
- [x] fixed hard bugs
@ -33,12 +34,17 @@ def mul(a,b):
```
R1CS Output:
```go
[[0 0 210 0 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0 0 0] [0 0 0 1 0 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [1 0 0 0 0 0 0 0 0 0 0 0]]
[[0 0 210 0 0 0 0 0 0 0 0 0] [0 0 5 0 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0] [0 210 0 0 0 0 0 0 0 0 0 0] [0 5 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 1 0 0 1 0 1 0]]
[[0 0 0 1 0 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 0 0 0 0 1]]
[[0 0 1 0 0 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 0 0 0 1 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 1 0 0] [1 0 0 0 0 0 0 0 0 0]]
[[0 0 1 0 0 0 0 0 0 0] [0 0 1 0 0 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 9724050000 0 1 9724050000]]
[[0 0 0 0 1 0 0 0 0 0] [0 0 0 0 0 1 0 0 0 0] [0 0 0 0 0 0 1 0 0 0] [0 0 0 0 0 0 0 1 0 0] [0 0 0 0 0 0 0 0 1 0] [0 0 0 0 0 0 0 0 0 1] [0 0 0 1 0 0 0 0 0 0]]
input
[7 11]
witness
[1 7 11 5336100 293485500 1566067976550000 2160900 75631500 163432108350000 49 343 1729500084900343]
[1 7 11 1729500084900343 121 1331 161051 49 343 16807]
another input
[365235 11876525]
witness
[1 365235 11876525 2297704271284150716235246193843898764109352875 141051846075625 1675205776213312203125 236290867291438012851239954111328125 133396605225 48721109109352875 6499230557984496821593771875]
```
Note that we only need 9 multiplication Gates instead of 16
Note that we only need 7 multiplication Gates instead of 16. The 4th witness value is the programs output. Use python script to check correctness!

+ 39
- 184
circuitcompiler/Programm.go

@ -21,6 +21,12 @@ type R1CS struct {
B [][]*big.Int
C [][]*big.Int
}
type MultiplicationGateSignature struct {
identifier string
commonExtracted [2]int //if the mgate had a extractable factor, it will be stored here
}
type Program struct {
functions map[string]*Circuit
globalInputs []string
@ -34,12 +40,12 @@ type Program struct {
//this datastructure is nice but maybe ill replace it later with something less confusing
//it serves the elementary purpose of not computing a variable a second time.
//it boosts parse time
computedInContext map[string]map[string]string
computedInContext map[string]map[string]MultiplicationGateSignature
//to reduce the number of multiplication gates, we store each factor signature, and the variable name,
//so each time a variable is computed, that happens to have the very same factors, we reuse the former
//it boost setup and proof time
computedFactors map[string]string
computedFactors map[string]MultiplicationGateSignature
}
//returns the cardinality of all main inputs + 1 for the "one" signal
@ -129,10 +135,10 @@ func (c *Circuit) buildTree(g *gate) {
func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
orderedmGates = []gate{}
p.computedInContext = make(map[string]map[string]string)
p.computedFactors = make(map[string]string)
rootHash := []byte{}
p.computedInContext[string(rootHash)] = make(map[string]string)
p.computedInContext = make(map[string]map[string]MultiplicationGateSignature)
p.computedFactors = make(map[string]MultiplicationGateSignature)
rootHash := make([]byte, 10)
p.computedInContext[string(rootHash)] = make(map[string]MultiplicationGateSignature)
p.r1CSRecursiveBuild(p.getMainCircuit(), p.getMainCircuit().root, rootHash, &orderedmGates, false, false)
return orderedmGates
}
@ -140,7 +146,7 @@ func (p *Program) ReduceCombinedTree() (orderedmGates []gate) {
//recursively walks through the parse tree to create a list of all
//multiplication gates needed for the QAP construction
//Takes into account, that multiplication with constants and addition (= substraction) can be reduced, and does so
func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs []factor, hashTraceResult []byte, variableEnd bool) {
func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTraceBuildup []byte, orderedmGates *[]gate, negate bool, invert bool) (facs factors, hashTraceResult []byte, variableEnd bool) {
if node.OperationType() == CONST {
b1, v1 := isValue(node.value.Out)
@ -152,7 +158,7 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
mul = [2]int{1, v1}
}
return []factor{{typ: CONST, negate: negate, multiplicative: mul}}, make([]byte, 10), false
return factors{{typ: CONST, negate: negate, multiplicative: mul}}, hashTraceBuildup, false
}
if node.OperationType() == FUNC {
@ -161,21 +167,19 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
node = nextContext.root
hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(currentCircuit.currentOutputName()))
if _, ex := p.computedInContext[string(hashTraceBuildup)]; !ex {
p.computedInContext[string(hashTraceBuildup)] = make(map[string]string)
p.computedInContext[string(hashTraceBuildup)] = make(map[string]MultiplicationGateSignature)
}
}
if node.OperationType() == IN {
fac := factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
return []factor{fac}, hashTraceBuildup, true
fac := &factor{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
return factors{fac}, hashTraceBuildup, true
}
if out, ex := p.computedInContext[string(hashTraceBuildup)][node.value.Out]; ex {
fac := factor{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}
hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(node.value.Out))
return []factor{fac}, hashTraceBuildup, true
fac := &factor{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: out.commonExtracted}
return factors{fac}, hashTraceBuildup, true
}
leftFactors, leftHash, variableEnd := p.r1CSRecursiveBuild(currentCircuit, node.left, hashTraceBuildup, orderedmGates, negate, invert)
@ -185,19 +189,25 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
if node.OperationType() == MULTIPLY {
if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
//if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
return mulFactors(leftFactors, rightFactors), append(leftHash, rightHash...), variableEnd || cons
return mulFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons
}
sig := factorsSignature(leftFactors, rightFactors)
if out, ex := p.computedFactors[sig]; ex {
return []factor{{typ: IN, name: out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
sig, newLef, newRigh := factorsSignature(leftFactors, rightFactors)
if out, ex := p.computedFactors[sig.identifier]; ex {
return factors{{typ: IN, name: out.identifier, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true
}
rootGate := cloneGate(node)
//rootGate := node
rootGate.index = len(*orderedmGates)
rootGate.leftIns = leftFactors
rootGate.rightIns = rightFactors
if p.getMainCircuit().root == node {
newLef = mulFactors(newLef, factors{&factor{typ: CONST, multiplicative: sig.commonExtracted}})
}
rootGate.leftIns = newLef
rootGate.rightIns = newRigh
out := hashTogether(leftHash, rightHash)
rootGate.value.V1 = rootGate.value.V1 + string(leftHash[:10])
rootGate.value.V2 = rootGate.value.V2 + string(rightHash[:10])
@ -208,183 +218,28 @@ func (p *Program) r1CSRecursiveBuild(currentCircuit *Circuit, node *gate, hashTr
rootGate.value.Out = rootGate.value.Out + string(out[:10])
}
p.computedInContext[string(hashTraceBuildup)][node.value.Out] = rootGate.value.Out
p.computedInContext[string(hashTraceBuildup)][node.value.Out] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted}
p.computedFactors[sig] = rootGate.value.Out
p.computedFactors[sig.identifier] = MultiplicationGateSignature{identifier: rootGate.value.Out, commonExtracted: sig.commonExtracted}
*orderedmGates = append(*orderedmGates, *rootGate)
hashTraceBuildup = hashTogether(hashTraceBuildup, []byte(rootGate.value.Out))
return []factor{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}, hashTraceBuildup, true
return factors{{typ: IN, name: rootGate.value.Out, invert: invert, negate: negate, multiplicative: sig.commonExtracted}}, hashTraceBuildup, true
}
switch node.OperationType() {
case PLUS:
return addFactors(leftFactors, rightFactors), hashTogether(leftHash, rightHash), variableEnd || cons
return addFactors(leftFactors, rightFactors), hashTraceBuildup, variableEnd || cons
default:
panic("unexpected gate")
}
}
type factor struct {
typ Token
name string
invert, negate bool
multiplicative [2]int
}
func (f factor) String() string {
if f.typ == CONST {
return fmt.Sprintf("(const fac: %v)", f.multiplicative)
}
str := f.name
if f.invert {
str += "^-1"
}
if f.negate {
str = "-" + str
}
return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
}
func mul2DVector(a, b [2]int) [2]int {
return [2]int{a[0] * b[0], a[1] * b[1]}
}
func factorsSignature(leftFactors, rightFactors []factor) string {
hasher.Reset()
//using a commutative operation here would be better. since a * b = b * a, but H(a,b) != H(b,a)
//could use (g^a)^b == (g^b)^a where g is a generator of some prime field where the dicrete log is known to be hard
for _, facLeft := range leftFactors {
hasher.Write([]byte(facLeft.String()))
}
for _, Righ := range rightFactors {
hasher.Write([]byte(Righ.String()))
}
return string(hasher.Sum(nil))[:16]
}
//multiplies factor elements and returns the result
//in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right
func mulFactors(leftFactors, rightFactors []factor) (result []factor) {
for _, facLeft := range leftFactors {
for i, facRight := range rightFactors {
if facLeft.typ == CONST && facRight.typ == IN {
rightFactors[i] = factor{typ: IN, name: facRight.name, negate: Xor(facLeft.negate, facRight.negate), invert: facRight.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
continue
}
if facRight.typ == CONST && facLeft.typ == IN {
rightFactors[i] = factor{typ: IN, name: facLeft.name, negate: Xor(facLeft.negate, facRight.negate), invert: facLeft.invert, multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
continue
}
if facRight.typ&facLeft.typ == CONST {
rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
continue
}
//tricky part here
//this one should only be reached, after a true mgate had its left and right braches computed. here we
//a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
if facRight.typ&facLeft.typ == IN {
if facLeft.name == facRight.name {
if facRight.invert != facLeft.invert {
rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
continue
}
}
//rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
//continue
}
panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ")
}
}
return rightFactors
}
//returns the absolute value of a signed int and a flag telling if the input was positive or not
//this implementation is awesome and fast (see Henry S Warren, Hackers's Delight)
func abs(n int) (val int, positive bool) {
y := n >> 63
return (n ^ y) - y, y == 0
}
//returns the reduced sum of two input factor arrays
//if no reduction was done (worst case), it returns the concatenation of the input arrays
func addFactors(leftFactors, rightFactors []factor) []factor {
var found bool
res := make([]factor, 0, len(leftFactors)+len(rightFactors))
for _, facLeft := range leftFactors {
found = false
for i, facRight := range rightFactors {
if facLeft.typ&facRight.typ == CONST {
var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
if facLeft.negate {
a0 *= -1
}
if facRight.negate {
b0 *= -1
}
absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
rightFactors[i] = factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
found = true
//res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
break
}
if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
if facLeft.negate {
a0 *= -1
}
if facRight.negate {
b0 *= -1
}
absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
rightFactors[i] = factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
found = true
//res = append(res, factor{typ: CONST, negate: negate, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}})
break
}
}
if !found {
res = append(res, facLeft)
}
}
for _, val := range rightFactors {
if val.multiplicative[0] != 0 {
res = append(res, val)
}
}
return res
}
//copies a gate neglecting its references to other gates
func cloneGate(in *gate) (out *gate) {
constr := &Constraint{Inputs: in.value.Inputs, Out: in.value.Out, Op: in.value.Op, invert: in.value.invert, negate: in.value.negate, V2: in.value.V2, V1: in.value.V1}
nRightins := make([]factor, len(in.rightIns))
nLeftInst := make([]factor, len(in.leftIns))
for k, v := range in.rightIns {
nRightins[k] = v
}
for k, v := range in.leftIns {
nLeftInst[k] = v
}
nRightins := in.rightIns.clone()
nLeftInst := in.leftIns.clone()
return &gate{value: constr, leftIns: nLeftInst, rightIns: nRightins, index: in.index}
}
@ -499,7 +354,7 @@ func (p *Program) GenerateReducedR1CS(mGates []gate) (r1CS R1CS) {
bConstraint := r1csqap.ArrayOfBigZeros(size)
cConstraint := r1csqap.ArrayOfBigZeros(size)
insertValue := func(val factor, arr []*big.Int) {
insertValue := func(val *factor, arr []*big.Int) {
if val.typ != CONST {
if _, ex := indexMap[val.name]; !ex {
panic(fmt.Sprintf("%v index not found!!!", val.name))

+ 154
- 1
circuitcompiler/Programm_test.go

@ -114,9 +114,26 @@ var correctnesTest = []TraceCorrectnessTest{
out = g * i
`,
},
{
io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
result: big.NewInt(int64(264)),
}},
code: `
def main(a,b,c,d):
e = a * 3
f = b * 7
g = c * 11
h = d * 13
i = e + f
j = g + h
k = i + j
out = k * 1
`,
},
}
func TestNewProgramm(t *testing.T) {
func TestCorrectness(t *testing.T) {
for _, test := range correctnesTest {
parser := NewParser(strings.NewReader(test.code))
@ -160,3 +177,139 @@ func TestNewProgramm(t *testing.T) {
}
}
//test to check gate optimisation
//mess around the code s.t. results is unchanged. number of gates should remain the same in any case
func TestGateOptimisation(t *testing.T) {
io := InOut{
inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
result: bigNumberResult1,
}
equalCodes := []string{
`
def main( x , z ) :
out = do(z) + add(x,x)
def do(x):
e = x * 5
b = e * 6
c = 7 * b
f = c * 1
d = f * c
out = d * mul(d,e)
def add(x ,k):
z = k * x
out = do(x) + mul(x,z)
def mul(a,b):
out = b * a
`, //switching order
`
def main( x , z ) :
out = do(z) + add(x,x)
def do(x):
e = x * 5
b = e * 6
c = b * 7
f = c * 1
d = c * f
out = d * mul(d,e)
def add(x ,k):
z = k * x
out = do(x) + mul(x,z)
def mul(a,b):
out = a * b
`, //switching order
`
def main( x , z ) :
out = do(z) + add(x,x)
def do(x):
e = x * 5
j = e * 3
k = e * 3
b = j+k
c = b * 7
f = c * 1
d = c * f
g = d * 1
out = g * mul(d,e)
def add(k ,x):
z = k * x
out = do(x) + mul(x,z)
def mul(b,a):
out = a * b
`, `
def main( x , z ) :
out = add(x,x)+do(z)
def do(x):
e = x * 5
j = 3 * e
k = e * 3
b = j+k
c = b * 7
f = c * 1
d = c * f
g = d * 1
out = mul(d,e) * g
def add(k ,x):
z = k * x
out = mul(x,z) + do(x)
def mul(b,a):
out = a * b
`,
}
var r1css = make([]R1CS, len(equalCodes))
for i, c := range equalCodes {
parser := NewParser(strings.NewReader(c))
program, err := parser.Parse()
if err != nil {
panic(err)
}
program.BuildConstraintTrees()
gates := program.ReduceCombinedTree()
for _, g := range gates {
fmt.Printf("\n %v", g)
}
fmt.Println("\n generating R1CS")
r1cs := program.GenerateReducedR1CS(gates)
r1css[i] = r1cs
fmt.Println(r1cs.A)
fmt.Println(r1cs.B)
fmt.Println(r1cs.C)
}
for i := 0; i < len(equalCodes)-1; i++ {
assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
}
for i := 0; i < len(equalCodes); i++ {
//assert.Equal(t, len(r1css[i].A), len(r1css[i+1].A))
w := CalculateWitness(io.inputs, r1css[i])
fmt.Println("witness")
fmt.Println(w)
assert.Equal(t, io.result, w[3])
}
}

+ 2
- 2
circuitcompiler/circuit.go

@ -27,8 +27,8 @@ type gate struct {
right *gate
funcInputs []*gate
value *Constraint //is a pointer a good thing here??
leftIns []factor //leftIns and RightIns after addition gates have been reduced. only multiplication gates remain
rightIns []factor
leftIns factors //leftIns and RightIns after addition gates have been reduced. only multiplication gates remain
rightIns factors
}
func (g gate) String() string {

+ 300
- 0
circuitcompiler/factorHandling.go

@ -0,0 +1,300 @@
package circuitcompiler
import (
"fmt"
"math/big"
"sort"
"strings"
)
type factors []*factor
type factor struct {
typ Token
name string
invert, negate bool
multiplicative [2]int
}
func (f factors) Len() int {
return len(f)
}
func (f factors) Swap(i, j int) {
f[i], f[j] = f[j], f[i]
}
func (f factors) Less(i, j int) bool {
if strings.Compare(f[i].String(), f[j].String()) < 0 {
return false
}
return true
}
func (f factor) String() string {
if f.typ == CONST {
return fmt.Sprintf("(const fac: %v)", f.multiplicative)
}
str := f.name
if f.invert {
str += "^-1"
}
if f.negate {
str = "-" + str
}
return fmt.Sprintf("(\"%s\" fac: %v)", str, f.multiplicative)
}
func (f factors) clone() (res factors) {
res = make(factors, len(f))
for k, v := range f {
res[k] = &factor{multiplicative: v.multiplicative, typ: v.typ, name: v.name, invert: v.invert, negate: v.negate}
}
return
}
func (f factors) normalizeAll() {
for i, _ := range f {
f[i].multiplicative = normalizeFactor(f[i].multiplicative)
}
}
// find Least Common Multiple (LCM) via GCD
func LCMsmall(a, b int) int {
result := a * b / GCD(a, b)
return result
}
func extractFactor(f factors) (factors, [2]int) {
lcm := f[0].multiplicative[1]
for i := 1; i < len(f); i++ {
lcm = LCMsmall(f[i].multiplicative[1], lcm)
}
for i := 0; i < len(f); i++ {
f[i].multiplicative[0] = (lcm / f[i].multiplicative[1]) * f[i].multiplicative[0]
}
gcd := f[0].multiplicative[0]
for i := 1; i < len(f); i++ {
gcd = GCD(f[i].multiplicative[0], gcd)
}
for i := 0; i < len(f); i++ {
f[i].multiplicative[0] = f[i].multiplicative[0] / gcd
f[i].multiplicative[1] = 1
}
return f, [2]int{gcd, lcm}
}
func factorsSignature(leftFactors, rightFactors factors) (sig MultiplicationGateSignature, extractedLeftFactors, extractedRightFactors factors) {
leftFactors = leftFactors.clone()
rightFactors = rightFactors.clone()
leftFactors.normalizeAll()
var extractedFacLeft [2]int
leftFactors, extractedFacLeft = extractFactor(leftFactors)
sort.Sort(leftFactors)
hasher.Reset()
for _, fac := range leftFactors {
hasher.Write([]byte(fac.String()))
}
leftNum := new(big.Int).SetBytes(hasher.Sum(nil))
rightFactors.normalizeAll()
var extractedFacRight [2]int
rightFactors, extractedFacRight = extractFactor(rightFactors)
sort.Sort(rightFactors)
hasher.Reset()
for _, fac := range rightFactors {
hasher.Write([]byte(fac.String()))
}
rightNum := new(big.Int).SetBytes(hasher.Sum(nil))
//we did all this, because multiplication is commutativ, and we want the signature of a
//mulitplication gate factorsSignature(a,b) == factorsSignature(b,a)
leftNum.Add(leftNum, rightNum)
res := normalizeFactor(mul2DVector(extractedFacLeft, extractedFacRight))
return MultiplicationGateSignature{identifier: leftNum.String()[:16], commonExtracted: res}, leftFactors, rightFactors
}
func lengthOfLongestSlice(a, b factors) int {
if len(a) > len(b) {
return len(a)
}
return len(b)
}
//multiplies factor elements and returns the result
//in case the factors do not hold any constants and all inputs are distinct, the output will be the concatenation of left+right
func mulFactors(leftFactors, rightFactors factors) (result factors) {
if len(leftFactors) < len(rightFactors) {
tmp := leftFactors
leftFactors = rightFactors
rightFactors = tmp
}
for i, left := range leftFactors {
for _, right := range rightFactors {
if left.typ == CONST && right.typ == IN {
leftFactors[i] = &factor{typ: IN, name: right.name, negate: Xor(left.negate, right.negate), invert: right.invert, multiplicative: mul2DVector(right.multiplicative, left.multiplicative)}
continue
}
if right.typ == CONST && left.typ == IN {
leftFactors[i] = &factor{typ: IN, name: left.name, negate: Xor(left.negate, right.negate), invert: left.invert, multiplicative: mul2DVector(right.multiplicative, left.multiplicative)}
continue
}
if right.typ&left.typ == CONST {
leftFactors[i] = &factor{typ: CONST, negate: Xor(right.negate, left.negate), multiplicative: mul2DVector(right.multiplicative, left.multiplicative)}
continue
}
//tricky part here
//this one should only be reached, after a true mgate had its left and right braches computed. here we
//a factor can appear at most in quadratic form. we reduce terms a*a^-1 here.
if right.typ&left.typ == IN {
if left.name == right.name {
if right.invert != left.invert {
leftFactors[i] = &factor{typ: CONST, negate: Xor(right.negate, left.negate), multiplicative: mul2DVector(right.multiplicative, left.multiplicative)}
continue
}
}
//rightFactors[i] = factor{typ: CONST, negate: Xor(facRight.negate, facLeft.negate), multiplicative: mul2DVector(facRight.multiplicative, facLeft.multiplicative)}
//continue
}
panic("unexpected. If this errror is thrown, its probably brcause a true multiplication gate has been skipped and treated as on with constant multiplication or addition ")
}
}
return leftFactors
}
//returns the absolute value of a signed int and a flag telling if the input was positive or not
//this implementation is awesome and fast (see Henry S Warren, Hackers's Delight)
func abs(n int) (val int, positive bool) {
y := n >> 63
return (n ^ y) - y, y == 0
}
//adds two factors to one iff they are both are constants or of the same variable
func addFactor(facLeft, facRight factor) (couldAdd bool, sum factor) {
if facLeft.typ&facRight.typ == CONST {
var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
if facLeft.negate {
a0 *= -1
}
if facRight.negate {
b0 *= -1
}
absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
return true, factor{typ: CONST, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
}
if facLeft.typ&facRight.typ == IN && facLeft.invert == facRight.invert && facLeft.name == facRight.name {
var a0, b0 = facLeft.multiplicative[0], facRight.multiplicative[0]
if facLeft.negate {
a0 *= -1
}
if facRight.negate {
b0 *= -1
}
absValue, positive := abs(a0*facRight.multiplicative[1] + facLeft.multiplicative[1]*b0)
return true, factor{typ: IN, invert: facRight.invert, name: facRight.name, negate: !positive, multiplicative: [2]int{absValue, facLeft.multiplicative[1] * facRight.multiplicative[1]}}
}
return false, factor{}
}
//returns the reduced sum of two input factor arrays
//if no reduction was done (worst case), it returns the concatenation of the input arrays
func addFactors(leftFactors, rightFactors factors) factors {
var found bool
res := make(factors, 0, len(leftFactors)+len(rightFactors))
for _, facLeft := range leftFactors {
found = false
for i, facRight := range rightFactors {
var sum factor
found, sum = addFactor(*facLeft, *facRight)
if found {
rightFactors[i] = &sum
break
}
}
if !found {
res = append(res, facLeft)
}
}
for _, val := range rightFactors {
if val.multiplicative[0] != 0 {
res = append(res, val)
}
}
return res
}
// -4/-5 -> 4/5 ; 5/-7 -> -5/7 ; 6 /2 -> 3 / 1
func normalizeFactor(b [2]int) [2]int {
resa, signa := abs(b[0])
resb, signb := abs(b[1])
gcd := GCD(resa, resb)
if Xor(signa, signb) {
resa = -resa
}
return [2]int{resa / gcd, resb / gcd}
}
//naive component multiplication
func mul2DVector(a, b [2]int) [2]int {
return [2]int{a[0] * b[0], a[1] * b[1]}
}
// find Least Common Multiple (LCM) via GCD
func LCM(a, b int, integers ...int) int {
result := a * b / GCD(a, b)
for i := 0; i < len(integers); i++ {
result = LCM(result, integers[i])
}
return result
}
//euclidean algo to determine greates common divisor
func GCD(a, b int) int {
for b != 0 {
t := b
b = a % b
a = t
}
return a
}

+ 205
- 0
circuitcompiler/factorHandling_test.go

@ -0,0 +1,205 @@
package circuitcompiler
import (
"fmt"
"github.com/stretchr/testify/assert"
"math/big"
"math/rand"
"strings"
"testing"
)
//factors are essential to identify, if a specific gate has been computed already
//eg. if we can extract a factor from a gate that is independent of commutativity, multiplicativitz we will do much better, in finding and reusing old outputs do
//minimize the multiplication gate number
// for example the gate a*b == gate b*a hence, we only need to compute one of both.
func TestFactorSignature(t *testing.T) {
facNeutral := factors{&factor{multiplicative: [2]int{1, 1}}}
//dont let the random number be to big, cuz of overflow
r1, r2 := rand.Intn(1<<16), rand.Intn(1<<16)
fmt.Println(r1, r2)
equalityGroups := [][]factors{
[]factors{ //test sign and gcd
{&factor{multiplicative: [2]int{r1 * 2, -r2 * 2}}},
{&factor{multiplicative: [2]int{-r1, r2}}},
{&factor{multiplicative: [2]int{r1, -r2}}},
{&factor{multiplicative: [2]int{r1 * 3, -r2 * 3}}},
{&factor{multiplicative: [2]int{r1 * r1, -r2 * r1}}},
{&factor{multiplicative: [2]int{r1 * r2, -r2 * r2}}},
}, []factors{ //test kommutativity
{&factor{multiplicative: [2]int{r1, -r2}}, &factor{multiplicative: [2]int{13, 27}}},
{&factor{multiplicative: [2]int{13, 27}}, &factor{multiplicative: [2]int{-r1, r2}}},
},
}
for _, equalityGroup := range equalityGroups {
for i := 0; i < len(equalityGroup)-1; i++ {
sig1, _, _ := factorsSignature(facNeutral, equalityGroup[i])
sig2, _, _ := factorsSignature(facNeutral, equalityGroup[i+1])
assert.Equal(t, sig1, sig2)
sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral)
sig2, _, _ = factorsSignature(facNeutral, equalityGroup[i+1])
assert.Equal(t, sig1, sig2)
sig1, _, _ = factorsSignature(facNeutral, equalityGroup[i])
sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral)
assert.Equal(t, sig1, sig2)
sig1, _, _ = factorsSignature(equalityGroup[i], facNeutral)
sig2, _, _ = factorsSignature(equalityGroup[i+1], facNeutral)
assert.Equal(t, sig1, sig2)
}
}
}
func TestGate_ExtractValues(t *testing.T) {
facNeutral := factors{&factor{multiplicative: [2]int{8, 7}}, &factor{multiplicative: [2]int{9, 3}}}
facNeutral2 := factors{&factor{multiplicative: [2]int{9, 1}}, &factor{multiplicative: [2]int{13, 7}}}
fmt.Println(factorsSignature(facNeutral, facNeutral2))
f, fc := extractFactor(facNeutral)
fmt.Println(f)
fmt.Println(fc)
f2, _ := extractFactor(facNeutral2)
fmt.Println(f)
fmt.Println(fc)
fmt.Println(factorsSignature(facNeutral, facNeutral2))
fmt.Println(factorsSignature(f, f2))
}
func TestGCD(t *testing.T) {
fmt.Println(LCM(10, 15))
fmt.Println(LCM(10, 15, 20))
fmt.Println(LCM(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
}
var correctnesTest2 = []TraceCorrectnessTest{
{
io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(643)), big.NewInt(int64(76548465))},
result: big.NewInt(int64(98441327276)),
}, {
inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
result: big.NewInt(int64(8675445947220)),
}},
code: `
def main(a,b):
c = a + b
e = c - a
f = e + b
g = f + 2
out = g * a
`,
},
{io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(7))},
result: big.NewInt(int64(4)),
}},
code: `
def mul(a,b):
out = a * b
def main(a):
b = a * a
c = 4 - b
d = 5 * c
out = mul(d,c) / mul(b,b)
`,
},
{io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(7)), big.NewInt(int64(11))},
result: big.NewInt(int64(22638)),
}, {
inputs: []*big.Int{big.NewInt(int64(365235)), big.NewInt(int64(11876525))},
result: bigNumberResult2,
}},
code: `
def main(a,b):
d = b + b
c = a * d
e = c - a
out = e * c
`,
},
{
io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
result: big.NewInt(int64(444675)),
}},
code: `
def main(a,b,c,d):
e = a * b
f = c * d
g = e * f
h = g / e
i = h * 5
out = g * i
`,
},
{
io: []InOut{{
inputs: []*big.Int{big.NewInt(int64(3)), big.NewInt(int64(5)), big.NewInt(int64(7)), big.NewInt(int64(11))},
result: big.NewInt(int64(264)),
}},
code: `
def main(a,b,c,d):
e = a * 3
f = b * 7
g = c * 11
h = d * 13
i = e + f
j = g + h
k = i + j
out = k * 1
`,
},
}
func TestCorrectness2(t *testing.T) {
for _, test := range correctnesTest2 {
parser := NewParser(strings.NewReader(test.code))
program, err := parser.Parse()
if err != nil {
panic(err)
}
fmt.Println("\n unreduced")
fmt.Println(test.code)
program.BuildConstraintTrees()
for k, v := range program.functions {
fmt.Println(k)
PrintTree(v.root)
}
fmt.Println("\nReduced gates")
//PrintTree(froots["mul"])
gates := program.ReduceCombinedTree()
for _, g := range gates {
fmt.Printf("\n %v", g)
}
fmt.Println("\n generating R1CS")
r1cs := program.GenerateReducedR1CS(gates)
fmt.Println(r1cs.A)
fmt.Println(r1cs.B)
fmt.Println(r1cs.C)
for _, io := range test.io {
inputs := io.inputs
fmt.Println("input")
fmt.Println(inputs)
w := CalculateWitness(inputs, r1cs)
fmt.Println("witness")
fmt.Println(w)
assert.Equal(t, io.result, w[program.GlobalInputCount()])
}
}
}

Loading…
Cancel
Save