You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

329 lines
7.0 KiB

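//! Complex numbers (`C<T>`): arithmetic operators, conjugate/modulus helpers, and naive multiplication of polynomials with complex coefficients.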
use rand::Rng;
use std::ops::{Add, Div, Mul, Neg, Sub};
/// A complex number `re + i·im` whose two components share the type `T`.
///
/// The derived `PartialOrd` compares lexicographically by field order:
/// first `re`, then `im`.
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct C<T> {
    /// Real component.
    pub re: T,
    /// Imaginary component.
    pub im: T,
}
impl From<f64> for C<f64> {
fn from(v: f64) -> Self {
Self { re: v, im: 0_f64 }
}
}
impl<T> C<T>
where
T: Default + From<i32>,
{
pub fn new(re: T, im: T) -> Self {
Self { re, im }
}
pub fn rand(mut rng: impl Rng, max: u64) -> Self {
Self::new(
T::from(rng.gen_range(0..max) as i32),
T::from(rng.gen_range(0..max) as i32),
)
}
pub fn zero() -> C<T> {
Self {
re: T::from(0),
im: T::from(0),
}
}
pub fn one() -> C<T> {
Self {
re: T::from(1),
im: T::from(0),
}
}
pub fn i() -> C<T> {
Self {
re: T::from(0),
im: T::from(1),
}
}
}
impl C<f64> {
    /// Cosine of a real angle `x` in radians.
    ///
    /// Delegates to `f64::cos`: a fixed-length Taylor series without range
    /// reduction loses accuracy rapidly as `|x|` grows.
    fn cos(x: f64) -> f64 {
        x.cos()
    }

    /// Sine of a real angle `x` in radians.
    ///
    /// Delegates to `f64::sin` for the same accuracy reason as `cos`.
    fn sin(x: f64) -> f64 {
        x.sin()
    }

    /// Complex exponential: `e^(re + i·im) = e^re · (cos(im) + i·sin(im))`.
    ///
    /// For purely imaginary input (`re == 0`) the scale factor is exactly `1`,
    /// giving the unit-circle point `cos(im) + i·sin(im)` (e.g. the root of
    /// unity `e^(2πi/n)`).
    pub fn exp(self) -> Self {
        // The real part scales the magnitude; the imaginary part is the angle.
        let scale = self.re.exp();
        Self {
            re: scale * Self::cos(self.im),
            im: scale * Self::sin(self.im),
        }
    }

    /// Raises `self` to the integer power `k` by binary exponentiation
    /// (O(log k) complex multiplications). `pow(0)` returns `1 + 0i`.
    pub fn pow(self, k: u32) -> Self {
        let mut k = k;
        if k == 0 {
            return Self::one();
        }
        let mut base = self;
        // Strip trailing zero bits: x^(2m) = (x^2)^m.
        while k & 1 == 0 {
            base = base * base;
            k >>= 1;
        }
        if k == 1 {
            return base;
        }
        // `k` is now odd: start from base^1 and multiply in base^(2^j) for
        // every remaining set bit j.
        let mut acc = base;
        while k > 1 {
            k >>= 1;
            base = base * base;
            if k & 1 == 1 {
                acc = acc * base;
            }
        }
        acc
    }

    /// Reduces both components to their non-negative representative in `[0, Q)`.
    pub fn modulus<const Q: u64>(self) -> Self {
        let q: f64 = Q as f64;
        // `%` keeps the sign of the dividend; adding `q` and reducing again
        // moves negative remainders into `[0, q)`.
        let re = (self.re % q + q) % q;
        let im = (self.im % q + q) % q;
        Self { re, im }
    }

    /// Applies `modulus_centered_f64::<Q>` to each component.
    pub fn modulus_centered<const Q: u64>(self) -> Self {
        let re = modulus_centered_f64::<Q>(self.re);
        let im = modulus_centered_f64::<Q>(self.im);
        Self { re, im }
    }
}
/// Reduces `v` modulo `Q` into the centered range `(-Q/2, Q/2]`.
///
/// `rem_euclid` first maps `v` into `[0, Q)` regardless of sign; plain `%`
/// keeps the sign of `v`, so negative inputs would land anywhere in `(-Q, 0]`
/// rather than in the centered range.
fn modulus_centered_f64<const Q: u64>(v: f64) -> f64 {
    let q = Q as f64;
    // `rem_euclid` can round up to exactly `q` for tiny negative `v`; the
    // subtraction below then maps that to 0, which is still in range.
    let res = v.rem_euclid(q);
    if res > q / 2.0 {
        res - q
    } else {
        res
    }
}
impl<T: Default> Default for C<T> {
fn default() -> Self {
C {
re: T::default(),
im: T::default(),
}
}
}
impl<T> Add for C<T>
where
    T: Add<Output = T> + Copy,
{
    type Output = Self;
    /// Component-wise sum: `(a + bi) + (c + di) = (a + c) + (b + d)i`.
    fn add(self, rhs: Self) -> Self::Output {
        let C { re: a, im: b } = self;
        let C { re: c, im: d } = rhs;
        C { re: a + c, im: b + d }
    }
}
impl<T> Sub for C<T>
where
    T: Sub<Output = T> + Copy,
{
    type Output = Self;
    /// Component-wise difference: `(a + bi) - (c + di) = (a - c) + (b - d)i`.
    fn sub(self, rhs: Self) -> Self::Output {
        let C { re: a, im: b } = self;
        let C { re: c, im: d } = rhs;
        C { re: a - c, im: b - d }
    }
}
impl<T> Mul for C<T>
where
    T: Mul<Output = T> + Sub<Output = T> + Add<Output = T> + Copy,
{
    type Output = Self;
    /// Complex product: `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, rhs: Self) -> Self::Output {
        let C { re: a, im: b } = self;
        let C { re: c, im: d } = rhs;
        let real = a * c - b * d;
        let imag = a * d + b * c;
        C { re: real, im: imag }
    }
}
impl<T> Neg for C<T>
where
    T: Neg<Output = T> + Copy,
{
    type Output = Self;
    /// Negates both components: `-(a + bi) = -a - bi`.
    fn neg(self) -> Self::Output {
        let C { re, im } = self;
        C { re: -re, im: -im }
    }
}
impl<T> Div for C<T>
where
    T: Copy
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + Neg<Output = T>,
{
    type Output = Self;
    /// Complex quotient, computed by multiplying through by the conjugate:
    /// `(a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c² + d²)`.
    ///
    /// No special handling of a zero divisor: the result follows `T`'s own
    /// division by zero (e.g. NaN/inf for floats, a panic for integers).
    fn div(self, rhs: Self) -> Self {
        let C { re: a, im: b } = self;
        let C { re: c, im: d } = rhs;
        let norm_sq = c * c + d * d;
        let real = (a * c + b * d) / norm_sq;
        let imag = (b * c - a * d) / norm_sq;
        C { re: real, im: imag }
    }
}
impl<T> C<T>
where
    T: Neg<Output = T> + Copy,
{
    /// Complex conjugate: `a + bi` becomes `a - bi`.
    pub fn conj(&self) -> Self {
        let C { re, im } = *self;
        C { re, im: -im }
    }
}
impl<T: Add<Output = T> + Default + Copy> std::iter::Sum for C<T> {
    /// Adds all items left to right, starting from `C::default()`
    /// (so an empty iterator sums to `0 + 0i` for numeric `T`).
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut total = C::default();
        for item in iter {
            total = total + item;
        }
        total
    }
}
impl C<f64> {
    /// Modulus (magnitude) of the complex number: `sqrt(re² + im²)`.
    ///
    /// Uses `f64::hypot`, which avoids overflow and underflow in the
    /// intermediate squares. For example, `|1e200 + 1e200i|` is finite here,
    /// while `(re*re + im*im).sqrt()` would overflow to infinity.
    pub fn abs(&self) -> f64 {
        self.re.hypot(self.im)
    }
}
/// Multiplies two polynomials with complex coefficients by schoolbook
/// convolution (O(N²) complex multiplications).
///
/// Coefficients are in ascending degree, and only the first `N` of each input
/// are used. The result holds all `2N - 1` coefficients of the product and is
/// *not* reduced modulo `X^N + 1`. For `N == 0` the result is empty.
///
/// # Panics
/// Panics if `poly1` or `poly2` has fewer than `N` coefficients.
pub fn naive_poly_mul<const N: usize>(poly1: &[C<f64>], poly2: &[C<f64>]) -> Vec<C<f64>> {
    // `2N - 1` would underflow for N == 0; saturate to an empty product instead.
    let mut result: Vec<C<f64>> = vec![C::<f64>::zero(); (2 * N).saturating_sub(1)];
    let (poly1, poly2) = (&poly1[..N], &poly2[..N]);
    for (i, &a) in poly1.iter().enumerate() {
        for (j, &b) in poly2.iter().enumerate() {
            result[i + j] = result[i + j] + a * b;
        }
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_i() {
        assert_eq!(C::i(), C::new(0.0, 1.0));
    }

    #[test]
    fn test_add() {
        let a = C::new(2.0, 3.0);
        let b = C::new(1.0, 4.0);
        assert_eq!(a + b, C::new(3.0, 7.0));
    }

    #[test]
    fn test_sub() {
        let a = C::new(5.0, 7.0);
        let b = C::new(2.0, 3.0);
        assert_eq!(a - b, C::new(3.0, 4.0));
    }

    #[test]
    fn test_mult() {
        let a = C::new(1.0, 2.0);
        let b = C::new(3.0, 4.0);
        assert_eq!(a * b, C::new(-5.0, 10.0));
    }

    #[test]
    fn test_div() {
        // (1 + 2i) / (3 + 4i) = (11 + 2i) / 25 = 0.44 + 0.08i
        let a: C<f64> = C::new(1.0, 2.0);
        let b: C<f64> = C::new(3.0, 4.0);
        let r = a / b;
        let expected = C::new(0.44, 0.08);
        let epsilon = 1e-12;
        assert!((r.re - expected.re).abs() < epsilon);
        assert!((r.im - expected.im).abs() < epsilon);
    }

    #[test]
    fn test_conj() {
        let a = C::new(3.0, -4.0);
        assert_eq!(a.conj(), C::new(3.0, 4.0));
        assert_eq!(a.conj().conj(), a);
    }

    #[test]
    fn test_neg() {
        let a = C::new(1.0, -2.0);
        assert_eq!(-a, C::new(-1.0, 2.0));
    }

    #[test]
    fn test_abs() {
        let a = C::new(3.0, 4.0);
        assert!((a.abs() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_exp() {
        // e^(2πi/n) with n = 4 is the primitive 4th root of unity, i.e. `i`.
        let pi = C::<f64>::from(std::f64::consts::PI);
        let n = 4;
        let a = ((C::<f64>::from(2f64) * pi * C::<f64>::i()) / C::<f64>::new(n as f64, 0f64)).exp();
        assert!(a.re.abs() < 1e-10);
        assert!((a.im - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_pow() {
        let a = C::new(1.0, 2.0);
        assert_eq!(a.pow(0), C::one());
        assert_eq!(a.pow(1), a);
        assert_eq!(a.pow(2), a * a);
        assert_eq!(a.pow(3), C::new(-11.0, -2.0));
        assert_eq!(a.pow(4), C::new(-7.0, -24.0));
    }

    #[test]
    fn test_sum() {
        let v = vec![C::new(1.0, 2.0), C::new(3.0, -4.0)];
        let s: C<f64> = v.into_iter().sum();
        assert_eq!(s, C::new(4.0, -2.0));
    }

    #[test]
    fn test_naive_poly_mul() {
        // (1 + 2x)(3 + i·x) = 3 + (6 + i)x + 2i·x²
        let a = vec![C::new(1.0, 0.0), C::new(2.0, 0.0)];
        let b = vec![C::new(3.0, 0.0), C::new(0.0, 1.0)];
        let r = naive_poly_mul::<2>(&a, &b);
        assert_eq!(
            r,
            vec![C::new(3.0, 0.0), C::new(6.0, 1.0), C::new(0.0, 2.0)]
        );
    }
}