add frame size to poseidon hasher

This commit is contained in:
Ilya
2023-05-18 13:37:31 +03:00
parent b015806983
commit 6ff38d47db
3 changed files with 19 additions and 11 deletions

View File

@@ -29,7 +29,8 @@ func TestBjjWrappedPrivateKeyInterfaceImpl(t *testing.T) {
func TestBjjWrappedPrivateKey(t *testing.T) { func TestBjjWrappedPrivateKey(t *testing.T) {
pk := RandomBjjWrappedKey() pk := RandomBjjWrappedKey()
hasher := poseidon.New() hasher, err := poseidon.New(16)
require.NoError(t, err)
hasher.Write([]byte("test")) hasher.Write([]byte("test"))
digest := hasher.Sum(nil) digest := hasher.Sum(nil)

View File

@@ -2,25 +2,32 @@ package poseidon
import ( import (
"bytes" "bytes"
"errors"
"hash" "hash"
) )
type digest struct { type digest struct {
buf *bytes.Buffer buf *bytes.Buffer
frameSize int
} }
// NewPoseidon returns the Poseidon hash of the input bytes. // NewPoseidon returns the Poseidon hash of the input bytes.
// use frame size of 16 inputs by default
func NewPoseidon(b []byte) []byte { func NewPoseidon(b []byte) []byte {
h := New() h, _ := New(16)
h.Write(b) h.Write(b)
return h.Sum(nil) return h.Sum(nil)
} }
// New returns a new hash.Hash computing the Poseidon hash. // New returns a new hash.Hash computing the Poseidon hash.
func New() hash.Hash { func New(frameSize int) (hash.Hash, error) {
return &digest{ if frameSize < 2 || frameSize > 16 {
buf: bytes.NewBuffer([]byte{}), return nil, errors.New("incorrect frame size")
} }
return &digest{
buf: bytes.NewBuffer([]byte{}),
frameSize: frameSize,
}, nil
} }
// Write (via the embedded io.Writer interface) adds more data to the running hash. // Write (via the embedded io.Writer interface) adds more data to the running hash.
@@ -30,7 +37,7 @@ func (d *digest) Write(p []byte) (n int, err error) {
// Sum returns the Poseidon checksum of the data. // Sum returns the Poseidon checksum of the data.
func (d *digest) Sum(b []byte) []byte { func (d *digest) Sum(b []byte) []byte {
hahs, err := HashBytes(d.buf.Bytes()) hahs, err := HashBytesX(d.buf.Bytes(), d.frameSize)
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -58,12 +57,13 @@ func TestPoseidonWrapperSum(t *testing.T) {
inputBytes, err := hex.DecodeString(vector.bytes) inputBytes, err := hex.DecodeString(vector.bytes)
require.NoError(t, err) require.NoError(t, err)
hasher := New() hasher, err := New(16)
require.NoError(t, err)
hasher.Write(inputBytes) hasher.Write(inputBytes)
res := hasher.Sum(nil) res := hasher.Sum(nil)
require.NotEmpty(t, res) require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res)) require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
}) })
} }
} }
@@ -120,7 +120,7 @@ func TestPoseidonNewPoseidon(t *testing.T) {
res := NewPoseidon(inputBytes) res := NewPoseidon(inputBytes)
require.NotEmpty(t, res) require.NotEmpty(t, res)
assert.Equal(t, vector.expectedHash, hex.EncodeToString(res)) require.Equal(t, vector.expectedHash, hex.EncodeToString(res))
}) })
} }
} }