mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
Add contiguous_copy_gpu util for copying array (#2379)
This commit is contained in:
parent
31fc530c76
commit
45adec102c
@ -237,8 +237,7 @@ void LayerNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@ -295,9 +294,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
copied = true;
|
copied = true;
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
return contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return x_copy;
|
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
bool donate_g = inputs[3].is_donatable();
|
bool donate_g = inputs[3].is_donatable();
|
||||||
|
@ -108,8 +108,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(x_copy);
|
encoder.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
@ -297,8 +297,7 @@ check_transpose(cu::CommandEncoder& enc, const Stream& s, const array& arr) {
|
|||||||
} else if (stx == 1 && sty == arr.shape(-2)) {
|
} else if (stx == 1 && sty == arr.shape(-2)) {
|
||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
enc.add_temporary(arr_copy);
|
enc.add_temporary(arr_copy);
|
||||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||||
}
|
}
|
||||||
|
@ -247,8 +247,7 @@ inline array ensure_row_contiguous(
|
|||||||
cu::CommandEncoder& enc,
|
cu::CommandEncoder& enc,
|
||||||
const Stream& s) {
|
const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
enc.add_temporary(x_copy);
|
enc.add_temporary(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
} else {
|
} else {
|
||||||
|
@ -47,8 +47,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
|
||||||
encoder.add_temporary(in_copy);
|
encoder.add_temporary(in_copy);
|
||||||
in = in_copy;
|
in = in_copy;
|
||||||
plan = get_reduction_plan(in, axes_);
|
plan = get_reduction_plan(in, axes_);
|
||||||
|
@ -206,8 +206,7 @@ void RMSNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@ -259,9 +258,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
copied = true;
|
copied = true;
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
return contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return x_copy;
|
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
bool donate_g = inputs[2].is_donatable();
|
bool donate_g = inputs[2].is_donatable();
|
||||||
|
@ -379,9 +379,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
|
||||||
in = std::move(arr_copy);
|
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,8 +125,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
@ -72,8 +72,7 @@ void gpu_sort(const Stream& s, array in, array& out_, int axis, bool argsort) {
|
|||||||
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
bool is_segmented_sort = in.flags().contiguous && in.strides()[axis] == 1;
|
||||||
if (!is_segmented_sort) {
|
if (!is_segmented_sort) {
|
||||||
array trans = swapaxes_in_eval(in, axis, last_dim);
|
array trans = swapaxes_in_eval(in, axis, last_dim);
|
||||||
in = array(trans.shape(), trans.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(trans, s);
|
||||||
copy_gpu(trans, in, CopyType::General, s);
|
|
||||||
encoder.add_temporary(in);
|
encoder.add_temporary(in);
|
||||||
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
out = array(allocator::malloc(out.nbytes()), in.shape(), out.dtype());
|
||||||
encoder.add_temporary(out);
|
encoder.add_temporary(out);
|
||||||
|
@ -46,4 +46,10 @@ void copy_gpu_inplace(
|
|||||||
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
in, out, in.shape(), i_strides, out.strides(), i_offset, 0, ctype, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array contiguous_copy_gpu(const array& arr, const Stream& s) {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -43,4 +43,7 @@ void copy_gpu_inplace(
|
|||||||
// Fill the output with the scalar val
|
// Fill the output with the scalar val
|
||||||
void fill_gpu(const array& val, array& out, const Stream& s);
|
void fill_gpu(const array& val, array& out, const Stream& s);
|
||||||
|
|
||||||
|
// Return a contiguous array with same shape that copies the data of |arr|.
|
||||||
|
array contiguous_copy_gpu(const array& arr, const Stream& s);
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -149,8 +149,7 @@ void explicit_gemm_conv_group_ND_gpu(
|
|||||||
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
wt, {wt.strides(0), 1, C_per_group}, wt.flags(), wt.size());
|
||||||
|
|
||||||
// Materialize
|
// Materialize
|
||||||
auto wt_transpose = array(wt_view.shape(), wt_view.dtype(), nullptr, {});
|
array wt_transpose = contiguous_copy_gpu(wt_view, s);
|
||||||
copy_gpu(wt_view, wt_transpose, CopyType::General, s);
|
|
||||||
|
|
||||||
// Perform gemm
|
// Perform gemm
|
||||||
std::vector<array> copies = {in_unfolded, wt_transpose};
|
std::vector<array> copies = {in_unfolded, wt_transpose};
|
||||||
@ -961,16 +960,12 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
auto in = inputs[0];
|
auto in = inputs[0];
|
||||||
auto wt = inputs[1];
|
auto wt = inputs[1];
|
||||||
if (!in.flags().row_contiguous) {
|
if (!in.flags().row_contiguous) {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
copies.push_back(in);
|
||||||
copies.push_back(arr_copy);
|
|
||||||
in = arr_copy;
|
|
||||||
}
|
}
|
||||||
if (!wt.flags().row_contiguous) {
|
if (!wt.flags().row_contiguous) {
|
||||||
array arr_copy(wt.shape(), wt.dtype(), nullptr, {});
|
wt = contiguous_copy_gpu(wt, s);
|
||||||
copy_gpu(wt, arr_copy, CopyType::General, s);
|
copies.push_back(wt);
|
||||||
copies.push_back(arr_copy);
|
|
||||||
wt = arr_copy;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3D conv
|
// 3D conv
|
||||||
|
@ -25,8 +25,7 @@ void LogSumExp::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
if (x.flags().contiguous && x.strides()[x.ndim() - 1] == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
@ -33,8 +33,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
} else if (stx == 1 && (!is_vector || sty == arr.shape(-2))) {
|
||||||
return std::make_tuple(true, sty, arr);
|
return std::make_tuple(true, sty, arr);
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
copies.push_back(arr_copy);
|
copies.push_back(arr_copy);
|
||||||
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
return std::make_tuple(false, arr.shape(-1), arr_copy);
|
||||||
}
|
}
|
||||||
@ -43,8 +42,7 @@ std::tuple<bool, int64_t, array> check_transpose(
|
|||||||
inline array
|
inline array
|
||||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
} else {
|
} else {
|
||||||
@ -75,8 +73,7 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
return std::make_tuple(false, x_copy.strides()[x_copy.ndim() - 2], x_copy);
|
||||||
}
|
}
|
||||||
@ -1894,8 +1891,7 @@ void segmented_mm(
|
|||||||
return std::make_tuple(false, x);
|
return std::make_tuple(false, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return std::make_tuple(true, x_copy);
|
return std::make_tuple(true, x_copy);
|
||||||
};
|
};
|
||||||
|
@ -40,8 +40,7 @@ void RMSNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@ -107,9 +106,7 @@ void RMSNormVJP::eval_gpu(
|
|||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return {x, false};
|
return {x, false};
|
||||||
}
|
}
|
||||||
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return {x_copy, true};
|
return {x_copy, true};
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
@ -241,8 +238,7 @@ void LayerNorm::eval_gpu(
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
@ -319,8 +315,7 @@ void LayerNormVJP::eval_gpu(
|
|||||||
if (x.flags().row_contiguous) {
|
if (x.flags().row_contiguous) {
|
||||||
return {x, false};
|
return {x, false};
|
||||||
}
|
}
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
return {x_copy, true};
|
return {x_copy, true};
|
||||||
};
|
};
|
||||||
bool donate_x = inputs[0].is_donatable();
|
bool donate_x = inputs[0].is_donatable();
|
||||||
|
@ -20,8 +20,7 @@ namespace {
|
|||||||
inline array
|
inline array
|
||||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||||
if (!x.flags().row_contiguous) {
|
if (!x.flags().row_contiguous) {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
} else {
|
} else {
|
||||||
@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
|
|||||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(x_copy, s.index);
|
d.add_temporary(x_copy, s.index);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
@ -989,8 +989,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// input for the axes with stride smaller than the minimum reduction
|
// input for the axes with stride smaller than the minimum reduction
|
||||||
// stride.
|
// stride.
|
||||||
if (plan.type == GeneralReduce) {
|
if (plan.type == GeneralReduce) {
|
||||||
array in_copy(in.shape(), in.dtype(), nullptr, {});
|
array in_copy = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, in_copy, CopyType::General, s);
|
|
||||||
d.add_temporary(in_copy, s.index);
|
d.add_temporary(in_copy, s.index);
|
||||||
in = in_copy;
|
in = in_copy;
|
||||||
plan = get_reduction_plan(in, axes_);
|
plan = get_reduction_plan(in, axes_);
|
||||||
|
@ -398,8 +398,7 @@ void ScaledDotProductAttention::eval_gpu(
|
|||||||
auto copy_unless = [&copies, &s](
|
auto copy_unless = [&copies, &s](
|
||||||
auto predicate, const array& arr) -> const array& {
|
auto predicate, const array& arr) -> const array& {
|
||||||
if (!predicate(arr)) {
|
if (!predicate(arr)) {
|
||||||
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
array arr_copy = contiguous_copy_gpu(arr, s);
|
||||||
copy_gpu(arr, arr_copy, CopyType::General, s);
|
|
||||||
copies.push_back(std::move(arr_copy));
|
copies.push_back(std::move(arr_copy));
|
||||||
return copies.back();
|
return copies.back();
|
||||||
} else {
|
} else {
|
||||||
|
@ -30,9 +30,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
in.flags());
|
in.flags());
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
in = contiguous_copy_gpu(in, s);
|
||||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
|
||||||
in = std::move(arr_copy);
|
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,8 +35,7 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
}
|
}
|
||||||
return x;
|
return x;
|
||||||
} else {
|
} else {
|
||||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
array x_copy = contiguous_copy_gpu(x, s);
|
||||||
copy_gpu(x, x_copy, CopyType::General, s);
|
|
||||||
out.copy_shared_buffer(x_copy);
|
out.copy_shared_buffer(x_copy);
|
||||||
return x_copy;
|
return x_copy;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user