fix bbjj api: return err when hash fails while sign/verify

This commit is contained in:
arnaucube
2023-08-23 14:47:33 +02:00
parent 3fb23d780c
commit 93bf45c299
3 changed files with 59 additions and 39 deletions

View File

@@ -62,7 +62,10 @@ func (w *BjjWrappedPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.S
} }
digestBI := big.NewInt(0).SetBytes(digest) digestBI := big.NewInt(0).SetBytes(digest)
sig := w.privKey.SignPoseidon(digestBI) sig, err := w.privKey.SignPoseidon(digestBI)
if err != nil {
return nil, err
}
return sig.Compress().MarshalText() return sig.Compress().MarshalText()
} }

View File

@@ -243,7 +243,7 @@ func (s Signature) Value() (driver.Value, error) {
// SignMimc7 signs a message encoded as a big.Int in Zq using blake-512 hash // SignMimc7 signs a message encoded as a big.Int in Zq using blake-512 hash
// for buffer hashing and mimc7 for big.Int hashing. // for buffer hashing and mimc7 for big.Int hashing.
func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature { func (k *PrivateKey) SignMimc7(msg *big.Int) (*Signature, error) {
h1 := Blake512(k[:]) h1 := Blake512(k[:])
msgBuf := utils.BigIntLEBytes(msg) msgBuf := utils.BigIntLEBytes(msg)
msgBuf32 := [32]byte{} msgBuf32 := [32]byte{}
@@ -256,23 +256,23 @@ func (k *PrivateKey) SignMimc7(msg *big.Int) *Signature {
hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg} hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg}
hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg)
if err != nil { if err != nil {
panic(err) return nil, err
} }
S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) S := new(big.Int).Lsh(k.Scalar().BigInt(), 3)
S = S.Mul(hm, S) S = S.Mul(hm, S)
S.Add(r, S) S.Add(r, S)
S.Mod(S, SubOrder) // S = r + hm * 8 * s S.Mod(S, SubOrder) // S = r + hm * 8 * s
return &Signature{R8: R8, S: S} return &Signature{R8: R8, S: S}, nil
} }
// VerifyMimc7 verifies the signature of a message encoded as a big.Int in Zq // VerifyMimc7 verifies the signature of a message encoded as a big.Int in Zq
// using blake-512 hash for buffer hashing and mimc7 for big.Int hashing. // using blake-512 hash for buffer hashing and mimc7 for big.Int hashing.
func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool { func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) error {
hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg} hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg}
hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) hm, err := mimc7.Hash(hmInput, nil) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg)
if err != nil { if err != nil {
return false return err
} }
left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B
@@ -282,12 +282,15 @@ func (pk *PublicKey) VerifyMimc7(msg *big.Int, sig *Signature) bool {
rightProj := right.Projective() rightProj := right.Projective()
rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A
right = rightProj.Affine() right = rightProj.Affine()
return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) if (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) {
return nil
}
return fmt.Errorf("verifyMimc7 failed")
} }
// SignPoseidon signs a message encoded as a big.Int in Zq using blake-512 hash // SignPoseidon signs a message encoded as a big.Int in Zq using blake-512 hash
// for buffer hashing and Poseidon for big.Int hashing. // for buffer hashing and Poseidon for big.Int hashing.
func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature { func (k *PrivateKey) SignPoseidon(msg *big.Int) (*Signature, error) {
h1 := Blake512(k[:]) h1 := Blake512(k[:])
msgBuf := utils.BigIntLEBytes(msg) msgBuf := utils.BigIntLEBytes(msg)
msgBuf32 := [32]byte{} msgBuf32 := [32]byte{}
@@ -301,7 +304,7 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature {
hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg} hmInput := []*big.Int{R8.X, R8.Y, A.X, A.Y, msg}
hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg)
if err != nil { if err != nil {
panic(err) return nil, err
} }
S := new(big.Int).Lsh(k.Scalar().BigInt(), 3) S := new(big.Int).Lsh(k.Scalar().BigInt(), 3)
@@ -309,16 +312,16 @@ func (k *PrivateKey) SignPoseidon(msg *big.Int) *Signature {
S.Add(r, S) S.Add(r, S)
S.Mod(S, SubOrder) // S = r + hm * 8 * s S.Mod(S, SubOrder) // S = r + hm * 8 * s
return &Signature{R8: R8, S: S} return &Signature{R8: R8, S: S}, nil
} }
// VerifyPoseidon verifies the signature of a message encoded as a big.Int in Zq // VerifyPoseidon verifies the signature of a message encoded as a big.Int in Zq
// using blake-512 hash for buffer hashing and Poseidon for big.Int hashing. // using blake-512 hash for buffer hashing and Poseidon for big.Int hashing.
func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool { func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) error {
hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg} hmInput := []*big.Int{sig.R8.X, sig.R8.Y, pk.X, pk.Y, msg}
hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg) hm, err := poseidon.Hash(hmInput) // hm = H1(8*R.x, 8*R.y, A.x, A.y, msg)
if err != nil { if err != nil {
return false return err
} }
left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B left := NewPoint().Mul(sig.S, B8) // left = s * 8 * B
@@ -328,7 +331,10 @@ func (pk *PublicKey) VerifyPoseidon(msg *big.Int, sig *Signature) bool {
rightProj := right.Projective() rightProj := right.Projective()
rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A rightProj.Add(sig.R8.Projective(), rightProj) // right = 8 * R + 8 * hm * A
right = rightProj.Affine() right = rightProj.Affine()
return (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) if (left.X.Cmp(right.X) == 0) && (left.Y.Cmp(right.Y) == 0) {
return nil
}
return fmt.Errorf("verifyPoseidon failed")
} }
// Scan implements Scanner for database/sql. // Scan implements Scanner for database/sql.

View File

@@ -43,7 +43,8 @@ func TestSignVerifyMimc7(t *testing.T) {
"13622229784656158136036771217484571176836296686641868549125388198837476602820", "13622229784656158136036771217484571176836296686641868549125388198837476602820",
pk.Y.String()) pk.Y.String())
sig := k.SignMimc7(msg) sig, err := k.SignMimc7(msg)
assert.NoError(t, err)
assert.Equal(t, assert.Equal(t,
"11384336176656855268977457483345535180380036354188103142384839473266348197733", "11384336176656855268977457483345535180380036354188103142384839473266348197733",
sig.R8.X.String()) sig.R8.X.String())
@@ -54,20 +55,20 @@ func TestSignVerifyMimc7(t *testing.T) {
"2523202440825208709475937830811065542425109372212752003460238913256192595070", "2523202440825208709475937830811065542425109372212752003460238913256192595070",
sig.S.String()) sig.S.String())
ok := pk.VerifyMimc7(msg, sig) err = pk.VerifyMimc7(msg, sig)
assert.Equal(t, true, ok) assert.NoError(t, err)
sigBuf := sig.Compress() sigBuf := sig.Compress()
sig2, err := new(Signature).Decompress(sigBuf) sig2, err := new(Signature).Decompress(sigBuf)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, ""+ assert.Equal(t, ""+
"dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+
"7ed40dab29bf993c928e789d007387998901a24913d44fddb64b1f21fc149405", "7ed40dab29bf993c928e789d007387998901a24913d44fddb64b1f21fc149405",
hex.EncodeToString(sigBuf[:])) hex.EncodeToString(sigBuf[:]))
ok = pk.VerifyMimc7(msg, sig2) err = pk.VerifyMimc7(msg, sig2)
assert.Equal(t, true, ok) assert.NoError(t, err)
} }
func TestSignVerifyPoseidon(t *testing.T) { func TestSignVerifyPoseidon(t *testing.T) {
@@ -89,7 +90,8 @@ func TestSignVerifyPoseidon(t *testing.T) {
"13622229784656158136036771217484571176836296686641868549125388198837476602820", "13622229784656158136036771217484571176836296686641868549125388198837476602820",
pk.Y.String()) pk.Y.String())
sig := k.SignPoseidon(msg) sig, err := k.SignPoseidon(msg)
assert.NoError(t, err)
assert.Equal(t, assert.Equal(t,
"11384336176656855268977457483345535180380036354188103142384839473266348197733", "11384336176656855268977457483345535180380036354188103142384839473266348197733",
sig.R8.X.String()) sig.R8.X.String())
@@ -100,20 +102,20 @@ func TestSignVerifyPoseidon(t *testing.T) {
"1672775540645840396591609181675628451599263765380031905495115170613215233181", "1672775540645840396591609181675628451599263765380031905495115170613215233181",
sig.S.String()) sig.S.String())
ok := pk.VerifyPoseidon(msg, sig) err = pk.VerifyPoseidon(msg, sig)
assert.Equal(t, true, ok) assert.NoError(t, err)
sigBuf := sig.Compress() sigBuf := sig.Compress()
sig2, err := new(Signature).Decompress(sigBuf) sig2, err := new(Signature).Decompress(sigBuf)
assert.Equal(t, nil, err) assert.NoError(t, err)
assert.Equal(t, ""+ assert.Equal(t, ""+
"dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+ "dfedb4315d3f2eb4de2d3c510d7a987dcab67089c8ace06308827bf5bcbe02a2"+
"9d043ece562a8f82bfc0adb640c0107a7d3a27c1c7c1a6179a0da73de5c1b203", "9d043ece562a8f82bfc0adb640c0107a7d3a27c1c7c1a6179a0da73de5c1b203",
hex.EncodeToString(sigBuf[:])) hex.EncodeToString(sigBuf[:]))
ok = pk.VerifyPoseidon(msg, sig2) err = pk.VerifyPoseidon(msg, sig2)
assert.Equal(t, true, ok) assert.NoError(t, err)
} }
func TestCompressDecompress(t *testing.T) { func TestCompressDecompress(t *testing.T) {
@@ -128,22 +130,28 @@ func TestCompressDecompress(t *testing.T) {
panic(err) panic(err)
} }
msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf) msg := utils.SetBigIntFromLEBytes(new(big.Int), msgBuf)
sig := k.SignMimc7(msg) sig, err := k.SignMimc7(msg)
assert.NoError(t, err)
sigBuf := sig.Compress() sigBuf := sig.Compress()
sig2, err := new(Signature).Decompress(sigBuf) sig2, err := new(Signature).Decompress(sigBuf)
assert.Equal(t, nil, err) assert.NoError(t, err)
ok := pk.VerifyMimc7(msg, sig2) err = pk.VerifyMimc7(msg, sig2)
assert.Equal(t, true, ok) assert.NoError(t, err)
} }
} }
func TestSignatureCompScannerValuer(t *testing.T) { func TestSignatureCompScannerValuer(t *testing.T) {
privK := NewRandPrivKey() privK := NewRandPrivKey()
var err error
sig, err := privK.SignPoseidon(big.NewInt(674238462))
assert.NoError(t, err)
var value driver.Valuer //nolint:gosimple // this is done to ensure interface compatibility var value driver.Valuer //nolint:gosimple // this is done to ensure interface compatibility
value = privK.SignPoseidon(big.NewInt(674238462)).Compress() value = sig.Compress()
scan := privK.SignPoseidon(big.NewInt(1)).Compress() sig, err = privK.SignPoseidon(big.NewInt(1))
assert.NoError(t, err)
scan := sig.Compress()
fromDB, err := value.Value() fromDB, err := value.Value()
assert.Nil(t, err) assert.NoError(t, err)
assert.Nil(t, scan.Scan(fromDB)) assert.Nil(t, scan.Scan(fromDB))
assert.Equal(t, value, scan) assert.Equal(t, value, scan)
} }
@@ -152,10 +160,13 @@ func TestSignatureScannerValuer(t *testing.T) {
privK := NewRandPrivKey() privK := NewRandPrivKey()
var value driver.Valuer var value driver.Valuer
var scan sql.Scanner var scan sql.Scanner
value = privK.SignPoseidon(big.NewInt(674238462)) var err error
scan = privK.SignPoseidon(big.NewInt(1)) value, err = privK.SignPoseidon(big.NewInt(674238462))
assert.NoError(t, err)
scan, err = privK.SignPoseidon(big.NewInt(1))
assert.NoError(t, err)
fromDB, err := value.Value() fromDB, err := value.Value()
assert.Nil(t, err) assert.NoError(t, err)
assert.Nil(t, scan.Scan(fromDB)) assert.Nil(t, scan.Scan(fromDB))
assert.Equal(t, value, scan) assert.Equal(t, value, scan)
} }
@@ -217,12 +228,12 @@ func BenchmarkBabyjubEddsa(b *testing.B) {
}) })
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
sigs[i%n] = k.SignMimc7(msgs[i%n]) sigs[i%n], _ = k.SignMimc7(msgs[i%n])
} }
b.Run("VerifyMimc7", func(b *testing.B) { b.Run("VerifyMimc7", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pk.VerifyMimc7(msgs[i%n], sigs[i%n]) _ = pk.VerifyMimc7(msgs[i%n], sigs[i%n])
} }
}) })
@@ -233,12 +244,12 @@ func BenchmarkBabyjubEddsa(b *testing.B) {
}) })
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
sigs[i%n] = k.SignPoseidon(msgs[i%n]) sigs[i%n], _ = k.SignPoseidon(msgs[i%n])
} }
b.Run("VerifyPoseidon", func(b *testing.B) { b.Run("VerifyPoseidon", func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
pk.VerifyPoseidon(msgs[i%n], sigs[i%n]) _ = pk.VerifyPoseidon(msgs[i%n], sigs[i%n])
} }
}) })
} }