From ae18326533e3bec848b2a22a4867759869271257 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 11 Apr 2024 21:15:36 -0700 Subject: [PATCH] No copy command encoder (#986) * no copy command encoder * up layer norm test tolerances --- mlx/backend/metal/device.cpp | 12 +++++++----- mlx/backend/metal/device.h | 3 ++- python/tests/test_fast.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index 844615284..1ee9335c0 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -129,7 +129,7 @@ Device::~Device() { b.second.second->release(); } for (auto& e : encoder_map_) { - e.second->release(); + (*e.second)->release(); } for (auto& k : kernel_map_) { k.second->release(); @@ -200,8 +200,8 @@ void Device::commit_command_buffer(int index) { void Device::end_encoding(int index) { auto eit = encoder_map_.find(index); if (eit != encoder_map_.end()) { - eit->second->endEncoding(); - eit->second->release(); + (*eit->second)->endEncoding(); + (*eit->second)->release(); encoder_map_.erase(eit); } } @@ -214,9 +214,11 @@ CommandEncoder& Device::get_command_encoder(int index) { cb->computeCommandEncoder(MTL::DispatchTypeConcurrent); // Increment ref count so the buffer is not garbage collected compute_encoder->retain(); - eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first; + eit = encoder_map_ + .emplace(index, std::make_unique(compute_encoder)) + .first; } - return eit->second; + return *(eit->second); } void Device::register_library( diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index 4fc43c164..029b2cc92 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -39,6 +39,7 @@ using MTLFCList = struct CommandEncoder { CommandEncoder(MTL::ComputeCommandEncoder* enc) : enc(enc), concurrent(false){}; + CommandEncoder(const CommandEncoder&) = delete; CommandEncoder& operator=(const CommandEncoder&) = delete; struct ConcurrentContext { @@ -197,7 +198,7 @@ class Device { MTL::Device* device_; std::unordered_map queue_map_; std::unordered_map> buffer_map_; - std::unordered_map encoder_map_; + std::unordered_map> encoder_map_; std::unordered_map kernel_map_; std::unordered_map library_map_; std::mutex mtx_; diff --git a/python/tests/test_fast.py b/python/tests/test_fast.py index 3b2db95c6..b554be55b 100644 --- a/python/tests/test_fast.py +++ b/python/tests/test_fast.py @@ -248,7 +248,7 @@ class TestFast(mlx_tests.MLXTestCase): def test_layer_norm(self): # Per dtype absolute tolerance - tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2} + tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2} dtypes = [mx.float32, mx.float16, mx.bfloat16] epss = [1e-3, 1e-5]