From 5b79ded5402fca8ee0763efd689865f8e3c6975c Mon Sep 17 00:00:00 2001 From: Eduard S Date: Mon, 1 Jul 2019 12:51:34 +0200 Subject: [PATCH] Move constants and utils to package, apply small fixes --- babyjub/babyjub.go | 91 +++++++++++++--------------------- babyjub/babyjub_test.go | 67 ++++++++++++------------- babyjub/eddsa.go | 29 +++++------ babyjub/eddsa_test.go | 10 ++-- babyjub/helpers.go | 75 ---------------------------- constants/constants.go | 26 ++++++++++ go.sum | 1 + mimc7/mimc7.go | 7 +-- poseidon/poseidon.go | 62 +++++++++-------------- utils/utils.go | 106 ++++++++++++++++++++++++++++++++++++++++ 10 files changed, 245 insertions(+), 229 deletions(-) create mode 100644 constants/constants.go create mode 100644 utils/utils.go diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go index 81dd405..e01e13b 100644 --- a/babyjub/babyjub.go +++ b/babyjub/babyjub.go @@ -2,27 +2,17 @@ package babyjub import ( "fmt" + "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/utils" "math/big" ) -// Q is the order of the integer field where the curve point coordinates are (Zq). -var Q *big.Int - // A is one of the babyjub constants. var A *big.Int // D is one of the babyjub constants. var D *big.Int -// Zero is 0. -var Zero *big.Int - -// One is 1. -var One *big.Int - -// MinusOne is -1. -var MinusOne *big.Int - // Order of the babyjub curve. var Order *big.Int @@ -34,34 +24,19 @@ var SubOrder *big.Int // the subgroup in the curve. var B8 *Point -// NewIntFromString creates a new big.Int from a decimal integer encoded as a -// string. It will panic if the string is not a decimal integer. -func NewIntFromString(s string) *big.Int { - v, ok := new(big.Int).SetString(s, 10) - if !ok { - panic(fmt.Sprintf("Bad base 10 string %s", s)) - } - return v -} - // init initializes global numbers and the subgroup base. func init() { - Zero = big.NewInt(0) - One = big.NewInt(1) - MinusOne = big.NewInt(-1) - Q = NewIntFromString( - "21888242871839275222246405745257275088548364400416034343698204186575808495617") - A = NewIntFromString("168700") - D = NewIntFromString("168696") - - Order = NewIntFromString( + A = utils.NewIntFromString("168700") + D = utils.NewIntFromString("168696") + + Order = utils.NewIntFromString( "21888242871839275222246405745257275088614511777268538073601725287587578984328") SubOrder = new(big.Int).Rsh(Order, 3) B8 = NewPoint() - B8.X = NewIntFromString( + B8.X = utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - B8.Y = NewIntFromString( + B8.Y = utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") } @@ -95,9 +70,9 @@ func (res *Point) Add(a *Point, b *Point) *Point { x2.Mul(x2, b.X) x2.Mul(x2, a.Y) x2.Mul(x2, b.Y) - x2.Add(One, x2) - x2.Mod(x2, Q) - x2.ModInverse(x2, Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 + x2.Add(constants.One, x2) + x2.Mod(x2, constants.Q) + x2.ModInverse(x2, constants.Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 // y = (a.y * b.y + A * a.x * a.x) * (1 - D * a.x * b.x * a.y * b.y)^-1 mod q y1a := new(big.Int).Mul(a.Y, b.Y) @@ -112,15 +87,15 @@ func (res *Point) Add(a *Point, b *Point) *Point { y2.Mul(y2, b.X) y2.Mul(y2, a.Y) y2.Mul(y2, b.Y) - y2.Sub(One, y2) - y2.Mod(y2, Q) - y2.ModInverse(y2, Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 + y2.Sub(constants.One, y2) + y2.Mod(y2, constants.Q) + y2.ModInverse(y2, constants.Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 res.X = x1a.Mul(x1a, x2) - res.X = res.X.Mod(res.X, Q) + res.X = res.X.Mod(res.X, constants.Q) res.Y = y1a.Mul(y1a, y2) - res.Y = res.Y.Mod(res.Y, Q) + res.Y = res.Y.Mod(res.Y, constants.Q) return res } @@ -146,21 +121,21 @@ func (res *Point) Mul(s *big.Int, p *Point) *Point { func (p *Point) InCurve() bool { x2 := new(big.Int).Set(p.X) x2.Mul(x2, x2) - x2.Mod(x2, Q) + x2.Mod(x2, constants.Q) y2 := new(big.Int).Set(p.Y) y2.Mul(y2, y2) - y2.Mod(y2, Q) + y2.Mod(y2, constants.Q) a := new(big.Int).Mul(A, x2) a.Add(a, y2) - a.Mod(a, Q) + a.Mod(a, constants.Q) b := new(big.Int).Set(D) b.Mul(b, x2) b.Mul(b, y2) - b.Add(One, b) - b.Mod(b, Q) + b.Add(constants.One, b) + b.Mod(b, constants.Q) return a.Cmp(b) == 0 } @@ -172,20 +147,20 @@ func (p *Point) InSubGroup() bool { return false } res := NewPoint().Mul(SubOrder, p) - return (res.X.Cmp(Zero) == 0) && (res.Y.Cmp(One) == 0) + return (res.X.Cmp(constants.Zero) == 0) && (res.Y.Cmp(constants.One) == 0) } // PointCoordSign returns the sign of the curve point coordinate. It returns // false if the sign is positive and false if the sign is negative. func PointCoordSign(c *big.Int) bool { - if c.Cmp(new(big.Int).Rsh(Q, 1)) == 1 { + if c.Cmp(new(big.Int).Rsh(constants.Q, 1)) == 1 { return true } return false } func PackPoint(ay *big.Int, sign bool) [32]byte { - leBuf := BigIntLEBytes(ay) + leBuf := utils.BigIntLEBytes(ay) if sign { leBuf[31] = leBuf[31] | 0x80 } @@ -210,31 +185,31 @@ func (p *Point) Decompress(leBuf [32]byte) (*Point, error) { sign = true leBuf[31] = leBuf[31] & 0x7F } - SetBigIntFromLEBytes(p.Y, leBuf[:]) - if p.Y.Cmp(Q) >= 0 { + utils.SetBigIntFromLEBytes(p.Y, leBuf[:]) + if p.Y.Cmp(constants.Q) >= 0 { return nil, fmt.Errorf("p.y >= Q") } y2 := new(big.Int).Mul(p.Y, p.Y) - y2.Mod(y2, Q) + y2.Mod(y2, constants.Q) xa := big.NewInt(1) xa.Sub(xa, y2) // xa == 1 - y^2 xb := new(big.Int).Mul(D, y2) - xb.Mod(xb, Q) + xb.Mod(xb, constants.Q) xb.Sub(A, xb) // xb = A - d * y^2 if xb.Cmp(big.NewInt(0)) == 0 { return nil, fmt.Errorf("division by 0") } - xb.ModInverse(xb, Q) + xb.ModInverse(xb, constants.Q) p.X.Mul(xa, xb) // xa / xb - p.X.Mod(p.X, Q) - p.X.ModSqrt(p.X, Q) + p.X.Mod(p.X, constants.Q) + p.X.ModSqrt(p.X, constants.Q) if (sign && !PointCoordSign(p.X)) || (!sign && PointCoordSign(p.X)) { - p.X.Mul(p.X, MinusOne) + p.X.Mul(p.X, constants.MinusOne) } - p.X.Mod(p.X, Q) + p.X.Mod(p.X, constants.Q) return p, nil } diff --git a/babyjub/babyjub_test.go b/babyjub/babyjub_test.go index 0328412..8791d67 100644 --- a/babyjub/babyjub_test.go +++ b/babyjub/babyjub_test.go @@ -5,6 +5,7 @@ import ( "math/big" "testing" + "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" ) @@ -19,15 +20,15 @@ func TestAdd1(t *testing.T) { } func TestAdd2(t *testing.T) { - aX := NewIntFromString( + aX := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := NewIntFromString( + aY := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := NewIntFromString( + bX := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - bY := NewIntFromString( + bY := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") b := &Point{X: bX, Y: bY} @@ -42,15 +43,15 @@ func TestAdd2(t *testing.T) { } func TestAdd3(t *testing.T) { - aX := NewIntFromString( + aX := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - aY := NewIntFromString( + aY := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") a := &Point{X: aX, Y: aY} - bX := NewIntFromString( + bX := utils.NewIntFromString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := NewIntFromString( + bY := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -65,15 +66,15 @@ func TestAdd3(t *testing.T) { } func TestAdd4(t *testing.T) { - aX := NewIntFromString( + aX := utils.NewIntFromString( "0") - aY := NewIntFromString( + aY := utils.NewIntFromString( "1") a := &Point{X: aX, Y: aY} - bX := NewIntFromString( + bX := utils.NewIntFromString( "16540640123574156134436876038791482806971768689494387082833631921987005038935") - bY := NewIntFromString( + bY := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") b := &Point{X: bX, Y: bY} @@ -98,12 +99,12 @@ func TestInCurve2(t *testing.T) { } func TestMul0(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := NewIntFromString( + y := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} - s := NewIntFromString("3") + s := utils.NewIntFromString("3") r2 := NewPoint().Add(p, p) r2 = NewPoint().Add(r2, p) @@ -120,12 +121,12 @@ func TestMul0(t *testing.T) { } func TestMul1(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := NewIntFromString( + y := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} - s := NewIntFromString( + s := utils.NewIntFromString( "14035240266687799601661095864649209771790948434046947201833777492504781204499") r := NewPoint().Mul(s, p) assert.Equal(t, @@ -137,12 +138,12 @@ func TestMul1(t *testing.T) { } func TestMul2(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := NewIntFromString( + y := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} - s := NewIntFromString( + s := utils.NewIntFromString( "20819045374670962167435360035096875258406992893633759881276124905556507972311") r := NewPoint().Mul(s, p) assert.Equal(t, @@ -154,45 +155,45 @@ func TestMul2(t *testing.T) { } func TestInCurve3(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := NewIntFromString( + y := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInCurve4(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := NewIntFromString( + y := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InCurve()) } func TestInSubGroup1(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := NewIntFromString( + y := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestInSubGroup2(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := NewIntFromString( + y := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} assert.Equal(t, true, p.InSubGroup()) } func TestCompressDecompress1(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "17777552123799933955779906779655732241715742912184938656739573121738514868268") - y := NewIntFromString( + y := utils.NewIntFromString( "2626589144620713026669568689430873010625803728049924121243784502389097019475") p := &Point{X: x, Y: y} @@ -206,9 +207,9 @@ func TestCompressDecompress1(t *testing.T) { } func TestCompressDecompress2(t *testing.T) { - x := NewIntFromString( + x := utils.NewIntFromString( "6890855772600357754907169075114257697580319025794532037257385534741338397365") - y := NewIntFromString( + y := utils.NewIntFromString( "4338620300185947561074059802482547481416142213883829469920100239455078257889") p := &Point{X: x, Y: y} diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go index e9e7cd7..3274b07 100644 --- a/babyjub/eddsa.go +++ b/babyjub/eddsa.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "github.com/iden3/go-iden3-crypto/mimc7" + "github.com/iden3/go-iden3-crypto/utils" "math/big" ) @@ -39,7 +40,7 @@ func (k *PrivateKey) Scalar() *PrivKeyScalar { copy(sBuf32[:], sBuf[:32]) pruneBuffer(&sBuf32) s := new(big.Int) - SetBigIntFromLEBytes(s, sBuf32[:]) + utils.SetBigIntFromLEBytes(s, sBuf32[:]) s.Rsh(s, 3) return NewPrivKeyScalar(s) } @@ -76,17 +77,17 @@ type PublicKey Point func (pk PublicKey) MarshalText() ([]byte, error) { pkc := pk.Compress() - return Hex(pkc[:]).MarshalText() + return utils.Hex(pkc[:]).MarshalText() } func (pk PublicKey) String() string { pkc := pk.Compress() - return Hex(pkc[:]).String() + return utils.Hex(pkc[:]).String() } func (pk *PublicKey) UnmarshalText(h []byte) error { var pkc PublicKeyComp - if err := HexDecodeInto(pkc[:], h); err != nil { + if err := utils.HexDecodeInto(pkc[:], h); err != nil { return err } pkd, err := pkc.Decompress() @@ -106,9 +107,9 @@ func (p *PublicKey) Point() *Point { // point. type PublicKeyComp [32]byte -func (buf PublicKeyComp) MarshalText() ([]byte, error) { return Hex(buf[:]).MarshalText() } -func (buf PublicKeyComp) String() string { return Hex(buf[:]).String() } -func (buf *PublicKeyComp) UnmarshalText(h []byte) error { return HexDecodeInto(buf[:], h) } +func (buf PublicKeyComp) MarshalText() ([]byte, error) { return utils.Hex(buf[:]).MarshalText() } +func (buf PublicKeyComp) String() string { return utils.Hex(buf[:]).String() } +func (buf *PublicKeyComp) UnmarshalText(h []byte) error { return utils.HexDecodeInto(buf[:], h) } func (p *PublicKey) Compress() PublicKeyComp { return PublicKeyComp((*Point)(p).Compress()) @@ -132,15 +133,15 @@ type Signature struct { // SignatureComp represents a compressed EdDSA signature. type SignatureComp [64]byte -func (buf SignatureComp) MarshalText() ([]byte, error) { return Hex(buf[:]).MarshalText() } -func (buf SignatureComp) String() string { return Hex(buf[:]).String() } -func (buf *SignatureComp) UnmarshalText(h []byte) error { return HexDecodeInto(buf[:], h) } +func (buf SignatureComp) MarshalText() ([]byte, error) { return utils.Hex(buf[:]).MarshalText() } +func (buf SignatureComp) String() string { return utils.Hex(buf[:]).String() } +func (buf *SignatureComp) UnmarshalText(h []byte) error { return utils.HexDecodeInto(buf[:], h) } // Compress an EdDSA signature by concatenating the compression of // the point R8 and the Little-Endian encoding of S. func (s *Signature) Compress() SignatureComp { R8p := s.R8.Compress() - Sp := BigIntLEBytes(s.S) + Sp := utils.BigIntLEBytes(s.S) buf := [64]byte{} copy(buf[:32], R8p[:]) copy(buf[32:], Sp[:]) @@ -156,7 +157,7 @@ func (s *Signature) Decompress(buf [64]byte) (*Signature, error) { if s.R8, err = NewPoint().Decompress(R8p); err != nil { return nil, err } - s.S = SetBigIntFromLEBytes(new(big.Int), buf[32:]) + s.S = utils.SetBigIntFromLEBytes(new(big.Int), buf[32:]) return s, nil } @@ -170,11 +171,11 @@ func (s *SignatureComp) Decompress() (*Signature, error) { // for buffer hashing and mimc7 for big.Int hashing. func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { h1 := Blake512(k[:]) - msgBuf := BigIntLEBytes(msg) + msgBuf := utils.BigIntLEBytes(msg) msgBuf32 := [32]byte{} copy(msgBuf32[:], msgBuf[:]) rBuf := Blake512(append(h1[32:], msgBuf32[:]...)) - r := SetBigIntFromLEBytes(new(big.Int), rBuf) // r = H(H_{32..63}(k), msg) + r := utils.SetBigIntFromLEBytes(new(big.Int), rBuf) // r = H(H_{32..63}(k), msg) r.Mod(r, SubOrder) R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B A := k.Public().Point() diff --git a/babyjub/eddsa_test.go b/babyjub/eddsa_test.go index c9c4153..90b5733 100644 --- a/babyjub/eddsa_test.go +++ b/babyjub/eddsa_test.go @@ -5,6 +5,8 @@ import ( "encoding/hex" "fmt" + "github.com/iden3/go-iden3-crypto/constants" + "github.com/iden3/go-iden3-crypto/utils" "github.com/stretchr/testify/assert" "math/big" @@ -17,8 +19,8 @@ func genInputs() (*PrivateKey, *big.Int) { msgBuf := [32]byte{} rand.Read(msgBuf[:]) - msg := SetBigIntFromLEBytes(new(big.Int), msgBuf[:]) - msg.Mod(msg, Q) + msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf[:]) + msg.Mod(msg, constants.Q) fmt.Println("msg", msg) return &k, msg @@ -31,7 +33,7 @@ func TestSignVerify1(t *testing.T) { if err != nil { panic(err) } - msg := SetBigIntFromLEBytes(new(big.Int), msgBuf) + msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) pk := k.Public() assert.Equal(t, @@ -77,7 +79,7 @@ func TestCompressDecompress(t *testing.T) { if err != nil { panic(err) } - msg := SetBigIntFromLEBytes(new(big.Int), msgBuf) + msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) sig := k.SignMimc7(msg) sigBuf := sig.Compress() sig2, err := new(Signature).Decompress(sigBuf) diff --git a/babyjub/helpers.go b/babyjub/helpers.go index 1372a2f..2392cf9 100644 --- a/babyjub/helpers.go +++ b/babyjub/helpers.go @@ -1,40 +1,9 @@ package babyjub import ( - "bytes" - "encoding/hex" - "fmt" - "math/big" - "strings" - "github.com/dchest/blake512" // I have personally reviewed that this module doesn't do anything suspicious ) -// SwapEndianness swaps the endianness of the value encoded in xs. If xs is -// Big-Endian, the result will be Little-Endian and viceversa. -func SwapEndianness(xs []byte) []byte { - ys := make([]byte, len(xs)) - for i, b := range xs { - ys[len(xs)-1-i] = b - } - return ys -} - -// BigIntLEBytes encodes a big.Int into an array in Little-Endian. -func BigIntLEBytes(v *big.Int) [32]byte { - le := SwapEndianness(v.Bytes()) - res := [32]byte{} - copy(res[:], le) - return res -} - -// SetBigIntFromLEBytes sets the value of a big.Int from a Little-Endian -// encoded value. -func SetBigIntFromLEBytes(v *big.Int, leBuf []byte) *big.Int { - beBuf := SwapEndianness(leBuf) - return v.SetBytes(beBuf) -} - // Blake512 performs the blake-512 hash over the buffer m. Note that this is // the original blake from the SHA3 competition and not the new blake2 version. func Blake512(m []byte) []byte { @@ -42,47 +11,3 @@ func Blake512(m []byte) []byte { h.Write(m[:]) return h.Sum(nil) } - -// Hex is a byte slice type that can be marshalled and unmarshaled in hex -type Hex []byte - -// MarshalText encodes buf as hex -func (buf Hex) MarshalText() ([]byte, error) { - return []byte(hex.EncodeToString(buf)), nil -} - -// String encodes buf as hex -func (buf Hex) String() string { - return hex.EncodeToString(buf) -} - -// HexEncode encodes an array of bytes into a string in hex. -func HexEncode(bs []byte) string { - return fmt.Sprintf("0x%s", hex.EncodeToString(bs)) -} - -// HexDecode decodes a hex string into an array of bytes. -func HexDecode(h string) ([]byte, error) { - if strings.HasPrefix(h, "0x") { - h = h[2:] - } - return hex.DecodeString(h) -} - -// HexDecodeInto decodes a hex string into an array of bytes (dst), verifying -// that the decoded array has the same length as dst. -func HexDecodeInto(dst []byte, h []byte) error { - if bytes.HasPrefix(h, []byte("0x")) { - h = h[2:] - } - if len(h)/2 != len(dst) { - return fmt.Errorf("expected %v bytes in hex string, got %v", len(dst), len(h)/2) - } - n, err := hex.Decode(dst, h) - if err != nil { - return err - } else if n != len(dst) { - return fmt.Errorf("expected %v bytes when decoding hex string, got %v", len(dst), n) - } - return nil -} diff --git a/constants/constants.go b/constants/constants.go new file mode 100644 index 0000000..986ceb8 --- /dev/null +++ b/constants/constants.go @@ -0,0 +1,26 @@ +package constants + +import ( + "github.com/iden3/go-iden3-crypto/utils" + "math/big" +) + +// Q is the order of the integer field (Zq) that fits inside the SNARK. +var Q *big.Int + +// Zero is 0. +var Zero *big.Int + +// One is 1. +var One *big.Int + +// MinusOne is -1. +var MinusOne *big.Int + +func init() { + Zero = big.NewInt(0) + One = big.NewInt(1) + MinusOne = big.NewInt(-1) + Q = utils.NewIntFromString( + "21888242871839275222246405745257275088548364400416034343698204186575808495617") +} diff --git a/go.sum b/go.sum index 3cbf7c6..6d5fdbb 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,7 @@ github.com/dchest/blake512 v1.0.0 h1:oDFEQFIqFSeuA34xLtXZ/rWxCXdSjirjzPhey5EUvmA github.com/dchest/blake512 v1.0.0/go.mod h1:FV1x7xPPLWukZlpDpWQ88rF/SFwZ5qbskrzhLMB92JI= github.com/ethereum/go-ethereum v1.8.27 h1:d+gkiLaBDk5fn3Pe/xNVaMrB/ozI+AUB2IlVBp29IrY= github.com/ethereum/go-ethereum v1.8.27/go.mod h1:PwpWDrCLZrV+tfrhqqF6kPknbISMHaJv9Ln3kPCZLwY= +github.com/iden3/go-iden3 v0.0.5 h1:NV6HXnLmp+1YmKd2FmymzU6OAP77q1WWDcB/B+BUL9g= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/mimc7/mimc7.go b/mimc7/mimc7.go index 18afdf2..5634925 100644 --- a/mimc7/mimc7.go +++ b/mimc7/mimc7.go @@ -6,6 +6,7 @@ import ( "math/big" "github.com/ethereum/go-ethereum/crypto" + _constants "github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/field" ) @@ -31,11 +32,7 @@ func getIV(seed string) { func generateConstantsData() constantsData { var constants constantsData - r, ok := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - if !ok { - - } - fqR := field.NewFq(r) + fqR := field.NewFq(_constants.Q) constants.fqR = fqR // maxFieldVal is the R value of the Finite Field diff --git a/poseidon/poseidon.go b/poseidon/poseidon.go index cb1a828..730479e 100644 --- a/poseidon/poseidon.go +++ b/poseidon/poseidon.go @@ -6,7 +6,9 @@ import ( "math/big" "strconv" + _constants "github.com/iden3/go-iden3-crypto/constants" "github.com/iden3/go-iden3-crypto/field" + "github.com/iden3/go-iden3-crypto/utils" "golang.org/x/crypto/blake2b" ) @@ -23,55 +25,35 @@ type constantsData struct { m [][]*big.Int } -// checkBigIntInField checks if given big.Int fits in a Field R element -func checkBigIntInField(a *big.Int, q *big.Int) bool { - if a.Cmp(q) != -1 { - return false - } - return true -} - -// checkBigIntArrayInField checks if given big.Int fits in a Field R element -func checkBigIntArrayInField(arr []*big.Int, q *big.Int) bool { - for _, a := range arr { - if !checkBigIntInField(a, q) { - return false - } - } - return true -} - func generateConstantsData() constantsData { var constants constantsData - r, ok := new(big.Int).SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) - if !ok { - - } - fqR := field.NewFq(r) + fqR := field.NewFq(_constants.Q) constants.fqR = fqR - constants.c = getPseudoRandom(fqR, SEED+"_constants", big.NewInt(int64(NROUNDSF+NROUNDSP))) + constants.c = getPseudoRandom(fqR, SEED+"_constants", NROUNDSF+NROUNDSP) constants.m = getMDS(fqR) return constants } -func getPseudoRandom(fqR field.Fq, seed string, n *big.Int) []*big.Int { - var res []*big.Int - hash := blake2b.Sum256([]byte(seed)) - for big.NewInt(int64(len(res))).Cmp(n) == -1 { // res < n - newN := fqR.Affine(leByteArrayToBigInt(fqR, hash[:])) - res = append(res, newN) - hash = blake2b.Sum256(hash[:]) +func leByteArrayToBigInt(b []byte) *big.Int { + res := big.NewInt(0) + for i := 0; i < len(b); i++ { + n := big.NewInt(int64(b[i])) + res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8))) } return res } -func leByteArrayToBigInt(fqR field.Fq, b []byte) *big.Int { - res := fqR.Zero() - for i := 0; i < len(b); i++ { - n := big.NewInt(int64(b[i])) - res = new(big.Int).Add(res, new(big.Int).Lsh(n, uint(i*8))) +func getPseudoRandom(fqR field.Fq, seed string, n int) []*big.Int { + var res []*big.Int + hash := blake2b.Sum256([]byte(seed)) + for len(res) < n { + hashBigInt := new(big.Int) + newN := fqR.Affine(utils.SetBigIntFromLEBytes(hashBigInt, hash[:])) + // newN := fqR.Affine(leByteArrayToBigInt(hash[:])) + res = append(res, newN) + hash = blake2b.Sum256(hash[:]) } return res } @@ -87,10 +69,10 @@ func nonceToString(n int) string { // https://eprint.iacr.org/2019/458.pdf pag.8 func getMDS(fqR field.Fq) [][]*big.Int { nonce := 0 - cauchyMatrix := getPseudoRandom(fqR, SEED+"_matrix_"+nonceToString(nonce), big.NewInt(T*2)) + cauchyMatrix := getPseudoRandom(fqR, SEED+"_matrix_"+nonceToString(nonce), T*2) for !checkAllDifferent(cauchyMatrix) { nonce += 1 - cauchyMatrix = getPseudoRandom(fqR, SEED+"_matrix_"+nonceToString(nonce), big.NewInt(T*2)) + cauchyMatrix = getPseudoRandom(fqR, SEED+"_matrix_"+nonceToString(nonce), T*2) } var m [][]*big.Int for i := 0; i < T; i++ { @@ -160,10 +142,10 @@ func mix(state []*big.Int, m [][]*big.Int) []*big.Int { // Hash computes the Poseidon hash for the given inputs func Hash(inp []*big.Int) (*big.Int, error) { var state []*big.Int - if len(inp) < 0 || len(inp) > T { + if len(inp) == 0 || len(inp) > T { return nil, errors.New("wrong inputs length") } - if !checkBigIntArrayInField(inp, constants.fqR.Q) { + if !utils.CheckBigIntArrayInField(inp, constants.fqR.Q) { return nil, errors.New("inputs values not inside Finite Field") } diff --git a/utils/utils.go b/utils/utils.go new file mode 100644 index 0000000..0f2d639 --- /dev/null +++ b/utils/utils.go @@ -0,0 +1,106 @@ +package utils + +import ( + "bytes" + "encoding/hex" + "fmt" + "math/big" + "strings" +) + +// NewIntFromString creates a new big.Int from a decimal integer encoded as a +// string. It will panic if the string is not a decimal integer. +func NewIntFromString(s string) *big.Int { + v, ok := new(big.Int).SetString(s, 10) + if !ok { + panic(fmt.Sprintf("Bad base 10 string %s", s)) + } + return v +} + +// SwapEndianness swaps the endianness of the value encoded in xs. If xs is +// Big-Endian, the result will be Little-Endian and viceversa. +func SwapEndianness(xs []byte) []byte { + ys := make([]byte, len(xs)) + for i, b := range xs { + ys[len(xs)-1-i] = b + } + return ys +} + +// BigIntLEBytes encodes a big.Int into an array in Little-Endian. +func BigIntLEBytes(v *big.Int) [32]byte { + le := SwapEndianness(v.Bytes()) + res := [32]byte{} + copy(res[:], le) + return res +} + +// SetBigIntFromLEBytes sets the value of a big.Int from a Little-Endian +// encoded value. +func SetBigIntFromLEBytes(v *big.Int, leBuf []byte) *big.Int { + beBuf := SwapEndianness(leBuf) + return v.SetBytes(beBuf) +} + +// Hex is a byte slice type that can be marshalled and unmarshaled in hex +type Hex []byte + +// MarshalText encodes buf as hex +func (buf Hex) MarshalText() ([]byte, error) { + return []byte(hex.EncodeToString(buf)), nil +} + +// String encodes buf as hex +func (buf Hex) String() string { + return hex.EncodeToString(buf) +} + +// HexEncode encodes an array of bytes into a string in hex. +func HexEncode(bs []byte) string { + return fmt.Sprintf("0x%s", hex.EncodeToString(bs)) +} + +// HexDecode decodes a hex string into an array of bytes. +func HexDecode(h string) ([]byte, error) { + if strings.HasPrefix(h, "0x") { + h = h[2:] + } + return hex.DecodeString(h) +} + +// HexDecodeInto decodes a hex string into an array of bytes (dst), verifying +// that the decoded array has the same length as dst. +func HexDecodeInto(dst []byte, h []byte) error { + if bytes.HasPrefix(h, []byte("0x")) { + h = h[2:] + } + if len(h)/2 != len(dst) { + return fmt.Errorf("expected %v bytes in hex string, got %v", len(dst), len(h)/2) + } + n, err := hex.Decode(dst, h) + if err != nil { + return err + } else if n != len(dst) { + return fmt.Errorf("expected %v bytes when decoding hex string, got %v", len(dst), n) + } + return nil +} + +// CheckBigIntInField checks if given big.Int fits in a Field Q element +func CheckBigIntInField(a *big.Int, q *big.Int) bool { + if a.Cmp(q) != -1 { + return false + } + return true +} + +// CheckBigIntArrayInField checks if given big.Int fits in a Field Q element +func CheckBigIntArrayInField(arr []*big.Int, q *big.Int) bool { + for _, a := range arr { + if !CheckBigIntInField(a, q) { + return false + } + } + return true +}