diff --git a/edwards_curve/eddsa25519.go b/edwards_curve/eddsa25519.go index 0a829ac..daa1ab6 100644 --- a/edwards_curve/eddsa25519.go +++ b/edwards_curve/eddsa25519.go @@ -4,6 +4,7 @@ package edwards_curve // This file is little-endian import ( + "fmt" "math/big" "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/emulated" @@ -12,8 +13,11 @@ import ( func H(api frontend.API, m []frontend.Variable) []frontend.Variable { - result := sha512.Sha512Bytes(api, m) - return result[:] + fmt.Println("sha input", m) + rawResult := sha512.Sha512(api, swapByteEndianness(m)) + sResult := swapByteEndianness(rawResult[:]) + fmt.Println("sha output", sResult) + return sResult } func pow2(n uint) *big.Int { @@ -32,10 +36,10 @@ func bits_to_scalar(c *EdCurve, s []frontend.Variable) EdCoordinate { elt := emulated.NewElement[Ed25519](0) if len(elt.Limbs) != 4 { panic("bad length") } i := 0 - elt.Limbs[0] = c.api.FromBinary(s[i:i+64]); i += 64 - elt.Limbs[1] = c.api.FromBinary(s[i:i+64]); i += 64 - elt.Limbs[2] = c.api.FromBinary(s[i:i+64]); i += 64 - elt.Limbs[3] = c.api.FromBinary(s[i:i+64]); i += 64 + elt.Limbs[0] = c.api.FromBinary(s[i:i+64]...); i += 64 + elt.Limbs[1] = c.api.FromBinary(s[i:i+64]...); i += 64 + elt.Limbs[2] = c.api.FromBinary(s[i:i+64]...); i += 64 + elt.Limbs[3] = c.api.FromBinary(s[i:i+64]...); i += 64 if i != len(s) { panic("bad length") } return elt } @@ -54,7 +58,15 @@ func bits_to_scalar(c *EdCurve, s []frontend.Variable) EdCoordinate { func bits_to_element(c *EdCurve, input []frontend.Variable) EdPoint { L := emulated.NewElement[Ed25519Scalars](rEd25519) unchecked_point := decodepoint(c, input) + + // TODO: https://github.com/warner/python-pure25519 says this check is not necessary: + // + // > This library is conservative, and performs full subgroup-membership checks on decoded + // > points, which adds considerable overhead. The Curve25519/Ed25519 algorithms were + // > designed to not require these checks, so a careful application might be able to + // > improve on this slightly (Ed25519 verify down to 6.2ms, DH-finish to 3.2ms). c.AssertIsZero(c.ScalarMul(unchecked_point, L)) + return unchecked_point } @@ -64,17 +76,31 @@ func bits_to_element(c *EdCurve, input []frontend.Variable) EdPoint { // return c.ScalarMul(c.g, a) // } -func checkvalid(c *EdCurve, s, m, pk []frontend.Variable) { +func CheckValid(c *EdCurve, s, m, pk []frontend.Variable) { if len(s) != 512 { panic("bad signature length") } if len(pk) != 256 { panic("bad public key length") } + if len(m) % 8 != 0 { panic("bad message length") } R := bits_to_element(c, s[:256]) A := bits_to_element(c, pk) h := H(c.api, concat(s[:256], pk, m)) + fmt.Println("h", h) + fmt.Println("g", dbg(c.g.X), dbg(c.g.Y)) + fmt.Println("s last half", s[256:]) v1 := c.ScalarMulBinary(c.g, s[256:]) + fmt.Println("v1", dbg(v1.X), dbg(v1.Y)) v2 := c.Add(R, c.ScalarMulBinary(A, h)) + fmt.Println("v2", dbg(v2.X), dbg(v2.Y)) c.AssertIsEqual(v1, v2) } +func reverse[T interface{}](arr []T) []T { + result := make([]T, len(arr)) + for i, v := range arr { + result[len(result)-i-1] = v + } + return result +} + func concat(args ...[]frontend.Variable) []frontend.Variable { result := []frontend.Variable{} for _, v := range args { @@ -83,11 +109,11 @@ func concat(args ...[]frontend.Variable) []frontend.Variable { return result } -func decodepoint(c *EdCurve, input []frontend.Variable) EdPoint { - if len(input) != 256 { panic("bad length") } +func decodepoint(c *EdCurve, unclamped []frontend.Variable) EdPoint { + if len(unclamped) != 256 { panic("bad length") } - s := make([]frontend.Variable, len(input)) - copy(s, input) + s := make([]frontend.Variable, len(unclamped)) + copy(s, unclamped) s[255] = 0 y := bits_to_scalar(c, s) // unclamped = int(binascii.hexlify(s[:32][::-1]), 16) @@ -99,7 +125,7 @@ func decodepoint(c *EdCurve, input []frontend.Variable) EdPoint { xbits := c.baseApi.ToBinary(x) if len(xbits) != 256 { panic("bad length") } - mismatch := c.api.Xor(xbits[0], xbits[255]) + mismatch := c.api.Xor(xbits[0], unclamped[255]) x = c.baseApi.Select(mismatch, c.baseApi.Neg(x), x).(EdCoordinate) // if bool(x & 1) != bool(unclamped & (1<<255)): x = Q-x @@ -115,6 +141,21 @@ func decodepoint(c *EdCurve, input []frontend.Variable) EdPoint { return P } +func toValue(s EdCoordinate) *big.Int { + result := big.NewInt(0) + placeValue := big.NewInt(1) + for _, v := range s.Limbs { + q := new(big.Int).Mul(placeValue, v.(*big.Int)) + result.Add(result, q) + placeValue.Lsh(placeValue, Ed25519{}.BitsPerLimb()) + } + return result +} + +func dbg(s EdCoordinate) string { + return toValue(s).Text(16) +} + func _const(x int64) EdCoordinate { return emulated.NewElement[Ed25519](big.NewInt(x)) } @@ -149,8 +190,8 @@ func xrecover(c *EdCurve, y EdCoordinate) EdCoordinate { x = c.baseApi.Select(matches, x, c.baseApi.Mul(x, emulated.NewElement[Ed25519](I))).(EdCoordinate) // if (x*x - xx) % Q != 0: x = (x*I) % Q - even := c.baseApi.ToBinary(x)[0] - x = c.baseApi.Select(even, x, c.baseApi.Neg(x)).(EdCoordinate) + odd := c.baseApi.ToBinary(x)[0] + x = c.baseApi.Select(odd, c.baseApi.Neg(x), x).(EdCoordinate) // if x % 2 != 0: x = Q-x return x @@ -169,6 +210,17 @@ func pow(c *EdCurve, base EdCoordinate, exponent *big.Int) EdCoordinate { return result } +func swapByteEndianness(in []frontend.Variable) []frontend.Variable { + if len(in) % 8 != 0 { panic("must be a multiple of 8 bits") } + result := make([]frontend.Variable, len(in)) + for i := 0; i < len(in); i += 8 { + for j := 0; j < 8; j++ { + result[i+j] = in[i+7-j] + } + } + return result +} + // def checkvalid(s, m, pk): // if len(s) != 64: raise Exception("signature length is wrong") // if len(pk) != 32: raise Exception("public-key length is wrong") diff --git a/edwards_curve/eddsa25519_test.go b/edwards_curve/eddsa25519_test.go new file mode 100644 index 0000000..d3cbf5f --- /dev/null +++ b/edwards_curve/eddsa25519_test.go @@ -0,0 +1,65 @@ +package edwards_curve + +import ( + "testing" + "encoding/hex" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/test" +) + +type Eddsa25519Circuit struct { + m []frontend.Variable + pk []frontend.Variable + sig []frontend.Variable +} + +func (circuit *Eddsa25519Circuit) Define(api frontend.API) error { + c, err := New[Ed25519, Ed25519Scalars](api) + if err != nil { + return err + } + CheckValid(c, circuit.sig, circuit.m, circuit.pk) + return nil +} + +func TestEddsa25519(t *testing.T) { + assert := test.NewAssert(t) + + m := "53756363696e6374204c616273" + pk := "f7ec1c43f4de9d49556de87b86b26a98942cb078486fdb44de38b80864c39731" + sig := "35c323757c20640a294345c89c0bfcebe3d554fdb0c7b7a0bdb72222c531b1ec849fed99a053e0f5b02dd9a25bb6eb018885526d9f583cdbde0b1e9f6329da09" + + circuit := Eddsa25519Circuit { + m: hexToBits(m), + pk: hexToBits(pk), + sig: hexToBits(sig), + } + witness := Eddsa25519Circuit { + m: hexToBits(m), + pk: hexToBits(pk), + sig: hexToBits(sig), + } + + err := test.IsSolved(&circuit, &witness, testCurve.ScalarField()) + assert.NoError(err) +} + +func hexToBits(h string) []frontend.Variable { + b, err := hex.DecodeString(h) + if err != nil { + panic(err) + } + result := make([]frontend.Variable, len(b) * 8) + for i, v := range b { + for j := 0; j < 8; j++ { + if (v & (1 << j)) != 0 { + result[i*8+j] = 1 + } else { + result[i*8+j] = 0 + } + } + } + return result +} + diff --git a/edwards_curve/edpoint.go b/edwards_curve/edpoint.go index b99733a..d1930bc 100644 --- a/edwards_curve/edpoint.go +++ b/edwards_curve/edpoint.go @@ -185,16 +185,17 @@ func (c *Curve[T, S]) ScalarMul(p AffinePoint[T], s emulated.Element[S]) AffineP } func (c *Curve[T, S]) ScalarMulBinary(p AffinePoint[T], sBits []frontend.Variable) AffinePoint[T] { - res := p - acc := c.Double(p) + res := AffinePoint[T]{ + X: emulated.NewElement[T](0), + Y: emulated.NewElement[T](1), + } + acc := p - for i := 1; i < len(sBits); i++ { + for i := 0; i < len(sBits); i++ { tmp := c.Add(res, acc) res = c.Select(sBits[i], tmp, res) acc = c.Double(acc) } - tmp := c.Add(res, c.Neg(p)) - res = c.Select(sBits[0], res, tmp) return res } diff --git a/sha512/sha512.go b/sha512/sha512.go index b9be0fd..f9a8da6 100644 --- a/sha512/sha512.go +++ b/sha512/sha512.go @@ -15,18 +15,7 @@ func _right_rotate(n [64]frontend.Variable, bits int) [64]frontend.Variable { return result } -func Sha512Bytes(api frontend.API, in []frontend.Variable) ([512]frontend.Variable) { - bits := []frontend.Variable{} - for _, v := range in { - b := api.ToBinary(v, 8) - for i := 0; i < 8; i++ { - bits = append(bits, b[7-i]) - } - } - return Sha512Bits(api, bits) -} - -func Sha512Bits(api frontend.API, in []frontend.Variable) ([512]frontend.Variable) { +func Sha512(api frontend.API, in []frontend.Variable) ([512]frontend.Variable) { _not := func(x [64]frontend.Variable) [64]frontend.Variable { return not(api, x) } diff --git a/sha512/sha_test.go b/sha512/sha_test.go index 66c4fe4..a5d0acb 100644 --- a/sha512/sha_test.go +++ b/sha512/sha_test.go @@ -15,7 +15,7 @@ type Sha512Circuit struct { } func (circuit *Sha512Circuit) Define(api frontend.API) error { - res := Sha512Bits(api, circuit.in) + res := Sha512(api, circuit.in) if len(res) != 512 { panic("bad length") } for i := 0; i < 512; i++ { api.AssertIsEqual(res[i], circuit.out[i]) @@ -28,8 +28,7 @@ var testCurve = ecc.BN254 func TestSha512(t *testing.T) { assert := test.NewAssert(t) - testCase := func(input, output string) { - in := toBytes(input) + testCase := func(in []byte, output string) { out, err := hex.DecodeString(output) if err != nil { panic(err) } if len(out) != 512 / 8 { panic("bad output length") } @@ -46,8 +45,9 @@ func TestSha512(t *testing.T) { assert.NoError(err) } - testCase("", "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") - testCase("Succinct Labs", "503ace098aa03f6feec1b5df0a38aee923f744a775508bc81f2b94ad139be297c2e8cd8c44af527b5d3f017a7fc929892c896604047e52e3f518924f52bff0dc") + testCase([]byte(""), "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e") + testCase([]byte("Succinct Labs"), "503ace098aa03f6feec1b5df0a38aee923f744a775508bc81f2b94ad139be297c2e8cd8c44af527b5d3f017a7fc929892c896604047e52e3f518924f52bff0dc") + testCase(decode("35c323757c20640a294345c89c0bfcebe3d554fdb0c7b7a0bdb72222c531b1ecf7ec1c43f4de9d49556de87b86b26a98942cb078486fdb44de38b80864c3973153756363696e6374204c616273"), "4388243c4452274402673de881b2f942ff5730fd2c7d8ddb94c3e3d789fb3754380cba8faa40554d9506a0730a681e88ab348a04bc5c41d18926f140b59aed39") } func toBits(arr []byte) []frontend.Variable { @@ -64,6 +64,10 @@ func toBits(arr []byte) []frontend.Variable { return result } -func toBytes(s string) []byte { - return []byte(s) +func decode(s string) []byte { + result, err := hex.DecodeString(s) + if err != nil { + panic(err) + } + return result }