diff --git a/sha512/binsum.go b/sha512/binsum.go new file mode 100644 index 0000000..029761e --- /dev/null +++ b/sha512/binsum.go @@ -0,0 +1,135 @@ +package sha512 + +import ( + "fmt" + "math/big" + "github.com/consensys/gnark/backend/hint" + "github.com/consensys/gnark/frontend" +) + + +func padToSameLength(args [][]frontend.Variable) ([][]frontend.Variable, int) { + maxLength := 0 + for _, v := range args { + if len(v) > maxLength { + maxLength = len(v) + } + } + result := make([][]frontend.Variable, len(args)) + for i := 0; i < len(args); i++ { + if len(args[i]) < maxLength { + arr := make([]frontend.Variable, maxLength) + for j := 0; j < maxLength; j++ { + if j < len(args[i]) { + arr[j] = args[i][j] + } else { + arr[j] = 0 + } + } + result[i] = arr + } else { + result[i] = args[i] + } + } + return result, maxLength +} + +func log2(n int) int { + if n <= 0 { panic("undefined") } + result := 0 + n -= 1 + for n > 0 { + n >>= 1 + result += 1 + } + return result +} + +func extractBit(n big.Int) bool { + if !n.IsInt64() { + panic("not bit") + } + val := n.Int64() + if val == 0 { + return false + } else if val == 1 { + return true + } else { + panic("not bit") + } +} + +func flatten(arr [][]frontend.Variable) ([]frontend.Variable) { + totalLength := 0 + for _, v := range arr { + totalLength += len(v) + } + result := make([]frontend.Variable, totalLength) + i := 0 + for _, v := range arr { + for _, u := range v { + result[i] = u + i += 1 + } + } + return result +} + +func BinSum(api frontend.API, args ...[]frontend.Variable) ([]frontend.Variable) { + ops := len(args) + in, n := padToSameLength(args) + nout := n + log2(ops) +// var nout = nbits((2**n -1)*ops); +// signal input in[ops][n]; +// signal output out[nout]; + + var hintFn hint.Function = func(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(inputs) != ops*n { panic("bad length") } + if len(outputs) != nout { panic("bad length") } + + maxOutputValue := big.NewInt(1) + maxOutputValue.Lsh(maxOutputValue, uint(nout)) + if maxOutputValue.Cmp(field) != -1 { panic("overflow") } + + result := big.NewInt(0) + for i := 0; i < ops; i++ { + placeValue := big.NewInt(1) + for j := 0; j < n; j++ { + if extractBit(*inputs[i*n+j]) { + result.Add(result, placeValue) + } + placeValue.Add(placeValue, placeValue) + } + } + for i := 0; i < nout; i++ { + v := new(big.Int).Rsh(result, uint(i)) + v.And(v, big.NewInt(1)) + outputs[i] = v + } + fmt.Println(ops, n, nout, inputs, outputs) + return nil + } + + out, err := api.NewHint(hintFn, nout, flatten(in)...) + if err != nil { + panic(err) + } + + var lhs frontend.Variable = 0 + var rhs frontend.Variable = 0 + + placeValue := big.NewInt(1) + for i := 0; i < nout; i++ { + for j := 0; j < ops; j++ { + if i < n { + lhs = api.Add(lhs, api.Mul(placeValue, in[j][i])) + } + } + rhs = api.Add(rhs, api.Mul(placeValue, out[i])) + api.AssertIsBoolean(out[i]) + placeValue.Add(placeValue, placeValue) + } + api.AssertIsEqual(lhs, rhs) + + return out +} diff --git a/sha512/binsum_test.go b/sha512/binsum_test.go new file mode 100644 index 0000000..2389c1b --- /dev/null +++ b/sha512/binsum_test.go @@ -0,0 +1,52 @@ +package sha512 + +import ( + "testing" + "fmt" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" + // "github.com/ethereum/go-ethereum/crypto/secp256k1" +) + +type BinsumTest struct { + A []frontend.Variable + B []frontend.Variable + C []frontend.Variable +} + +func (c *BinsumTest) Define(api frontend.API) error { + sum := BinSum(api, c.A, c.B) + for i := 0; i < len(sum) || i < len(c.C); i++ { + fmt.Println(i, c.C) + if i < len(sum) && i < len(c.C) { + api.Println(sum[i]) + api.AssertIsEqual(sum[i], c.C[i]) + } else if i < len(sum) { + api.AssertIsEqual(sum[i], 0) + } else { + fmt.Println(i, c.C[i]) + api.AssertIsEqual(c.C[i], 0) + } + } + return nil +} + +func TestBinsum(t *testing.T) { + assert := test.NewAssert(t) + circuit := BinsumTest{ + A: []frontend.Variable{0, 0, 0}, + B: []frontend.Variable{0, 0, 0}, + C: []frontend.Variable{0, 0, 0, 0}, + } + witness := BinsumTest{ + A: []frontend.Variable{1, 0, 1}, + B: []frontend.Variable{1, 1, 1}, + C: []frontend.Variable{0, 0, 1, 1}, + } + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +var testCurve = ecc.BN254