From 56292d916737fa2af0824c833f0d517c0fdf659f Mon Sep 17 00:00:00 2001 From: arnaucube Date: Fri, 6 Nov 2020 22:58:10 +0100 Subject: [PATCH] Add DFT --- .gitignore | 2 ++ Cargo.toml | 19 ++++++++++++++++ src/lib.rs | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..96ef6c0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b177955 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "fft-rs" +version = "0.1.0" +authors = ["arnaucube "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rand = "0.7.3" +num = "0.3.0" +num-complex= "0.3" + +[dev-dependencies] +criterion = "0.3" + +[[bench]] +name = "bench_fft" +harness = false diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..75fa60f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,67 @@ +use num::complex::{Complex, Complex64}; +use std::f64::consts::PI; + +// dft computes the Discrete Fourier Transform +pub fn dft(x: &Vec) -> Vec { + let mut x_compl: Vec = vec![Complex::new(0_f64, 0_f64); x.len()]; + for i in 0..x.len() { + x_compl[i] = Complex::new(x[i], 0_f64); + } + + let mut w = Complex::new(0_f64, -2_f64 * PI / x.len() as f64); + + // f_k = SUM{n=0, N-1} f_n * e^(-j2pi*k*n)/N + // https://en.wikipedia.org/wiki/Discrete_Fourier_transform + let mut f: Vec> = Vec::new(); + for i in 0..x.len() { + let mut f_k: Vec = Vec::new(); + for j in 0..x.len() { + let i_compl = Complex::new(0_f64, i as f64); + let j_compl = Complex::new(0_f64, j as f64); + let fe = (w * i_compl * j_compl).exp(); + f_k.push(fe); + } + f.push(f_k.clone()); + } + let r = mul_vv(f, x_compl); + r +} + +// mul_vv multiplies a Matrix by a Vector +fn mul_vv(a: Vec>, b: Vec) -> Vec { + if a[0].len() != a.len() { + panic!("err b[0].len():{:?} != b.len():{:?}", a[0].len(), a.len()); + } + if a.len() != b.len() { + panic!("err a.len():{:?} != b.len():{:?}", a.len(), b.len()); + } + + let rows = a.len(); + let cols = a.len(); + + let mut c: Vec = vec![Complex::new(0_f64, 0_f64); cols]; + for i in 0..rows { + for j in 0..cols { + c[i] += a[i][j] * b[j]; + } + } + c +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_dft_simple_values() { + let values: Vec = vec![0.2, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]; + let r = dft(&values); + assert_eq!(r.len(), 8); + + assert_eq!(format!("{:.2}", r[0]), "3.70+0.00i"); + assert_eq!(format!("{:.2}", r[1]), "-0.30-0.97i"); + assert_eq!(format!("{:.2}", r[2]), "-0.30-0.40i"); + assert_eq!(format!("{:.2}", r[3]), "-0.30-0.17i"); + assert_eq!(format!("{:.2}", r[4]), "-0.30+0.00i"); + } +}