mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Put the reshape utils in gpu/copy.h
This commit is contained in:
@@ -4,7 +4,6 @@
|
|||||||
#include "mlx/backend/cuda/device/config.h"
|
#include "mlx/backend/cuda/device/config.h"
|
||||||
#include "mlx/backend/cuda/lru_cache.h"
|
#include "mlx/backend/cuda/lru_cache.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/reshape.h"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
#include "mlx/backend/cuda/device.h"
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/reshape.h"
|
|
||||||
#include "mlx/dtype_utils.h"
|
#include "mlx/dtype_utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
|||||||
@@ -2,5 +2,4 @@ target_sources(
|
|||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/reshape.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
||||||
|
|||||||
@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
|||||||
return arr_copy;
|
return arr_copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||||
|
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||||
|
if (copy_necessary) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
copy_gpu_inplace(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
in.shape(),
|
||||||
|
in.strides(),
|
||||||
|
make_contiguous_strides(in.shape()),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
CopyType::General,
|
||||||
|
s);
|
||||||
|
} else {
|
||||||
|
shared_buffer_reshape(in, out_strides, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten the axes [start_axis, end_axis] of |x| and return the result.
//
// Negative axes are offset by x.ndim(). After that, start_axis is clamped to
// be at least 0 and end_axis to be at most ndim - 1. The output shape is
// computed by Flatten::output_shape() and the actual work is delegated to
// reshape_in_eval() on stream |s|.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
  const int rank = x.ndim();

  // Wrap negative indices, then clamp into the valid range.
  start_axis = std::max(0, start_axis < 0 ? start_axis + rank : start_axis);
  end_axis = std::min(rank - 1, end_axis < 0 ? end_axis + rank : end_axis);

  auto flat_shape = Flatten::output_shape(x, start_axis, end_axis);
  return reshape_in_eval(x, std::move(flat_shape), s);
}
|
||||||
|
|
||||||
|
// Return |x| reshaped to |shape|.
//
// Builds an output array with the requested shape and |x|'s dtype, with no
// primitive and no inputs, then lets reshape_gpu() populate it on stream |s|.
array reshape_in_eval(const array& x, Shape shape, Stream s) {
  array reshaped(std::move(shape), x.dtype(), nullptr, {});
  reshape_gpu(x, reshaped, s);
  return reshaped;
}
|
||||||
|
|
||||||
|
// Return an array with |axis1| and |axis2| of |x| exchanged.
//
// Negative axes are offset by x.ndim(); no further range check is done here.
// The shape and stride entries of the two axes are swapped, the contiguity
// flags are recomputed with check_contiguity(), and the result is set up via
// copy_shared_buffer() on |x| with x.data_size() as its data size.
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
  const int rank = x.ndim();
  auto wrap_axis = [rank](int axis) { return axis < 0 ? axis + rank : axis; };
  axis1 = wrap_axis(axis1);
  axis2 = wrap_axis(axis2);

  auto swapped_shape = x.shape();
  auto swapped_strides = x.strides();
  std::swap(swapped_shape[axis1], swapped_shape[axis2]);
  std::swap(swapped_strides[axis1], swapped_strides[axis2]);

  auto [swapped_data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(swapped_shape, swapped_strides);
  // The result counts as contiguous only if the recomputed data size matches
  // the data size of the source array.
  const bool is_contiguous = swapped_data_size == x.data_size();

  array out(std::move(swapped_shape), x.dtype(), nullptr, {});
  out.copy_shared_buffer(
      x,
      std::move(swapped_strides),
      {is_contiguous, is_row_contiguous, is_col_contiguous},
      x.data_size());
  return out;
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
|
|||||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||||
|
|
||||||
|
// Reshape |in| into |out|, copying the data only when the reshape requires it.
|
||||||
|
void reshape_gpu(const array& in, array& out, Stream s);
|
||||||
|
|
||||||
|
// Like the normal ops but safe to call in eval_gpu.
|
||||||
|
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||||
|
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
#include "mlx/backend/common/slicing.h"
|
#include "mlx/backend/common/slicing.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
#include "mlx/backend/gpu/reshape.h"
|
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
|
|
||||||
#if defined(MLX_USE_CUDA)
|
#if defined(MLX_USE_CUDA)
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/gpu/reshape.h"
|
|
||||||
#include "mlx/backend/gpu/copy.h"
|
|
||||||
#include "mlx/primitives.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
|
||||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
|
||||||
if (copy_necessary) {
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
copy_gpu_inplace(
|
|
||||||
in,
|
|
||||||
out,
|
|
||||||
in.shape(),
|
|
||||||
in.strides(),
|
|
||||||
make_contiguous_strides(in.shape()),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
CopyType::General,
|
|
||||||
s);
|
|
||||||
} else {
|
|
||||||
shared_buffer_reshape(in, out_strides, out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flatten the axes [start_axis, end_axis] of |x| and return the result.
//
// Negative axes are offset by x.ndim(). After that, start_axis is clamped to
// be at least 0 and end_axis to be at most ndim - 1. The output shape is
// computed by Flatten::output_shape() and the actual work is delegated to
// reshape_in_eval() on stream |s|.
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
  const int rank = x.ndim();

  // Wrap negative indices, then clamp into the valid range.
  start_axis = std::max(0, start_axis < 0 ? start_axis + rank : start_axis);
  end_axis = std::min(rank - 1, end_axis < 0 ? end_axis + rank : end_axis);

  auto flat_shape = Flatten::output_shape(x, start_axis, end_axis);
  return reshape_in_eval(x, std::move(flat_shape), s);
}
|
|
||||||
|
|
||||||
// Return |x| reshaped to |shape|.
//
// Builds an output array with the requested shape and |x|'s dtype, with no
// primitive and no inputs, then lets reshape_gpu() populate it on stream |s|.
array reshape_in_eval(const array& x, Shape shape, Stream s) {
  array reshaped(std::move(shape), x.dtype(), nullptr, {});
  reshape_gpu(x, reshaped, s);
  return reshaped;
}
|
|
||||||
|
|
||||||
// Return an array with |axis1| and |axis2| of |x| exchanged.
//
// Negative axes are offset by x.ndim(); no further range check is done here.
// The shape and stride entries of the two axes are swapped, the contiguity
// flags are recomputed with check_contiguity(), and the result is set up via
// copy_shared_buffer() on |x| with x.data_size() as its data size.
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
  const int rank = x.ndim();
  auto wrap_axis = [rank](int axis) { return axis < 0 ? axis + rank : axis; };
  axis1 = wrap_axis(axis1);
  axis2 = wrap_axis(axis2);

  auto swapped_shape = x.shape();
  auto swapped_strides = x.strides();
  std::swap(swapped_shape[axis1], swapped_shape[axis2]);
  std::swap(swapped_strides[axis1], swapped_strides[axis2]);

  auto [swapped_data_size, is_row_contiguous, is_col_contiguous] =
      check_contiguity(swapped_shape, swapped_strides);
  // The result counts as contiguous only if the recomputed data size matches
  // the data size of the source array.
  const bool is_contiguous = swapped_data_size == x.data_size();

  array out(std::move(swapped_shape), x.dtype(), nullptr, {});
  out.copy_shared_buffer(
      x,
      std::move(swapped_strides),
      {is_contiguous, is_row_contiguous, is_col_contiguous},
      x.data_size());
  return out;
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/common/utils.h"
|
|
||||||
|
|
||||||
namespace mlx::core {
|
|
||||||
|
|
||||||
// Reshape |in| into |out|, copying the data only when the reshape requires it.
|
|
||||||
void reshape_gpu(const array& in, array& out, Stream s);
|
|
||||||
|
|
||||||
// Like the normal ops but safe to call in eval_gpu.
|
|
||||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
|
||||||
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
|
||||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
|
||||||
Reference in New Issue
Block a user