Mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)
More buffer donation with no-ops (#1591)

* more donation
* fix test
* fix build

commit 9bd03dd9b4 (parent 6931f84412)
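In short: no-op primitives (Broadcast, Copy, Reshape, StopGradient, Transpose, slicing, and friends) used to forward their input with out.copy_shared_buffer(in), which kept the input holding a live reference to the shared buffer and blocked donation further down the graph. They now route through a move_or_copy helper that moves the buffer instead whenever the input is donatable. Below is a minimal sketch of the effect, using the same public APIs as the test added in this commit; it assumes a Metal device, and the flat-peak-memory outcome is the behavior this commit introduces:

import mlx.core as mx

x = mx.zeros((4096, 4096))  # one 64 MiB float32 buffer
mx.eval(x)
pre = mx.metal.get_peak_memory()

y = mx.abs(x)                # elementwise op: can write into a donated buffer
y = mx.reshape(y, (-1,))     # no-op: re-aliases the same buffer
y = mx.stop_gradient(y.T.T)  # more no-ops: still the same buffer
y = mx.abs(y)                # can reuse the buffer again, because the no-ops
                             # moved the reference instead of keeping an extra
                             # alias alive
del x                        # drop the last outside reference before eval
mx.eval(y)

post = mx.metal.get_peak_memory()
assert pre == post  # the whole chain ran inside x's original buffer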
@@ -39,7 +39,7 @@ void AsStrided::eval(const std::vector<array>& inputs, array& out) {
   // rely on data_size anyway.
   size_t data_size = out.size();
 
-  return out.copy_shared_buffer(in, strides_, flags, data_size, offset_);
+  return move_or_copy(in, out, strides_, flags, data_size, offset_);
 }
 
 void Broadcast::eval(const std::vector<array>& inputs, array& out) {
@@ -58,12 +58,12 @@ void Broadcast::eval(const std::vector<array>& inputs, array& out) {
   if (out.size() > in.size()) {
     flags.row_contiguous = flags.col_contiguous = false;
   }
-  out.copy_shared_buffer(in, strides, flags, in.data_size());
+  move_or_copy(in, out, strides, flags, in.data_size());
 }
 
 void Copy::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }
 
 void CustomTransforms::eval(
@@ -72,7 +72,7 @@ void CustomTransforms::eval(
   assert(inputs.size() > outputs.size());
   for (int i = 0, j = inputs.size() - outputs.size(); i < outputs.size();
        i++, j++) {
-    outputs[i].copy_shared_buffer(inputs[j]);
+    move_or_copy(inputs[j], outputs[i]);
   }
 }
 
@@ -81,7 +81,7 @@ void Depends::eval(
     std::vector<array>& outputs) {
   assert(inputs.size() > outputs.size());
   for (int i = 0; i < outputs.size(); i++) {
-    outputs[i].copy_shared_buffer(inputs[i]);
+    move_or_copy(inputs[i], outputs[i]);
   }
 }
 
@@ -194,7 +194,7 @@ void Reshape::shared_buffer_reshape(
     auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
     flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }
 
 void Split::eval(
@@ -263,7 +263,7 @@ std::tuple<int64_t, std::vector<int64_t>> SliceUpdate::prepare_slice(
 
 void StopGradient::eval(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
-  out.copy_shared_buffer(inputs[0]);
+  move_or_copy(inputs[0], out);
 }
 
 void Transpose::eval(const std::vector<array>& inputs, array& out) {
@@ -297,7 +297,7 @@ void Transpose::eval(const std::vector<array>& inputs, array& out) {
       b_stride *= out.shape(ri);
     }
   }
-  out.copy_shared_buffer(in, out_strides, flags, in.data_size());
+  move_or_copy(in, out, out_strides, flags, in.data_size());
 }
 
 } // namespace mlx::core

@@ -34,7 +34,7 @@ void shared_buffer_slice(
   flags.col_contiguous = is_col_contiguous;
   flags.contiguous = (no_bsx_size == data_size);
 
-  out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
+  move_or_copy(in, out, out_strides, flags, data_size, data_offset);
 }
 
 } // namespace mlx::core

@@ -4,6 +4,28 @@
 
 namespace mlx::core {
 
+void move_or_copy(const array& in, array& out) {
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in);
+  } else {
+    out.copy_shared_buffer(in);
+  }
+}
+
+void move_or_copy(
+    const array& in,
+    array& out,
+    const std::vector<size_t>& strides,
+    array::Flags flags,
+    size_t data_size,
+    size_t offset /* = 0 */) {
+  if (in.is_donatable()) {
+    out.move_shared_buffer(in, strides, flags, data_size, offset);
+  } else {
+    out.copy_shared_buffer(in, strides, flags, data_size, offset);
+  }
+}
+
 template <typename StrideT>
 std::tuple<std::vector<int>, std::vector<std::vector<StrideT>>>
 collapse_contiguous_dims_impl(

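This helper is the heart of the change: array::is_donatable() is true, roughly, when the array holds the only live reference to its data, and in that case the buffer handle is moved into out rather than aliased, so the input no longer pins the buffer and a later elementwise kernel can write into it in place. When the input is not donatable, nothing changes: copy_shared_buffer shares the buffer exactly as before. Either way no bytes are copied; this is purely reference bookkeeping.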
@@ -178,4 +178,13 @@ inline bool is_donatable(const array& in, const array& out) {
       in.buffer_size() <= out.nbytes() + donation_extra;
 }
 
+void move_or_copy(const array& in, array& out);
+void move_or_copy(
+    const array& in,
+    array& out,
+    const std::vector<size_t>& strides,
+    array::Flags flags,
+    size_t data_size,
+    size_t offset = 0);
+
 } // namespace mlx::core

@@ -5,6 +5,7 @@
 #include <sstream>
 
 #include "mlx/backend/common/load.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/backend/metal/copy.h"
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
@@ -343,7 +344,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& upd = inputs[1];
 
   if (upd.size() == 0) {
-    out.copy_shared_buffer(in);
+    move_or_copy(in, out);
     return;
   }
 
@@ -420,8 +421,8 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
       strides[i] *= ibytes;
       strides[i] /= obytes;
     }
-    out.copy_shared_buffer(
-        in, strides, in.flags(), in.data_size() * ibytes / obytes);
+    move_or_copy(
+        in, out, strides, in.flags(), in.data_size() * ibytes / obytes);
   } else {
     auto tmp = array(in.shape(), in.dtype(), nullptr, {});
     tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));

@@ -161,7 +161,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
     unary_op_gpu(inputs, out, get_primitive_string(this));
   } else {
     // No-op integer types
-    out.copy_shared_buffer(in);
+    move_or_copy(in, out);
   }
 }
 

@@ -137,6 +137,43 @@ class TestEval(mlx_tests.MLXTestCase):
         mx.async_eval(x)
         mx.eval(a + b)
 
+    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
+    def test_donation_for_noops(self):
+        def fun(x):
+            s = x.shape
+            for _ in range(10):
+                x = mx.abs(x)
+                x = mx.reshape(x, (-1,))
+                x = x.T.T
+                x = mx.stop_gradient(x)
+            x = mx.abs(x)
+            return x
+
+        x = mx.zeros((4096, 4096))
+        mx.eval(x)
+        pre = mx.metal.get_peak_memory()
+        out = fun(x)
+        del x
+        mx.eval(out)
+        post = mx.metal.get_peak_memory()
+        self.assertEqual(pre, post)
+
+        def fun(x):
+            for _ in range(10):
+                x = mx.abs(x)
+                x = x[:-1]
+            x = mx.abs(x)
+            return x
+
+        x = mx.zeros((4096 * 4096,))
+        mx.eval(x)
+        pre = mx.metal.get_peak_memory()
+        out = fun(x)
+        del x
+        mx.eval(out)
+        post = mx.metal.get_peak_memory()
+        self.assertEqual(pre, post)
+
 
 if __name__ == "__main__":
     unittest.main()
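Both halves of the test assert the same invariant: after running a ten-iteration chain of elementwise ops and no-ops (reshape, double transpose, and stop_gradient in the first function, slicing in the second) and deleting the input before evaluation, peak Metal memory is unchanged. In other words, the entire chain executed inside the input's original 64 MiB buffer instead of allocating a fresh buffer per step, which is exactly what donation through no-ops buys.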