Add contiguous_copy_gpu util for copying array (#2379)

This commit is contained in:
Cheng 2025-07-18 22:44:25 +09:00 committed by GitHub
parent 31fc530c76
commit 45adec102c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 40 additions and 67 deletions

View File

@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
return x; return x;
} }
copied = true; copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); return contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable(); bool donate_g = inputs[3].is_donatable();

View File

@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
encoder.add_temporary(x_copy); encoder.add_temporary(x_copy);
return x_copy; return x_copy;
} }

View File

@ -297,8 +297,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
} else if (stx == 1 && sty == arr.shape(-2)) { } else if (stx == 1 && sty == arr.shape(-2)) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy = contiguous_copy_gpu(arr, s);
copy_gpu(arr, arr_copy, CopyType::General, s);
enc.add_temporary(arr_copy); enc.add_temporary(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }

View File

@ -247,8 +247,7 @@ inline array ensure_row_contiguous(
cu::CommandEncoder& enc, cu::CommandEncoder& enc,
const Stream& s) { const Stream& s) {
if (!x.flags().row_contiguous) { if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
enc.add_temporary(x_copy); enc.add_temporary(x_copy);
return x_copy; return x_copy;
} else { } else {

View File

@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
} }
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) { if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {}); array in_copy = contiguous_copy_gpu(in, s);
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy); encoder.add_temporary(in_copy);
in = in_copy; in = in_copy;
plan = get_reduction_plan(in, axes_); plan = get_reduction_plan(in, axes_);

View File

@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
return x; return x;
} }
copied = true; copied = true;
array x_copy(x.shape(), x.dtype(), nullptr, {}); return contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
return x_copy;
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable(); bool donate_g = inputs[2].is_donatable();

View File

@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags()); in.flags());
} }
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_gpu(in, s);
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} }

View File

@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }

View File

@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1; bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
if (!is_segmented_sort) { if (!is_segmented_sort) {
array trans = swapaxes_in_eval(in, axis, last_dim); array trans = swapaxes_in_eval(in, axis, last_dim);
in = array(trans.shape(), trans.dtype(), nullptr, {}); in = contiguous_copy_gpu(trans, s);
copy_gpu(trans, in, CopyType::General, s);
encoder.add_temporary(in); encoder.add_temporary(in);
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype()); out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
encoder.add_temporary(out); encoder.add_temporary(out);

View File

@ -46,4 +46,10 @@ void copy_gpu_inplace(
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s); in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
} }
array contiguous_copy_gpu(const array& arr, const Stream& s) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
return arr_copy;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -43,4 +43,7 @@ void copy_gpu_inplace(
// Fill the output with the scalar val // Fill the output with the scalar val
void fill_gpu(const array& val, array& out, const Stream& s); void fill_gpu(const array& val, array& out, const Stream& s);
// Return a contiguous array with same shape that copies the data of |arr|.
array contiguous_copy_gpu(const array& arr, const Stream& s);
} // namespace mlx::core } // namespace mlx::core

View File

@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size()); wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
// Materialize // Materialize
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {}); array wt_transpose = contiguous_copy_gpu(wt_view, s);
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
// Perform gemm // Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose}; std::vector<array> copies = {in_unfolded, wt_transpose};
@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0]; auto in = inputs[0];
auto wt = inputs[1]; auto wt = inputs[1];
if (!in.flags().row_contiguous) { if (!in.flags().row_contiguous) {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_gpu(in, s);
copy_gpu(in, arr_copy, CopyType::General, s); copies.push_back(in);
copies.push_back(arr_copy);
in = arr_copy;
} }
if (!wt.flags().row_contiguous) { if (!wt.flags().row_contiguous) {
array arr_copy(wt.shape(), wt.dtype(), nullptr, {}); wt = contiguous_copy_gpu(wt, s);
copy_gpu(wt, arr_copy, CopyType::General, s); copies.push_back(wt);
copies.push_back(arr_copy);
wt = arr_copy;
} }
// 3D conv // 3D conv

View File

@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) { if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return x_copy; return x_copy;
} }

View File

@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) { } else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
return std::make_tuple(true, sty, arr); return std::make_tuple(true, sty, arr);
} else { } else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy = contiguous_copy_gpu(arr, s);
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy); copies.push_back(arr_copy);
return std::make_tuple(false, arr.shape(-1), arr_copy); return std::make_tuple(false, arr.shape(-1), arr_copy);
} }
@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
inline array inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) { if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return x_copy; return x_copy;
} else { } else {
@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
} }
} }
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy); return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
} }
@ -1894,8 +1891,7 @@ void segmented_mm(
return std::make_tuple(false, x); return std::make_tuple(false, x);
} }
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return std::make_tuple(true, x_copy); return std::make_tuple(true, x_copy);
}; };

View File

@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; return {x, false};
} }
array x_copy = contiguous_copy_gpu(x, s);
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return {x_copy, true};
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();
@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }
@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
if (x.flags().row_contiguous) { if (x.flags().row_contiguous) {
return {x, false}; return {x, false};
} }
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
return {x_copy, true}; return {x_copy, true};
}; };
bool donate_x = inputs[0].is_donatable(); bool donate_x = inputs[0].is_donatable();

View File

@ -20,8 +20,7 @@ namespace {
inline array inline array
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) { ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
if (!x.flags().row_contiguous) { if (!x.flags().row_contiguous) {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return x_copy; return x_copy;
} else { } else {
@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
if (stride_0 == x.shape(-1) && stride_1 == 1) { if (stride_0 == x.shape(-1) && stride_1 == 1) {
return x; return x;
} else { } else {
array x_copy(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index); d.add_temporary(x_copy, s.index);
return x_copy; return x_copy;
} }

View File

@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
// input for the axes with stride smaller than the minimum reduction // input for the axes with stride smaller than the minimum reduction
// stride. // stride.
if (plan.type == GeneralReduce) { if (plan.type == GeneralReduce) {
array in_copy(in.shape(), in.dtype(), nullptr, {}); array in_copy = contiguous_copy_gpu(in, s);
copy_gpu(in, in_copy, CopyType::General, s);
d.add_temporary(in_copy, s.index); d.add_temporary(in_copy, s.index);
in = in_copy; in = in_copy;
plan = get_reduction_plan(in, axes_); plan = get_reduction_plan(in, axes_);

View File

@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
auto copy_unless = [&copies, &s]( auto copy_unless = [&copies, &s](
auto predicate, const array& arr) -> const array& { auto predicate, const array& arr) -> const array& {
if (!predicate(arr)) { if (!predicate(arr)) {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); array arr_copy = contiguous_copy_gpu(arr, s);
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(std::move(arr_copy)); copies.push_back(std::move(arr_copy));
return copies.back(); return copies.back();
} else { } else {

View File

@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
in.flags()); in.flags());
} }
} else { } else {
array arr_copy(in.shape(), in.dtype(), nullptr, {}); in = contiguous_copy_gpu(in, s);
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} }

View File

@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
} }
return x; return x;
} else { } else {
auto x_copy = array(x.shape(), x.dtype(), nullptr, {}); array x_copy = contiguous_copy_gpu(x, s);
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy); out.copy_shared_buffer(x_copy);
return x_copy; return x_copy;
} }