@ -0,0 +1,42 @@ |
|||||
|
# Primitive Root of Unity |
||||
|
def get_primitive_root_of_unity(F, n): |
||||
|
# using the method described by Thomas Pornin in |
||||
|
# https://crypto.stackexchange.com/a/63616 |
||||
|
q = F.order() |
||||
|
for k in range(q): |
||||
|
if k==0: |
||||
|
continue |
||||
|
g = F(k) |
||||
|
# g = F.random_element() |
||||
|
if g==0: |
||||
|
continue |
||||
|
w = g ^ ((q-1)/n) |
||||
|
if w^(n/2) != 1: |
||||
|
return g, w |
||||
|
|
||||
|
# Roots of Unity |
||||
|
def get_nth_roots_of_unity(n, primitive_w): |
||||
|
w = [0]*n |
||||
|
for i in range(n): |
||||
|
w[i] = primitive_w^i |
||||
|
return w |
||||
|
|
||||
|
|
||||
|
# fft (Fast Fourier Transform) returns: |
||||
|
# - nth roots of unity |
||||
|
# - Vandermonde matrix for the nth roots of unity |
||||
|
# - Inverse Vandermonde matrix |
||||
|
def fft(F, n): |
||||
|
g, primitive_w = get_primitive_root_of_unity(F, n) |
||||
|
|
||||
|
w = get_nth_roots_of_unity(n, primitive_w) |
||||
|
|
||||
|
ft = matrix(F, n) |
||||
|
for j in range(n): |
||||
|
row = [] |
||||
|
for k in range(n): |
||||
|
row.append(primitive_w^(j*k)) |
||||
|
ft.set_row(j, row) |
||||
|
ft_inv = ft^-1 |
||||
|
return w, ft, ft_inv |
||||
|
|
@ -0,0 +1,59 @@ |
|||||
|
load("fft.sage") |
||||
|
|
||||
|
##### |
||||
|
# Roots of Unity test: |
||||
|
q = 17 |
||||
|
F = GF(q) |
||||
|
n = 4 |
||||
|
g, primitive_w = get_primitive_root_of_unity(F, n) |
||||
|
print("generator:", g) |
||||
|
print("primitive_w:", primitive_w) |
||||
|
|
||||
|
w = get_nth_roots_of_unity(n, primitive_w) |
||||
|
print(f"{n}th roots of unity: {w}") |
||||
|
assert w == [1, 13, 16, 4] |
||||
|
|
||||
|
|
||||
|
##### |
||||
|
# FFT test: |
||||
|
|
||||
|
def isprime(num): |
||||
|
for n in range(2,int(num^1/2)+1): |
||||
|
if num%n==0: |
||||
|
return False |
||||
|
return True |
||||
|
|
||||
|
# list valid values for q |
||||
|
for i in range(20): |
||||
|
if isprime(8*i+1): |
||||
|
print("q =", 8*i+1) |
||||
|
|
||||
|
q = 41 |
||||
|
F = GF(q) |
||||
|
n = 4 |
||||
|
# q needs to be a prime, s.t. q-1 is divisible by n |
||||
|
assert (q-1)%n==0 |
||||
|
print("q =", q, "n = ", n) |
||||
|
|
||||
|
# ft: Vandermonde matrix for the nth roots of unity |
||||
|
w, ft, ft_inv = fft(F, n) |
||||
|
print("nth roots of unity:", w) |
||||
|
print("Vandermonde matrix:") |
||||
|
print(ft) |
||||
|
|
||||
|
a = vector([3,4,5,9]) |
||||
|
print("a:", a) |
||||
|
|
||||
|
# interpolate f_a(x) |
||||
|
fa_coef = ft_inv * a |
||||
|
print("fa_coef:", fa_coef) |
||||
|
|
||||
|
P.<x> = PolynomialRing(F) |
||||
|
fa = P(list(fa_coef)) |
||||
|
print("f_a(x):", fa) |
||||
|
|
||||
|
# check that evaluating fa(x) at the roots of unity returns the expected values of a |
||||
|
for i in range(len(a)): |
||||
|
assert fa(w[i]) == a[i] |
||||
|
|
||||
|
|