mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
No copy command encoder (#986)
* no copy command encoder * up layer norm test tolerances
This commit is contained in:
parent
91eba8e485
commit
ae18326533
@ -129,7 +129,7 @@ Device::~Device() {
|
|||||||
b.second.second->release();
|
b.second.second->release();
|
||||||
}
|
}
|
||||||
for (auto& e : encoder_map_) {
|
for (auto& e : encoder_map_) {
|
||||||
e.second->release();
|
(*e.second)->release();
|
||||||
}
|
}
|
||||||
for (auto& k : kernel_map_) {
|
for (auto& k : kernel_map_) {
|
||||||
k.second->release();
|
k.second->release();
|
||||||
@ -200,8 +200,8 @@ void Device::commit_command_buffer(int index) {
|
|||||||
void Device::end_encoding(int index) {
|
void Device::end_encoding(int index) {
|
||||||
auto eit = encoder_map_.find(index);
|
auto eit = encoder_map_.find(index);
|
||||||
if (eit != encoder_map_.end()) {
|
if (eit != encoder_map_.end()) {
|
||||||
eit->second->endEncoding();
|
(*eit->second)->endEncoding();
|
||||||
eit->second->release();
|
(*eit->second)->release();
|
||||||
encoder_map_.erase(eit);
|
encoder_map_.erase(eit);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -214,9 +214,11 @@ CommandEncoder& Device::get_command_encoder(int index) {
|
|||||||
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||||
// Increment ref count so the buffer is not garbage collected
|
// Increment ref count so the buffer is not garbage collected
|
||||||
compute_encoder->retain();
|
compute_encoder->retain();
|
||||||
eit = encoder_map_.emplace(index, CommandEncoder{compute_encoder}).first;
|
eit = encoder_map_
|
||||||
|
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
|
||||||
|
.first;
|
||||||
}
|
}
|
||||||
return eit->second;
|
return *(eit->second);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::register_library(
|
void Device::register_library(
|
||||||
|
@ -39,6 +39,7 @@ using MTLFCList =
|
|||||||
struct CommandEncoder {
|
struct CommandEncoder {
|
||||||
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
CommandEncoder(MTL::ComputeCommandEncoder* enc)
|
||||||
: enc(enc), concurrent(false){};
|
: enc(enc), concurrent(false){};
|
||||||
|
CommandEncoder(const CommandEncoder&) = delete;
|
||||||
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
CommandEncoder& operator=(const CommandEncoder&) = delete;
|
||||||
|
|
||||||
struct ConcurrentContext {
|
struct ConcurrentContext {
|
||||||
@ -197,7 +198,7 @@ class Device {
|
|||||||
MTL::Device* device_;
|
MTL::Device* device_;
|
||||||
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
|
||||||
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
|
||||||
std::unordered_map<int32_t, CommandEncoder> encoder_map_;
|
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
|
||||||
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;
|
||||||
std::unordered_map<std::string, MTL::Library*> library_map_;
|
std::unordered_map<std::string, MTL::Library*> library_map_;
|
||||||
std::mutex mtx_;
|
std::mutex mtx_;
|
||||||
|
@ -248,7 +248,7 @@ class TestFast(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
def test_layer_norm(self):
|
def test_layer_norm(self):
|
||||||
# Per dtype absolute tolerance
|
# Per dtype absolute tolerance
|
||||||
tolerances = {mx.float32: 3e-6, mx.float16: 3e-3, mx.bfloat16: 3e-2}
|
tolerances = {mx.float32: 1e-5, mx.float16: 5e-3, mx.bfloat16: 5e-2}
|
||||||
|
|
||||||
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
dtypes = [mx.float32, mx.float16, mx.bfloat16]
|
||||||
epss = [1e-3, 1e-5]
|
epss = [1e-3, 1e-5]
|
||||||
|
Loading…
Reference in New Issue
Block a user