mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
* Nicer exceptions for ops on non-arrays
This commit is contained in:
committed by
GitHub
parent
3fc993f82d
commit
0caf35f4b8
@@ -2,6 +2,7 @@
|
||||
#pragma once
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
@@ -56,6 +57,19 @@ inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
||||
}
|
||||
}
|
||||
|
||||
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
|
||||
return std::get<nb::object>(v).ptr();
|
||||
}
|
||||
|
||||
inline void throw_invalid_operation(
|
||||
const std::string& operation,
|
||||
const ScalarOrArray operand) {
|
||||
std::ostringstream msg;
|
||||
msg << "Cannot perform " << operation << " on an mlx.core.array and "
|
||||
<< nb::type_name(get_handle_of_object(operand).type()).c_str();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
inline array to_array(
|
||||
const ScalarOrArray& v,
|
||||
std::optional<Dtype> dtype = std::nullopt) {
|
||||
|
||||
Reference in New Issue
Block a user