diff --git a/babyjub/babyjub.go b/babyjub/babyjub.go new file mode 100644 index 0000000..81dd405 --- /dev/null +++ b/babyjub/babyjub.go @@ -0,0 +1,240 @@ +package babyjub + +import ( + "fmt" + "math/big" +) + +// Q is the order of the integer field where the curve point coordinates are (Zq). +var Q *big.Int + +// A is one of the babyjub constants. +var A *big.Int + +// D is one of the babyjub constants. +var D *big.Int + +// Zero is 0. +var Zero *big.Int + +// One is 1. +var One *big.Int + +// MinusOne is -1. +var MinusOne *big.Int + +// Order of the babyjub curve. +var Order *big.Int + +// SubOrder is the order of the subgroup of the babyjub curve that contains the +// points that we use. +var SubOrder *big.Int + +// B8 is a base point of the babyjub multiplied by 8 to make it a base point of +// the subgroup in the curve. +var B8 *Point + +// NewIntFromString creates a new big.Int from a decimal integer encoded as a +// string. It will panic if the string is not a decimal integer. +func NewIntFromString(s string) *big.Int { + v, ok := new(big.Int).SetString(s, 10) + if !ok { + panic(fmt.Sprintf("Bad base 10 string %s", s)) + } + return v +} + +// init initializes global numbers and the subgroup base. +func init() { + Zero = big.NewInt(0) + One = big.NewInt(1) + MinusOne = big.NewInt(-1) + Q = NewIntFromString( + "21888242871839275222246405745257275088548364400416034343698204186575808495617") + A = NewIntFromString("168700") + D = NewIntFromString("168696") + + Order = NewIntFromString( + "21888242871839275222246405745257275088614511777268538073601725287587578984328") + SubOrder = new(big.Int).Rsh(Order, 3) + + B8 = NewPoint() + B8.X = NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + B8.Y = NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") +} + +// Point represents a point of the babyjub curve. +type Point struct { + X *big.Int + Y *big.Int +} + +// NewPoint creates a new Point. +func NewPoint() *Point { + return &Point{X: big.NewInt(0), Y: big.NewInt(1)} +} + +// Set copies a Point c into the Point p +func (p *Point) Set(c *Point) *Point { + p.X.Set(c.X) + p.Y.Set(c.Y) + return p +} + +// Add adds Point a and b into res +func (res *Point) Add(a *Point, b *Point) *Point { + // x = (a.x * b.y + b.x * a.y) * (1 + D * a.x * b.x * a.y * b.y)^-1 mod q + x1a := new(big.Int).Mul(a.X, b.Y) + x1b := new(big.Int).Mul(b.X, a.Y) + x1a.Add(x1a, x1b) // x1a = a.x * b.y + b.x * a.y + + x2 := new(big.Int).Set(D) + x2.Mul(x2, a.X) + x2.Mul(x2, b.X) + x2.Mul(x2, a.Y) + x2.Mul(x2, b.Y) + x2.Add(One, x2) + x2.Mod(x2, Q) + x2.ModInverse(x2, Q) // x2 = (1 + D * a.x * b.x * a.y * b.y)^-1 + + // y = (a.y * b.y + A * a.x * a.x) * (1 - D * a.x * b.x * a.y * b.y)^-1 mod q + y1a := new(big.Int).Mul(a.Y, b.Y) + y1b := new(big.Int).Set(A) + y1b.Mul(y1b, a.X) + y1b.Mul(y1b, b.X) + + y1a.Sub(y1a, y1b) // y1a = a.y * b.y - A * a.x * b.x + + y2 := new(big.Int).Set(D) + y2.Mul(y2, a.X) + y2.Mul(y2, b.X) + y2.Mul(y2, a.Y) + y2.Mul(y2, b.Y) + y2.Sub(One, y2) + y2.Mod(y2, Q) + y2.ModInverse(y2, Q) // y2 = (1 - D * a.x * b.x * a.y * b.y)^-1 + + res.X = x1a.Mul(x1a, x2) + res.X = res.X.Mod(res.X, Q) + + res.Y = y1a.Mul(y1a, y2) + res.Y = res.Y.Mod(res.Y, Q) + + return res +} + +// Mul multiplies the Point p by the scalar s and stores the result in res, +// which is also returned. +func (res *Point) Mul(s *big.Int, p *Point) *Point { + res.X = big.NewInt(0) + res.Y = big.NewInt(1) + exp := NewPoint().Set(p) + + for i := 0; i < s.BitLen(); i++ { + if s.Bit(i) == 1 { + res.Add(res, exp) + } + exp.Add(exp, exp) + } + + return res +} + +// InCurve returns true when the Point p is in the babyjub curve. +func (p *Point) InCurve() bool { + x2 := new(big.Int).Set(p.X) + x2.Mul(x2, x2) + x2.Mod(x2, Q) + + y2 := new(big.Int).Set(p.Y) + y2.Mul(y2, y2) + y2.Mod(y2, Q) + + a := new(big.Int).Mul(A, x2) + a.Add(a, y2) + a.Mod(a, Q) + + b := new(big.Int).Set(D) + b.Mul(b, x2) + b.Mul(b, y2) + b.Add(One, b) + b.Mod(b, Q) + + return a.Cmp(b) == 0 +} + +// InSubGroup returns true when the Point p is in the subgroup of the babyjub +// curve. +func (p *Point) InSubGroup() bool { + if !p.InCurve() { + return false + } + res := NewPoint().Mul(SubOrder, p) + return (res.X.Cmp(Zero) == 0) && (res.Y.Cmp(One) == 0) +} + +// PointCoordSign returns the sign of the curve point coordinate. It returns +// false if the sign is positive and false if the sign is negative. +func PointCoordSign(c *big.Int) bool { + if c.Cmp(new(big.Int).Rsh(Q, 1)) == 1 { + return true + } + return false +} + +func PackPoint(ay *big.Int, sign bool) [32]byte { + leBuf := BigIntLEBytes(ay) + if sign { + leBuf[31] = leBuf[31] | 0x80 + } + return leBuf +} + +// Compress the point into a 32 byte array that contains the y coordinate in +// little endian and the sign of the x coordinate. +func (p *Point) Compress() [32]byte { + sign := false + if PointCoordSign(p.X) { + sign = true + } + return PackPoint(p.Y, sign) +} + +// Decompress a compressed Point into p, and also returns the decompressed +// Point. Returns error if the compressed Point is invalid. +func (p *Point) Decompress(leBuf [32]byte) (*Point, error) { + sign := false + if (leBuf[31] & 0x80) != 0x00 { + sign = true + leBuf[31] = leBuf[31] & 0x7F + } + SetBigIntFromLEBytes(p.Y, leBuf[:]) + if p.Y.Cmp(Q) >= 0 { + return nil, fmt.Errorf("p.y >= Q") + } + + y2 := new(big.Int).Mul(p.Y, p.Y) + y2.Mod(y2, Q) + xa := big.NewInt(1) + xa.Sub(xa, y2) // xa == 1 - y^2 + + xb := new(big.Int).Mul(D, y2) + xb.Mod(xb, Q) + xb.Sub(A, xb) // xb = A - d * y^2 + + if xb.Cmp(big.NewInt(0)) == 0 { + return nil, fmt.Errorf("division by 0") + } + xb.ModInverse(xb, Q) + p.X.Mul(xa, xb) // xa / xb + p.X.Mod(p.X, Q) + p.X.ModSqrt(p.X, Q) + if (sign && !PointCoordSign(p.X)) || (!sign && PointCoordSign(p.X)) { + p.X.Mul(p.X, MinusOne) + } + p.X.Mod(p.X, Q) + + return p, nil +} diff --git a/babyjub/babyjub_test.go b/babyjub/babyjub_test.go new file mode 100644 index 0000000..0a820a8 --- /dev/null +++ b/babyjub/babyjub_test.go @@ -0,0 +1,232 @@ +package babyjub + +import ( + // "fmt" + "encoding/hex" + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestAdd1(t *testing.T) { + a := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + b := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + + c := NewPoint().Add(a, b) + // fmt.Printf("%v = 2 * %v", *c, *a) + assert.Equal(t, "0", c.X.String()) + assert.Equal(t, "1", c.Y.String()) +} + +func TestAdd2(t *testing.T) { + aX := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + aY := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + a := &Point{X: aX, Y: aY} + + bX := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + bY := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + b := &Point{X: bX, Y: bY} + + c := NewPoint().Add(a, b) + // fmt.Printf("%v = 2 * %v", *c, *a) + assert.Equal(t, + "6890855772600357754907169075114257697580319025794532037257385534741338397365", + c.X.String()) + assert.Equal(t, + "4338620300185947561074059802482547481416142213883829469920100239455078257889", + c.Y.String()) +} + +func TestAdd3(t *testing.T) { + aX := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + aY := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + a := &Point{X: aX, Y: aY} + + bX := NewIntFromString( + "16540640123574156134436876038791482806971768689494387082833631921987005038935") + bY := NewIntFromString( + "20819045374670962167435360035096875258406992893633759881276124905556507972311") + b := &Point{X: bX, Y: bY} + + c := NewPoint().Add(a, b) + // fmt.Printf("%v = 2 * %v", *c, *a) + assert.Equal(t, + "7916061937171219682591368294088513039687205273691143098332585753343424131937", + c.X.String()) + assert.Equal(t, + "14035240266687799601661095864649209771790948434046947201833777492504781204499", + c.Y.String()) +} + +func TestAdd4(t *testing.T) { + aX := NewIntFromString( + "0") + aY := NewIntFromString( + "1") + a := &Point{X: aX, Y: aY} + + bX := NewIntFromString( + "16540640123574156134436876038791482806971768689494387082833631921987005038935") + bY := NewIntFromString( + "20819045374670962167435360035096875258406992893633759881276124905556507972311") + b := &Point{X: bX, Y: bY} + + c := NewPoint().Add(a, b) + // fmt.Printf("%v = 2 * %v", *c, *a) + assert.Equal(t, + "16540640123574156134436876038791482806971768689494387082833631921987005038935", + c.X.String()) + assert.Equal(t, + "20819045374670962167435360035096875258406992893633759881276124905556507972311", + c.Y.String()) +} + +func TestInCurve1(t *testing.T) { + p := &Point{X: big.NewInt(0), Y: big.NewInt(1)} + assert.Equal(t, true, p.InCurve()) +} + +func TestInCurve2(t *testing.T) { + p := &Point{X: big.NewInt(1), Y: big.NewInt(0)} + assert.Equal(t, false, p.InCurve()) +} + +func TestMul0(t *testing.T) { + x := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + y := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + p := &Point{X: x, Y: y} + s := NewIntFromString("3") + + r2 := NewPoint().Add(p, p) + r2 = NewPoint().Add(r2, p) + r := NewPoint().Mul(s, p) + assert.Equal(t, r2.X.String(), r.X.String()) + assert.Equal(t, r2.Y.String(), r.Y.String()) + + assert.Equal(t, + "19372461775513343691590086534037741906533799473648040012278229434133483800898", + r.X.String()) + assert.Equal(t, + "9458658722007214007257525444427903161243386465067105737478306991484593958249", + r.Y.String()) +} + +func TestMul1(t *testing.T) { + x := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + y := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + p := &Point{X: x, Y: y} + s := NewIntFromString( + "14035240266687799601661095864649209771790948434046947201833777492504781204499") + r := NewPoint().Mul(s, p) + assert.Equal(t, + "17070357974431721403481313912716834497662307308519659060910483826664480189605", + r.X.String()) + assert.Equal(t, + "4014745322800118607127020275658861516666525056516280575712425373174125159339", + r.Y.String()) +} + +func TestMul2(t *testing.T) { + x := NewIntFromString( + "6890855772600357754907169075114257697580319025794532037257385534741338397365") + y := NewIntFromString( + "4338620300185947561074059802482547481416142213883829469920100239455078257889") + p := &Point{X: x, Y: y} + s := NewIntFromString( + "20819045374670962167435360035096875258406992893633759881276124905556507972311") + r := NewPoint().Mul(s, p) + assert.Equal(t, + "13563888653650925984868671744672725781658357821216877865297235725727006259983", + r.X.String()) + assert.Equal(t, + "8442587202676550862664528699803615547505326611544120184665036919364004251662", + r.Y.String()) +} + +func TestInCurve3(t *testing.T) { + x := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + y := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + p := &Point{X: x, Y: y} + assert.Equal(t, true, p.InCurve()) +} + +func TestInCurve4(t *testing.T) { + x := NewIntFromString( + "6890855772600357754907169075114257697580319025794532037257385534741338397365") + y := NewIntFromString( + "4338620300185947561074059802482547481416142213883829469920100239455078257889") + p := &Point{X: x, Y: y} + assert.Equal(t, true, p.InCurve()) +} + +func TestInSubGroup1(t *testing.T) { + x := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + y := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + p := &Point{X: x, Y: y} + assert.Equal(t, true, p.InSubGroup()) +} + +func TestInSubGroup2(t *testing.T) { + x := NewIntFromString( + "6890855772600357754907169075114257697580319025794532037257385534741338397365") + y := NewIntFromString( + "4338620300185947561074059802482547481416142213883829469920100239455078257889") + p := &Point{X: x, Y: y} + assert.Equal(t, true, p.InSubGroup()) +} + +func TestCompressDecompress1(t *testing.T) { + x := NewIntFromString( + "17777552123799933955779906779655732241715742912184938656739573121738514868268") + y := NewIntFromString( + "2626589144620713026669568689430873010625803728049924121243784502389097019475") + p := &Point{X: x, Y: y} + + buf := p.Compress() + assert.Equal(t, "53b81ed5bffe9545b54016234682e7b2f699bd42a5e9eae27ff4051bc698ce85", hex.EncodeToString(buf[:])) + + p2, err := NewPoint().Decompress(buf) + assert.Equal(t, nil, err) + assert.Equal(t, p.X.String(), p2.X.String()) + assert.Equal(t, p.Y.String(), p2.Y.String()) +} + +func TestCompressDecompress2(t *testing.T) { + x := NewIntFromString( + "6890855772600357754907169075114257697580319025794532037257385534741338397365") + y := NewIntFromString( + "4338620300185947561074059802482547481416142213883829469920100239455078257889") + p := &Point{X: x, Y: y} + + buf := p.Compress() + assert.Equal(t, "e114eb17eddf794f063a68fecac515e3620e131976108555735c8b0773929709", hex.EncodeToString(buf[:])) + + p2, err := NewPoint().Decompress(buf) + assert.Equal(t, nil, err) + assert.Equal(t, p.X.String(), p2.X.String()) + assert.Equal(t, p.Y.String(), p2.Y.String()) +} + +func TestCompressDecompressRnd(t *testing.T) { + for i := 0; i < 64; i++ { + p1 := NewPoint().Mul(big.NewInt(int64(i)), B8) + buf := p1.Compress() + p2, err := NewPoint().Decompress(buf) + assert.Equal(t, nil, err) + assert.Equal(t, p1, p2) + } +} diff --git a/babyjub/eddsa.go b/babyjub/eddsa.go new file mode 100644 index 0000000..5a189d2 --- /dev/null +++ b/babyjub/eddsa.go @@ -0,0 +1,211 @@ +package babyjub + +import ( + "crypto/rand" + // "encoding/hex" + // "fmt" + common3 "github.com/iden3/go-iden3/common" + "github.com/iden3/go-iden3/crypto/mimc7" + // "golang.org/x/crypto/blake2b" + "math/big" +) + +// pruneBuffer prunes the buffer during key generation according to RFC 8032. +// https://tools.ietf.org/html/rfc8032#page-13 +func pruneBuffer(buf *[32]byte) *[32]byte { + buf[0] = buf[0] & 0xF8 + buf[31] = buf[31] & 0x7F + buf[31] = buf[31] | 0x40 + return buf +} + +// PrivateKey is an EdDSA private key, which is a 32byte buffer. +type PrivateKey [32]byte + +// NewRandPrivKey generates a new random private key (using cryptographically +// secure randomness). +func NewRandPrivKey() PrivateKey { + var k PrivateKey + _, err := rand.Read(k[:]) + if err != nil { + panic(err) + } + return k +} + +// Scalar converts a private key into the scalar value s following the EdDSA +// standard, and using blake-512 hash. +func (k *PrivateKey) Scalar() *PrivKeyScalar { + sBuf := Blake512(k[:]) + sBuf32 := [32]byte{} + copy(sBuf32[:], sBuf[:32]) + pruneBuffer(&sBuf32) + s := new(big.Int) + SetBigIntFromLEBytes(s, sBuf32[:]) + s.Rsh(s, 3) + return NewPrivKeyScalar(s) +} + +// Pub returns the public key corresponding to a private key. +func (k *PrivateKey) Public() *PublicKey { + return k.Scalar().Public() +} + +// PrivKeyScalar represents the scalar s output of a private key +type PrivKeyScalar big.Int + +// NewPrivKeyScalar creates a new PrivKeyScalar from a big.Int +func NewPrivKeyScalar(s *big.Int) *PrivKeyScalar { + sk := PrivKeyScalar(*s) + return &sk +} + +// Pub returns the public key corresponding to the scalar value s of a private +// key. +func (s *PrivKeyScalar) Public() *PublicKey { + p := NewPoint().Mul((*big.Int)(s), B8) + pk := PublicKey(*p) + return &pk +} + +// BigInt returns the big.Int corresponding to a PrivKeyScalar. +func (s *PrivKeyScalar) BigInt() *big.Int { + return (*big.Int)(s) +} + +// PublicKey represents an EdDSA public key, which is a curve point. +type PublicKey Point + +func (pk PublicKey) MarshalText() ([]byte, error) { + pkc := pk.Compress() + return common3.Hex(pkc[:]).MarshalText() +} + +func (pk PublicKey) String() string { + pkc := pk.Compress() + return common3.Hex(pkc[:]).String() +} + +func (pk *PublicKey) UnmarshalText(h []byte) error { + var pkc PublicKeyComp + if err := common3.HexDecodeInto(pkc[:], h); err != nil { + return err + } + pkd, err := pkc.Decompress() + if err != nil { + return err + } + *pk = *pkd + return nil +} + +// Point returns the Point corresponding to a PublicKey. +func (p *PublicKey) Point() *Point { + return (*Point)(p) +} + +// PublicKeyComp represents a compressed EdDSA Public key; it's a compressed curve +// point. +type PublicKeyComp [32]byte + +func (buf PublicKeyComp) MarshalText() ([]byte, error) { return common3.Hex(buf[:]).MarshalText() } +func (buf PublicKeyComp) String() string { return common3.Hex(buf[:]).String() } +func (buf *PublicKeyComp) UnmarshalText(h []byte) error { return common3.HexDecodeInto(buf[:], h) } + +func (p *PublicKey) Compress() PublicKeyComp { + return PublicKeyComp((*Point)(p).Compress()) +} + +func (p *PublicKeyComp) Decompress() (*PublicKey, error) { + point, err := NewPoint().Decompress(*p) + if err != nil { + return nil, err + } + pk := PublicKey(*point) + return &pk, nil +} + +// Signature represents an EdDSA uncompressed signature. +type Signature struct { + R8 *Point + S *big.Int +} + +// SignatureComp represents a compressed EdDSA signature. +type SignatureComp [64]byte + +func (buf SignatureComp) MarshalText() ([]byte, error) { return common3.Hex(buf[:]).MarshalText() } +func (buf SignatureComp) String() string { return common3.Hex(buf[:]).String() } +func (buf *SignatureComp) UnmarshalText(h []byte) error { return common3.HexDecodeInto(buf[:], h) } + +// Compress an EdDSA signature by concatenating the compression of +// the point R8 and the Little-Endian encoding of S. +func (s *Signature) Compress() SignatureComp { + R8p := s.R8.Compress() + Sp := BigIntLEBytes(s.S) + buf := [64]byte{} + copy(buf[:32], R8p[:]) + copy(buf[32:], Sp[:]) + return SignatureComp(buf) +} + +// Decompress a compressed signature into s, and also returns the decompressed +// signature. Returns error if the Point decompression fails. +func (s *Signature) Decompress(buf [64]byte) (*Signature, error) { + R8p := [32]byte{} + copy(R8p[:], buf[:32]) + var err error + if s.R8, err = NewPoint().Decompress(R8p); err != nil { + return nil, err + } + s.S = SetBigIntFromLEBytes(new(big.Int), buf[32:]) + return s, nil +} + +// Decompress a compressed signature. Returns error if the Point decompression +// fails. +func (s *SignatureComp) Decompress() (*Signature, error) { + return new(Signature).Decompress(*s) +} + +// SignMimc7 signs a message encoded as a big.Int in Zq using blake-512 hash +// for buffer hashing and mimc7 for big.Int hashing. +func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { + h1 := Blake512(k[:]) + msgBuf := BigIntLEBytes(msg) + msgBuf32 := [32]byte{} + copy(msgBuf32[:], msgBuf[:]) + rBuf := Blake512(append(h1[32:], msgBuf32[:]...)) + r := SetBigIntFromLEBytes(new(big.Int), rBuf) // r = H(H_{32..63}(k), msg) + r.Mod(r, SubOrder) + R8 := NewPoint().Mul(r, B8) // R8 = r * 8 * B + A := k.Public().Point() + hmInput, err := mimc7.BigIntsToRElems([]*big.Int{R8.X, R8.Y, A.X, A.Y, msg}) + if err != nil { + panic(err) + } + hm := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) + S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) + S = S.Mul(hm, S) + S.Add(r, S) + S.Mod(S, SubOrder) // S = r + hm * 8 * s + + return &Signature{R8: R8, S: S} +} + +// VerifyMimc7 verifies the signature of a message encoded as a big.Int in Zq +// using blake-512 hash for buffer hashing and mimc7 for big.Int hashing. +func (p *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { + hmInput, err := mimc7.BigIntsToRElems([]*big.Int{sig.R8.X, sig.R8.Y, p.X, p.Y, msg}) + if err != nil { + panic(err) + } + hm := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) + + left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B + r1 := big.NewInt(8) + r1.Mul(r1, hm) + right := NewPoint().Mul(r1, p.Point()) + right.Add(sig.R8, right) // right = 8 * R + 8 * hm * A + return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) +} diff --git a/babyjub/eddsa_test.go b/babyjub/eddsa_test.go new file mode 100644 index 0000000..fbf56dd --- /dev/null +++ b/babyjub/eddsa_test.go @@ -0,0 +1,89 @@ +package babyjub + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + + "github.com/stretchr/testify/assert" + + // "golang.org/x/crypto/blake2b" + "math/big" + "testing" +) + +func genInputs() (*PrivateKey, *big.Int) { + k := NewRandPrivKey() + fmt.Println("k", hex.EncodeToString(k[:])) + + msgBuf := [32]byte{} + rand.Read(msgBuf[:]) + msg := SetBigIntFromLEBytes(new(big.Int), msgBuf[:]) + msg.Mod(msg, Q) + fmt.Println("msg", msg) + + return &k, msg +} + +func TestSignVerify1(t *testing.T) { + var k PrivateKey + hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) + msgBuf, err := hex.DecodeString("00010203040506070809") + if err != nil { + panic(err) + } + msg := SetBigIntFromLEBytes(new(big.Int), msgBuf) + + pk := k.Public() + assert.Equal(t, + "2610057752638682202795145288373380503107623443963127956230801721756904484787", + pk.X.String()) + assert.Equal(t, + "16617171478497210597712478520507818259149717466230047843969353176573634386897", + pk.Y.String()) + + sig := k.SignMimc7(msg) + assert.Equal(t, + "4974729414807584049518234760796200867685098748448054182902488636762478901554", + sig.R8.X.String()) + assert.Equal(t, + "18714049394522540751536514815950425694461287643205706667341348804546050128733", + sig.R8.Y.String()) + assert.Equal(t, + "2171284143457722024136077617757713039502332290425057126942676527240038689549", + sig.S.String()) + + ok := pk.VerifyMimc7(msg, sig) + assert.Equal(t, true, ok) + + sigBuf := sig.Compress() + sig2, err := new(Signature).Decompress(sigBuf) + assert.Equal(t, nil, err) + + assert.Equal(t, ""+ + "5dfb6f843c023fe3e52548ccf22e55c81b426f7af81b4f51f7152f2fcfc65f29"+ + "0dab19c5a0a75973cd75a54780de0c3a41ede6f57396fe99b5307fff3ce7cc04", + hex.EncodeToString(sigBuf[:])) + + ok = pk.VerifyMimc7(msg, sig2) + assert.Equal(t, true, ok) +} + +func TestCompressDecompress(t *testing.T) { + var k PrivateKey + hex.Decode(k[:], []byte("0001020304050607080900010203040506070809000102030405060708090001")) + pk := k.Public() + for i := 0; i < 64; i++ { + msgBuf, err := hex.DecodeString(fmt.Sprintf("000102030405060708%02d", i)) + if err != nil { + panic(err) + } + msg := SetBigIntFromLEBytes(new(big.Int), msgBuf) + sig := k.SignMimc7(msg) + sigBuf := sig.Compress() + sig2, err := new(Signature).Decompress(sigBuf) + assert.Equal(t, nil, err) + ok := pk.VerifyMimc7(msg, sig2) + assert.Equal(t, true, ok) + } +} diff --git a/babyjub/helpers.go b/babyjub/helpers.go new file mode 100644 index 0000000..c983979 --- /dev/null +++ b/babyjub/helpers.go @@ -0,0 +1,39 @@ +package babyjub + +import ( + "github.com/dchest/blake512" // I have personally reviewed that this module doesn't do anything suspicious + "math/big" +) + +// SwapEndianness swaps the endianness of the value encoded in xs. If xs is +// Big-Endian, the result will be Little-Endian and viceversa. +func SwapEndianness(xs []byte) []byte { + ys := make([]byte, len(xs)) + for i, b := range xs { + ys[len(xs)-1-i] = b + } + return ys +} + +// BigIntLEBytes encodes a big.Int into an array in Little-Endian. +func BigIntLEBytes(v *big.Int) [32]byte { + le := SwapEndianness(v.Bytes()) + res := [32]byte{} + copy(res[:], le) + return res +} + +// SetBigIntFromLEBytes sets the value of a big.Int from a Little-Endian +// encoded value. +func SetBigIntFromLEBytes(v *big.Int, leBuf []byte) *big.Int { + beBuf := SwapEndianness(leBuf) + return v.SetBytes(beBuf) +} + +// Blake512 performs the blake-512 hash over the buffer m. Note that this is +// the original blake from the SHA3 competition and not the new blake2 version. +func Blake512(m []byte) []byte { + h := blake512.New() + h.Write(m[:]) + return h.Sum(nil) +}