Allow querying the allocator for the buffer size (#1404)

This commit is contained in:
Angelos Katharopoulos 2024-09-11 21:02:16 -07:00 committed by GitHub
parent 8b30acd7eb
commit 881f09b2e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 42 additions and 13 deletions

View File

@ -23,11 +23,22 @@ void free(Buffer buffer) {
} }
Buffer CommonAllocator::malloc(size_t size, bool) { Buffer CommonAllocator::malloc(size_t size, bool) {
return Buffer{std::malloc(size)}; void* ptr = std::malloc(size + sizeof(size_t));
if (ptr != nullptr) {
*static_cast<size_t*>(ptr) = size;
}
return Buffer{ptr};
} }
void CommonAllocator::free(Buffer buffer) { void CommonAllocator::free(Buffer buffer) {
std::free(buffer.raw_ptr()); std::free(buffer.ptr());
}
size_t CommonAllocator::size(Buffer buffer) const {
if (buffer.ptr() == nullptr) {
return 0;
}
return *static_cast<size_t*>(buffer.ptr());
} }
Buffer malloc_or_wait(size_t size) { Buffer malloc_or_wait(size_t size) {

View File

@ -41,6 +41,7 @@ class Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0; virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
virtual void free(Buffer buffer) = 0; virtual void free(Buffer buffer) = 0;
virtual size_t size(Buffer buffer) const = 0;
Allocator() = default; Allocator() = default;
Allocator(const Allocator& other) = delete; Allocator(const Allocator& other) = delete;
@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
private: private:
CommonAllocator() = default; CommonAllocator() = default;

View File

@ -324,6 +324,10 @@ class array {
return array_desc_->data->buffer; return array_desc_->data->buffer;
} }
size_t buffer_size() const {
return allocator::allocator().size(buffer());
}
// Return a copy of the shared pointer // Return a copy of the shared pointer
// to the array::Data struct // to the array::Data struct
std::shared_ptr<Data> data_shared_ptr() const { std::shared_ptr<Data> data_shared_ptr() const {

View File

@ -43,13 +43,15 @@ void set_binary_op_output_data(
array& out, array& out,
BinaryOpType bopt, BinaryOpType bopt,
bool donate_with_move = false) { bool donate_with_move = false) {
bool b_donatable = is_donatable(b, out);
bool a_donatable = is_donatable(a, out);
switch (bopt) { switch (bopt) {
case BinaryOpType::ScalarScalar: case BinaryOpType::ScalarScalar:
out.set_data( out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags()); allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break; break;
case BinaryOpType::ScalarVector: case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) { if (b_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@ -64,7 +66,7 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorScalar: case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (a_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
@ -79,13 +81,13 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::VectorVector: case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) { if (a_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) { } else if (b_donatable) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {
@ -100,16 +102,14 @@ void set_binary_op_output_data(
} }
break; break;
case BinaryOpType::General: case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous && if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(a); out.move_shared_buffer(a);
} else { } else {
out.copy_shared_buffer(a); out.copy_shared_buffer(a);
} }
} else if ( } else if (
b.is_donatable() && b.flags().row_contiguous && b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
b.itemsize() == out.itemsize() && b.size() == out.size()) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(b); out.move_shared_buffer(b);
} else { } else {

View File

@ -41,7 +41,7 @@ void set_ternary_op_output_data(
TernaryOpType topt, TernaryOpType topt,
bool donate_with_move = false) { bool donate_with_move = false) {
auto maybe_donate = [&out, donate_with_move](const array& x) { auto maybe_donate = [&out, donate_with_move](const array& x) {
if (x.is_donatable() && x.itemsize() == out.itemsize()) { if (is_donatable(x, out)) {
if (donate_with_move) { if (donate_with_move) {
out.move_shared_buffer(x); out.move_shared_buffer(x);
} else { } else {

View File

@ -12,7 +12,7 @@ namespace mlx::core {
namespace { namespace {
void set_unary_output_data(const array& in, array& out) { void set_unary_output_data(const array& in, array& out) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) { if (is_donatable(in, out)) {
out.copy_shared_buffer(in); out.copy_shared_buffer(in);
} else { } else {
auto size = in.data_size(); auto size = in.data_size();

View File

@ -155,4 +155,11 @@ inline auto check_contiguity(
no_broadcast_data_size, is_row_contiguous, is_col_contiguous); no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
} }
inline bool is_donatable(const array& in, const array& out) {
constexpr size_t donation_extra = 16384;
return in.is_donatable() && in.itemsize() == out.itemsize() &&
in.buffer_size() <= out.nbytes() + donation_extra;
}
} // namespace mlx::core } // namespace mlx::core

View File

@ -241,6 +241,10 @@ void MetalAllocator::free(Buffer buffer) {
} }
} }
size_t MetalAllocator::size(Buffer buffer) const {
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
}
MetalAllocator& allocator() { MetalAllocator& allocator() {
// By creating the |allocator_| on heap, the destructor of MetalAllocator will // By creating the |allocator_| on heap, the destructor of MetalAllocator will
// not be called on exit and all the buffers will be leaked. This is necessary // not be called on exit and all the buffers will be leaked. This is necessary

View File

@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator {
public: public:
virtual Buffer malloc(size_t size, bool allow_swap = false) override; virtual Buffer malloc(size_t size, bool allow_swap = false) override;
virtual void free(Buffer buffer) override; virtual void free(Buffer buffer) override;
virtual size_t size(Buffer buffer) const override;
size_t get_active_memory() { size_t get_active_memory() {
return active_memory_; return active_memory_;
}; };

View File

@ -10,7 +10,7 @@ Allocator& allocator() {
} }
void* Buffer::raw_ptr() { void* Buffer::raw_ptr() {
return ptr_; return static_cast<size_t*>(ptr_) + 1;
} }
} // namespace mlx::core::allocator } // namespace mlx::core::allocator