Buffer Donation (#519)

* buffer donation

* fix to move shared pointer

* format

* gpu in place for copy and binary

* revert ops test

* cpu in place

* a little cleanup

* remove useless bench
This commit is contained in:
Awni Hannun 2024-01-26 16:30:33 -08:00 committed by GitHub
parent 07f35c9d8a
commit 8993382aaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 199 additions and 178 deletions

View File

@ -155,6 +155,14 @@ void array::copy_shared_buffer(const array& other) {
copy_shared_buffer(other, other.strides(), other.flags(), other.data_size()); copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
} }
// Take ownership of `other`'s data buffer (buffer donation).
// `other` is taken by value so the caller's handle is consumed; moving
// out of other.array_desc_->data leaves other's Data pointer null,
// which callers use to detect a donated input (via data_shared_ptr()).
void array::move_shared_buffer(array other) {
// Move (not copy) the shared Data pointer: this array becomes the owner.
array_desc_->data = std::move(other.array_desc_->data);
// Copy the layout metadata so this array views the buffer the same way.
array_desc_->strides = other.strides();
array_desc_->flags = other.flags();
array_desc_->data_size = other.data_size();
array_desc_->data_ptr = other.array_desc_->data_ptr;
}
array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype) array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
: shape(shape), dtype(dtype) { : shape(shape), dtype(dtype) {
std::tie(size, strides) = cum_prod(shape); std::tie(size, strides) = cum_prod(shape);

View File

@ -240,6 +240,11 @@ class array {
return array_desc_->inputs; return array_desc_->inputs;
} }
/** True indicates the array's buffer is safe to reuse: both the array
 * descriptor and its underlying Data are uniquely owned (use_count == 1),
 * so no other array or pending computation can observe the buffer. */
bool is_donatable() const {
return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1);
}
/** The array's siblings. */ /** The array's siblings. */
const std::vector<array>& siblings() const { const std::vector<array>& siblings() const {
return array_desc_->siblings; return array_desc_->siblings;
@ -282,6 +287,12 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
}; };
// Return a copy of the shared pointer to the array::Data struct.
// A null result means the data was moved out (donated) via
// move_shared_buffer; callers test this to detect donated inputs.
// Note: copying the shared_ptr increments the refcount, which makes the
// buffer temporarily non-donatable while the copy is alive.
std::shared_ptr<Data> data_shared_ptr() const {
return array_desc_->data;
}
// Return a raw pointer to the arrays data
template <typename T> template <typename T>
T* data() { T* data() {
return static_cast<T*>(array_desc_->data_ptr); return static_cast<T*>(array_desc_->data_ptr);
@ -322,6 +333,8 @@ class array {
void copy_shared_buffer(const array& other); void copy_shared_buffer(const array& other);
void move_shared_buffer(array other);
void overwrite_descriptor(const array& other) { void overwrite_descriptor(const array& other) {
array_desc_ = other.array_desc_; array_desc_ = other.array_desc_;
} }

View File

@ -71,21 +71,11 @@ void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabs(in.data<float>(), 1, out.data<float>(), 1, size);
} else if (in.dtype() == int32 && in.flags().contiguous) { } else if (in.dtype() == int32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vabsi(in.data<int>(), 1, out.data<int>(), 1, size);
} else if (is_unsigned(in.dtype())) { } else if (is_unsigned(in.dtype())) {
// No-op for unsigned types // No-op for unsigned types
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
@ -138,12 +128,8 @@ void ArcCos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacosf(out.data<float>(), in.data<float>(), &size); vvacosf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -154,12 +140,8 @@ void ArcCosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvacoshf(out.data<float>(), in.data<float>(), &size); vvacoshf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -170,12 +152,8 @@ void ArcSin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinf(out.data<float>(), in.data<float>(), &size); vvasinf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -186,12 +164,8 @@ void ArcSinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvasinhf(out.data<float>(), in.data<float>(), &size); vvasinhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -202,12 +176,8 @@ void ArcTan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanf(out.data<float>(), in.data<float>(), &size); vvatanf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -218,12 +188,8 @@ void ArcTanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvatanhf(out.data<float>(), in.data<float>(), &size); vvatanhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -235,30 +201,23 @@ void AsType::eval_cpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0]; auto& in = inputs[0];
if (in.flags().contiguous) { if (in.flags().contiguous) {
auto allocfn = [&in, &out]() {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
};
// Use accelerate functions if possible // Use accelerate functions if possible
if (in.dtype() == float32 && out.dtype() == uint32) { if (in.dtype() == float32 && out.dtype() == uint32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfixu32( vDSP_vfixu32(
in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size()); in.data<float>(), 1, out.data<uint32_t>(), 1, in.data_size());
return; return;
} else if (in.dtype() == float32 && out.dtype() == int32) { } else if (in.dtype() == float32 && out.dtype() == int32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size()); vDSP_vfix32(in.data<float>(), 1, out.data<int32_t>(), 1, in.data_size());
return; return;
} else if (in.dtype() == uint32 && out.dtype() == float32) { } else if (in.dtype() == uint32 && out.dtype() == float32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vfltu32( vDSP_vfltu32(
in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size()); in.data<uint32_t>(), 1, out.data<float>(), 1, in.data_size());
return; return;
} else if (in.dtype() == int32 && out.dtype() == float32) { } else if (in.dtype() == int32 && out.dtype() == float32) {
allocfn(); set_unary_output_data(in, out);
vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size()); vDSP_vflt32(in.data<int32_t>(), 1, out.data<float>(), 1, in.data_size());
return; return;
} }
@ -270,12 +229,8 @@ void Cos::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcosf(out.data<float>(), in.data<float>(), &size); vvcosf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -286,12 +241,8 @@ void Cosh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvcoshf(out.data<float>(), in.data<float>(), &size); vvcoshf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -379,12 +330,8 @@ void Exp::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); vvexpf(out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (is_floating_point(out.dtype())) {
unary_fp(in, out, [](auto x) { return std::exp(x); }); unary_fp(in, out, [](auto x) { return std::exp(x); });
@ -411,12 +358,8 @@ void Log::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
switch (base_) { switch (base_) {
case Base::e: case Base::e:
vvlogf( vvlogf(
@ -440,12 +383,8 @@ void Log1p::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvlog1pf( vvlog1pf(
out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size)); out.data<float>(), in.data<float>(), reinterpret_cast<int*>(&size));
} else if (is_floating_point(out.dtype())) { } else if (is_floating_point(out.dtype())) {
@ -527,13 +466,8 @@ void Negative::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
auto size = in.data_size(); set_unary_output_data(in, out);
out.set_data( vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, in.data_size());
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vneg(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
unary(in, out, [](auto x) { return -x; }); unary(in, out, [](auto x) { return -x; });
} }
@ -546,7 +480,13 @@ void Power::eval_cpu(const std::vector<array>& inputs, array& out) {
if (out.dtype() == float32 && a.flags().row_contiguous && if (out.dtype() == float32 && a.flags().row_contiguous &&
b.flags().row_contiguous) { b.flags().row_contiguous) {
int size = a.size(); int size = a.size();
out.set_data(allocator::malloc_or_wait(out.nbytes())); if (a.is_donatable() && a.itemsize() == out.itemsize()) {
out.copy_shared_buffer(a);
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
out.copy_shared_buffer(b);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size); vvpowf(out.data<float>(), b.data<float>(), a.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -588,12 +528,8 @@ void Sin::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinf(out.data<float>(), in.data<float>(), &size); vvsinf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -604,12 +540,8 @@ void Sinh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvsinhf(out.data<float>(), in.data<float>(), &size); vvsinhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -620,12 +552,8 @@ void Square::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
auto size = in.data_size(); auto size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size); vDSP_vsq(in.data<float>(), 1, out.data<float>(), 1, size);
} else { } else {
unary(in, out, [](auto x) { return x * x; }); unary(in, out, [](auto x) { return x * x; });
@ -636,12 +564,8 @@ void Sqrt::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
auto& in = inputs[0]; auto& in = inputs[0];
if (in.dtype() == float32 && in.flags().contiguous) { if (in.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
if (recip_) { if (recip_) {
vvrsqrtf(out.data<float>(), in.data<float>(), &size); vvrsqrtf(out.data<float>(), in.data<float>(), &size);
} else { } else {
@ -696,12 +620,8 @@ void Tan::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanf(out.data<float>(), in.data<float>(), &size); vvtanf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);
@ -712,12 +632,8 @@ void Tanh::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1); assert(inputs.size() == 1);
const auto& in = inputs[0]; const auto& in = inputs[0];
if (out.dtype() == float32 && in.flags().contiguous) { if (out.dtype() == float32 && in.flags().contiguous) {
set_unary_output_data(in, out);
int size = in.data_size(); int size = in.data_size();
out.set_data(
allocator::malloc_or_wait(size * out.itemsize()),
size,
in.strides(),
in.flags());
vvtanhf(out.data<float>(), in.data<float>(), &size); vvtanhf(out.data<float>(), in.data<float>(), &size);
} else { } else {
eval(inputs, out); eval(inputs, out);

View File

@ -1,7 +1,6 @@
// Copyright © 2023 Apple Inc. // Copyright © 2023 Apple Inc.
#pragma once #pragma once
#include "mlx/allocator.h" #include "mlx/allocator.h"
#include "mlx/array.h" #include "mlx/array.h"
#include "mlx/backend/common/utils.h" #include "mlx/backend/common/utils.h"
@ -40,29 +39,83 @@ void set_binary_op_output_data(
const array& a, const array& a,
const array& b, const array& b,
array& out, array& out,
BinaryOpType bopt) { BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) { switch (bopt) {
case ScalarScalar: case ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case ScalarVector: case ScalarVector:
out.set_data( if (b.is_donatable() && b.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(b.data_size() * out.itemsize()), if (donate_with_move) {
b.data_size(), out.move_shared_buffer(b);
b.strides(), } else {
b.flags()); out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(b.data_size() * out.itemsize()),
b.data_size(),
b.strides(),
b.flags());
}
break; break;
case VectorScalar: case VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break;
case VectorVector: case VectorVector:
out.set_data( if (a.is_donatable() && a.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(a.data_size() * out.itemsize()), if (donate_with_move) {
a.data_size(), out.move_shared_buffer(a);
a.strides(), } else {
a.flags()); out.copy_shared_buffer(a);
}
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
}
break; break;
case General: case General:
out.set_data(allocator::malloc_or_wait(out.nbytes())); if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(a);
} else {
out.copy_shared_buffer(a);
}
} else if (
b.is_donatable() && b.flags().row_contiguous &&
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) {
out.move_shared_buffer(b);
} else {
out.copy_shared_buffer(b);
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
break; break;
} }
} }

View File

@ -289,11 +289,16 @@ void copy(const array& src, array& dst, CopyType ctype) {
// Allocate the output // Allocate the output
switch (ctype) { switch (ctype) {
case CopyType::Vector: case CopyType::Vector:
dst.set_data( if (src.is_donatable() && src.itemsize() == dst.itemsize()) {
allocator::malloc_or_wait(src.data_size() * dst.itemsize()), dst.copy_shared_buffer(src);
src.data_size(), } else {
src.strides(), auto size = src.data_size();
src.flags()); dst.set_data(
allocator::malloc_or_wait(size * dst.itemsize()),
size,
src.strides(),
src.flags());
}
break; break;
case CopyType::Scalar: case CopyType::Scalar:
case CopyType::General: case CopyType::General:

View File

@ -237,17 +237,14 @@ void Erf::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { switch (out.dtype()) {
case float32: case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return std::erf(x); }); unary_op<float>(in, out, [](auto x) { return std::erf(x); });
break; break;
case float16: case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) { unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(std::erf(static_cast<float>(x))); return static_cast<float16_t>(std::erf(static_cast<float>(x)));
}); });
break; break;
case bfloat16: case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) { unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(std::erf(static_cast<float>(x))); return static_cast<bfloat16_t>(std::erf(static_cast<float>(x)));
}); });
@ -264,17 +261,14 @@ void ErfInv::eval(const std::vector<array>& inputs, array& out) {
const auto& in = inputs[0]; const auto& in = inputs[0];
switch (out.dtype()) { switch (out.dtype()) {
case float32: case float32:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float>(in, out, [](auto x) { return erfinv(x); }); unary_op<float>(in, out, [](auto x) { return erfinv(x); });
break; break;
case float16: case float16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<float16_t>(in, out, [](auto x) { unary_op<float16_t>(in, out, [](auto x) {
return static_cast<float16_t>(erfinv(static_cast<float>(x))); return static_cast<float16_t>(erfinv(static_cast<float>(x)));
}); });
break; break;
case bfloat16: case bfloat16:
out.set_data(allocator::malloc_or_wait(out.nbytes()));
unary_op<bfloat16_t>(in, out, [](auto x) { unary_op<bfloat16_t>(in, out, [](auto x) {
return static_cast<bfloat16_t>(erfinv(static_cast<float>(x))); return static_cast<bfloat16_t>(erfinv(static_cast<float>(x)));
}); });

View File

@ -64,15 +64,24 @@ struct RoundOp {
} }
}; };
// Prepare the output buffer for a contiguous unary op. When the input's
// buffer is donatable and the element sizes agree, reuse it in place;
// otherwise allocate a fresh buffer that mirrors the input's layout
// (same data_size, strides, and flags).
void set_unary_output_data(const array& in, array& out) {
  const bool can_donate = in.is_donatable() && in.itemsize() == out.itemsize();
  if (can_donate) {
    out.copy_shared_buffer(in);
    return;
  }
  const auto n = in.data_size();
  out.set_data(
      allocator::malloc_or_wait(n * out.itemsize()),
      n,
      in.strides(),
      in.flags());
}
template <typename T, typename Op> template <typename T, typename Op>
void unary_op(const array& a, array& out, Op op) { void unary_op(const array& a, array& out, Op op) {
const T* a_ptr = a.data<T>(); const T* a_ptr = a.data<T>();
if (a.flags().contiguous) { if (a.flags().contiguous) {
out.set_data( set_unary_output_data(a, out);
allocator::malloc_or_wait(a.data_size() * out.itemsize()),
a.data_size(),
a.strides(),
a.flags());
T* dst = out.data<T>(); T* dst = out.data<T>();
for (size_t i = 0; i < a.data_size(); ++i) { for (size_t i = 0; i < a.data_size(); ++i) {
dst[i] = op(a_ptr[i]); dst[i] = op(a_ptr[i]);

View File

@ -12,11 +12,15 @@ namespace mlx::core {
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) { if (ctype == CopyType::Vector) {
out.set_data( if (in.is_donatable() && in.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(in.data_size() * out.itemsize()), out.move_shared_buffer(in);
in.data_size(), } else {
in.strides(), out.set_data(
in.flags()); allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
@ -67,7 +71,8 @@ void copy_gpu_inplace(
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0); bool donate_in = in.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_in ? out : in, 0);
set_array_buffer(compute_encoder, out, 1); set_array_buffer(compute_encoder, out, 1);
if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {

View File

@ -64,14 +64,23 @@ std::function<void()> make_task(
auto command_buffer = increment_command_buffer(s); auto command_buffer = increment_command_buffer(s);
auto outputs = arr.outputs(); auto outputs = arr.outputs();
arr.primitive().eval_gpu(arr.inputs(), outputs); arr.primitive().eval_gpu(arr.inputs(), outputs);
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.push_back(s.data_shared_ptr());
}
if (!arr.is_tracer()) {
arr.detach();
}
if (p) { if (p) {
metal::device(s.device).end_encoding(s.index); metal::device(s.device).end_encoding(s.index);
scheduler::notify_new_task(s); scheduler::notify_new_task(s);
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable { [s, buffers = std::move(buffers), p = std::move(p)](
if (!arr.is_tracer()) { MTL::CommandBuffer* cbuf) {
arr.detach();
}
p->set_value(); p->set_value();
scheduler::notify_task_completion(s); scheduler::notify_task_completion(s);
check_error(cbuf); check_error(cbuf);
@ -79,10 +88,7 @@ std::function<void()> make_task(
metal::device(s.device).commit_command_buffer(s.index); metal::device(s.device).commit_command_buffer(s.index);
} else { } else {
command_buffer->addCompletedHandler( command_buffer->addCompletedHandler(
[s, arr](MTL::CommandBuffer* cbuf) mutable { [s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
if (!arr.is_tracer()) {
arr.detach();
}
check_error(cbuf); check_error(cbuf);
}); });
} }

View File

@ -27,8 +27,8 @@ void binary_op(
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt); set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt); set_binary_op_output_data(a, b, outputs[1], bopt, true);
auto& out = outputs[0]; auto& out = outputs[0];
if (out.size() == 0) { if (out.size() == 0) {
@ -69,8 +69,14 @@ void binary_op(
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0); // - If a is donated it goes to the first output
set_array_buffer(compute_encoder, b, 1); // - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? outputs[0] : a, 0);
set_array_buffer(
compute_encoder, donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, 1);
set_array_buffer(compute_encoder, outputs[0], 2); set_array_buffer(compute_encoder, outputs[0], 2);
set_array_buffer(compute_encoder, outputs[1], 3); set_array_buffer(compute_encoder, outputs[1], 3);
@ -122,7 +128,7 @@ void binary_op(
auto& a = inputs[0]; auto& a = inputs[0];
auto& b = inputs[1]; auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b); auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt); set_binary_op_output_data(a, b, out, bopt, true);
if (out.size() == 0) { if (out.size() == 0) {
return; return;
} }
@ -161,8 +167,10 @@ void binary_op(
auto kernel = d.get_kernel(kname.str()); auto kernel = d.get_kernel(kname.str());
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, a, 0); bool donate_a = a.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, b, 1); bool donate_b = b.data_shared_ptr() == nullptr;
set_array_buffer(compute_encoder, donate_a ? out : a, 0);
set_array_buffer(compute_encoder, donate_b ? out : b, 1);
set_array_buffer(compute_encoder, out, 2); set_array_buffer(compute_encoder, out, 2);
if (bopt == General) { if (bopt == General) {
@ -212,11 +220,15 @@ void unary_op(
auto& in = inputs[0]; auto& in = inputs[0];
bool contig = in.flags().contiguous; bool contig = in.flags().contiguous;
if (contig) { if (contig) {
out.set_data( if (in.is_donatable() && in.itemsize() == out.itemsize()) {
allocator::malloc_or_wait(in.data_size() * out.itemsize()), out.move_shared_buffer(in);
in.data_size(), } else {
in.strides(), out.set_data(
in.flags()); allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else { } else {
out.set_data(allocator::malloc_or_wait(out.nbytes())); out.set_data(allocator::malloc_or_wait(out.nbytes()));
} }
@ -240,7 +252,8 @@ void unary_op(
auto compute_encoder = d.get_command_encoder(s.index); auto compute_encoder = d.get_command_encoder(s.index);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
set_array_buffer(compute_encoder, in, 0); set_array_buffer(
compute_encoder, in.data_shared_ptr() == nullptr ? out : in, 0);
set_array_buffer(compute_encoder, out, 1); set_array_buffer(compute_encoder, out, 1);
if (!contig) { if (!contig) {
compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2); compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2);

View File

@ -98,6 +98,7 @@ void eval(const std::vector<array>& outputs) {
auto stream = arr.primitive().stream(); auto stream = arr.primitive().stream();
std::vector<std::shared_future<void>> arr_deps; std::vector<std::shared_future<void>> arr_deps;
for (auto& in : arr.inputs()) { for (auto& in : arr.inputs()) {
// TODO(review): dependency lookup keyed by in.primitive_id() appears
// incorrect here — confirm the intended key and fix.
if (auto it = deps.find(in.primitive_id()); it != deps.end()) { if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
arr_deps.push_back(it->second); arr_deps.push_back(it->second);
} }

View File

@ -65,11 +65,9 @@ class MLXTestCase(unittest.TestCase):
) )
if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array): if not isinstance(mx_res, mx.array) and not isinstance(expected, mx.array):
np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol) np.testing.assert_allclose(mx_res, expected, rtol=rtol, atol=atol)
return
elif not isinstance(mx_res, mx.array): elif not isinstance(mx_res, mx.array):
mx_res = mx.array(mx_res) mx_res = mx.array(mx_res)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
elif not isinstance(expected, mx.array): elif not isinstance(expected, mx.array):
expected = mx.array(expected) expected = mx.array(expected)
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol)) self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))
else:
self.assertTrue(mx.allclose(mx_res, expected, rtol=rtol, atol=atol))