From 8bf8034ffdf94c15dd25e0c042ce9513076c0ea8 Mon Sep 17 00:00:00 2001 From: Cheng Date: Thu, 14 Aug 2025 16:59:38 +0900 Subject: [PATCH] Put reshape utils in one file --- mlx/backend/common/utils.cpp | 27 ------------ mlx/backend/common/utils.h | 3 -- mlx/backend/gpu/CMakeLists.txt | 1 + mlx/backend/gpu/primitives.cpp | 30 ++------------ mlx/backend/gpu/reshape.cpp | 75 ++++++++++++++++++++++++++++++++++ mlx/backend/gpu/reshape.h | 15 +++++++ 6 files changed, 95 insertions(+), 56 deletions(-) create mode 100644 mlx/backend/gpu/reshape.cpp create mode 100644 mlx/backend/gpu/reshape.h diff --git a/mlx/backend/common/utils.cpp b/mlx/backend/common/utils.cpp index 4c9e39dc6..ae169e35e 100644 --- a/mlx/backend/common/utils.cpp +++ b/mlx/backend/common/utils.cpp @@ -228,31 +228,4 @@ std::pair get_grid_and_block_common(int dim0, int dim1, int dim2) { std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz)); } -array swapaxes_in_eval(const array& x, int axis1, int axis2) { - int ndim = x.ndim(); - if (axis1 < 0) { - axis1 += ndim; - } - if (axis2 < 0) { - axis2 += ndim; - } - - auto shape = x.shape(); - std::swap(shape[axis1], shape[axis2]); - auto strides = x.strides(); - std::swap(strides[axis1], strides[axis2]); - - auto [data_size, row_contiguous, col_contiguous] = - check_contiguity(shape, strides); - bool contiguous = data_size == x.data_size(); - - array out(std::move(shape), x.dtype(), nullptr, {}); - out.copy_shared_buffer( - x, - std::move(strides), - {contiguous, row_contiguous, col_contiguous}, - x.data_size()); - return out; -} - } // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h index db0da5e10..1b6902ff3 100644 --- a/mlx/backend/common/utils.h +++ b/mlx/backend/common/utils.h @@ -196,9 +196,6 @@ void shared_buffer_reshape( const Strides& out_strides, array& out); -// Like the swapaxes op but safe to call in eval_gpu. 
-array swapaxes_in_eval(const array& x, int axis1, int axis2); - template inline SmallVector remove_index(SmallVector vec, size_t index) { vec.erase(std::next(vec.begin(), index)); diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt index 0396ae03a..6d92502d2 100644 --- a/mlx/backend/gpu/CMakeLists.txt +++ b/mlx/backend/gpu/CMakeLists.txt @@ -2,4 +2,5 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 56d389b4f..63a983864 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/reshape.h" #include "mlx/backend/gpu/slicing.h" #if defined(MLX_USE_CUDA) @@ -20,29 +21,6 @@ namespace mlx::core { -namespace { - -void reshape(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - -} // namespace - void AsStrided::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("AsStrided::eval_gpu"); eval(inputs, out); @@ -124,7 +102,7 @@ void Full::eval_gpu(const std::vector& inputs, array& out) { void Flatten::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Flatten::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void NumberOfElements::eval_gpu(const std::vector& inputs, array& out) { @@ -150,7 +128,7 @@ void Pad::eval_gpu(const std::vector& inputs, array& 
out) { void Reshape::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Reshape::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void Split::eval_gpu( @@ -224,7 +202,7 @@ void Transpose::eval_gpu(const std::vector& inputs, array& out) { void Unflatten::eval_gpu(const std::vector& inputs, array& out) { MLX_PROFILER_RANGE("Unflatten::eval_gpu"); - reshape(inputs[0], out, stream()); + reshape_gpu(inputs[0], out, stream()); } void View::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/gpu/reshape.cpp b/mlx/backend/gpu/reshape.cpp new file mode 100644 index 000000000..78f3c4e63 --- /dev/null +++ b/mlx/backend/gpu/reshape.cpp @@ -0,0 +1,75 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/gpu/reshape.h" +#include "mlx/backend/gpu/copy.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) { + int ndim = x.ndim(); + if (start_axis < 0) { + start_axis += ndim; + } + if (end_axis < 0) { + end_axis += ndim; + } + start_axis = std::max(0, start_axis); + end_axis = std::min(ndim - 1, end_axis); + + return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s); +} + +array reshape_in_eval(const array& x, Shape shape, Stream s) { + array out(std::move(shape), x.dtype(), nullptr, {}); + reshape_gpu(x, out, s); + return out; +} + +array swapaxes_in_eval(const array& x, int axis1, int axis2) { + int ndim = x.ndim(); + if (axis1 < 0) { + axis1 += ndim; + } + if (axis2 < 0) { + axis2 += ndim; + 
} + + auto shape = x.shape(); + std::swap(shape[axis1], shape[axis2]); + auto strides = x.strides(); + std::swap(strides[axis1], strides[axis2]); + + auto [data_size, row_contiguous, col_contiguous] = + check_contiguity(shape, strides); + bool contiguous = data_size == x.data_size(); + + array out(std::move(shape), x.dtype(), nullptr, {}); + out.copy_shared_buffer( + x, + std::move(strides), + {contiguous, row_contiguous, col_contiguous}, + x.data_size()); + return out; +} + +} // namespace mlx::core diff --git a/mlx/backend/gpu/reshape.h b/mlx/backend/gpu/reshape.h new file mode 100644 index 000000000..e06567482 --- /dev/null +++ b/mlx/backend/gpu/reshape.h @@ -0,0 +1,15 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +// Reshape |in| to |out|'s shape, copying the data only when necessary. +void reshape_gpu(const array& in, array& out, Stream s); + +// Like the normal ops but safe to call in eval_gpu. +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); +array reshape_in_eval(const array& x, Shape shape, Stream s); +array swapaxes_in_eval(const array& x, int axis1, int axis2); + +} // namespace mlx::core