Browse Source

add frame size to poseidon hasher

fix/bbjj-err
Ilya 1 year ago
parent
commit
6ff38d47db
3 changed files with 19 additions and 11 deletions
  1. +2
    -1
      babyjub/babyjub_wrapper_test.go
  2. +13
    -6
      poseidon/poseidon_wrapper.go
  3. +4
    -4
      poseidon/poseidon_wrapper_test.go

+ 2
- 1
babyjub/babyjub_wrapper_test.go

@ -29,7 +29,8 @@ func TestBjjWrappedPrivateKeyInterfaceImpl(t *testing.T) {
func TestBjjWrappedPrivateKey(t *testing.T) { func TestBjjWrappedPrivateKey(t *testing.T) {
pk := RandomBjjWrappedKey() pk := RandomBjjWrappedKey()
hasher := poseidon.New()
hasher, err := poseidon.New(16)
require.NoError(t, err)
hasher.Write([]byte("test")) hasher.Write([]byte("test"))
digest := hasher.Sum(nil) digest := hasher.Sum(nil)

+ 13
- 6
poseidon/poseidon_wrapper.go

@ -2,25 +2,32 @@ package poseidon
import ( import (
"bytes" "bytes"
"errors"
"hash" "hash"
) )
type digest struct { type digest struct {
buf *bytes.Buffer
buf *bytes.Buffer
frameSize int
} }
// NewPoseidon returns the Poseidon hash of the input bytes. // NewPoseidon returns the Poseidon hash of the input bytes.
// use frame size of 16 inputs by default
func NewPoseidon(b []byte) []byte { func NewPoseidon(b []byte) []byte {
h := New()
h, _ := New(16)
h.Write(b) h.Write(b)
return h.Sum(nil) return h.Sum(nil)
} }
// New returns a new hash.Hash computing the Poseidon hash. // New returns a new hash.Hash computing the Poseidon hash.
func New() hash.Hash {
return &digest{
buf: bytes.NewBuffer([]byte{}),
func New(frameSize int) (hash.Hash, error) {
if frameSize < 2 || frameSize > 16 {
return nil, errors.New("incorrect frame size")
} }
return &digest{
buf: bytes.NewBuffer([]byte{}),
frameSize: frameSize,
}, nil
} }
// Write (via the embedded io.Writer interface) adds more data to the running hash. // Write (via the embedded io.Writer interface) adds more data to the running hash.
@ -30,7 +37,7 @@ func (d *digest) Write(p []byte) (n int, err error) {
// Sum returns the Poseidon checksum of the data. // Sum returns the Poseidon checksum of the data.
func (d *digest) Sum(b []byte) []byte { func (d *digest) Sum(b []byte) []byte {
hahs, err := HashBytes(d.buf.Bytes())
hahs, err := HashBytesX(d.buf.Bytes(), d.frameSize)
if err != nil { if err != nil {
panic(err) panic(err)
} }

+ 4
- 4
poseidon/poseidon_wrapper_test.go

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -58,12 +57,13 @@ func TestPoseidonWrapperSum(t *testing.T) {
inputBytes, err := hex.DecodeString(vector.bytes) inputBytes, err := hex.DecodeString(vector.bytes)
require.NoError(t, err) require.NoError(t, err)
hasher := New()
hasher, err := New(16)
require.NoError(t, err)
hasher.Write(inputBytes) hasher.Write(inputBytes)
res := hasher.Sum(nil) res := hasher.Sum(nil)
require.NotEmpty(t, res) require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res))
require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
}) })
} }
} }
@ -120,7 +120,7 @@ func TestPoseidonNewPoseidon(t *testing.T) {
res := NewPoseidon(inputBytes) res := NewPoseidon(inputBytes)
require.NotEmpty(t, res) require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res))
require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
}) })
} }
} }

Loading…
Cancel
Save