mirror of
https://github.com/arnaucube/gnark-plonky2-verifier.git
synced 2026-01-12 09:01:32 +01:00
write binsum
This commit is contained in:
135
sha512/binsum.go
Normal file
135
sha512/binsum.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package sha512
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"github.com/consensys/gnark/backend/hint"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
)
|
||||
|
||||
|
||||
func padToSameLength(args [][]frontend.Variable) ([][]frontend.Variable, int) {
|
||||
maxLength := 0
|
||||
for _, v := range args {
|
||||
if len(v) > maxLength {
|
||||
maxLength = len(v)
|
||||
}
|
||||
}
|
||||
result := make([][]frontend.Variable, len(args))
|
||||
for i := 0; i < len(args); i++ {
|
||||
if len(args[i]) < maxLength {
|
||||
arr := make([]frontend.Variable, maxLength)
|
||||
for j := 0; j < maxLength; j++ {
|
||||
if j < len(args[i]) {
|
||||
arr[j] = args[i][j]
|
||||
} else {
|
||||
arr[j] = 0
|
||||
}
|
||||
}
|
||||
result[i] = arr
|
||||
} else {
|
||||
result[i] = args[i]
|
||||
}
|
||||
}
|
||||
return result, maxLength
|
||||
}
|
||||
|
||||
func log2(n int) int {
|
||||
if n <= 0 { panic("undefined") }
|
||||
result := 0
|
||||
n -= 1
|
||||
for n > 0 {
|
||||
n >>= 1
|
||||
result += 1
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func extractBit(n big.Int) bool {
|
||||
if !n.IsInt64() {
|
||||
panic("not bit")
|
||||
}
|
||||
val := n.Int64()
|
||||
if val == 0 {
|
||||
return false
|
||||
} else if val == 1 {
|
||||
return true
|
||||
} else {
|
||||
panic("not bit")
|
||||
}
|
||||
}
|
||||
|
||||
func flatten(arr [][]frontend.Variable) ([]frontend.Variable) {
|
||||
totalLength := 0
|
||||
for _, v := range arr {
|
||||
totalLength += len(v)
|
||||
}
|
||||
result := make([]frontend.Variable, totalLength)
|
||||
i := 0
|
||||
for _, v := range arr {
|
||||
for _, u := range v {
|
||||
result[i] = u
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func BinSum(api frontend.API, args ...[]frontend.Variable) ([]frontend.Variable) {
|
||||
ops := len(args)
|
||||
in, n := padToSameLength(args)
|
||||
nout := n + log2(ops)
|
||||
// var nout = nbits((2**n -1)*ops);
|
||||
// signal input in[ops][n];
|
||||
// signal output out[nout];
|
||||
|
||||
var hintFn hint.Function = func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
|
||||
if len(inputs) != ops*n { panic("bad length") }
|
||||
if len(outputs) != nout { panic("bad length") }
|
||||
|
||||
maxOutputValue := big.NewInt(1)
|
||||
maxOutputValue.Lsh(maxOutputValue, uint(nout))
|
||||
if maxOutputValue.Cmp(field) != -1 { panic("overflow") }
|
||||
|
||||
result := big.NewInt(0)
|
||||
for i := 0; i < ops; i++ {
|
||||
placeValue := big.NewInt(1)
|
||||
for j := 0; j < n; j++ {
|
||||
if extractBit(*inputs[i*n+j]) {
|
||||
result.Add(result, placeValue)
|
||||
}
|
||||
placeValue.Add(placeValue, placeValue)
|
||||
}
|
||||
}
|
||||
for i := 0; i < nout; i++ {
|
||||
v := new(big.Int).Rsh(result, uint(i))
|
||||
v.And(v, big.NewInt(1))
|
||||
outputs[i] = v
|
||||
}
|
||||
fmt.Println(ops, n, nout, inputs, outputs)
|
||||
return nil
|
||||
}
|
||||
|
||||
out, err := api.NewHint(hintFn, nout, flatten(in)...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var lhs frontend.Variable = 0
|
||||
var rhs frontend.Variable = 0
|
||||
|
||||
placeValue := big.NewInt(1)
|
||||
for i := 0; i < nout; i++ {
|
||||
for j := 0; j < ops; j++ {
|
||||
if i < n {
|
||||
lhs = api.Add(lhs, api.Mul(placeValue, in[j][i]))
|
||||
}
|
||||
}
|
||||
rhs = api.Add(rhs, api.Mul(placeValue, out[i]))
|
||||
api.AssertIsBoolean(out[i])
|
||||
placeValue.Add(placeValue, placeValue)
|
||||
}
|
||||
api.AssertIsEqual(lhs, rhs)
|
||||
|
||||
return out
|
||||
}
|
||||
52
sha512/binsum_test.go
Normal file
52
sha512/binsum_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package sha512
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"fmt"
|
||||
|
||||
"github.com/consensys/gnark-crypto/ecc"
|
||||
"github.com/consensys/gnark/frontend"
|
||||
"github.com/consensys/gnark/test"
|
||||
// "github.com/ethereum/go-ethereum/crypto/secp256k1"
|
||||
)
|
||||
|
||||
type BinsumTest struct {
|
||||
A []frontend.Variable
|
||||
B []frontend.Variable
|
||||
C []frontend.Variable
|
||||
}
|
||||
|
||||
func (c *BinsumTest) Define(api frontend.API) error {
|
||||
sum := BinSum(api, c.A, c.B)
|
||||
for i := 0; i < len(sum) || i < len(c.C); i++ {
|
||||
fmt.Println(i, c.C)
|
||||
if i < len(sum) && i < len(c.C) {
|
||||
api.Println(sum[i])
|
||||
api.AssertIsEqual(sum[i], c.C[i])
|
||||
} else if i < len(sum) {
|
||||
api.AssertIsEqual(sum[i], 0)
|
||||
} else {
|
||||
fmt.Println(i, c.C[i])
|
||||
api.AssertIsEqual(c.C[i], 0)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBinsum(t *testing.T) {
|
||||
assert := test.NewAssert(t)
|
||||
circuit := BinsumTest{
|
||||
A: []frontend.Variable{0, 0, 0},
|
||||
B: []frontend.Variable{0, 0, 0},
|
||||
C: []frontend.Variable{0, 0, 0, 0},
|
||||
}
|
||||
witness := BinsumTest{
|
||||
A: []frontend.Variable{1, 0, 1},
|
||||
B: []frontend.Variable{1, 1, 1},
|
||||
C: []frontend.Variable{0, 0, 1, 1},
|
||||
}
|
||||
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField())
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
var testCurve = ecc.BN254
|
||||
Reference in New Issue
Block a user