You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

143 lines
3.1 KiB

  1. # Primitive Root of Unity
  2. def get_primitive_root_of_unity(F, n):
  3. # using the method described by Thomas Pornin in
  4. # https://crypto.stackexchange.com/a/63616
  5. q = F.order()
  6. for k in range(q):
  7. if k==0:
  8. continue
  9. g = F(k)
  10. # g = F.random_element()
  11. if g==0:
  12. continue
  13. w = g ^ ((q-1)/n)
  14. if w^(n/2) != 1:
  15. return g, w
  16. # Roots of Unity
  17. def get_nth_roots_of_unity(n, primitive_w):
  18. w = [0]*n
  19. for i in range(n):
  20. w[i] = primitive_w^i
  21. return w
  22. # fft (Fast Fourier Transform) returns:
  23. # - nth roots of unity
  24. # - Vandermonde matrix for the nth roots of unity
  25. # - Inverse Vandermonde matrix
  26. def fft(F, n):
  27. g, primitive_w = get_primitive_root_of_unity(F, n)
  28. w = get_nth_roots_of_unity(n, primitive_w)
  29. ft = matrix(F, n)
  30. for j in range(n):
  31. row = []
  32. for k in range(n):
  33. row.append(primitive_w^(j*k))
  34. ft.set_row(j, row)
  35. ft_inv = ft^-1
  36. return w, ft, ft_inv
  37. # Fast polynomial multiplicaton using FFT
  38. def poly_mul(fa, fb, F, n):
  39. w, ft, ft_inv = fft(F, n)
  40. # compute evaluation points from polynomials fa & fb at the roots of unity
  41. a_evals = []
  42. b_evals = []
  43. for i in range(n):
  44. a_evals.append(fa(w[i]))
  45. b_evals.append(fb(w[i]))
  46. # multiply elements in a_evals by b_evals
  47. c_evals = map(operator.mul, a_evals, b_evals)
  48. c_evals = vector(c_evals)
  49. # using FFT, convert the c_evals into fc(x)
  50. fc_coef = c_evals*ft_inv
  51. fc2=P(fc_coef.list())
  52. return fc2, c_evals
  53. # Tests
  54. #####
  55. # Roots of Unity test:
  56. q = 17
  57. F = GF(q)
  58. n = 4
  59. g, primitive_w = get_primitive_root_of_unity(F, n)
  60. print("generator:", g)
  61. print("primitive_w:", primitive_w)
  62. w = get_nth_roots_of_unity(n, primitive_w)
  63. print(f"{n}th roots of unity: {w}")
  64. assert w == [1, 13, 16, 4]
  65. #####
  66. # FFT test:
  67. def isprime(num):
  68. for n in range(2,int(num^1/2)+1):
  69. if num%n==0:
  70. return False
  71. return True
# list valid values for q
# Print the primes of the form 8*i + 1 for i in range(20). For such q, q - 1
# is divisible by 8 (and hence by 4 and 2), as the tests below require of n.
for i in range(20):
    if isprime(8*i+1):
        print("q =", 8*i+1)
q = 41
F = GF(q)
n = 4
# q needs to be a prime such that q-1 is divisible by n (so that F contains
# a primitive n-th root of unity).
assert (q-1)%n==0
print("q =", q, "n = ", n)
# ft: Vandermonde matrix for the nth roots of unity; ft_inv is its inverse.
w, ft, ft_inv = fft(F, n)
print("nth roots of unity:", w)
print("Vandermonde matrix:")
print(ft)
# Values the interpolated polynomial f_a must take at w[0], ..., w[3].
a = vector([3,4,5,9])
print("a:", a)
# interpolate f_a(x): ft maps coefficients to evaluations, so the inverse
# Vandermonde matrix maps the evaluations a back to the coefficients of f_a
# (constant term first).
fa_coef = ft_inv * a
print("fa_coef:", fa_coef)
# Sage syntax: P is the univariate polynomial ring F[x], with generator x.
P.<x> = PolynomialRing(F)
fa = P(list(fa_coef))
print("f_a(x):", fa)
# check that evaluating fa(x) at the roots of unity returns the expected values of a
for i in range(len(a)):
    assert fa(w[i]) == a[i]
# Fast polynomial multiplication using FFT
# Reuses q = 41, F = GF(41) and the ring P from the FFT test above.
print("\n---------")
print("---Fast polynomial multiplication using FFT")
# 8 evaluation points: the product of two degree-3 polynomials has degree 6,
# so at least 7 points are needed to recover it without wrap-around.
n = 8
# q needs to be a prime such that q-1 is divisible by n
assert (q-1)%n==0
print("q =", q, "n = ", n)
# Coefficient lists start from the constant term: 1 + 2x + 3x^2 + 4x^3.
fa=P([1,2,3,4])
fb=P([1,2,3,4])
# Reference product computed directly with the ring's multiplication.
fc_expected = fa*fb
print("fc expected result:", fc_expected) # expected result
# NOTE(review): Sage's Polynomial.coefficients() is believed to omit zero
# coefficients by default (.list() gives the dense list) — confirm if this
# printout is meant to show every coefficient.
print("fc expected coef", fc_expected.coefficients())
fc, c_evals = poly_mul(fa, fb, F, n)
print("c_evals=(a_evals*b_evals)=", c_evals)
print("fc:", fc)
# The evaluate/multiply/interpolate product must match the direct product.
assert fc_expected == fc