Tmp FFT commit

commit 1704809f29 (parent 0cae0bdac8)
@@ -6,4 +6,5 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
@@ -2,6 +2,7 @@
 #include <cassert>
 
 #include "mlx/backend/common/broadcasting.h"
+#include "mlx/backend/common/transpose.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
 
@@ -19,26 +20,19 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
         "AsStrided must be used with row contiguous arrays only.");
   }
 
-  // Compute the flags given the shape and strides
-  bool row_contiguous = true, col_contiguous = true;
-  size_t r = 1, c = 1;
-  for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) {
-    row_contiguous &= (r == strides_[i]) || (shape_[i] == 1);
-    col_contiguous &= (c == strides_[j]) || (shape_[j] == 1);
-    r *= shape_[i];
-    c *= shape_[j];
-  }
+  // Calculate the contiguity based on the given shape and strides
+  auto [ds, rc, cc] = check_contiguity(shape_, strides_);
   auto flags = in.flags();
-
-  // TODO: Compute the contiguous flag in a better way cause now we are
-  // unnecessarily strict.
-  flags.contiguous = row_contiguous || col_contiguous;
-  flags.row_contiguous = row_contiguous;
-  flags.col_contiguous = col_contiguous;
+  flags.contiguous = rc || cc;
+  flags.row_contiguous = rc;
+  flags.col_contiguous = cc;
 
-  // There is no easy way to compute the actual data size so we use out.size().
-  // The contiguous flag will almost certainly not be set so no code should
-  // rely on data_size anyway.
-  size_t data_size = out.size();
+  // There is no easy way to compute the actual data size so we use out.size()
+  // when the array is not contiguous.
+  size_t data_size = flags.contiguous ? ds : out.size();
 
   return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
 }
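A quick worked example of what the forward/backward stride scan behind check_contiguity reports (values are illustrative, not taken from this commit): a {3, 2} view with strides {1, 3} over a row-major length-6 buffer is column contiguous but not row contiguous, so the flags above become contiguous + col_contiguous and data_size comes from ds rather than the out.size() fallback.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Standalone sketch of the scan: the forward pass expects strides
    // (1, 3) and matches, the backward pass expects (2, 1) against (3, 1)
    // and fails, so cc = true and rc = false.
    int main() {
      std::vector<int> shape = {3, 2};
      std::vector<int64_t> strides = {1, 3};
      bool rc = true, cc = true;
      int64_t f = 1, b = 1;
      for (int i = 0, ri = 1; i < 2; ++i, --ri) {
        cc &= (strides[i] == f || shape[i] == 1);
        rc &= (strides[ri] == b || shape[ri] == 1);
        f *= shape[i];
        b *= shape[ri];
      }
      assert(cc && !rc);
      return 0;
    }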
@@ -270,36 +264,7 @@ void StopGradient::eval(const std::vector<array>& inputs, array& out) {
 
 void Transpose::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  Strides out_strides(out.ndim());
-  auto& in = inputs[0];
-  for (int ax = 0; ax < axes_.size(); ++ax) {
-    out_strides[ax] = in.strides()[axes_[ax]];
-  }
-
-  // Conditions for {row/col}_contiguous
-  // - array must be contiguous (no gaps)
-  // - underlying buffer size should have the same size as the array
-  // - cumulative product of shapes is equal to the strides (we can ignore axes
-  //   with size == 1)
-  //   - in the forward direction (column contiguous)
-  //   - in the reverse direction (row contiguous)
-  // - vectors are both row and col contiguous (hence if both row/col are
-  //   true, they stay true)
-  auto flags = in.flags();
-  if (flags.contiguous && in.data_size() == in.size()) {
-    int64_t f_stride = 1;
-    int64_t b_stride = 1;
-    flags.col_contiguous = true;
-    flags.row_contiguous = true;
-    for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) {
-      flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1);
-      f_stride *= out.shape(i);
-      flags.row_contiguous &=
-          (out_strides[ri] == b_stride || out.shape(ri) == 1);
-      b_stride *= out.shape(ri);
-    }
-  }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  transpose(inputs[0], out, axes_);
 }
 
 } // namespace mlx::core
mlx/backend/common/transpose.cpp (new file, 31 lines)
@@ -0,0 +1,31 @@
+// Copyright © 2024 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+
+namespace mlx::core {
+
+void transpose(const array& in, array& out, const std::vector<int>& axes) {
+  Strides out_strides(out.ndim());
+  for (int ax = 0; ax < axes.size(); ++ax) {
+    out_strides[ax] = in.strides()[axes[ax]];
+  }
+
+  // Conditions for {row/col}_contiguous
+  // - array must be contiguous (no gaps)
+  // - underlying buffer size should have the same size as the array
+  // - cumulative product of shapes is equal to the strides (we can ignore axes
+  //   with size == 1)
+  //   - in the forward direction (column contiguous)
+  //   - in the reverse direction (row contiguous)
+  // - vectors are both row and col contiguous (hence if both row/col are
+  //   true, they stay true)
+  auto flags = in.flags();
+  if (flags.contiguous && in.data_size() == in.size()) {
+    auto [_, rc, cc] = check_contiguity(out.shape(), out_strides);
+    flags.row_contiguous = rc;
+    flags.col_contiguous = cc;
+  }
+  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+}
+
+} // namespace mlx::core
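To see what the stride permutation does concretely, here is a plain-C++ illustration independent of the mlx array type: transposing a row-major 3×4 array with axes = {1, 0} only permutes shape and strides, and the result is column contiguous, which is exactly what check_contiguity reports for it.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Row-major 3x4: shape {3, 4}, strides {4, 1} (in elements).
      std::vector<int> shape = {3, 4};
      std::vector<int64_t> strides = {4, 1};
      std::vector<int> axes = {1, 0};

      // Same permutation as in transpose() above.
      std::vector<int> out_shape(2);
      std::vector<int64_t> out_strides(2);
      for (int ax = 0; ax < (int)axes.size(); ++ax) {
        out_shape[ax] = shape[axes[ax]];
        out_strides[ax] = strides[axes[ax]];
      }

      // shape {4, 3} with strides {1, 4}: not row contiguous,
      // but column contiguous.
      assert(out_shape[0] == 4 && out_strides[0] == 1);
      assert(out_shape[1] == 3 && out_strides[1] == 4);
      return 0;
    }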
mlx/backend/common/transpose.h (new file, 11 lines)
@@ -0,0 +1,11 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/array.h"
+
+namespace mlx::core {
+
+void transpose(const array& in, array& out, const std::vector<int>& axes);
+
+} // namespace mlx::core
@@ -132,6 +132,11 @@ struct ContiguousIterator {
 };
 
+inline auto check_contiguity(const Shape& shape, const Strides& strides) {
+  // Conditions for {row/col}_contiguous
+  // - cumulative product of shapes is equal to the strides (we can ignore axes
+  //   with size == 1)
+  //   - in the forward direction (column contiguous)
+  //   - in the reverse direction (row contiguous)
   size_t no_broadcast_data_size = 1;
   int64_t f_stride = 1;
   int64_t b_stride = 1;
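The viewer truncates this hunk mid-function. Given how the call sites unpack it (auto [ds, rc, cc] = check_contiguity(shape_, strides_)), the remainder plausibly mirrors the loop deleted from AsStrided::eval above; a hedged reconstruction, not the actual committed code:

    // Plausible continuation of check_contiguity (reconstruction):
    //   bool rc = true, cc = true;
    //   for (int i = 0, ri = shape.size() - 1; i < (int)shape.size(); ++i, --ri) {
    //     cc &= (strides[i] == f_stride || shape[i] == 1);
    //     rc &= (strides[ri] == b_stride || shape[ri] == 1);
    //     f_stride *= shape[i];
    //     b_stride *= shape[ri];
    //     if (strides[i] != 0) {
    //       no_broadcast_data_size *= shape[i];
    //     }
    //   }
    //   return std::make_tuple(no_broadcast_data_size, rc, cc);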
@@ -71,7 +71,12 @@ void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
        (allow_col_major_ && in.flags().col_contiguous))) {
     out.copy_shared_buffer(in);
   } else {
-    copy_gpu(in, out, CopyType::General);
+    out.set_data(allocator::malloc(out.nbytes()));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.flags().row_contiguous ? CopyType::Vector : CopyType::General,
+        stream());
   }
 }
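The change allocates the output first and then picks the cheapest copy kind instead of always issuing a General copy. A minimal sketch of the decision, with the CopyType semantics assumed from the names (Vector = flat element-for-element copy, General = fully strided gather):

    // Hypothetical illustration of the dispatch above, not the real enum:
    enum class CopyType { Vector, General };

    inline CopyType pick_copy_type(bool row_contiguous) {
      // A row-contiguous input lines up with the freshly allocated output
      // buffer element after element, so a flat Vector copy suffices;
      // anything else needs the strided General path.
      return row_contiguous ? CopyType::Vector : CopyType::General;
    }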
@@ -1,11 +1,13 @@
 // Copyright © 2024 Apple Inc.
 #include <cassert>
 #include <complex>
+#include <iostream>
 #include <map>
 #include <numeric>
 #include <set>
 
 #include "mlx/3rdparty/pocketfft.h"
+#include "mlx/backend/common/transpose.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
@@ -27,7 +29,7 @@ using MTLFC = std::tuple<const void*, MTL::DataType, NS::UInteger>;
 // For strided reads/writes, coalesce at least this many complex64s
 #define MIN_COALESCE_WIDTH 4
 
-inline const std::vector<int> supported_radices() {
+inline constexpr std::array<int, 9> supported_radices() {
   // Ordered by preference in decomposition.
   return {13, 11, 8, 7, 6, 5, 4, 3, 2};
 }
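Returning a constexpr std::array instead of a heap-allocated std::vector makes the radix table a compile-time constant, and the membership tests in the hunks below switch from a per-call std::set to a linear std::find, which is both simpler and cheap for nine elements. A minimal sketch of the pattern:

    #include <algorithm>
    #include <array>

    inline constexpr std::array<int, 9> radices() {
      return {13, 11, 8, 7, 6, 5, 4, 3, 2};
    }

    // Linear scan instead of building a std::set on every call; with nine
    // entries this beats set construction plus lookup.
    inline bool is_supported_radix(int factor) {
      auto r = radices();
      return std::find(r.begin(), r.end(), factor) != r.end();
    }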
@@ -65,6 +67,7 @@ void fft_op(
     bool real,
     const FourStepParams four_step_params,
     bool inplace,
+    metal::Device& d,
     const Stream& s);
 
 struct FFTPlan {
@@ -112,13 +115,10 @@ std::vector<int> plan_stockham_fft(int n) {
 
 FFTPlan plan_fft(int n) {
   auto radices = supported_radices();
-  std::set<int> radices_set(radices.begin(), radices.end());
 
   FFTPlan plan;
   plan.n = n;
   plan.rader = std::vector<int>(radices.size(), 0);
-  auto factors = prime_factors(n);
-  int remaining_n = n;
 
   // Four Step FFT when N is too large for shared mem.
   if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
@@ -128,16 +128,20 @@ FFTPlan plan_fft(int n) {
     plan.n2 = n > 65536 ? 1024 : 64;
     plan.n1 = n / plan.n2;
     return plan;
-  } else if (n > MAX_STOCKHAM_FFT_SIZE) {
+  }
+
+  if (n > MAX_STOCKHAM_FFT_SIZE) {
     // Otherwise we use a multi-upload Bluestein's
     plan.four_step = true;
     plan.bluestein_n = next_fast_n(2 * n - 1);
     return plan;
   }
 
+  int remaining_n = n;
+  auto factors = prime_factors(n);
   for (int factor : factors) {
     // Make sure the factor is a supported radix
-    if (radices_set.find(factor) == radices_set.end()) {
+    if (std::find(radices.begin(), radices.end(), factor) == radices.end()) {
       // We only support a single Rader factor currently
       // TODO(alexbarron) investigate weirdness with large
       // Rader sizes -- possibly a compiler issue?
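The planner leans on prime_factors(n), whose definition is outside this diff. A standard trial-division version (an assumption about its behavior, not necessarily the mlx implementation) would look like this; the loop above then checks each prime against the supported radix table:

    #include <vector>

    // Trial-division factorization; hypothetical stand-in for the real
    // prime_factors helper referenced above.
    inline std::vector<int> prime_factors_sketch(int n) {
      std::vector<int> factors;
      for (int p = 2; p * p <= n; ++p) {
        while (n % p == 0) {
          factors.push_back(p);
          n /= p;
        }
      }
      if (n > 1) {
        factors.push_back(n); // remaining factor is prime
      }
      return factors;
    }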
@@ -154,7 +158,7 @@ FFTPlan plan_fft(int n) {
     for (int rf : rader_factors) {
       // We don't nest Rader's algorithm so if `factor - 1`
       // isn't Stockham decomposable we give up and do Bluestein's.
-      if (radices_set.find(rf) == radices_set.end()) {
+      if (std::find(radices.begin(), radices.end(), rf) == radices.end()) {
         plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE;
         plan.bluestein_n = next_fast_n(2 * n - 1);
         plan.stockham = plan_stockham_fft(plan.bluestein_n);
@@ -358,6 +362,8 @@ void multi_upload_bluestein_fft(
     FFTPlan& plan,
     std::vector<array>& copies,
     const Stream& s) {
+  auto& d = metal::device(s.device);
+
   // TODO(alexbarron) Implement fused kernels for multi upload Bluestein's
   // algorithm
   int n = inverse ? out.shape(axis) : in.shape(axis);
@@ -420,6 +426,7 @@
       /*real=*/false,
       FourStepParams(),
       /*inplace=*/false,
+      d,
       s);
   copies.push_back(pad_temp1);
 
||||
@ -435,6 +442,7 @@ void multi_upload_bluestein_fft(
|
||||
/* real= */ false,
|
||||
FourStepParams(),
|
||||
/*inplace=*/true,
|
||||
d,
|
||||
s);
|
||||
|
||||
int offset = plan.bluestein_n - (2 * n - 1);
|
||||
@@ -493,7 +501,15 @@ void four_step_fft(
     auto temp_shape = (real && inverse) ? out.shape() : in.shape();
     array temp(temp_shape, complex64, nullptr, {});
-    fft_op(
-        in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s);
+    fft_op(
+        in,
+        temp,
+        axis,
+        inverse,
+        real,
+        four_step_params,
+        /*inplace=*/false,
+        d,
+        s);
     four_step_params.first_step = false;
     fft_op(
         temp,
@@ -503,6 +519,7 @@
         real,
         four_step_params,
         /*inplace=*/in_place,
+        d,
         s);
     copies.push_back(temp);
   } else {
@@ -518,9 +535,8 @@ void fft_op(
     bool real,
     const FourStepParams four_step_params,
     bool inplace,
+    metal::Device& d,
     const Stream& s) {
-  auto& d = metal::device(s.device);
-
   size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
   if (n == 1) {
     out.copy_shared_buffer(in);
@@ -755,57 +771,116 @@ void fft_op(
   d.add_temporaries(std::move(copies), s.index);
 }
 
-void fft_op(
+inline array prepare_input(
+
+void fft_stockham_inplace(
     const array& in,
     array& out,
     size_t axis,
     bool inverse,
     bool real,
     bool inplace,
+    metal::Device& d,
     const Stream& s) {
-  fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s);
+
 }
 
-void nd_fft_op(
+void fft_op_inplace(
     const array& in,
     array& out,
     size_t axis,
     bool inverse,
     bool real,
+    metal::Device& d,
     const Stream& s) {
+  // Get the FFT size and plan it
+  size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis);
+  auto plan = plan_fft(n);
+  if (n == 1) {
+    std::cout << "--------------> 1-size FFT <-----------------" << std::endl;
+  }
+
+  if (plan.four_step && plan.bluestein_n < 0) {
+    // four_step_fft(in, out, axis, inverse, real, plan, inplace, d, s);
+    return;
+  }
+}
+
+void nd_fft_op_inplace(
+    const array& in,
+    array& out,
     const std::vector<size_t>& axes,
     bool inverse,
     bool real,
+    metal::Device& d,
     const Stream& s) {
-  // Perform ND FFT on GPU as a series of 1D FFTs
-  auto temp_shape = inverse ? in.shape() : out.shape();
-  array temp1(temp_shape, complex64, nullptr, {});
-  array temp2(temp_shape, complex64, nullptr, {});
-  std::vector<array> temp_arrs = {temp1, temp2};
-  for (int i = axes.size() - 1; i >= 0; i--) {
-    int reverse_index = axes.size() - i - 1;
-    // For 5D and above, we don't want to reallocate our two temporary arrays
-    bool inplace = reverse_index >= 3 && i != 0;
-    // Opposite order for fft vs ifft
-    int index = inverse ? reverse_index : i;
-    size_t axis = axes[index];
-    // Mirror np.fft.(i)rfftn and perform a real transform
-    // only on the final axis.
-    bool step_real = (real && index == axes.size() - 1);
-    auto step_shape = inverse ? out.shape(axis) : in.shape(axis);
-    const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2];
-    array& out_arr = i == 0 ? out : temp_arrs[i % 2];
-    fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s);
-  }
-  auto& d = metal::device(s.device);
-  d.add_temporaries(std::move(temp_arrs), s.index);
+  // We are going to make and possibly reuse some intermediate arrays that will
+  // hold the intermediate fft results.
+  auto shape = inverse ? in.shape() : out.shape();
+  std::vector<array> intermediates;
+  intermediates.reserve(2);
+
+  // Utility to return either in or one of the intermediates.
+  auto get_input_array = [&](int step) -> const array& {
+    // The first step so use the input array
+    if (step == 0) {
+      return in;
+    }
+
+    return intermediates[(step - 1) % 2];
+  };
+
+  // Utility to return either out or one of the intermediates. It also informs
+  // us if we should allocate memory for that output or there is already some
+  // allocated.
+  auto get_output_array = [&](int step) -> array& {
+    // It is the final step so return the output array
+    if (step == axes.size() - 1) {
+      return out;
+    }
+
+    // We already have made an array that we can use so go ahead and use it and
+    // don't reallocate the memory.
+    if (step % 2 < intermediates.size()) {
+      return intermediates[step % 2];
+    }
+
+    array x(shape, complex64, nullptr, {});
+    x.set_data(allocator::malloc(x.nbytes()));
+    intermediates.emplace_back(std::move(x));
+    d.add_temporary(intermediates.back(), s.index);
+
+    return intermediates.back();
+  };
+
+  // Perform ND FFT on GPU as a series of 1D FFTs
+  for (int step = 0; step < axes.size(); step++) {
+    auto x = get_input_array(step);
+    auto y = get_output_array(step);
+    auto step_axis = axes[inverse ? step : axes.size() - step - 1];
+    auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0);
+    fft_op_inplace(x, y, step_axis, inverse, step_real, d, s);
+  }
+}
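The two lambdas implement a classic ping-pong scheme: step k reads from intermediate (k - 1) % 2 and writes to intermediate k % 2, with each of the two scratch arrays allocated lazily the first time it is needed, so an N-axis FFT never holds more than two intermediates. The indexing in isolation, as a hypothetical standalone sketch:

    #include <cassert>
    #include <cstdio>

    // -1 stands for the original input array, -2 for the final output
    // array, 0/1 are the two reusable intermediates.
    int main() {
      const int steps = 4; // e.g. a 4-axis FFT
      int prev_out = -1;
      for (int step = 0; step < steps; ++step) {
        int in_buf = (step == 0) ? -1 : (step - 1) % 2;
        int out_buf = (step == steps - 1) ? -2 : step % 2;
        // Each step reads exactly what the previous step wrote.
        assert(step == 0 || in_buf == prev_out);
        std::printf("step %d: in=%d out=%d\n", step, in_buf, out_buf);
        prev_out = out_buf;
      }
      return 0;
    }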
 
 void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& s = stream();
+  auto& d = metal::device(s.device);
   auto& in = inputs[0];
 
+  // The FFT ops above have the *_inplace suffix. This means that the memory
+  // needs to be already allocated in the output array. Similar to
+  // copy_gpu_inplace and so on.
+  //
+  // Even though we allocate the memory, we do not necessarily want the
+  // contiguous strides so the *_inplace ops may change the strides and flags
+  // of the array but will not reallocate the memory.
+  out.set_data(allocator::malloc(out.nbytes()));
+
   if (axes_.size() > 1) {
-    nd_fft_op(in, out, axes_, inverse_, real_, s);
+    nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s);
   } else {
-    fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s);
+    fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s);
   }
 }