Browse Source

Add FFT & Roots of Unity in Sage

master
arnaucube 2 years ago
parent
commit
9468993bda
2 changed files with 101 additions and 0 deletions
  1. +42
    -0
      fft.sage
  2. +59
    -0
      fft_test.sage

+ 42
- 0
fft.sage

@ -0,0 +1,42 @@
# Primitive Root of Unity
def get_primitive_root_of_unity(F, n):
# using the method described by Thomas Pornin in
# https://crypto.stackexchange.com/a/63616
q = F.order()
for k in range(q):
if k==0:
continue
g = F(k)
# g = F.random_element()
if g==0:
continue
w = g ^ ((q-1)/n)
if w^(n/2) != 1:
return g, w
# Roots of Unity
def get_nth_roots_of_unity(n, primitive_w):
w = [0]*n
for i in range(n):
w[i] = primitive_w^i
return w
# fft (Fast Fourier Transform) returns:
# - nth roots of unity
# - Vandermonde matrix for the nth roots of unity
# - Inverse Vandermonde matrix
def fft(F, n):
g, primitive_w = get_primitive_root_of_unity(F, n)
w = get_nth_roots_of_unity(n, primitive_w)
ft = matrix(F, n)
for j in range(n):
row = []
for k in range(n):
row.append(primitive_w^(j*k))
ft.set_row(j, row)
ft_inv = ft^-1
return w, ft, ft_inv

+ 59
- 0
fft_test.sage

@ -0,0 +1,59 @@
load("fft.sage")
#####
# Roots of Unity test:
q = 17
F = GF(q)
n = 4
g, primitive_w = get_primitive_root_of_unity(F, n)
print("generator:", g)
print("primitive_w:", primitive_w)
w = get_nth_roots_of_unity(n, primitive_w)
print(f"{n}th roots of unity: {w}")
assert w == [1, 13, 16, 4]
#####
# FFT test:
def isprime(num):
for n in range(2,int(num^1/2)+1):
if num%n==0:
return False
return True
# list valid values for q
for i in range(20):
if isprime(8*i+1):
print("q =", 8*i+1)
q = 41
F = GF(q)
n = 4
# q needs to be a prime, s.t. q-1 is divisible by n
assert (q-1)%n==0
print("q =", q, "n = ", n)
# ft: Vandermonde matrix for the nth roots of unity
w, ft, ft_inv = fft(F, n)
print("nth roots of unity:", w)
print("Vandermonde matrix:")
print(ft)
a = vector([3,4,5,9])
print("a:", a)
# interpolate f_a(x)
fa_coef = ft_inv * a
print("fa_coef:", fa_coef)
P.<x> = PolynomialRing(F)
fa = P(list(fa_coef))
print("f_a(x):", fa)
# check that evaluating fa(x) at the roots of unity returns the expected values of a
for i in range(len(a)):
assert fa(w[i]) == a[i]

Loading…
Cancel
Save