diff --git a/README.md b/README.md new file mode 100644 index 0000000..b680dd3 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +# fft-rs + +Fast Fourier Transform implementation in Rust. + +https://en.wikipedia.org/wiki/Fast_Fourier_transform & [DFT](https://en.wikipedia.org/wiki/Discrete_Fourier_transform) + diff --git a/src/lib.rs b/src/lib.rs index eb3e40c..1f96942 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,49 @@ use num::complex::{Complex, Complex64}; use std::f64::consts::PI; +// fft computes the Fast Fourier Transform +pub fn fft(x: &Vec) -> Vec { + let N = x.len(); + if N % 2 > 0 { + panic!("not a power of 2"); + } else if N <= 2 { + return dft(x); + } + + let mut x_even: Vec = Vec::new(); + let mut x_odd: Vec = Vec::new(); + for i in 0..x.len() { + if i % 2 == 0 { + x_even.push(x[i]); + } else { + x_odd.push(x[i]); + } + } + let mut x_even_cmplx = fft(&x_even); + let mut x_odd_cmplx = fft(&x_odd); + + let mut w = Complex::new(0_f64, 2_f64 * PI / N as f64); + let mut f_k: Vec = Vec::new(); + for k in 0..x.len() { + let k_compl = Complex::new(k as f64, 0_f64); + f_k.push((w * k_compl).exp()); + } + + let mut r: Vec = Vec::new(); + let mut aa = add_vv( + x_even_cmplx.clone(), + mul_vv_el(x_odd_cmplx.clone(), f_k.clone()[0..x.len() / 2].to_vec()), + ); + let mut bb = add_vv( + x_even_cmplx.clone(), + mul_vv_el(x_odd_cmplx.clone(), f_k.clone()[x.len() / 2..].to_vec()), + ); + r.append(&mut aa); + r.append(&mut bb); + + r +} + // dft computes the Discrete Fourier Transform pub fn dft(x: &Vec) -> Vec { let mut x_compl: Vec = vec![Complex::new(0_f64, 0_f64); x.len()]; @@ -74,6 +117,29 @@ fn mul_mv(a: Vec>, b: Vec) -> Vec { c } +fn add_vv(a: Vec, b: Vec) -> Vec { + if a.len() != b.len() { + panic!("err a.len():{:?} != b.len():{:?}", a.len(), b.len()); + } + let mut c: Vec = vec![Complex::new(0_f64, 0_f64); a.len()]; + for i in 0..a.len() { + c[i] = a[i] + b[i]; + } + c +} + +// mul_vv_el multiplies elements of one vector by the elements of another vector +fn mul_vv_el(a: Vec, b: Vec) -> Vec { + if a.len() != b.len() { + panic!("err a.len():{:?} != b.len():{:?}", a.len(), b.len()); + } + let mut c: Vec = vec![Complex::new(0_f64, 0_f64); a.len()]; + for i in 0..a.len() { + c[i] = a[i] * b[i]; + } + c +} + #[cfg(test)] mod tests { use super::*; @@ -101,4 +167,18 @@ mod tests { assert_eq!(format!("{:.1}", o[6]), "0.7"); assert_eq!(format!("{:.1}", o[7]), "0.8"); } + + #[test] + fn test_fft_simple_values() { + let values: Vec = vec![0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]; + let r = fft(&values); + + assert_eq!(r.len(), 8); + + assert_eq!(format!("{:.2}", r[0]), "3.70+0.00i"); + assert_eq!(format!("{:.2}", r[1]), "-0.30-0.97i"); + assert_eq!(format!("{:.2}", r[2]), "-0.30-0.40i"); + assert_eq!(format!("{:.2}", r[3]), "-0.30-0.17i"); + assert_eq!(format!("{:.2}", r[4]), "-0.30+0.00i"); + } }