mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add contiguous_copy_gpu util for copying array (#2379)
This commit is contained in:
@@ -20,8 +20,7 @@ namespace {
|
||||
inline array
|
||||
ensure_row_contiguous(const array& x, metal::Device& d, const Stream& s) {
|
||||
if (!x.flags().row_contiguous) {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
} else {
|
||||
@@ -38,8 +37,7 @@ inline array ensure_row_contiguous_matrix(
|
||||
if (stride_0 == x.shape(-1) && stride_1 == 1) {
|
||||
return x;
|
||||
} else {
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
array x_copy = contiguous_copy_gpu(x, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
return x_copy;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user