mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-10 11:16:38 +08:00
Allow querying the allocator for the buffer size (#1404)
This commit is contained in:
parent
8b30acd7eb
commit
881f09b2e2
@ -23,11 +23,22 @@ void free(Buffer buffer) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Buffer CommonAllocator::malloc(size_t size, bool) {
|
Buffer CommonAllocator::malloc(size_t size, bool) {
|
||||||
return Buffer{std::malloc(size)};
|
void* ptr = std::malloc(size + sizeof(size_t));
|
||||||
|
if (ptr != nullptr) {
|
||||||
|
*static_cast<size_t*>(ptr) = size;
|
||||||
|
}
|
||||||
|
return Buffer{ptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommonAllocator::free(Buffer buffer) {
|
void CommonAllocator::free(Buffer buffer) {
|
||||||
std::free(buffer.raw_ptr());
|
std::free(buffer.ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CommonAllocator::size(Buffer buffer) const {
|
||||||
|
if (buffer.ptr() == nullptr) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
return *static_cast<size_t*>(buffer.ptr());
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer malloc_or_wait(size_t size) {
|
Buffer malloc_or_wait(size_t size) {
|
||||||
|
@ -41,6 +41,7 @@ class Allocator {
|
|||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) = 0;
|
||||||
virtual void free(Buffer buffer) = 0;
|
virtual void free(Buffer buffer) = 0;
|
||||||
|
virtual size_t size(Buffer buffer) const = 0;
|
||||||
|
|
||||||
Allocator() = default;
|
Allocator() = default;
|
||||||
Allocator(const Allocator& other) = delete;
|
Allocator(const Allocator& other) = delete;
|
||||||
@ -57,6 +58,7 @@ class CommonAllocator : public Allocator {
|
|||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
|
virtual size_t size(Buffer buffer) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
CommonAllocator() = default;
|
CommonAllocator() = default;
|
||||||
|
@ -324,6 +324,10 @@ class array {
|
|||||||
return array_desc_->data->buffer;
|
return array_desc_->data->buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t buffer_size() const {
|
||||||
|
return allocator::allocator().size(buffer());
|
||||||
|
}
|
||||||
|
|
||||||
// Return a copy of the shared pointer
|
// Return a copy of the shared pointer
|
||||||
// to the array::Data struct
|
// to the array::Data struct
|
||||||
std::shared_ptr<Data> data_shared_ptr() const {
|
std::shared_ptr<Data> data_shared_ptr() const {
|
||||||
|
@ -43,13 +43,15 @@ void set_binary_op_output_data(
|
|||||||
array& out,
|
array& out,
|
||||||
BinaryOpType bopt,
|
BinaryOpType bopt,
|
||||||
bool donate_with_move = false) {
|
bool donate_with_move = false) {
|
||||||
|
bool b_donatable = is_donatable(b, out);
|
||||||
|
bool a_donatable = is_donatable(a, out);
|
||||||
switch (bopt) {
|
switch (bopt) {
|
||||||
case BinaryOpType::ScalarScalar:
|
case BinaryOpType::ScalarScalar:
|
||||||
out.set_data(
|
out.set_data(
|
||||||
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::ScalarVector:
|
case BinaryOpType::ScalarVector:
|
||||||
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
if (b_donatable) {
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(b);
|
out.move_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
@ -64,7 +66,7 @@ void set_binary_op_output_data(
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorScalar:
|
case BinaryOpType::VectorScalar:
|
||||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
if (a_donatable) {
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(a);
|
out.move_shared_buffer(a);
|
||||||
} else {
|
} else {
|
||||||
@ -79,13 +81,13 @@ void set_binary_op_output_data(
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::VectorVector:
|
case BinaryOpType::VectorVector:
|
||||||
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
|
if (a_donatable) {
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(a);
|
out.move_shared_buffer(a);
|
||||||
} else {
|
} else {
|
||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
}
|
}
|
||||||
} else if (b.is_donatable() && b.itemsize() == out.itemsize()) {
|
} else if (b_donatable) {
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(b);
|
out.move_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
@ -100,16 +102,14 @@ void set_binary_op_output_data(
|
|||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case BinaryOpType::General:
|
case BinaryOpType::General:
|
||||||
if (a.is_donatable() && a.flags().row_contiguous &&
|
if (a_donatable && a.flags().row_contiguous && a.size() == out.size()) {
|
||||||
a.itemsize() == out.itemsize() && a.size() == out.size()) {
|
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(a);
|
out.move_shared_buffer(a);
|
||||||
} else {
|
} else {
|
||||||
out.copy_shared_buffer(a);
|
out.copy_shared_buffer(a);
|
||||||
}
|
}
|
||||||
} else if (
|
} else if (
|
||||||
b.is_donatable() && b.flags().row_contiguous &&
|
b_donatable && b.flags().row_contiguous && b.size() == out.size()) {
|
||||||
b.itemsize() == out.itemsize() && b.size() == out.size()) {
|
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(b);
|
out.move_shared_buffer(b);
|
||||||
} else {
|
} else {
|
||||||
|
@ -41,7 +41,7 @@ void set_ternary_op_output_data(
|
|||||||
TernaryOpType topt,
|
TernaryOpType topt,
|
||||||
bool donate_with_move = false) {
|
bool donate_with_move = false) {
|
||||||
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
auto maybe_donate = [&out, donate_with_move](const array& x) {
|
||||||
if (x.is_donatable() && x.itemsize() == out.itemsize()) {
|
if (is_donatable(x, out)) {
|
||||||
if (donate_with_move) {
|
if (donate_with_move) {
|
||||||
out.move_shared_buffer(x);
|
out.move_shared_buffer(x);
|
||||||
} else {
|
} else {
|
||||||
|
@ -12,7 +12,7 @@ namespace mlx::core {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
void set_unary_output_data(const array& in, array& out) {
|
void set_unary_output_data(const array& in, array& out) {
|
||||||
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
|
if (is_donatable(in, out)) {
|
||||||
out.copy_shared_buffer(in);
|
out.copy_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
auto size = in.data_size();
|
auto size = in.data_size();
|
||||||
|
@ -155,4 +155,11 @@ inline auto check_contiguity(
|
|||||||
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
|
no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool is_donatable(const array& in, const array& out) {
|
||||||
|
constexpr size_t donation_extra = 16384;
|
||||||
|
|
||||||
|
return in.is_donatable() && in.itemsize() == out.itemsize() &&
|
||||||
|
in.buffer_size() <= out.nbytes() + donation_extra;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
@ -241,6 +241,10 @@ void MetalAllocator::free(Buffer buffer) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t MetalAllocator::size(Buffer buffer) const {
|
||||||
|
return static_cast<MTL::Buffer*>(buffer.ptr())->length();
|
||||||
|
}
|
||||||
|
|
||||||
MetalAllocator& allocator() {
|
MetalAllocator& allocator() {
|
||||||
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
|
// By creating the |allocator_| on heap, the destructor of MetalAllocator will
|
||||||
// not be called on exit and all the buffers will be leaked. This is necessary
|
// not be called on exit and all the buffers will be leaked. This is necessary
|
||||||
|
@ -56,6 +56,7 @@ class MetalAllocator : public allocator::Allocator {
|
|||||||
public:
|
public:
|
||||||
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
virtual Buffer malloc(size_t size, bool allow_swap = false) override;
|
||||||
virtual void free(Buffer buffer) override;
|
virtual void free(Buffer buffer) override;
|
||||||
|
virtual size_t size(Buffer buffer) const override;
|
||||||
size_t get_active_memory() {
|
size_t get_active_memory() {
|
||||||
return active_memory_;
|
return active_memory_;
|
||||||
};
|
};
|
||||||
|
@ -10,7 +10,7 @@ Allocator& allocator() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void* Buffer::raw_ptr() {
|
void* Buffer::raw_ptr() {
|
||||||
return ptr_;
|
return static_cast<size_t*>(ptr_) + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mlx::core::allocator
|
} // namespace mlx::core::allocator
|
||||||
|
Loading…
Reference in New Issue
Block a user