mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Print exceptions in eval_cpu/eval_gpu and abort (#1754)
This commit is contained in:
parent
d1766f2c70
commit
b8f76f717a
@ -47,7 +47,11 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
try {
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
} catch (const std::exception& error) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
|
@ -220,7 +220,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
||||
}
|
||||
scheduler::notify_new_task(stream);
|
||||
auto outputs = arr.outputs();
|
||||
try {
|
||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||
} catch (const std::exception& error) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
@ -62,6 +63,13 @@ PrintFormatter& get_global_formatter() {
|
||||
return formatter;
|
||||
}
|
||||
|
||||
void abort_with_exception(const std::exception& error) {
|
||||
std::ostringstream msg;
|
||||
msg << "Terminating due to uncaught exception: " << error.what();
|
||||
std::cerr << msg.str() << std::endl;
|
||||
std::abort();
|
||||
}
|
||||
|
||||
Dtype result_type(const std::vector<array>& arrays) {
|
||||
Dtype t = bool_;
|
||||
for (auto& arr : arrays) {
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <variant>
|
||||
|
||||
#include "mlx/array.h"
|
||||
@ -53,6 +54,9 @@ struct PrintFormatter {
|
||||
|
||||
PrintFormatter& get_global_formatter();
|
||||
|
||||
/** Print the exception and then abort. */
|
||||
void abort_with_exception(const std::exception& error);
|
||||
|
||||
/** Holds information about floating-point types. */
|
||||
struct finfo {
|
||||
explicit finfo(Dtype dtype);
|
||||
|
Loading…
Reference in New Issue
Block a user