@ -1,10 +1,12 @@
package circuitcompiler
package circuitcompiler
import (
import (
"crypto/sha256"
"fmt"
"fmt"
"github.com/mottla/go-snark/bn128"
"github.com/mottla/go-snark/bn128"
"github.com/mottla/go-snark/fields"
"github.com/mottla/go-snark/fields"
"github.com/mottla/go-snark/r1csqap"
"github.com/mottla/go-snark/r1csqap"
"hash"
"math/big"
"math/big"
"sync"
"sync"
)
)
@ -19,8 +21,9 @@ type Program struct {
functions map [ string ] * Circuit
functions map [ string ] * Circuit
globalInputs [ ] string
globalInputs [ ] string
arithmeticEnvironment utils //find a better name
arithmeticEnvironment utils //find a better name
R1CS struct {
sha256Hasher hash . Hash
computedInContext map [ string ] map [ string ] string
R1CS struct {
A [ ] [ ] * big . Int
A [ ] [ ] * big . Int
B [ ] [ ] * big . Int
B [ ] [ ] * big . Int
C [ ] [ ] * big . Int
C [ ] [ ] * big . Int
@ -45,10 +48,12 @@ func (p *Program) BuildConstraintTrees() {
p . getMainCircuit ( ) . gateMap [ mainRoot . value . Out ] = mainRoot
p . getMainCircuit ( ) . gateMap [ mainRoot . value . Out ] = mainRoot
}
}
//for _, in := range p.getMainCircuit().Inputs {
// p.globalInputs = append(p.globalInputs, composeNewFunction(in, p.getMainCircuit().Inputs))
//}
for _ , in := range p . getMainCircuit ( ) . Inputs {
for _ , in := range p . getMainCircuit ( ) . Inputs {
p . globalInputs = append ( p . globalInputs , composeNewFunction ( in , p . getMainCircuit ( ) . Inputs ) )
p . globalInputs = append ( p . globalInputs , in )
}
}
var wg = sync . WaitGroup { }
var wg = sync . WaitGroup { }
for _ , circuit := range p . functions {
for _ , circuit := range p . functions {
@ -101,67 +106,86 @@ func (c *Circuit) buildTree(g *gate) {
}
}
func ( p * Program ) ReduceCombinedTree ( ) ( orderedmGates [ ] gate ) {
func ( p * Program ) ReduceCombinedTree ( ) ( orderedmGates [ ] gate ) {
mGatesUsed := make ( map [ string ] bool )
//mGatesUsed := make(map[string]bool)
orderedmGates = [ ] gate { }
orderedmGates = [ ] gate { }
p . r1CSRecursiveBuild ( p . getMainCircuit ( ) , p . getMainCircuit ( ) . root , mGatesUsed , & orderedmGates , false , false )
p . computedInContext = make ( map [ string ] map [ string ] string )
rootHash := [ ] byte { }
p . computedInContext [ string ( rootHash ) ] = make ( map [ string ] string )
p . r1CSRecursiveBuild ( p . getMainCircuit ( ) , p . getMainCircuit ( ) . root , rootHash , & orderedmGates , false , false )
return orderedmGates
return orderedmGates
}
}
func ( p * Program ) r1CSRecursiveBuild ( currentCircuit * Circuit , root * gate , mGatesUsed map [ string ] bool , orderedmGates * [ ] gate , negate bool , inverse bool ) ( variableEnd bool ) {
func ( p * Program ) r1CSRecursiveBuild ( currentCircuit * Circuit , node * gate , hashTraceBuildup [ ] byte , orderedmGates * [ ] gate , negate bool , invert bool ) ( facs [ ] factor , hashTraceResult [ ] byte , variableEnd bool ) {
if root . OperationType ( ) == IN {
return true
}
if node . OperationType ( ) == CONST {
b1 , v1 := isValue ( node . value . Out )
if ! b1 {
panic ( "not a constant" )
}
mul := [ 2 ] int { v1 , 1 }
if invert {
mul = [ 2 ] int { 1 , v1 }
if root . OperationType ( ) == CONST {
return false
}
return [ ] factor { { typ : CONST , negate : negate , multiplicative : mul } } , make ( [ ] byte , 10 ) , false
}
}
if root . OperationType ( ) == FUNC {
nextContext := p . extendedFunctionRenamer ( currentCircuit , root . value )
if node . OperationType ( ) == FUNC {
nextContext := p . extendedFunctionRenamer ( currentCircuit , node . value )
currentCircuit = nextContext
currentCircuit = nextContext
root = nextContext . root
node = nextContext . root
hashTraceBuildup = hashTogether ( hashTraceBuildup , [ ] byte ( currentCircuit . currentOutputName ( ) ) )
if _ , ex := p . computedInContext [ string ( hashTraceBuildup ) ] ; ! ex {
p . computedInContext [ string ( hashTraceBuildup ) ] = make ( map [ string ] string )
}
}
}
originOfVariable := p . functions [ getContextFromVariable ( root . value . Out ) ]
if _ , alreadyComputed := mGatesUsed [ composeNewFunction ( root . value . Out , originOfVariable . currentOutputs ( ) ) ] ; alreadyComputed {
return true
if node . OperationType ( ) == IN {
fac := factor { typ : IN , name : node . value . Out , invert : invert , negate : negate , multiplicative : [ 2 ] int { 1 , 1 } }
hashTraceBuildup = hashTogether ( hashTraceBuildup , [ ] byte ( node . value . Out ) )
return [ ] factor { fac } , hashTraceBuildup , true
}
}
variableEnd = p . r1CSRecursiveBuild ( currentCircuit , root . left , mGatesUsed , orderedmGates , negate , inverse )
if out , ex := p . computedInContext [ string ( hashTraceBuildup ) ] [ node . value . Out ] ; ex {
fac := factor { typ : IN , name : out , invert : invert , negate : negate , multiplicative : [ 2 ] int { 1 , 1 } }
hashTraceBuildup = hashTogether ( hashTraceBuildup , [ ] byte ( node . value . Out ) )
return [ ] factor { fac } , hashTraceBuildup , true
}
cons := p . r1CSRecursiveBuild ( currentCircuit , root . right , mGatesUsed , orderedmGates , Xor ( negate , root . value . negate ) , Xor ( inverse , root . value . invert ) )
leftFactors , leftHash , variableEnd := p . r1CSRecursiveBuild ( currentCircuit , node . left , hashTraceBuildup , orderedmGates , negate , invert )
if root . OperationType ( ) == MULTIPLY {
rightFactors , rightHash , cons := p . r1CSRecursiveBuild ( currentCircuit , node . right , hashTraceBuildup , orderedmGates , Xor ( negate , node . value . negate ) , Xor ( invert , node . value . invert ) )
if ! ( variableEnd && cons ) && ! root . value . invert && root != p . getMainCircuit ( ) . root {
return variableEnd || cons
}
root . leftIns = p . collectFactors ( currentCircuit , root . left , mGatesUsed , false , false )
//if root.left.value.Out== root.right.value.Out{
// //note this is not a full copy, but shouldnt be a problem
// root.rightIns= root.leftIns
//}else{
// collectAtomsInSubtree(root.right, mGatesUsed, 1, root.rightIns, functionRootMap, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
//}
//root.rightIns = collectAtomsInSubtree3(root.right, mGatesUsed, Xor(negate, root.value.negate), Xor(inverse, root.value.invert))
root . rightIns = p . collectFactors ( currentCircuit , root . right , mGatesUsed , false , false )
root . index = len ( mGatesUsed )
var nn = composeNewFunction ( root . value . Out , originOfVariable . currentOutputs ( ) )
//var nn = root.value.Out
//if _, ex := p.functions[root.value.Out]; ex {
// nn = currentCircuit.currentOutputName()
//}
if _ , ex := mGatesUsed [ nn ] ; ex {
panic ( fmt . Sprintf ( "told ya so %v" , nn ) )
if node . OperationType ( ) == MULTIPLY {
if ! ( variableEnd && cons ) && ! node . value . invert && node != p . getMainCircuit ( ) . root {
//if !(variableEnd && cons) && !node.value.invert && node != p.getMainCircuit().root {
return mulFactors ( leftFactors , rightFactors ) , append ( leftHash , rightHash ... ) , variableEnd || cons
}
}
mGatesUsed [ nn ] = true
rootGate := cloneGate ( root )
rootGate . value . Out = nn
rootGate := cloneGate ( node )
rootGate . index = len ( * orderedmGates )
rootGate . leftIns = leftFactors
rootGate . rightIns = rightFactors
out := hashTogether ( leftHash , rightHash )
rootGate . value . V1 = rootGate . value . V1 + string ( leftHash [ : 10 ] )
rootGate . value . V2 = rootGate . value . V2 + string ( rightHash [ : 10 ] )
rootGate . value . Out = rootGate . value . Out + string ( out [ : 10 ] )
p . computedInContext [ string ( hashTraceBuildup ) ] [ node . value . Out ] = rootGate . value . Out
* orderedmGates = append ( * orderedmGates , * rootGate )
* orderedmGates = append ( * orderedmGates , * rootGate )
hashTraceBuildup = hashTogether ( hashTraceBuildup , [ ] byte ( rootGate . value . Out ) )
return [ ] factor { { typ : IN , name : rootGate . value . Out , invert : invert , negate : negate , multiplicative : [ 2 ] int { 1 , 1 } } } , hashTraceBuildup , true
}
switch node . OperationType ( ) {
case PLUS :
return addFactors ( leftFactors , rightFactors ) , hashTogether ( leftHash , rightHash ) , variableEnd || cons
default :
panic ( "unexpected gate" )
}
}
return variableEnd || cons
//TODO optimize if output is not a multipication gate
//TODO optimize if output is not a multipication gate
}
}
@ -298,61 +322,6 @@ func addFactors(leftFactors, rightFactors []factor) []factor {
return res
return res
}
}
func ( p * Program ) collectFactors ( contextCircut * Circuit , node * gate , mGatesUsed map [ string ] bool , negate bool , invert bool ) [ ] factor {
if node . OperationType ( ) == CONST {
b1 , v1 := isValue ( node . value . Out )
if ! b1 {
panic ( "not a constant" )
}
if invert {
return [ ] factor { { typ : CONST , negate : negate , multiplicative : [ 2 ] int { 1 , v1 } } }
}
return [ ] factor { { typ : CONST , negate : negate , multiplicative : [ 2 ] int { v1 , 1 } } }
}
if node . OperationType ( ) == FUNC {
nextContext := p . extendedFunctionRenamer ( contextCircut , node . value )
//if _, ex := mGatesUsed[nextContext.currentOutputName()]; ex {
// return []factor{{typ: IN, name: nextContext.currentOutputName(), invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
//}
contextCircut = nextContext
node = nextContext . root
}
originOfVariable := p . functions [ getContextFromVariable ( node . value . Out ) ]
if originOfVariable == nil {
fmt . Println ( "asdf" )
}
lookingFOr := composeNewFunction ( node . value . Out , originOfVariable . currentOutputs ( ) )
//if _, ex := mGatesUsed[node.value.Out]; ex {
// return []factor{{typ: IN, name: node.value.Out, invert: invert, negate: negate, multiplicative: [2]int{1, 1}}}
//}
if node . OperationType ( ) == IN {
return [ ] factor { { typ : IN , name : lookingFOr , invert : invert , negate : negate , multiplicative : [ 2 ] int { 1 , 1 } } }
}
if _ , alreadyComputed := mGatesUsed [ lookingFOr ] ; alreadyComputed {
return [ ] factor { { typ : IN , name : lookingFOr , invert : invert , negate : negate , multiplicative : [ 2 ] int { 1 , 1 } } }
}
leftFactors := p . collectFactors ( contextCircut , node . left , mGatesUsed , negate , invert )
rightFactors := p . collectFactors ( contextCircut , node . right , mGatesUsed , Xor ( negate , node . value . negate ) , Xor ( invert , node . value . invert ) )
switch node . OperationType ( ) {
case MULTIPLY :
return mulFactors ( leftFactors , rightFactors )
case PLUS :
return addFactors ( leftFactors , rightFactors )
default :
panic ( "unexpected gate" )
}
}
//copies a gate neglecting its references to other gates
//copies a gate neglecting its references to other gates
func cloneGate ( in * gate ) ( out * gate ) {
func cloneGate ( in * gate ) ( out * gate ) {
constr := & Constraint { Inputs : in . value . Inputs , Out : in . value . Out , Op : in . value . Op , invert : in . value . invert , negate : in . value . negate , V2 : in . value . V2 , V1 : in . value . V1 }
constr := & Constraint { Inputs : in . value . Inputs , Out : in . value . Out , Op : in . value . Op , invert : in . value . invert , negate : in . value . negate , V2 : in . value . V2 , V1 : in . value . V1 }
@ -445,7 +414,12 @@ func (p *Program) extendedFunctionRenamer(contextCircuit *Circuit, constraint *C
}
}
func NewProgram ( ) ( p * Program ) {
func NewProgram ( ) ( p * Program ) {
p = & Program { functions : map [ string ] * Circuit { } , globalInputs : [ ] string { "one" } , arithmeticEnvironment : prepareUtils ( ) }
p = & Program {
functions : map [ string ] * Circuit { } ,
globalInputs : [ ] string { "one" } ,
arithmeticEnvironment : prepareUtils ( ) ,
sha256Hasher : sha256 . New ( ) ,
}
return
return
}
}
@ -606,3 +580,18 @@ func (p *Program) CalculateWitness(input []*big.Int) (witness []*big.Int) {
return
return
}
}
var hasher = sha256 . New ( )
func hashFactorWithContext ( f factor , currentCircuit * Circuit ) [ ] byte {
hasher . Reset ( )
hasher . Write ( [ ] byte ( f . name ) )
hasher . Write ( [ ] byte ( currentCircuit . currentOutputName ( ) ) )
return hasher . Sum ( nil )
}
func hashTogether ( a , b [ ] byte ) [ ] byte {
hasher . Reset ( )
hasher . Write ( a )
hasher . Write ( b )
return hasher . Sum ( nil )
}