You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

147 lines
3.3 KiB

  1. # Primitive Root of Unity
  2. def get_primitive_root_of_unity(F, n):
  3. # using the method described by Thomas Pornin in
  4. # https://crypto.stackexchange.com/a/63616
  5. q = F.order()
  6. for k in range(q):
  7. if k==0:
  8. continue
  9. g = F(k)
  10. # g = F.random_element()
  11. if g==0:
  12. continue
  13. w = g ^ ((q-1)/n)
  14. if w^(n/2) != 1:
  15. return g, w
  16. # Roots of Unity
  17. def get_nth_roots_of_unity(n, primitive_w):
  18. w = [0]*n
  19. for i in range(n):
  20. w[i] = primitive_w^i
  21. return w
  22. # fft (Fast Fourier Transform) returns:
  23. # - nth roots of unity
  24. # - Vandermonde matrix for the nth roots of unity
  25. # - Inverse Vandermonde matrix
  26. def fft(F, n):
  27. g, primitive_w = get_primitive_root_of_unity(F, n)
  28. w = get_nth_roots_of_unity(n, primitive_w)
  29. ft = matrix(F, n)
  30. for j in range(n):
  31. row = []
  32. for k in range(n):
  33. row.append(primitive_w^(j*k))
  34. ft.set_row(j, row)
  35. ft_inv = ft^-1
  36. return w, ft, ft_inv
  37. # Fast polynomial multiplication using FFT
  38. def poly_mul(fa, fb, F, n):
  39. w, ft, ft_inv = fft(F, n)
  40. # compute evaluation points from polynomials fa & fb at the roots of unity
  41. a_evals = []
  42. b_evals = []
  43. for i in range(n):
  44. a_evals.append(fa(w[i]))
  45. b_evals.append(fb(w[i]))
  46. # multiply elements in a_evals by b_evals
  47. c_evals = map(operator.mul, a_evals, b_evals)
  48. c_evals = vector(c_evals)
  49. # using FFT, convert the c_evals into fc(x)
  50. fc_coef = c_evals*ft_inv
  51. fc2=P(fc_coef.list())
  52. return fc2, c_evals
  53. # Tests
  54. #####
  55. # Roots of Unity test:
  56. q = 17
  57. F = GF(q)
  58. n = 4
  59. g, primitive_w = get_primitive_root_of_unity(F, n)
  60. print("generator:", g)
  61. print("primitive_w:", primitive_w)
  62. w = get_nth_roots_of_unity(n, primitive_w)
  63. print(f"{n}th roots of unity: {w}")
  64. assert w == [1, 13, 16, 4]
  65. #####
  66. # FFT test:
  67. def isprime(num):
  68. for n in range(2,int(num^1/2)+1):
  69. if num%n==0:
  70. return False
  71. return True
  72. # list valid values for q
  73. for i in range(20):
  74. if isprime(8*i+1):
  75. print("q =", 8*i+1)
  76. q = 41
  77. F = GF(q)
  78. n = 4
  79. # q needs to be a prime, s.t. q-1 is divisible by n
  80. assert (q-1)%n==0
  81. print("q =", q, "n = ", n)
  82. # ft: Vandermonde matrix for the nth roots of unity
  83. w, ft, ft_inv = fft(F, n)
  84. print("nth roots of unity:", w)
  85. print("Vandermonde matrix:")
  86. print(ft)
  87. fa_eval = vector([3,4,5,9])
  88. print("fa_eval:", fa_eval)
  89. # interpolate f_a(x)
  90. fa_coef = ft_inv * fa_eval
  91. print("fa_coef:", fa_coef)
  92. P.<x> = PolynomialRing(F)
  93. fa = P(list(fa_coef))
  94. print("f_a(x):", fa)
  95. # check that evaluating fa(x) at the roots of unity returns the expected values of fa_eval
  96. for i in range(len(fa_eval)):
  97. assert fa(w[i]) == fa_eval[i]
  98. # go from coefficient form to evaluation form
  99. fa_eval2 = ft * fa_coef
  100. print("fa_eval'", fa_eval)
  101. assert fa_eval2 == fa_eval
  102. # Fast polynomial multiplication using FFT
  103. print("\n---------")
  104. print("---Fast polynomial multiplication using FFT")
  105. n = 8
  106. # q needs to be a prime, s.t. q-1 is divisible by n
  107. assert (q-1)%n==0
  108. print("q =", q, "n = ", n)
  109. fa=P([1,2,3,4])
  110. fb=P([1,2,3,4])
  111. fc_expected = fa*fb
  112. print("fc expected result:", fc_expected) # expected result
  113. print("fc expected coef", fc_expected.coefficients())
  114. fc, c_evals = poly_mul(fa, fb, F, n)
  115. print("c_evals=(a_evals*b_evals)=", c_evals)
  116. print("fc:", fc)
  117. assert fc_expected == fc