mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
Support disabling the Metal buffer cache to prevent performance degradation caused by large memory caching (#390)
* Support disabling the Metal buffer cache, because a large amount of unused memory stays cached when an LLM generates long-context tokens * Run format and add "cache_enabled" feature tests
This commit is contained in:
parent
49a52610b7
commit
a749a91c75
@ -23,6 +23,16 @@ void* Buffer::raw_ptr() {
|
|||||||
|
|
||||||
namespace metal {
|
namespace metal {
|
||||||
|
|
||||||
|
// File-local switch for the Metal buffer cache; it starts out enabled and is
// read/written only through cache_enabled() / set_cache_enabled().
// NOTE(review): this is a plain bool with no synchronization — confirm it is
// never toggled while another thread is inside the allocator.
static bool cache_enabled_ = true;
|
||||||
|
|
||||||
|
// Returns the current value of the buffer-cache switch (true unless
// set_cache_enabled(false) has been called).
bool cache_enabled() {
|
||||||
|
return cache_enabled_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Turns the buffer cache on or off. MetalAllocator::free() consults this
// switch on every call, so the change applies to buffers freed afterwards.
// NOTE(review): buffers already held by the cache are not touched here —
// confirm whether disabling is expected to drain them.
void set_cache_enabled(bool enabled) {
|
||||||
|
cache_enabled_ = enabled;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
BufferCache::BufferCache(MTL::Device* device)
|
BufferCache::BufferCache(MTL::Device* device)
|
||||||
@ -196,7 +206,11 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
|
|||||||
|
|
||||||
// Gives a buffer back to the allocator.
//
// With the buffer cache enabled, the underlying MTL::Buffer is handed to
// buffer_cache_ so a later allocation can reuse it. With the cache disabled,
// the buffer is released right away instead of being kept around.
//
// @param buffer  Buffer whose ptr() is the MTL::Buffer to give back.
void MetalAllocator::free(Buffer buffer) {
  auto buf = static_cast<MTL::Buffer*>(buffer.ptr());
  if (cache_enabled()) {
    buffer_cache_.recycle_to_cache(buf);
  } else if (buf != nullptr) {
    // An empty Buffer has nothing to release, and calling a member function
    // through a null pointer is undefined behavior, so skip it explicitly.
    buf->release();
  }
}
|
||||||
|
|
||||||
MetalAllocator& allocator() {
|
MetalAllocator& allocator() {
|
||||||
|
@ -19,6 +19,9 @@ constexpr bool is_available() {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns true when the Metal buffer cache is enabled (the default).
bool cache_enabled(void);
|
||||||
|
// Enables (true) or disables (false) the Metal buffer cache.
void set_cache_enabled(bool enabled);
|
||||||
|
|
||||||
void new_stream(Stream stream);
|
void new_stream(Stream stream);
|
||||||
std::shared_ptr<void> new_scoped_memory_pool();
|
std::shared_ptr<void> new_scoped_memory_pool();
|
||||||
|
|
||||||
|
@ -11,4 +11,12 @@ using namespace mlx::core;
|
|||||||
// Registers the "metal" submodule on the given Python module and exposes
// is_available, cache_enabled and set_cache_enabled from the C++ metal
// namespace on it.
void init_metal(py::module_& m) {
|
void init_metal(py::module_& m) {
|
||||||
// The local variable `metal` hides the `metal` namespace for unqualified use,
// but `metal::name` below still resolves to the namespace because lookup of a
// name before `::` ignores variables.
py::module_ metal = m.def_submodule("metal", "mlx.metal");
|
py::module_ metal = m.def_submodule("metal", "mlx.metal");
|
||||||
metal.def("is_available", &metal::is_available);
|
metal.def("is_available", &metal::is_available);
|
||||||
|
// Query side of the buffer-cache switch.
metal.def(
|
||||||
|
"cache_enabled",
|
||||||
|
&metal::cache_enabled,
|
||||||
|
"check if metal buffer cache is enabled, default is true");
|
||||||
|
// Mutating side of the buffer-cache switch; takes a single bool.
metal.def(
|
||||||
|
"set_cache_enabled",
|
||||||
|
&metal::set_cache_enabled,
|
||||||
|
"enable or disable metal buffer cache");
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/metal.h"
|
#include "mlx/backend/metal/metal.h"
|
||||||
|
#include "mlx/backend/metal/allocator.h"
|
||||||
#include "mlx/mlx.h"
|
#include "mlx/mlx.h"
|
||||||
|
|
||||||
using namespace mlx::core;
|
using namespace mlx::core;
|
||||||
@ -471,3 +472,43 @@ TEST_CASE("test metal validation") {
|
|||||||
|
|
||||||
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
eval(scatter_max(array(1), {}, array(2), std::vector<int>{}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that the cache switch changes what MetalAllocator::free() does with
// a buffer: recycled into the cache when enabled, released when disabled.
TEST_CASE("test metal enable/disable cache") {
  // Remember the global setting so this test does not leave the cache
  // disabled for the test cases that run after it.
  const bool was_enabled = metal::cache_enabled();

  auto& a = metal::allocator();
  const size_t size = 100;

  // Test enable metal cache
  {
    metal::set_cache_enabled(true);
    CHECK(metal::cache_enabled());

    auto buf = a.malloc(size, false);

    // Release a
    a.free(buf);

    // With the cache enabled free() hands the buffer to the cache instead of
    // releasing it, so it is expected to still be alive and report its size.
    CHECK_EQ(static_cast<MTL::Buffer*>(buf.ptr())->length(), size);
  }

  // Test disable metal cache
  {
    metal::set_cache_enabled(false);
    CHECK(!metal::cache_enabled());

    auto buf = a.malloc(size, false);
    auto buf_ptr = static_cast<MTL::Buffer*>(buf.ptr());

    // Hold our own reference so the object stays valid after free(); reading
    // through the pointer once the allocator has released it would be a
    // use-after-free.
    buf_ptr->retain();
    const auto count_before = buf_ptr->retainCount();

    // Release a
    a.free(buf);

    // With the cache disabled free() must have dropped exactly one reference.
    // NOTE(review): assumes nothing else retains or releases this buffer
    // between the two retainCount() reads.
    CHECK_EQ(buf_ptr->retainCount(), count_before - 1);

    // Drop the reference taken above so the buffer is actually destroyed.
    buf_ptr->release();
  }

  metal::set_cache_enabled(was_enabled);
}
|
||||||
|
Loading…
Reference in New Issue
Block a user