mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Put the reshape utils in gpu/copy.h
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#include "mlx/backend/cuda/device/config.h"
|
||||
#include "mlx/backend/cuda/lru_cache.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/cuda/kernel_utils.cuh"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/dtype_utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
|
||||
@@ -2,5 +2,4 @@ target_sources(
|
||||
mlx
|
||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/reshape.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
||||
|
||||
@@ -52,4 +52,70 @@ array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||
return arr_copy;
|
||||
}
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
|
||||
int ndim = x.ndim();
|
||||
if (start_axis < 0) {
|
||||
start_axis += ndim;
|
||||
}
|
||||
if (end_axis < 0) {
|
||||
end_axis += ndim;
|
||||
}
|
||||
start_axis = std::max(0, start_axis);
|
||||
end_axis = std::min(ndim - 1, end_axis);
|
||||
|
||||
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
|
||||
}
|
||||
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s) {
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
reshape_gpu(x, out, s);
|
||||
return out;
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -46,4 +46,12 @@ void fill_gpu(const array& val, array& out, const Stream& s);
|
||||
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||
|
||||
// Copy data from |in| and transpose to |out|'s shape.
|
||||
void reshape_gpu(const array& in, array& out, Stream s);
|
||||
|
||||
// Like the normal ops but safe to call in eval_gpu.
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
#if defined(MLX_USE_CUDA)
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/gpu/reshape.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/primitives.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||
if (copy_necessary) {
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
copy_gpu_inplace(
|
||||
in,
|
||||
out,
|
||||
in.shape(),
|
||||
in.strides(),
|
||||
make_contiguous_strides(in.shape()),
|
||||
0,
|
||||
0,
|
||||
CopyType::General,
|
||||
s);
|
||||
} else {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
|
||||
int ndim = x.ndim();
|
||||
if (start_axis < 0) {
|
||||
start_axis += ndim;
|
||||
}
|
||||
if (end_axis < 0) {
|
||||
end_axis += ndim;
|
||||
}
|
||||
start_axis = std::max(0, start_axis);
|
||||
end_axis = std::min(ndim - 1, end_axis);
|
||||
|
||||
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
|
||||
}
|
||||
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s) {
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
reshape_gpu(x, out, s);
|
||||
return out;
|
||||
}
|
||||
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||
int ndim = x.ndim();
|
||||
if (axis1 < 0) {
|
||||
axis1 += ndim;
|
||||
}
|
||||
if (axis2 < 0) {
|
||||
axis2 += ndim;
|
||||
}
|
||||
|
||||
auto shape = x.shape();
|
||||
std::swap(shape[axis1], shape[axis2]);
|
||||
auto strides = x.strides();
|
||||
std::swap(strides[axis1], strides[axis2]);
|
||||
|
||||
auto [data_size, row_contiguous, col_contiguous] =
|
||||
check_contiguity(shape, strides);
|
||||
bool contiguous = data_size == x.data_size();
|
||||
|
||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||
out.copy_shared_buffer(
|
||||
x,
|
||||
std::move(strides),
|
||||
{contiguous, row_contiguous, col_contiguous},
|
||||
x.data_size());
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
// Copy data from |in| and transpose to |out|'s shape.
|
||||
void reshape_gpu(const array& in, array& out, Stream s);
|
||||
|
||||
// Like the normal ops but safe to call in eval_gpu.
|
||||
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||
|
||||
} // namespace mlx::core
|
||||
Reference in New Issue
Block a user