mlx/mlx/fft.h
paramthakkar123 80002ed42f Updates
2025-05-06 09:53:10 +05:30

215 lines
5.4 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "mlx/mlx.h"
#include "utils.h"
namespace mx = mlx::core;
namespace mlx::core::fft {
/** Compute the n-dimensional Fourier Transform. */
array fftn(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array fftn(const array& a, const std::vector<int>& axes, StreamOrDevice s = {});
array fftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse Fourier Transform. */
array ifftn(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array ifftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array ifftn(const array& a, StreamOrDevice s = {});
/** Compute the one-dimensional Fourier Transform. */
inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return fftn(a, {n}, {axis}, s);
}
inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return fftn(a, {axis}, s);
}
/** Compute the one-dimensional inverse Fourier Transform. */
inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return ifftn(a, {n}, {axis}, s);
}
inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return ifftn(a, {axis}, s);
}
/** Compute the two-dimensional Fourier Transform. */
inline array fft2(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return fftn(a, n, axes, s);
}
inline array fft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return fftn(a, axes, s);
}
/** Compute the two-dimensional inverse Fourier Transform. */
inline array ifft2(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return ifftn(a, n, axes, s);
}
inline array ifft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return ifftn(a, axes, s);
}
/** Compute the n-dimensional Fourier Transform on a real input. */
array rfftn(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array rfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array rfftn(const array& a, StreamOrDevice s = {});
/** Compute the n-dimensional inverse of `rfftn`. */
array irfftn(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {});
array irfftn(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
array irfftn(const array& a, StreamOrDevice s = {});
/** Compute the one-dimensional Fourier Transform on a real input. */
inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return rfftn(a, {n}, {axis}, s);
}
inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return rfftn(a, {axis}, s);
}
/** Compute the one-dimensional inverse of `rfft`. */
inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) {
return irfftn(a, {n}, {axis}, s);
}
inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) {
return irfftn(a, {axis}, s);
}
/** Compute the two-dimensional Fourier Transform on a real input. */
inline array rfft2(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return rfftn(a, n, axes, s);
}
inline array rfft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return rfftn(a, axes, s);
}
/** Compute the two-dimensional inverse of `rfft2`. */
inline array irfft2(
const array& a,
const Shape& n,
const std::vector<int>& axes,
StreamOrDevice s = {}) {
return irfftn(a, n, axes, s);
}
inline array irfft2(
const array& a,
const std::vector<int>& axes = {-2, -1},
StreamOrDevice s = {}) {
return irfftn(a, axes, s);
}
/** Shift the zero-frequency component to the center of the spectrum. */
array fftshift(const array& a, StreamOrDevice s = {});
/** Shift the zero-frequency component to the center of the spectrum along
* specified axes. */
array fftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
/** The inverse of fftshift. */
array ifftshift(const array& a, StreamOrDevice s = {});
/** The inverse of fftshift along specified axes. */
array ifftshift(
const array& a,
const std::vector<int>& axes,
StreamOrDevice s = {});
inline array stft(
const array& x,
int n_fft = 2048,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
const std::string& pad_mode = "reflect",
bool normalized = false,
bool onesided = true,
StreamOrDevice s = {}) {
return mlx::core::fft::stft(
x,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
s);
}
inline array istft(
const array& stft_matrix,
int hop_length = -1,
int win_length = -1,
const array& window = mx::array({}),
bool center = true,
int length = -1,
bool normalized = false,
StreamOrDevice s = {}) {
return mlx::core::fft::istft(
stft_matrix,
hop_length,
win_length,
window,
center,
length,
normalized,
s);
}
} // namespace mlx::core::fft