mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Put reshape utils in one file
This commit is contained in:
@@ -228,31 +228,4 @@ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
|
|||||||
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
|
||||||
}
|
}
|
||||||
|
|
||||||
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
|
||||||
int ndim = x.ndim();
|
|
||||||
if (axis1 < 0) {
|
|
||||||
axis1 += ndim;
|
|
||||||
}
|
|
||||||
if (axis2 < 0) {
|
|
||||||
axis2 += ndim;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto shape = x.shape();
|
|
||||||
std::swap(shape[axis1], shape[axis2]);
|
|
||||||
auto strides = x.strides();
|
|
||||||
std::swap(strides[axis1], strides[axis2]);
|
|
||||||
|
|
||||||
auto [data_size, row_contiguous, col_contiguous] =
|
|
||||||
check_contiguity(shape, strides);
|
|
||||||
bool contiguous = data_size == x.data_size();
|
|
||||||
|
|
||||||
array out(std::move(shape), x.dtype(), nullptr, {});
|
|
||||||
out.copy_shared_buffer(
|
|
||||||
x,
|
|
||||||
std::move(strides),
|
|
||||||
{contiguous, row_contiguous, col_contiguous},
|
|
||||||
x.data_size());
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
|||||||
@@ -196,9 +196,6 @@ void shared_buffer_reshape(
|
|||||||
const Strides& out_strides,
|
const Strides& out_strides,
|
||||||
array& out);
|
array& out);
|
||||||
|
|
||||||
// Like the swapaxes op but safe to call in eval_gpu.
|
|
||||||
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
|
||||||
vec.erase(std::next(vec.begin(), index));
|
vec.erase(std::next(vec.begin(), index));
|
||||||
|
|||||||
@@ -2,4 +2,5 @@ target_sources(
|
|||||||
mlx
|
mlx
|
||||||
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/reshape.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "mlx/backend/common/slicing.h"
|
#include "mlx/backend/common/slicing.h"
|
||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/gpu/copy.h"
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/backend/gpu/reshape.h"
|
||||||
#include "mlx/backend/gpu/slicing.h"
|
#include "mlx/backend/gpu/slicing.h"
|
||||||
|
|
||||||
#if defined(MLX_USE_CUDA)
|
#if defined(MLX_USE_CUDA)
|
||||||
@@ -20,29 +21,6 @@
|
|||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
void reshape(const array& in, array& out, Stream s) {
|
|
||||||
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
|
||||||
if (copy_necessary) {
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
copy_gpu_inplace(
|
|
||||||
in,
|
|
||||||
out,
|
|
||||||
in.shape(),
|
|
||||||
in.strides(),
|
|
||||||
make_contiguous_strides(in.shape()),
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
CopyType::General,
|
|
||||||
s);
|
|
||||||
} else {
|
|
||||||
shared_buffer_reshape(in, out_strides, out);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void AsStrided::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
MLX_PROFILER_RANGE("AsStrided::eval_gpu");
|
||||||
eval(inputs, out);
|
eval(inputs, out);
|
||||||
@@ -124,7 +102,7 @@ void Full::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Flatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
MLX_PROFILER_RANGE("Flatten::eval_gpu");
|
||||||
reshape(inputs[0], out, stream());
|
reshape_gpu(inputs[0], out, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
@@ -150,7 +128,7 @@ void Pad::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Reshape::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
MLX_PROFILER_RANGE("Reshape::eval_gpu");
|
||||||
reshape(inputs[0], out, stream());
|
reshape_gpu(inputs[0], out, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Split::eval_gpu(
|
void Split::eval_gpu(
|
||||||
@@ -224,7 +202,7 @@ void Transpose::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
MLX_PROFILER_RANGE("Unflatten::eval_gpu");
|
||||||
reshape(inputs[0], out, stream());
|
reshape_gpu(inputs[0], out, stream());
|
||||||
}
|
}
|
||||||
|
|
||||||
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
|||||||
75
mlx/backend/gpu/reshape.cpp
Normal file
75
mlx/backend/gpu/reshape.cpp
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/gpu/reshape.h"
|
||||||
|
#include "mlx/backend/gpu/copy.h"
|
||||||
|
#include "mlx/primitives.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
void reshape_gpu(const array& in, array& out, Stream s) {
|
||||||
|
auto [copy_necessary, out_strides] = prepare_reshape(in, out);
|
||||||
|
if (copy_necessary) {
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
copy_gpu_inplace(
|
||||||
|
in,
|
||||||
|
out,
|
||||||
|
in.shape(),
|
||||||
|
in.strides(),
|
||||||
|
make_contiguous_strides(in.shape()),
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
CopyType::General,
|
||||||
|
s);
|
||||||
|
} else {
|
||||||
|
shared_buffer_reshape(in, out_strides, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s) {
|
||||||
|
int ndim = x.ndim();
|
||||||
|
if (start_axis < 0) {
|
||||||
|
start_axis += ndim;
|
||||||
|
}
|
||||||
|
if (end_axis < 0) {
|
||||||
|
end_axis += ndim;
|
||||||
|
}
|
||||||
|
start_axis = std::max(0, start_axis);
|
||||||
|
end_axis = std::min(ndim - 1, end_axis);
|
||||||
|
|
||||||
|
return reshape_in_eval(x, Flatten::output_shape(x, start_axis, end_axis), s);
|
||||||
|
}
|
||||||
|
|
||||||
|
array reshape_in_eval(const array& x, Shape shape, Stream s) {
|
||||||
|
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||||
|
reshape_gpu(x, out, s);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2) {
|
||||||
|
int ndim = x.ndim();
|
||||||
|
if (axis1 < 0) {
|
||||||
|
axis1 += ndim;
|
||||||
|
}
|
||||||
|
if (axis2 < 0) {
|
||||||
|
axis2 += ndim;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = x.shape();
|
||||||
|
std::swap(shape[axis1], shape[axis2]);
|
||||||
|
auto strides = x.strides();
|
||||||
|
std::swap(strides[axis1], strides[axis2]);
|
||||||
|
|
||||||
|
auto [data_size, row_contiguous, col_contiguous] =
|
||||||
|
check_contiguity(shape, strides);
|
||||||
|
bool contiguous = data_size == x.data_size();
|
||||||
|
|
||||||
|
array out(std::move(shape), x.dtype(), nullptr, {});
|
||||||
|
out.copy_shared_buffer(
|
||||||
|
x,
|
||||||
|
std::move(strides),
|
||||||
|
{contiguous, row_contiguous, col_contiguous},
|
||||||
|
x.data_size());
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
15
mlx/backend/gpu/reshape.h
Normal file
15
mlx/backend/gpu/reshape.h
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/common/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Copy data from |in| and transpose to |out|'s shape.
|
||||||
|
void reshape_gpu(const array& in, array& out, Stream s);
|
||||||
|
|
||||||
|
// Like the normal ops but safe to call in eval_gpu.
|
||||||
|
array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s);
|
||||||
|
array reshape_in_eval(const array& x, Shape shape, Stream s);
|
||||||
|
array swapaxes_in_eval(const array& x, int axis1, int axis2);
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
||||||
Reference in New Issue
Block a user