diff --git a/shamirsecretsharing.go b/shamirsecretsharing.go index 2d138d2..f018aea 100644 --- a/shamirsecretsharing.go +++ b/shamirsecretsharing.go @@ -71,3 +71,45 @@ func unpackSharesAndI(sharesPacked [][]*big.Int) ([]*big.Int, []*big.Int) { } return shares, i } + +// LagrangeInterpolation calculates the secret from given shares +func LagrangeInterpolation(sharesGiven [][]*big.Int, p *big.Int) *big.Int { + resultN := big.NewInt(int64(0)) + resultD := big.NewInt(int64(0)) + + //unpack shares + sharesBigInt, sharesIBigInt := unpackSharesAndI(sharesGiven) + + for i := 0; i < len(sharesBigInt); i++ { + lagrangeNumerator := big.NewInt(int64(1)) + lagrangeDenominator := big.NewInt(int64(1)) + for j := 0; j < len(sharesBigInt); j++ { + if sharesIBigInt[i] != sharesIBigInt[j] { + currLagrangeNumerator := sharesIBigInt[j] + currLagrangeDenominator := new(big.Int).Sub(sharesIBigInt[j], sharesIBigInt[i]) + lagrangeNumerator = new(big.Int).Mul(lagrangeNumerator, currLagrangeNumerator) + lagrangeDenominator = new(big.Int).Mul(lagrangeDenominator, currLagrangeDenominator) + } + } + numerator := new(big.Int).Mul(sharesBigInt[i], lagrangeNumerator) + quo := new(big.Int).Quo(numerator, lagrangeDenominator) + if quo.Int64() != 0 { + resultN = resultN.Add(resultN, quo) + } else { + resultNMULlagrangeDenominator := new(big.Int).Mul(resultN, lagrangeDenominator) + resultN = new(big.Int).Add(resultNMULlagrangeDenominator, numerator) + + resultD = resultD.Add(resultD, lagrangeDenominator) + } + } + + var modinvMul *big.Int + if resultD.Int64() != 0 { + modinv := new(big.Int).ModInverse(resultD, p) + modinvMul = new(big.Int).Mul(resultN, modinv) + } else { + modinvMul = resultN + } + r := new(big.Int).Mod(modinvMul, p) + return r +} diff --git a/shamirsecretsharing_test.go b/shamirsecretsharing_test.go new file mode 100644 index 0000000..2bba93d --- /dev/null +++ b/shamirsecretsharing_test.go @@ -0,0 +1,47 @@ +package shamirsecretsharing + +import ( + "bytes" + "crypto/rand" + "math/big" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCreate(t *testing.T) { + k, ok := new(big.Int).SetString("123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890", 10) + assert.True(t, ok) + + p, err := rand.Prime(rand.Reader, bits/2) + assert.Nil(t, err) + + nShares := big.NewInt(int64(6)) + nNeededShares := big.NewInt(int64(3)) + shares, err := Create( + nNeededShares, + nShares, + p, + k) + assert.Nil(t, err) + + //generate sharesToUse + var sharesToUse [][]*big.Int + sharesToUse = append(sharesToUse, shares[2]) + sharesToUse = append(sharesToUse, shares[1]) + sharesToUse = append(sharesToUse, shares[0]) + secr := LagrangeInterpolation(sharesToUse, p) + + // fmt.Print("original secret: ") + // fmt.Println(k) + // fmt.Print("p: ") + // fmt.Println(p) + // fmt.Print("shares: ") + // fmt.Println(shares) + // fmt.Print("recovered secret result: ") + // fmt.Println(secr) + if !bytes.Equal(k.Bytes(), secr.Bytes()) { + t.Errorf("reconstructed secret not correspond to original secret") + } + assert.Equal(t, k, secr) +}