diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp index 661985c3f..cd4d67a17 100644 --- a/mlx/backend/metal/metal.cpp +++ b/mlx/backend/metal/metal.cpp @@ -47,7 +47,11 @@ std::function make_task(array arr, bool signal) { } debug_set_primitive_buffer_label(command_buffer, arr.primitive()); - arr.primitive().eval_gpu(arr.inputs(), outputs); + try { + arr.primitive().eval_gpu(arr.inputs(), outputs); + } catch (const std::exception& error) { + abort_with_exception(error); + } } std::vector> buffers; for (auto& in : arr.inputs()) { diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 88bfebc1b..3c3066e93 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -220,7 +220,11 @@ array eval_impl(std::vector outputs, bool async) { } scheduler::notify_new_task(stream); auto outputs = arr.outputs(); - arr.primitive().eval_cpu(arr.inputs(), outputs); + try { + arr.primitive().eval_cpu(arr.inputs(), outputs); + } catch (const std::exception& error) { + abort_with_exception(error); + } if (!arr.is_tracer()) { arr.detach(); } diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 6a840172f..13fa70994 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,7 @@ // Copyright © 2023 Apple Inc. #include +#include #include #include @@ -62,6 +63,13 @@ PrintFormatter& get_global_formatter() { return formatter; } +void abort_with_exception(const std::exception& error) { + std::ostringstream msg; + msg << "Terminating due to uncaught exception: " << error.what(); + std::cerr << msg.str() << std::endl; + std::abort(); +} + Dtype result_type(const std::vector& arrays) { Dtype t = bool_; for (auto& arr : arrays) { diff --git a/mlx/utils.h b/mlx/utils.h index 28134bd20..aca6ccd6f 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,6 +2,7 @@ #pragma once +#include #include #include "mlx/array.h" @@ -53,6 +54,9 @@ struct PrintFormatter { PrintFormatter& get_global_formatter(); +/** Print the exception and then abort. */ +void abort_with_exception(const std::exception& error); + /** Holds information about floating-point types. */ struct finfo { explicit finfo(Dtype dtype);