diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 844615284..1ee9335c0 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -129,7 +129,7 @@ Device::~Device() { b.second.second->release(); } for (auto& e : encoder_map_) { - e.second->release(); + (*e.second)->release(); } for (auto& k : kernel_map_) { k.second->release(); @@ -200,8 +200,8 @@ void Device::commit_command_buffer(int index) { void Device::end_encoding(int index) { auto eit = encoder_map_.find(index); if (eit != encoder_map_.end()) { - eit->second->endEncoding(); - eit->second->release(); + (*eit->second)->endEncoding(); + (*eit->second)->release(); encoder_map_.erase(eit); } } @@ -214,9 +214,11 @@ CommandEncoder& Device::get_command_encoder(int index) { cb->computeCommandEncoder(MTL::DispatchTypeConcurrent); // Increment ref count so the buffer is not garbage collected compute_encoder->retain(); - eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first; + eit = encoder_map_ + .emplace(index, std::make_unique(compute_encoder)) + .first; } - return eit->second; + return *(eit->second); } void Device::register_library( diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 4fc43c164..029b2cc92 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -39,6 +39,7 @@ using MTLFCList = struct CommandEncoder { CommandEncoder(MTL::ComputeCommandEncoder* enc) : enc(enc), concurrent(false){}; + CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; struct ConcurrentContext { @@ -197,7 +198,7 @@ class Device { MTL::Device* device_; std::unordered_map queue_map_; std::unordered_map> buffer_map_; - std::unordered_map encoder_map_; + std::unordered_map> encoder_map_; std::unordered_map kernel_map_; std::unordered_map library_map_; std::mutex mtx_; diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 3b2db95c6..b554be55b 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -248,7 +248,7 @@ class TestFast(mlx_tests.MLXTestCase): def test_layer_norm(self): # Per dtype absolute tolerance - tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2} + tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2} dtypes = [mx.float32, mx.float16, mx.bfloat16] epss = [1e-3, 1e-5]