You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
2.8 KiB

package sha512
import (
"fmt"
"math/big"
"github.com/consensys/gnark/backend/hint"
"github.com/consensys/gnark/frontend"
)
// padToSameLength zero-extends every slice in args to the length of the
// longest one and returns the resulting slices together with that length.
// Slices that already have the maximum length are returned as-is (they are
// not copied); shorter ones are copied into fresh slices whose tail is 0.
func padToSameLength(args [][]frontend.Variable) ([][]frontend.Variable, int) {
	longest := 0
	for _, row := range args {
		if size := len(row); size > longest {
			longest = size
		}
	}

	padded := make([][]frontend.Variable, len(args))
	for idx, row := range args {
		if len(row) >= longest {
			padded[idx] = row
			continue
		}
		ext := make([]frontend.Variable, longest)
		copy(ext, row)
		for k := len(row); k < longest; k++ {
			ext[k] = 0
		}
		padded[idx] = ext
	}
	return padded, longest
}
// log2 returns ceil(log2(n)), i.e. the number of bits needed to write any
// value in [0, n-1]. It returns 0 for n == 1 and panics with "undefined"
// when n is not positive.
func log2(n int) int {
	if n < 1 {
		panic("undefined")
	}
	bitsNeeded := 0
	for remaining := n - 1; remaining != 0; remaining >>= 1 {
		bitsNeeded++
	}
	return bitsNeeded
}
// extractBit interprets n as a single bit: 0 yields false and 1 yields true.
// Any other value, including one that does not fit in an int64, panics with
// "not bit".
func extractBit(n big.Int) bool {
	if n.IsInt64() {
		switch n.Int64() {
		case 0:
			return false
		case 1:
			return true
		}
	}
	panic("not bit")
}
// flatten concatenates the rows of arr, in order, into one newly allocated
// slice. The result is an empty non-nil slice when arr holds no elements.
func flatten(arr [][]frontend.Variable) []frontend.Variable {
	total := 0
	for _, row := range arr {
		total += len(row)
	}
	out := make([]frontend.Variable, 0, total)
	for _, row := range arr {
		out = append(out, row...)
	}
	return out
}
func BinSum(api frontend.API, args ...[]frontend.Variable) ([]frontend.Variable) {
ops := len(args)
in, n := padToSameLength(args)
nout := n + log2(ops)
// var nout = nbits((2**n -1)*ops);
// signal input in[ops][n];
// signal output out[nout];
var hintFn hint.Function = func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error {
if len(inputs) != ops*n { panic("bad length") }
if len(outputs) != nout { panic("bad length") }
maxOutputValue := big.NewInt(1)
maxOutputValue.Lsh(maxOutputValue, uint(nout))
if maxOutputValue.Cmp(field) != -1 { panic("overflow") }
result := big.NewInt(0)
for i := 0; i < ops; i++ {
placeValue := big.NewInt(1)
for j := 0; j < n; j++ {
if extractBit(*inputs[i*n+j]) {
result.Add(result, placeValue)
}
placeValue.Add(placeValue, placeValue)
}
}
for i := 0; i < nout; i++ {
v := new(big.Int).Rsh(result, uint(i))
v.And(v, big.NewInt(1))
outputs[i] = v
}
fmt.Println(ops, n, nout, inputs, outputs)
return nil
}
out, err := api.NewHint(hintFn, nout, flatten(in)...)
if err != nil {
panic(err)
}
var lhs frontend.Variable = 0
var rhs frontend.Variable = 0
placeValue := big.NewInt(1)
for i := 0; i < nout; i++ {
for j := 0; j < ops; j++ {
if i < n {
lhs = api.Add(lhs, api.Mul(placeValue, in[j][i]))
}
}
rhs = api.Add(rhs, api.Mul(placeValue, out[i]))
api.AssertIsBoolean(out[i])
placeValue.Add(placeValue, placeValue)
}
api.AssertIsEqual(lhs, rhs)
return out
}