diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 6c4e25067..efeacc375 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -6,4 +6,5 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/transpose.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) diff --git a/mlx/backend/common/common.cpp b/mlx/backend/common/common.cpp index 2cda88a31..86376b6a6 100644 --- a/mlx/backend/common/common.cpp +++ b/mlx/backend/common/common.cpp @@ -2,6 +2,7 @@ #include #include "mlx/backend/common/broadcasting.h" +#include "mlx/backend/common/transpose.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" @@ -19,26 +20,19 @@ void AsStrided::eval(const std::vector& inputs, array& out) { "AsStrided must be used with row contiguous arrays only."); } - // Compute the flags given the shape and strides - bool row_contiguous = true, col_contiguous = true; - size_t r = 1, c = 1; - for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { - row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); - col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); - r *= shape_[i]; - c *= shape_[j]; - } + // Calculate the contiguity based on the given shape and strides + auto [ds, rc, cc] = check_contiguity(shape_, strides_); auto flags = in.flags(); + // TODO: Compute the contiguous flag in a better way cause now we are // unnecessarily strict. - flags.contiguous = row_contiguous || col_contiguous; - flags.row_contiguous = row_contiguous; - flags.col_contiguous = col_contiguous; + flags.contiguous = rc || cc; + flags.row_contiguous = rc; + flags.col_contiguous = cc; - // There is no easy way to compute the actual data size so we use out.size(). - // The contiguous flag will almost certainly not be set so no code should - // rely on data_size anyway. - size_t data_size = out.size(); + // There is no easy way to compute the actual data size so we use out.size() + // when the array is not contiguous. + size_t data_size = flags.contiguous ? ds : out.size(); return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); } @@ -270,36 +264,7 @@ void StopGradient::eval(const std::vector& inputs, array& out) { void Transpose::eval(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - Strides out_strides(out.ndim()); - auto& in = inputs[0]; - for (int ax = 0; ax < axes_.size(); ++ax) { - out_strides[ax] = in.strides()[axes_[ax]]; - } - - // Conditions for {row/col}_contiguous - // - array must be contiguous (no gaps) - // - underlying buffer size should have the same size as the array - // - cumulative product of shapes is equal to the strides (we can ignore axes - // with size == 1) - // - in the forward direction (column contiguous) - // - in the reverse direction (row contiguous) - // - vectors are both row and col contiguous (hence if both row/col are - // true, they stay true) - auto flags = in.flags(); - if (flags.contiguous && in.data_size() == in.size()) { - int64_t f_stride = 1; - int64_t b_stride = 1; - flags.col_contiguous = true; - flags.row_contiguous = true; - for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { - flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1); - f_stride *= out.shape(i); - flags.row_contiguous &= - (out_strides[ri] == b_stride || out.shape(ri) == 1); - b_stride *= out.shape(ri); - } - } - out.copy_shared_buffer(in, out_strides, flags, in.data_size()); + transpose(inputs[0], out, axes_); } } // namespace mlx::core diff --git a/mlx/backend/common/transpose.cpp b/mlx/backend/common/transpose.cpp new file mode 100644 index 000000000..d0c79011e --- /dev/null +++ b/mlx/backend/common/transpose.cpp @@ -0,0 +1,31 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void transpose(const array& in, array& out, const std::vector& axes) { + Strides out_strides(out.ndim()); + for (int ax = 0; ax < axes.size(); ++ax) { + out_strides[ax] = in.strides()[axes[ax]]; + } + + // Conditions for {row/col}_contiguous + // - array must be contiguous (no gaps) + // - underlying buffer size should have the same size as the array + // - cumulative product of shapes is equal to the strides (we can ignore axes + // with size == 1) + // - in the forward direction (column contiguous) + // - in the reverse direction (row contiguous) + // - vectors are both row and col contiguous (hence if both row/col are + // true, they stay true) + auto flags = in.flags(); + if (flags.contiguous && in.data_size() == in.size()) { + auto [_, rc, cc] = check_contiguity(out.shape(), out_strides); + flags.row_contiguous = rc; + flags.col_contiguous = cc; + } + out.copy_shared_buffer(in, out_strides, flags, in.data_size()); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/transpose.h b/mlx/backend/common/transpose.h new file mode 100644 index 000000000..378e80a2a --- /dev/null +++ b/mlx/backend/common/transpose.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void transpose(const array& in, array& out, const std::vector& axes); + +} // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index 20a65d7b1..685bcb1ca 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -132,6 +132,11 @@ struct ContiguousIterator { }; inline auto check_contiguity(const Shape& shape, const Strides& strides) { + // Conditions for {row/col}_contiguous + // - cumulative product of shapes is equal to the strides (we can ignore axes + // with size == 1) + // - in the forward direction (column contiguous) + // - in the reverse direction (row contiguous) size_t no_broadcast_data_size = 1; int64_t f_stride = 1; int64_t b_stride = 1; diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index cd9296075..035d8ae47 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -71,7 +71,12 @@ void Contiguous::eval_gpu(const std::vector& inputs, array& out) { (allow_col_major_ && in.flags().col_contiguous))) { out.copy_shared_buffer(in); } else { - copy_gpu(in, out, CopyType::General); + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.flags().row_contiguous ? CopyType::Vector : CopyType::General, + stream()); } } diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index 011eb7ebb..60e256174 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -1,11 +1,13 @@ // Copyright © 2024 Apple Inc. #include #include +#include #include #include #include #include "mlx/3rdparty/pocketfft.h" +#include "mlx/backend/common/transpose.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" @@ -27,7 +29,7 @@ using MTLFC = std::tuple; // For strided reads/writes, coalesce at least this many complex64s #define MIN_COALESCE_WIDTH 4 -inline const std::vector supported_radices() { +inline constexpr std::array supported_radices() { // Ordered by preference in decomposition. return {13, 11, 8, 7, 6, 5, 4, 3, 2}; } @@ -65,6 +67,7 @@ void fft_op( bool real, const FourStepParams four_step_params, bool inplace, + metal::Device& d, const Stream& s); struct FFTPlan { @@ -112,13 +115,10 @@ std::vector plan_stockham_fft(int n) { FFTPlan plan_fft(int n) { auto radices = supported_radices(); - std::set radices_set(radices.begin(), radices.end()); FFTPlan plan; plan.n = n; plan.rader = std::vector(radices.size(), 0); - auto factors = prime_factors(n); - int remaining_n = n; // Four Step FFT when N is too large for shared mem. if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { @@ -128,16 +128,20 @@ FFTPlan plan_fft(int n) { plan.n2 = n > 65536 ? 1024 : 64; plan.n1 = n / plan.n2; return plan; - } else if (n > MAX_STOCKHAM_FFT_SIZE) { + } + + if (n > MAX_STOCKHAM_FFT_SIZE) { // Otherwise we use a multi-upload Bluestein's plan.four_step = true; plan.bluestein_n = next_fast_n(2 * n - 1); return plan; } + int remaining_n = n; + auto factors = prime_factors(n); for (int factor : factors) { // Make sure the factor is a supported radix - if (radices_set.find(factor) == radices_set.end()) { + if (std::find(radices.begin(), radices.end(), factor) == radices.end()) { // We only support a single Rader factor currently // TODO(alexbarron) investigate weirdness with large // Rader sizes -- possibly a compiler issue? @@ -154,7 +158,7 @@ FFTPlan plan_fft(int n) { for (int rf : rader_factors) { // We don't nest Rader's algorithm so if `factor - 1` // isn't Stockham decomposable we give up and do Bluestein's. - if (radices_set.find(rf) == radices_set.end()) { + if (std::find(radices.begin(), radices.end(), rf) == radices.end()) { plan.four_step = n > MAX_BLUESTEIN_FFT_SIZE; plan.bluestein_n = next_fast_n(2 * n - 1); plan.stockham = plan_stockham_fft(plan.bluestein_n); @@ -358,6 +362,8 @@ void multi_upload_bluestein_fft( FFTPlan& plan, std::vector& copies, const Stream& s) { + auto& d = metal::device(s.device); + // TODO(alexbarron) Implement fused kernels for mutli upload bluestein's // algorithm int n = inverse ? out.shape(axis) : in.shape(axis); @@ -420,6 +426,7 @@ void multi_upload_bluestein_fft( /*real=*/false, FourStepParams(), /*inplace=*/false, + d, s); copies.push_back(pad_temp1); @@ -435,6 +442,7 @@ void multi_upload_bluestein_fft( /* real= */ false, FourStepParams(), /*inplace=*/true, + d, s); int offset = plan.bluestein_n - (2 * n - 1); @@ -493,7 +501,15 @@ void four_step_fft( auto temp_shape = (real && inverse) ? out.shape() : in.shape(); array temp(temp_shape, complex64, nullptr, {}); fft_op( - in, temp, axis, inverse, real, four_step_params, /*inplace=*/false, s); + in, + temp, + axis, + inverse, + real, + four_step_params, + /*inplace=*/false, + d, + s); four_step_params.first_step = false; fft_op( temp, @@ -503,6 +519,7 @@ void four_step_fft( real, four_step_params, /*inplace=*/in_place, + d, s); copies.push_back(temp); } else { @@ -518,9 +535,8 @@ void fft_op( bool real, const FourStepParams four_step_params, bool inplace, + metal::Device& d, const Stream& s) { - auto& d = metal::device(s.device); - size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis); if (n == 1) { out.copy_shared_buffer(in); @@ -755,57 +771,116 @@ void fft_op( d.add_temporaries(std::move(copies), s.index); } -void fft_op( +inline array prepare_input( + +void fft_stockham_inplace( const array& in, array& out, size_t axis, bool inverse, bool real, - bool inplace, + metal::Device& d, const Stream& s) { - fft_op(in, out, axis, inverse, real, FourStepParams(), inplace, s); + } -void nd_fft_op( +void fft_op_inplace( + const array& in, + array& out, + size_t axis, + bool inverse, + bool real, + metal::Device &d, + const Stream& s) { + // Get the FFT size and plan it + size_t n = out.dtype() == float32 ? out.shape(axis) : in.shape(axis); + auto plan = plan_fft(n); + if (n == 1) { + std::cout << "--------------> 1-size FFT <-----------------" << std::endl; + } + + if (plan.four_step && plan.bluestein_n < 0) { + // four_step_fft(in, out, axis, inverse, real, plan, inplace, d, s); + return; + } +} + +void nd_fft_op_inplace( const array& in, array& out, const std::vector& axes, bool inverse, bool real, + metal::Device &d, const Stream& s) { - // Perform ND FFT on GPU as a series of 1D FFTs - auto temp_shape = inverse ? in.shape() : out.shape(); - array temp1(temp_shape, complex64, nullptr, {}); - array temp2(temp_shape, complex64, nullptr, {}); - std::vector temp_arrs = {temp1, temp2}; - for (int i = axes.size() - 1; i >= 0; i--) { - int reverse_index = axes.size() - i - 1; - // For 5D and above, we don't want to reallocate our two temporary arrays - bool inplace = reverse_index >= 3 && i != 0; - // Opposite order for fft vs ifft - int index = inverse ? reverse_index : i; - size_t axis = axes[index]; - // Mirror np.fft.(i)rfftn and perform a real transform - // only on the final axis. - bool step_real = (real && index == axes.size() - 1); - auto step_shape = inverse ? out.shape(axis) : in.shape(axis); - const array& in_arr = i == axes.size() - 1 ? in : temp_arrs[1 - i % 2]; - array& out_arr = i == 0 ? out : temp_arrs[i % 2]; - fft_op(in_arr, out_arr, axis, inverse, step_real, inplace, s); - } + // We are going to make and possibly reuse some intermediate arrays that will + // hold the intermediate fft results. + auto shape = inverse ? in.shape() : out.shape(); + std::vector intermediates; + intermediates.reserve(2); - auto& d = metal::device(s.device); - d.add_temporaries(std::move(temp_arrs), s.index); + // Utility to return either in or one of the intermediates. + auto get_input_array = [&](int step) -> const array& { + // The first step so use the input array + if (step == 0) { + return in; + } + + return intermediates[(step - 1) % 2]; + }; + + // Utility to return either out or one of the intermediates. It also informs + // us if we should allocate memory for that output or there is already some + // allocated. + auto get_output_array = [&](int step) -> array& { + // It is the final step so return the output array + if (step == axes.size() - 1) { + return out; + } + + // We already have made an array that we can use so go ahead and use it and + // don't reallocate the memory. + if (step % 2 < intermediates.size()) { + return intermediates[step % 2]; + } + + array x(shape, complex64, nullptr, {}); + x.set_data(allocator::malloc(x.nbytes())); + intermediates.emplace_back(std::move(x)); + d.add_temporary(intermediates.back(), s.index); + + return intermediates.back(); + }; + + // Perform ND FFT on GPU as a series of 1D FFTs + for (int step = 0; step < axes.size(); step++) { + auto x = get_input_array(step); + auto y = get_output_array(step); + auto step_axis = axes[inverse ? step : axes.size() - step - 1]; + auto step_real = real && (inverse ? step == axes.size() - 1 : step == 0); + fft_op_inplace(x, y, step_axis, inverse, step_real, d, s); + } } void FFT::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); + auto& d = metal::device(s.device); auto& in = inputs[0]; + // The FFT ops above have the *_inplace suffix. This means that the memory + // needs to be already allocated in the output array. Similar to + // copy_gpu_inplace and so on. + // + // Even though we allocate the memory, we do not necessarily want the + // contiguous strides so the *_inplace ops may change the strides and flags + // of the array but will not reallocate the memory. + + out.set_data(allocator::malloc(out.nbytes())); + if (axes_.size() > 1) { - nd_fft_op(in, out, axes_, inverse_, real_, s); + nd_fft_op_inplace(in, out, axes_, inverse_, real_, d, s); } else { - fft_op(in, out, axes_[0], inverse_, real_, /*inplace=*/false, s); + fft_op_inplace(in, out, axes_[0], inverse_, real_, d, s); } }