From 275db7221a1104d2ebf3fff8657a15c3ce2b4f0f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 17 Jan 2024 11:53:30 -0800 Subject: [PATCH] Command buffer reports errors (#479) * command buffer reports errors * typo * simplify --- mlx/backend/metal/metal.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 8436dd3d9..557397c4f 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -42,6 +42,15 @@ MTL::CommandBuffer* increment_command_buffer(Stream s) { return command_buffer; } +inline void check_error(MTL::CommandBuffer* cbuf) { + if (cbuf->status() == MTL::CommandBufferStatusError) { + std::ostringstream msg; + msg << "[METAL] Command buffer execution failed: " + << cbuf->error()->localizedDescription()->utf8String(); + throw std::runtime_error(msg.str()); + } +} + std::function make_task( array& arr, std::vector> deps, @@ -59,7 +68,7 @@ std::function make_task( metal::device(s.device).end_encoding(s.index); scheduler::notify_new_task(s); command_buffer->addCompletedHandler( - [s, arr, p = std::move(p)](MTL::CommandBuffer*) mutable { + [s, arr, p = std::move(p)](MTL::CommandBuffer* cbuf) mutable { if (!arr.is_tracer()) { arr.detach(); for (auto s : arr.siblings()) { @@ -68,14 +77,16 @@ std::function make_task( } p->set_value(); scheduler::notify_task_completion(s); + check_error(cbuf); }); metal::device(s.device).commit_command_buffer(s.index); } else { command_buffer->addCompletedHandler( - [s, arr](MTL::CommandBuffer*) mutable { + [s, arr](MTL::CommandBuffer* cbuf) mutable { if (!arr.is_tracer()) { arr.detach(); } + check_error(cbuf); }); } };