Print exceptions in eval_cpu/eval_gpu and abort (#1754)

This commit is contained in:
Cheng 2025-01-08 23:31:09 +09:00 committed by GitHub
parent d1766f2c70
commit b8f76f717a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 22 additions and 2 deletions

View File

@ -47,7 +47,11 @@ std::function<void()> make_task(array arr, bool signal) {
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
try {
arr.primitive().eval_gpu(arr.inputs(), outputs);
} catch (const std::exception& error) {
abort_with_exception(error);
}
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {

View File

@ -220,7 +220,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
}
scheduler::notify_new_task(stream);
auto outputs = arr.outputs();
arr.primitive().eval_cpu(arr.inputs(), outputs);
try {
arr.primitive().eval_cpu(arr.inputs(), outputs);
} catch (const std::exception& error) {
abort_with_exception(error);
}
if (!arr.is_tracer()) {
arr.detach();
}

View File

@ -1,6 +1,7 @@
// Copyright © 2023 Apple Inc.
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <vector>
@ -62,6 +63,13 @@ PrintFormatter& get_global_formatter() {
return formatter;
}
void abort_with_exception(const std::exception& error) {
std::ostringstream msg;
msg << "Terminating due to uncaught exception: " << error.what();
std::cerr << msg.str() << std::endl;
std::abort();
}
Dtype result_type(const std::vector<array>& arrays) {
Dtype t = bool_;
for (auto& arr : arrays) {

View File

@ -2,6 +2,7 @@
#pragma once
#include <exception>
#include <variant>
#include "mlx/array.h"
@ -53,6 +54,9 @@ struct PrintFormatter {
PrintFormatter& get_global_formatter();
/** Print the exception and then abort. */
void abort_with_exception(const std::exception& error);
/** Holds information about floating-point types. */
struct finfo {
explicit finfo(Dtype dtype);