Browse Source

add the unit-test

fix/bbjj-err
Cool Developer 2 years ago
parent
commit
26bfd1051a
4 changed files with 90 additions and 76 deletions
  1. +6
    -0
      ffg/element.go
  2. +3
    -9
      goldenposeidon/constants.go
  3. +7
    -65
      goldenposeidon/poseidon.go
  4. +74
    -2
      goldenposeidon/poseidon_test.go

+ 6
- 0
ffg/element.go

@ -559,6 +559,12 @@ func (z Element) ToBigIntRegular(res *big.Int) *big.Int {
return z.ToBigInt(res) return z.ToBigInt(res)
} }
// ToUint64Regular returns z as a uint64 in regular form
func (z Element) ToUint64Regular() uint64 {
z.FromMont()
return z[0]
}
// Bytes returns the regular (non montgomery) value // Bytes returns the regular (non montgomery) value
// of z as a big-endian byte array. // of z as a big-endian byte array.
func (z *Element) Bytes() (res [Limbs * 8]byte) { func (z *Element) Bytes() (res [Limbs * 8]byte) {

+ 3
- 9
goldenposeidon/constants.go

@ -13,7 +13,7 @@ var (
mcirc = []uint64{17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20} mcirc = []uint64{17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20}
mdiag = []uint64{8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} mdiag = []uint64{8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
c = [360]uint64{
c = []uint64{
0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0x0f6760a4803427d7, 0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0x0f6760a4803427d7,
0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d, 0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d,
0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76, 0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76,
@ -115,18 +115,12 @@ func init() {
C = append(C, ffg.NewElementFromUint64(c[i])) C = append(C, ffg.NewElementFromUint64(c[i]))
} }
var mFFCirc, mFFDiag []*ffg.Element
for i := 0; i < mLen; i++ {
mFFCirc = append(mFFCirc, ffg.NewElementFromUint64(mcirc[i]))
mFFDiag = append(mFFDiag, ffg.NewElementFromUint64(mdiag[i]))
}
for i := 0; i < mLen; i++ { for i := 0; i < mLen; i++ {
var row []*ffg.Element var row []*ffg.Element
for j := 0; j < mLen; j++ { for j := 0; j < mLen; j++ {
ele := mFFCirc[(-i+j+mLen)%mLen]
ele := ffg.NewElementFromUint64(mcirc[(-i+j+mLen)%mLen])
if i == j { if i == j {
ele = ele.Add(ele, mFFDiag[i])
ele = ffg.NewElementFromUint64(mcirc[0] + mdiag[i])
} }
row = append(row, ele) row = append(row, ele)
} }

+ 7
- 65
goldenposeidon/poseidon.go

@ -43,7 +43,7 @@ func mix(state []*ffg.Element) []*ffg.Element {
for i := 0; i < mLen; i++ { for i := 0; i < mLen; i++ {
newState[i].SetUint64(0) newState[i].SetUint64(0)
for j := 0; j < mLen; j++ { for j := 0; j < mLen; j++ {
mul.Mul(M[j][i], state[j])
mul.Mul(M[i][j], state[j])
newState[i].Add(newState[i], mul) newState[i].Add(newState[i], mul)
} }
} }
@ -51,12 +51,12 @@ func mix(state []*ffg.Element) []*ffg.Element {
} }
// Hash computes the Poseidon hash for the given inputs // Hash computes the Poseidon hash for the given inputs
func Hash(inpBI []*big.Int, capBI []*big.Int) (*big.Int, error) {
func Hash(inpBI []*big.Int, capBI []*big.Int) ([CAPLEN]uint64, error) {
if len(inpBI) != NROUNDSF { if len(inpBI) != NROUNDSF {
return nil, fmt.Errorf("invalid inputs length %d, must be 8", len(inpBI))
return [CAPLEN]uint64{}, fmt.Errorf("invalid inputs length %d, must be 8", len(inpBI))
} }
if len(capBI) != CAPLEN { if len(capBI) != CAPLEN {
return nil, fmt.Errorf("invalid capcity length %d, must be 4", len(capBI))
return [CAPLEN]uint64{}, fmt.Errorf("invalid capcity length %d, must be 4", len(capBI))
} }
state := make([]*ffg.Element, mLen) state := make([]*ffg.Element, mLen)
@ -69,73 +69,15 @@ func Hash(inpBI []*big.Int, capBI []*big.Int) (*big.Int, error) {
for r := 0; r < NROUNDSF+NROUNDSP; r++ { for r := 0; r < NROUNDSF+NROUNDSP; r++ {
ark(state, r*mLen) ark(state, r*mLen)
if r < NROUNDSF/2 || r >= NROUNDSF/2+NROUNDSP { if r < NROUNDSF/2 || r >= NROUNDSF/2+NROUNDSP {
exp7state(state) exp7state(state)
} else { } else {
exp7(state[0]) exp7(state[0])
} }
state = mix(state)
}
r := big.NewInt(0)
for i := 0; i < CAPLEN; i++ {
res := big.NewInt(0)
state[i].ToBigIntRegular(res)
r.Add(r.Lsh(r, 64), res)
}
return r, nil
}
// HashBytes returns a sponge hash of a msg byte slice split into blocks of 31 bytes
func HashBytes(msg []byte) (*big.Int, error) {
// not used inputs default to zero
inputs := make([]*big.Int, spongeInputs)
for j := 0; j < spongeInputs; j++ {
inputs[j] = new(big.Int)
}
dirty := false
var hash *big.Int
var err error
k := 0
for i := 0; i < len(msg)/spongeChunkSize; i++ {
dirty = true
inputs[k].SetBytes(msg[spongeChunkSize*i : spongeChunkSize*(i+1)])
if k == spongeInputs-1 {
hash, err = Hash(inputs, []*big.Int{big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)})
dirty = false
if err != nil {
return nil, err
}
inputs = make([]*big.Int, spongeInputs)
inputs[0] = hash
for j := 1; j < spongeInputs; j++ {
inputs[j] = new(big.Int)
}
k = 1
} else {
k++
}
}
if len(msg)%spongeChunkSize != 0 {
// the last chunk of the message is less than 31 bytes
// zero padding it, so that 0xdeadbeaf becomes
// 0xdeadbeaf000000000000000000000000000000000000000000000000000000
var buf [spongeChunkSize]byte
copy(buf[:], msg[(len(msg)/spongeChunkSize)*spongeChunkSize:])
inputs[k] = new(big.Int).SetBytes(buf[:])
dirty = true
}
if dirty {
// we haven't hashed something in the main sponge loop and need to do hash here
hash, err = Hash(inputs, []*big.Int{big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)})
if err != nil {
return nil, err
}
state = mix(state)
} }
return hash, nil
return [CAPLEN]uint64{state[0].ToUint64Regular(), state[1].ToUint64Regular(), state[2].ToUint64Regular(), state[3].ToUint64Regular()}, nil
} }

+ 74
- 2
goldenposeidon/poseidon_test.go

@ -9,10 +9,82 @@ import (
func TestPoseidonHash(t *testing.T) { func TestPoseidonHash(t *testing.T) {
b0 := big.NewInt(0) b0 := big.NewInt(0)
b1 := big.NewInt(1)
b_1 := big.NewInt(-1)
bM := new(big.Int).SetUint64(18446744069414584321)
h, err := Hash([]*big.Int{b0, b0, b0, b0, b0, b0, b0, b0}, []*big.Int{b0, b0, b0, b0}) h, err := Hash([]*big.Int{b0, b0, b0, b0, b0, b0, b0, b0}, []*big.Int{b0, b0, b0, b0})
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, assert.Equal(t,
"18586133768512220936620570745912940619677854269274689475585506675881198879027",
h.String())
[CAPLEN]uint64{
4330397376401421145,
14124799381142128323,
8742572140681234676,
14345658006221440202,
}, h,
)
h, err = Hash([]*big.Int{b1, b1, b1, b1, b1, b1, b1, b1}, []*big.Int{b1, b1, b1, b1})
assert.Nil(t, err)
assert.Equal(t,
[CAPLEN]uint64{
16428316519797902711,
13351830238340666928,
682362844289978626,
12150588177266359240,
}, h,
)
h, err = Hash([]*big.Int{b1, b1, b1, b1, b1, b1, b1, b1}, []*big.Int{b1, b1, b1, b1})
assert.Nil(t, err)
assert.Equal(t,
[CAPLEN]uint64{
16428316519797902711,
13351830238340666928,
682362844289978626,
12150588177266359240,
}, h,
)
h, err = Hash([]*big.Int{b_1, b_1, b_1, b_1, b_1, b_1, b_1, b_1}, []*big.Int{b_1, b_1, b_1, b_1})
assert.Nil(t, err)
assert.Equal(t,
[CAPLEN]uint64{
13691089994624172887,
15662102337790434313,
14940024623104903507,
10772674582659927682,
}, h,
)
h, err = Hash([]*big.Int{bM, bM, bM, bM, bM, bM, bM, bM}, []*big.Int{b0, b0, b0, b0})
assert.Nil(t, err)
assert.Equal(t,
[CAPLEN]uint64{
4330397376401421145,
14124799381142128323,
8742572140681234676,
14345658006221440202,
}, h,
)
h, err = Hash([]*big.Int{
new(big.Int).SetUint64(923978),
new(big.Int).SetUint64(235763497586),
new(big.Int).SetUint64(9827635653498),
new(big.Int).SetUint64(112870),
new(big.Int).SetUint64(289273673480943876),
new(big.Int).SetUint64(230295874986745876),
new(big.Int).SetUint64(6254867324987),
new(big.Int).SetUint64(2087),
}, []*big.Int{b0, b0, b0, b0})
assert.Nil(t, err)
assert.Equal(t,
[CAPLEN]uint64{
1892171027578617759,
984732815927439256,
7866041765487844082,
8161503938059336191,
}, h,
)
} }

Loading…
Cancel
Save