mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Print exceptions in eval_cpu/eval_gpu and abort (#1754)
This commit is contained in:
parent
d1766f2c70
commit
b8f76f717a
@ -47,7 +47,11 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
try {
|
||||||
|
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||||
|
} catch (const std::exception& error) {
|
||||||
|
abort_with_exception(error);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
|
@ -220,7 +220,11 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
scheduler::notify_new_task(stream);
|
scheduler::notify_new_task(stream);
|
||||||
auto outputs = arr.outputs();
|
auto outputs = arr.outputs();
|
||||||
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
try {
|
||||||
|
arr.primitive().eval_cpu(arr.inputs(), outputs);
|
||||||
|
} catch (const std::exception& error) {
|
||||||
|
abort_with_exception(error);
|
||||||
|
}
|
||||||
if (!arr.is_tracer()) {
|
if (!arr.is_tracer()) {
|
||||||
arr.detach();
|
arr.detach();
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -62,6 +63,13 @@ PrintFormatter& get_global_formatter() {
|
|||||||
return formatter;
|
return formatter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void abort_with_exception(const std::exception& error) {
|
||||||
|
std::ostringstream msg;
|
||||||
|
msg << "Terminating due to uncaught exception: " << error.what();
|
||||||
|
std::cerr << msg.str() << std::endl;
|
||||||
|
std::abort();
|
||||||
|
}
|
||||||
|
|
||||||
Dtype result_type(const std::vector<array>& arrays) {
|
Dtype result_type(const std::vector<array>& arrays) {
|
||||||
Dtype t = bool_;
|
Dtype t = bool_;
|
||||||
for (auto& arr : arrays) {
|
for (auto& arr : arrays) {
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
@ -53,6 +54,9 @@ struct PrintFormatter {
|
|||||||
|
|
||||||
PrintFormatter& get_global_formatter();
|
PrintFormatter& get_global_formatter();
|
||||||
|
|
||||||
|
/** Print the exception and then abort. */
|
||||||
|
void abort_with_exception(const std::exception& error);
|
||||||
|
|
||||||
/** Holds information about floating-point types. */
|
/** Holds information about floating-point types. */
|
||||||
struct finfo {
|
struct finfo {
|
||||||
explicit finfo(Dtype dtype);
|
explicit finfo(Dtype dtype);
|
||||||
|
Loading…
Reference in New Issue
Block a user