@ -0,0 +1,135 @@ |
|||
package sha512 |
|||
|
|||
import ( |
|||
"fmt" |
|||
"math/big" |
|||
"github.com/consensys/gnark/backend/hint" |
|||
"github.com/consensys/gnark/frontend" |
|||
) |
|||
|
|||
|
|||
func padToSameLength(args [][]frontend.Variable) ([][]frontend.Variable, int) { |
|||
maxLength := 0 |
|||
for _, v := range args { |
|||
if len(v) > maxLength { |
|||
maxLength = len(v) |
|||
} |
|||
} |
|||
result := make([][]frontend.Variable, len(args)) |
|||
for i := 0; i < len(args); i++ { |
|||
if len(args[i]) < maxLength { |
|||
arr := make([]frontend.Variable, maxLength) |
|||
for j := 0; j < maxLength; j++ { |
|||
if j < len(args[i]) { |
|||
arr[j] = args[i][j] |
|||
} else { |
|||
arr[j] = 0 |
|||
} |
|||
} |
|||
result[i] = arr |
|||
} else { |
|||
result[i] = args[i] |
|||
} |
|||
} |
|||
return result, maxLength |
|||
} |
|||
|
|||
func log2(n int) int { |
|||
if n <= 0 { panic("undefined") } |
|||
result := 0 |
|||
n -= 1 |
|||
for n > 0 { |
|||
n >>= 1 |
|||
result += 1 |
|||
} |
|||
return result |
|||
} |
|||
|
|||
func extractBit(n big.Int) bool { |
|||
if !n.IsInt64() { |
|||
panic("not bit") |
|||
} |
|||
val := n.Int64() |
|||
if val == 0 { |
|||
return false |
|||
} else if val == 1 { |
|||
return true |
|||
} else { |
|||
panic("not bit") |
|||
} |
|||
} |
|||
|
|||
func flatten(arr [][]frontend.Variable) ([]frontend.Variable) { |
|||
totalLength := 0 |
|||
for _, v := range arr { |
|||
totalLength += len(v) |
|||
} |
|||
result := make([]frontend.Variable, totalLength) |
|||
i := 0 |
|||
for _, v := range arr { |
|||
for _, u := range v { |
|||
result[i] = u |
|||
i += 1 |
|||
} |
|||
} |
|||
return result |
|||
} |
|||
|
|||
func BinSum(api frontend.API, args ...[]frontend.Variable) ([]frontend.Variable) { |
|||
ops := len(args) |
|||
in, n := padToSameLength(args) |
|||
nout := n + log2(ops) |
|||
// var nout = nbits((2**n -1)*ops);
|
|||
// signal input in[ops][n];
|
|||
// signal output out[nout];
|
|||
|
|||
var hintFn hint.Function = func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { |
|||
if len(inputs) != ops*n { panic("bad length") } |
|||
if len(outputs) != nout { panic("bad length") } |
|||
|
|||
maxOutputValue := big.NewInt(1) |
|||
maxOutputValue.Lsh(maxOutputValue, uint(nout)) |
|||
if maxOutputValue.Cmp(field) != -1 { panic("overflow") } |
|||
|
|||
result := big.NewInt(0) |
|||
for i := 0; i < ops; i++ { |
|||
placeValue := big.NewInt(1) |
|||
for j := 0; j < n; j++ { |
|||
if extractBit(*inputs[i*n+j]) { |
|||
result.Add(result, placeValue) |
|||
} |
|||
placeValue.Add(placeValue, placeValue) |
|||
} |
|||
} |
|||
for i := 0; i < nout; i++ { |
|||
v := new(big.Int).Rsh(result, uint(i)) |
|||
v.And(v, big.NewInt(1)) |
|||
outputs[i] = v |
|||
} |
|||
fmt.Println(ops, n, nout, inputs, outputs) |
|||
return nil |
|||
} |
|||
|
|||
out, err := api.NewHint(hintFn, nout, flatten(in)...) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
|
|||
var lhs frontend.Variable = 0 |
|||
var rhs frontend.Variable = 0 |
|||
|
|||
placeValue := big.NewInt(1) |
|||
for i := 0; i < nout; i++ { |
|||
for j := 0; j < ops; j++ { |
|||
if i < n { |
|||
lhs = api.Add(lhs, api.Mul(placeValue, in[j][i])) |
|||
} |
|||
} |
|||
rhs = api.Add(rhs, api.Mul(placeValue, out[i])) |
|||
api.AssertIsBoolean(out[i]) |
|||
placeValue.Add(placeValue, placeValue) |
|||
} |
|||
api.AssertIsEqual(lhs, rhs) |
|||
|
|||
return out |
|||
} |
@ -0,0 +1,52 @@ |
|||
package sha512 |
|||
|
|||
import ( |
|||
"testing" |
|||
"fmt" |
|||
|
|||
"github.com/consensys/gnark-crypto/ecc" |
|||
"github.com/consensys/gnark/frontend" |
|||
"github.com/consensys/gnark/test" |
|||
// "github.com/ethereum/go-ethereum/crypto/secp256k1"
|
|||
) |
|||
|
|||
type BinsumTest struct { |
|||
A []frontend.Variable |
|||
B []frontend.Variable |
|||
C []frontend.Variable |
|||
} |
|||
|
|||
func (c *BinsumTest) Define(api frontend.API) error { |
|||
sum := BinSum(api, c.A, c.B) |
|||
for i := 0; i < len(sum) || i < len(c.C); i++ { |
|||
fmt.Println(i, c.C) |
|||
if i < len(sum) && i < len(c.C) { |
|||
api.Println(sum[i]) |
|||
api.AssertIsEqual(sum[i], c.C[i]) |
|||
} else if i < len(sum) { |
|||
api.AssertIsEqual(sum[i], 0) |
|||
} else { |
|||
fmt.Println(i, c.C[i]) |
|||
api.AssertIsEqual(c.C[i], 0) |
|||
} |
|||
} |
|||
return nil |
|||
} |
|||
|
|||
func TestBinsum(t *testing.T) { |
|||
assert := test.NewAssert(t) |
|||
circuit := BinsumTest{ |
|||
A: []frontend.Variable{0, 0, 0}, |
|||
B: []frontend.Variable{0, 0, 0}, |
|||
C: []frontend.Variable{0, 0, 0, 0}, |
|||
} |
|||
witness := BinsumTest{ |
|||
A: []frontend.Variable{1, 0, 1}, |
|||
B: []frontend.Variable{1, 1, 1}, |
|||
C: []frontend.Variable{0, 0, 1, 1}, |
|||
} |
|||
err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) |
|||
assert.NoError(err) |
|||
} |
|||
|
|||
var testCurve = ecc.BN254 |