mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Fixes output donation for IO ops on the GPU (#1857)
This commit is contained in:
parent
0a5215693e
commit
0145911bea
@ -76,6 +76,18 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
||||
set_data(data, deleter);
|
||||
}
|
||||
|
||||
// Constructs an array around an existing buffer.
//
// The array description is created from `shape` and `dtype`; the buffer,
// its data size, the strides, the flags and the deleter are then handed
// to set_data() unchanged.
array::array(
    allocator::Buffer data,
    Shape shape,
    Dtype dtype,
    Strides strides,
    size_t data_size,
    Flags flags,
    Deleter deleter)
    : array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
  // `strides` is a by-value parameter, so it can be moved into set_data().
  set_data(data, data_size, std::move(strides), flags, deleter);
}
|
||||
|
||||
void array::detach() {
|
||||
for (auto& s : array_desc_->siblings) {
|
||||
s.array_desc_->inputs.clear();
|
||||
|
12
mlx/array.h
12
mlx/array.h
@ -243,6 +243,18 @@ class array {
|
||||
bool col_contiguous : 1;
|
||||
};
|
||||
|
||||
/** Build an array from all the info held by the array description, including
|
||||
* the buffer, strides, and flags.
|
||||
*/
|
||||
explicit array(
|
||||
allocator::Buffer data,
|
||||
Shape shape,
|
||||
Dtype dtype,
|
||||
Strides strides,
|
||||
size_t data_size,
|
||||
Flags flags,
|
||||
Deleter deleter = allocator::free);
|
||||
|
||||
/** The array's primitive. */
|
||||
Primitive& primitive() const {
|
||||
return *(array_desc_->primitive);
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/fence.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@ -43,7 +44,7 @@ void AllReduce::eval_gpu(
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
@ -80,14 +81,16 @@ void AllGather::eval_gpu(
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto task =
|
||||
[in = in, out = out, f = std::move(f), group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
f.update();
|
||||
};
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
||||
@ -110,7 +113,7 @@ void Send::eval_gpu(
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in,
|
||||
out = out,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group(),
|
||||
dst = dst_]() mutable {
|
||||
@ -136,11 +139,13 @@ void Recv::eval_gpu(
|
||||
f.wait_gpu(out);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task =
|
||||
[out = out, f = std::move(f), group = group(), src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
f.update();
|
||||
};
|
||||
auto task = [out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group(),
|
||||
src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
||||
|
@ -303,7 +303,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out = out,
|
||||
auto read_task = [out = unsafe_weak_copy(out),
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness = swap_endianness_]() mutable {
|
||||
|
@ -69,4 +69,19 @@ void concatenate(std::string& acc, T first, Args... args) {
|
||||
concatenate(acc, args...);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a new array that refers to the same data but has a non-owning pointer to
|
||||
* them.
|
||||
*/
|
||||
inline array unsafe_weak_copy(const array& x) {
|
||||
return array(
|
||||
x.buffer(),
|
||||
x.shape(),
|
||||
x.dtype(),
|
||||
x.strides(),
|
||||
x.data_size(),
|
||||
x.flags(),
|
||||
[](auto b) {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
||||
finally:
|
||||
mx.distributed.all_sum = original_all_sum
|
||||
|
||||
def test_donation(self):
    """Peak memory after `all_sum` alone must equal the peak after
    `all_sum` followed by a multiply.

    The peak counter is reset once, so the second reading can only match
    the first if the extra multiply did not raise the peak.
    NOTE(review): presumably this holds because the all_sum output buffer
    is donated to the multiply -- confirm.
    """
    x = mx.random.normal((1024,))
    mx.eval(x)

    mx.metal.reset_peak_memory()
    scale = mx.array(2.0)

    # Reading 1: the all_sum on its own.
    result = mx.distributed.all_sum(x)
    mx.eval(result)
    all_sum_only = mx.metal.get_peak_memory()

    # Reading 2: all_sum feeding a binary op. The same name is rebound on
    # purpose so the previous result is not kept alive by a second variable.
    result = mx.distributed.all_sum(x) * scale
    mx.eval(result)
    all_sum_with_binary = mx.metal.get_peak_memory()

    self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
|
||||
aload = mx.load(save_file)["a"]
|
||||
self.assertTrue(mx.array_equal(a, aload))
|
||||
|
||||
def test_load_donation(self):
    """Peak memory after `mx.load` alone must equal the peak after
    `mx.load` followed by a multiply.

    The peak counter is reset once, so the second reading can only match
    the first if the extra multiply did not raise the peak.
    NOTE(review): presumably this holds because the loaded buffer is
    donated to the multiply -- confirm.
    """
    x = mx.random.normal((1024,))
    mx.eval(x)
    save_file = os.path.join(self.test_dir, "donation.npy")
    mx.save(save_file, x)

    mx.metal.reset_peak_memory()
    scale = mx.array(2.0)

    # Reading 1: the load on its own.
    loaded = mx.load(save_file)
    mx.eval(loaded)
    load_only = mx.metal.get_peak_memory()

    # Reading 2: load feeding a binary op. The same name is rebound on
    # purpose so the previous result is not kept alive by a second variable.
    loaded = mx.load(save_file) * scale
    mx.eval(loaded)
    load_with_binary = mx.metal.get_peak_memory()

    self.assertEqual(load_only, load_with_binary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user