diff --git a/mlx/backend/cuda/conv.cpp b/mlx/backend/cuda/conv.cpp index 3f7c68f8c..3d7ef60bc 100644 --- a/mlx/backend/cuda/conv.cpp +++ b/mlx/backend/cuda/conv.cpp @@ -4,7 +4,6 @@ #include "mlx/backend/cuda/device/config.h" #include "mlx/backend/cuda/lru_cache.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/backend/gpu/reshape.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/cuda/sort.cu b/mlx/backend/cuda/sort.cu index 5936c08e2..e81ae12eb 100644 --- a/mlx/backend/cuda/sort.cu +++ b/mlx/backend/cuda/sort.cu @@ -3,7 +3,6 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/gpu/copy.h" -#include "mlx/backend/gpu/reshape.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/gpu/CMakeLists.txt b/mlx/backend/gpu/CMakeLists.txt index 6d92502d2..0396ae03a 100644 --- a/mlx/backend/gpu/CMakeLists.txt +++ b/mlx/backend/gpu/CMakeLists.txt @@ -2,5 +2,4 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/reshape.cpp ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp) diff --git a/mlx/backend/gpu/copy.cpp b/mlx/backend/gpu/copy.cpp index 4556f7d98..472ee486b 100644 --- a/mlx/backend/gpu/copy.cpp +++ b/mlx/backend/gpu/copy.cpp @@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) { return arr_copy; } +void reshape_gpu(const array& in, array& out, Stream s) { + auto [copy_necessary, out_strides] = prepare_reshape(in, out); + if (copy_necessary) { + out.set_data(allocator::malloc(out.nbytes())); + copy_gpu_inplace( + in, + out, + in.shape(), + in.strides(), + make_contiguous_strides(in.shape()), + 0, + 0, + CopyType::General, + s); + } else { + shared_buffer_reshape(in, out_strides, out); + } +} + +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) { + int ndim = x.ndim(); + if (start_axis < 0) { + start_axis += ndim; + } + if (end_axis < 0) { + end_axis += ndim; + } + start_axis = std::max(0, start_axis); + end_axis = std::min(ndim - 1, end_axis); + + return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s); +} + +array reshape_in_eval(const array& x, Shape shape, Stream s) { + array out(std::move(shape), x.dtype(), nullptr, {}); + reshape_gpu(x, out, s); + return out; +} + +array swapaxes_in_eval(const array& x, int axis1, int axis2) { + int ndim = x.ndim(); + if (axis1 < 0) { + axis1 += ndim; + } + if (axis2 < 0) { + axis2 += ndim; + } + + auto shape = x.shape(); + std::swap(shape[axis1], shape[axis2]); + auto strides = x.strides(); + std::swap(strides[axis1], strides[axis2]); + + auto [data_size, row_contiguous, col_contiguous] = + check_contiguity(shape, strides); + bool contiguous = data_size == x.data_size(); + + array out(std::move(shape), x.dtype(), nullptr, {}); + out.copy_shared_buffer( + x, + std::move(strides), + {contiguous, row_contiguous, col_contiguous}, + x.data_size()); + return out; +} + } // namespace mlx::core diff --git a/mlx/backend/gpu/copy.h b/mlx/backend/gpu/copy.h index f01fe9fda..274250202 100644 --- a/mlx/backend/gpu/copy.h +++ b/mlx/backend/gpu/copy.h @@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s); // Return a contiguous array with same shape that copies the data of |arr|. array contiguous_copy_gpu(const array& arr, const Stream& s); +// Copy data from |in| and transpose to |out|'s shape. +void reshape_gpu(const array& in, array& out, Stream s); + +// Like the normal ops but safe to call in eval_gpu. +array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); +array reshape_in_eval(const array& x, Shape shape, Stream s); +array swapaxes_in_eval(const array& x, int axis1, int axis2); + } // namespace mlx::core diff --git a/mlx/backend/gpu/primitives.cpp b/mlx/backend/gpu/primitives.cpp index 63a983864..6017879a5 100644 --- a/mlx/backend/gpu/primitives.cpp +++ b/mlx/backend/gpu/primitives.cpp @@ -4,7 +4,6 @@ #include "mlx/backend/common/slicing.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/backend/gpu/reshape.h" #include "mlx/backend/gpu/slicing.h" #if defined(MLX_USE_CUDA) diff --git a/mlx/backend/gpu/reshape.cpp b/mlx/backend/gpu/reshape.cpp deleted file mode 100644 index 78f3c4e63..000000000 --- a/mlx/backend/gpu/reshape.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/gpu/reshape.h" -#include "mlx/backend/gpu/copy.h" -#include "mlx/primitives.h" - -namespace mlx::core { - -void reshape_gpu(const array& in, array& out, Stream s) { - auto [copy_necessary, out_strides] = prepare_reshape(in, out); - if (copy_necessary) { - out.set_data(allocator::malloc(out.nbytes())); - copy_gpu_inplace( - in, - out, - in.shape(), - in.strides(), - make_contiguous_strides(in.shape()), - 0, - 0, - CopyType::General, - s); - } else { - shared_buffer_reshape(in, out_strides, out); - } -} - -array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) { - int ndim = x.ndim(); - if (start_axis < 0) { - start_axis += ndim; - } - if (end_axis < 0) { - end_axis += ndim; - } - start_axis = std::max(0, start_axis); - end_axis = std::min(ndim - 1, end_axis); - - return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s); -} - -array reshape_in_eval(const array& x, Shape shape, Stream s) { - array out(std::move(shape), x.dtype(), nullptr, {}); - reshape_gpu(x, out, s); - return out; -} - -array swapaxes_in_eval(const array& x, int axis1, int axis2) { - int ndim = x.ndim(); - if (axis1 < 0) { - axis1 += ndim; - } - if (axis2 < 0) { - axis2 += ndim; - } - - auto shape = x.shape(); - std::swap(shape[axis1], shape[axis2]); - auto strides = x.strides(); - std::swap(strides[axis1], strides[axis2]); - - auto [data_size, row_contiguous, col_contiguous] = - check_contiguity(shape, strides); - bool contiguous = data_size == x.data_size(); - - array out(std::move(shape), x.dtype(), nullptr, {}); - out.copy_shared_buffer( - x, - std::move(strides), - {contiguous, row_contiguous, col_contiguous}, - x.data_size()); - return out; -} - -} // namespace mlx::core diff --git a/mlx/backend/gpu/reshape.h b/mlx/backend/gpu/reshape.h deleted file mode 100644 index e06567482..000000000 --- a/mlx/backend/gpu/reshape.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include "mlx/backend/common/utils.h" - -namespace mlx::core { - -// Copy data from |in| and transpose to |out|'s shape. -void reshape_gpu(const array& in, array& out, Stream s); - -// Like the normal ops but safe to call in eval_gpu. -array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); -array reshape_in_eval(const array& x, Shape shape, Stream s); -array swapaxes_in_eval(const array& x, int axis1, int axis2); - -} // namespace mlx::core