mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-26 02:33:21 +08:00
Fixes output donation for IO ops on the GPU (#1857)
This commit is contained in:
parent
0a5215693e
commit
0145911bea
@ -76,6 +76,18 @@ array::array(allocator::Buffer data, Shape shape, Dtype dtype, Deleter deleter)
|
|||||||
set_data(data, deleter);
|
set_data(data, deleter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array::array(
|
||||||
|
allocator::Buffer data,
|
||||||
|
Shape shape,
|
||||||
|
Dtype dtype,
|
||||||
|
Strides strides,
|
||||||
|
size_t data_size,
|
||||||
|
Flags flags,
|
||||||
|
Deleter deleter)
|
||||||
|
: array_desc_(std::make_shared<ArrayDesc>(std::move(shape), dtype)) {
|
||||||
|
set_data(data, data_size, std::move(strides), flags, deleter);
|
||||||
|
}
|
||||||
|
|
||||||
void array::detach() {
|
void array::detach() {
|
||||||
for (auto& s : array_desc_->siblings) {
|
for (auto& s : array_desc_->siblings) {
|
||||||
s.array_desc_->inputs.clear();
|
s.array_desc_->inputs.clear();
|
||||||
|
12
mlx/array.h
12
mlx/array.h
@ -243,6 +243,18 @@ class array {
|
|||||||
bool col_contiguous : 1;
|
bool col_contiguous : 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** Build an array from all the info held by the array description. Including
|
||||||
|
* the buffer, strides, flags.
|
||||||
|
*/
|
||||||
|
explicit array(
|
||||||
|
allocator::Buffer data,
|
||||||
|
Shape shape,
|
||||||
|
Dtype dtype,
|
||||||
|
Strides strides,
|
||||||
|
size_t data_size,
|
||||||
|
Flags flags,
|
||||||
|
Deleter deleter = allocator::free);
|
||||||
|
|
||||||
/** The array's primitive. */
|
/** The array's primitive. */
|
||||||
Primitive& primitive() const {
|
Primitive& primitive() const {
|
||||||
return *(array_desc_->primitive);
|
return *(array_desc_->primitive);
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/event.h"
|
#include "mlx/backend/metal/event.h"
|
||||||
#include "mlx/backend/metal/fence.h"
|
#include "mlx/backend/metal/fence.h"
|
||||||
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
@ -43,7 +44,7 @@ void AllReduce::eval_gpu(
|
|||||||
f.wait_gpu(out);
|
f.wait_gpu(out);
|
||||||
|
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = out,
|
out = unsafe_weak_copy(out),
|
||||||
f = std::move(f),
|
f = std::move(f),
|
||||||
reduce_type = reduce_type_,
|
reduce_type = reduce_type_,
|
||||||
group = group()]() mutable {
|
group = group()]() mutable {
|
||||||
@ -80,8 +81,10 @@ void AllGather::eval_gpu(
|
|||||||
}
|
}
|
||||||
f.wait_gpu(out);
|
f.wait_gpu(out);
|
||||||
|
|
||||||
auto task =
|
auto task = [in = in,
|
||||||
[in = in, out = out, f = std::move(f), group = group()]() mutable {
|
out = unsafe_weak_copy(out),
|
||||||
|
f = std::move(f),
|
||||||
|
group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
f.wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
@ -110,7 +113,7 @@ void Send::eval_gpu(
|
|||||||
|
|
||||||
// Schedule an async send on the comm stream
|
// Schedule an async send on the comm stream
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = out,
|
out = unsafe_weak_copy(out),
|
||||||
f = std::move(f),
|
f = std::move(f),
|
||||||
group = group(),
|
group = group(),
|
||||||
dst = dst_]() mutable {
|
dst = dst_]() mutable {
|
||||||
@ -136,8 +139,10 @@ void Recv::eval_gpu(
|
|||||||
f.wait_gpu(out);
|
f.wait_gpu(out);
|
||||||
|
|
||||||
// Schedule an async recv on the comm stream
|
// Schedule an async recv on the comm stream
|
||||||
auto task =
|
auto task = [out = unsafe_weak_copy(out),
|
||||||
[out = out, f = std::move(f), group = group(), src = src_]() mutable {
|
f = std::move(f),
|
||||||
|
group = group(),
|
||||||
|
src = src_]() mutable {
|
||||||
distributed::detail::recv(group, out, src);
|
distributed::detail::recv(group, out, src);
|
||||||
f.update();
|
f.update();
|
||||||
};
|
};
|
||||||
|
@ -303,7 +303,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
auto read_task = [out = out,
|
auto read_task = [out = unsafe_weak_copy(out),
|
||||||
offset = offset_,
|
offset = offset_,
|
||||||
reader = reader_,
|
reader = reader_,
|
||||||
swap_endianness = swap_endianness_]() mutable {
|
swap_endianness = swap_endianness_]() mutable {
|
||||||
|
@ -69,4 +69,19 @@ void concatenate(std::string& acc, T first, Args... args) {
|
|||||||
concatenate(acc, args...);
|
concatenate(acc, args...);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a new array that refers to the same data but has a non-owning pointer to
|
||||||
|
* them.
|
||||||
|
*/
|
||||||
|
inline array unsafe_weak_copy(const array& x) {
|
||||||
|
return array(
|
||||||
|
x.buffer(),
|
||||||
|
x.shape(),
|
||||||
|
x.dtype(),
|
||||||
|
x.strides(),
|
||||||
|
x.data_size(),
|
||||||
|
x.flags(),
|
||||||
|
[](auto b) {});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -174,6 +174,21 @@ class TestDistributed(mlx_tests.MLXTestCase):
|
|||||||
finally:
|
finally:
|
||||||
mx.distributed.all_sum = original_all_sum
|
mx.distributed.all_sum = original_all_sum
|
||||||
|
|
||||||
|
def test_donation(self):
|
||||||
|
x = mx.random.normal((1024,))
|
||||||
|
mx.eval(x)
|
||||||
|
|
||||||
|
mx.metal.reset_peak_memory()
|
||||||
|
scale = mx.array(2.0)
|
||||||
|
y = mx.distributed.all_sum(x)
|
||||||
|
mx.eval(y)
|
||||||
|
all_sum_only = mx.metal.get_peak_memory()
|
||||||
|
y = mx.distributed.all_sum(x) * scale
|
||||||
|
mx.eval(y)
|
||||||
|
all_sum_with_binary = mx.metal.get_peak_memory()
|
||||||
|
|
||||||
|
self.assertEqual(all_sum_only, all_sum_with_binary)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -356,6 +356,23 @@ class TestLoad(mlx_tests.MLXTestCase):
|
|||||||
aload = mx.load(save_file)["a"]
|
aload = mx.load(save_file)["a"]
|
||||||
self.assertTrue(mx.array_equal(a, aload))
|
self.assertTrue(mx.array_equal(a, aload))
|
||||||
|
|
||||||
|
def test_load_donation(self):
|
||||||
|
x = mx.random.normal((1024,))
|
||||||
|
mx.eval(x)
|
||||||
|
save_file = os.path.join(self.test_dir, "donation.npy")
|
||||||
|
mx.save(save_file, x)
|
||||||
|
|
||||||
|
mx.metal.reset_peak_memory()
|
||||||
|
scale = mx.array(2.0)
|
||||||
|
y = mx.load(save_file)
|
||||||
|
mx.eval(y)
|
||||||
|
load_only = mx.metal.get_peak_memory()
|
||||||
|
y = mx.load(save_file) * scale
|
||||||
|
mx.eval(y)
|
||||||
|
load_with_binary = mx.metal.get_peak_memory()
|
||||||
|
|
||||||
|
self.assertEqual(load_only, load_with_binary)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user