mirror of
https://github.com/arnaucube/math.git
synced 2026-01-09 23:41:33 +01:00
Add FFT & Roots of Unity in Sage
This commit is contained in:
42
fft.sage
Normal file
42
fft.sage
Normal file
@@ -0,0 +1,42 @@
|
||||
# Primitive Root of Unity
|
||||
def get_primitive_root_of_unity(F, n):
|
||||
# using the method described by Thomas Pornin in
|
||||
# https://crypto.stackexchange.com/a/63616
|
||||
q = F.order()
|
||||
for k in range(q):
|
||||
if k==0:
|
||||
continue
|
||||
g = F(k)
|
||||
# g = F.random_element()
|
||||
if g==0:
|
||||
continue
|
||||
w = g ^ ((q-1)/n)
|
||||
if w^(n/2) != 1:
|
||||
return g, w
|
||||
|
||||
# Roots of Unity
|
||||
def get_nth_roots_of_unity(n, primitive_w):
|
||||
w = [0]*n
|
||||
for i in range(n):
|
||||
w[i] = primitive_w^i
|
||||
return w
|
||||
|
||||
|
||||
# fft (Fast Fourier Transform) returns:
|
||||
# - nth roots of unity
|
||||
# - Vandermonde matrix for the nth roots of unity
|
||||
# - Inverse Vandermonde matrix
|
||||
def fft(F, n):
|
||||
g, primitive_w = get_primitive_root_of_unity(F, n)
|
||||
|
||||
w = get_nth_roots_of_unity(n, primitive_w)
|
||||
|
||||
ft = matrix(F, n)
|
||||
for j in range(n):
|
||||
row = []
|
||||
for k in range(n):
|
||||
row.append(primitive_w^(j*k))
|
||||
ft.set_row(j, row)
|
||||
ft_inv = ft^-1
|
||||
return w, ft, ft_inv
|
||||
|
||||
59
fft_test.sage
Normal file
59
fft_test.sage
Normal file
@@ -0,0 +1,59 @@
|
||||
load("fft.sage")
|
||||
|
||||
#####
|
||||
# Roots of Unity test:
|
||||
q = 17
|
||||
F = GF(q)
|
||||
n = 4
|
||||
g, primitive_w = get_primitive_root_of_unity(F, n)
|
||||
print("generator:", g)
|
||||
print("primitive_w:", primitive_w)
|
||||
|
||||
w = get_nth_roots_of_unity(n, primitive_w)
|
||||
print(f"{n}th roots of unity: {w}")
|
||||
assert w == [1, 13, 16, 4]
|
||||
|
||||
|
||||
#####
|
||||
# FFT test:
|
||||
|
||||
def isprime(num):
|
||||
for n in range(2,int(num^1/2)+1):
|
||||
if num%n==0:
|
||||
return False
|
||||
return True
|
||||
|
||||
# list valid values for q
|
||||
for i in range(20):
|
||||
if isprime(8*i+1):
|
||||
print("q =", 8*i+1)
|
||||
|
||||
q = 41
|
||||
F = GF(q)
|
||||
n = 4
|
||||
# q needs to be a prime, s.t. q-1 is divisible by n
|
||||
assert (q-1)%n==0
|
||||
print("q =", q, "n = ", n)
|
||||
|
||||
# ft: Vandermonde matrix for the nth roots of unity
|
||||
w, ft, ft_inv = fft(F, n)
|
||||
print("nth roots of unity:", w)
|
||||
print("Vandermonde matrix:")
|
||||
print(ft)
|
||||
|
||||
a = vector([3,4,5,9])
|
||||
print("a:", a)
|
||||
|
||||
# interpolate f_a(x)
|
||||
fa_coef = ft_inv * a
|
||||
print("fa_coef:", fa_coef)
|
||||
|
||||
P.<x> = PolynomialRing(F)
|
||||
fa = P(list(fa_coef))
|
||||
print("f_a(x):", fa)
|
||||
|
||||
# check that evaluating fa(x) at the roots of unity returns the expected values of a
|
||||
for i in range(len(a)):
|
||||
assert fa(w[i]) == a[i]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user