feat: Plonk optimizations (#39)

* Fixed poseidion hash TOOD in fri/fri.go

* optimized goldilocks

* Another optimization

* Down to 16 million

* Finished TODOs
This commit is contained in:
puma314
2023-10-13 14:00:54 -07:00
committed by GitHub
parent 940c81b212
commit 89b5a01e4b
9 changed files with 171 additions and 82 deletions

View File

@@ -47,6 +47,9 @@ func (c *BN254Chip) HashNoPad(input []gl.Variable) BN254HashOut {
frontend.Variable(0),
}
two_to_32 := new(big.Int).SetInt64(1 << 32)
two_to_64 := new(big.Int).Mul(two_to_32, two_to_32)
for i := 0; i < len(input); i += BN254_SPONGE_RATE * 3 {
endI := c.min(len(input), i+BN254_SPONGE_RATE*3)
rateChunk := input[i:endI]
@@ -54,13 +57,12 @@ func (c *BN254Chip) HashNoPad(input []gl.Variable) BN254HashOut {
endJ := c.min(len(rateChunk), j+3)
bn254Chunk := rateChunk[j:endJ]
bits := []frontend.Variable{}
inter := frontend.Variable(0)
for k := 0; k < len(bn254Chunk); k++ {
bn254Chunk[k] = c.gl.Reduce(bn254Chunk[k])
bits = append(bits, c.api.ToBinary(bn254Chunk[k].Limb, 64)...)
inter = c.api.MulAcc(inter, bn254Chunk[k].Limb, new(big.Int).Exp(two_to_64, big.NewInt(int64(k)), nil))
}
state[stateIdx+1] = c.api.FromBinary(bits...)
state[stateIdx+1] = inter
}
state = c.Poseidon(state)
@@ -75,7 +77,7 @@ func (c *BN254Chip) HashOrNoop(input []gl.Variable) BN254HashOut {
alpha := new(big.Int).SetInt64(1 << 32)
for i, inputElement := range input {
returnVal = c.api.Add(returnVal, c.api.Mul(inputElement, alpha.Exp(alpha, big.NewInt(int64(i)), nil)))
returnVal = c.api.MulAcc(returnVal, inputElement, alpha.Exp(alpha, big.NewInt(int64(i)), nil))
}
return BN254HashOut(returnVal)
@@ -145,16 +147,13 @@ func (c *BN254Chip) partialRounds(state BN254State) BN254State {
state[0] = c.exp5(state[0])
state[0] = c.api.Add(state[0], cConstants[(BN254_FULL_ROUNDS/2+1)*BN254_SPONGE_WIDTH+i])
var mul frontend.Variable
newState0 := frontend.Variable(0)
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
mul = c.api.Mul(sConstants[(BN254_SPONGE_WIDTH*2-1)*i+j], state[j])
newState0 = c.api.Add(newState0, mul)
newState0 = c.api.MulAcc(newState0, sConstants[(BN254_SPONGE_WIDTH*2-1)*i+j], state[j])
}
for k := 1; k < BN254_SPONGE_WIDTH; k++ {
mul = c.api.Mul(state[0], sConstants[(BN254_SPONGE_WIDTH*2-1)*i+BN254_SPONGE_WIDTH+k-1])
state[k] = c.api.Add(state[k], mul)
state[k] = c.api.MulAcc(state[k], state[0], sConstants[(BN254_SPONGE_WIDTH*2-1)*i+BN254_SPONGE_WIDTH+k-1])
}
state[0] = newState0
}
@@ -186,7 +185,6 @@ func (c *BN254Chip) exp5state(state BN254State) BN254State {
}
func (c *BN254Chip) mix(state_ BN254State, constantMatrix [][]*big.Int) BN254State {
var mul frontend.Variable
var result BN254State
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
@@ -195,8 +193,7 @@ func (c *BN254Chip) mix(state_ BN254State, constantMatrix [][]*big.Int) BN254Sta
for i := 0; i < BN254_SPONGE_WIDTH; i++ {
for j := 0; j < BN254_SPONGE_WIDTH; j++ {
mul = c.api.Mul(constantMatrix[j][i], state_[j])
result[i] = c.api.Add(result[i], mul)
result[i] = c.api.MulAcc(result[i], constantMatrix[j][i], state_[j])
}
}