2024-03-19 11:12:25 +08:00
|
|
|
// Copyright © 2023-2024 Apple Inc.
|
2023-11-30 02:30:41 +08:00
|
|
|
#pragma once
|
|
|
|
#include <complex>
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>

#include "mlx/array.h"
|
|
|
|
|
2024-03-19 11:12:25 +08:00
|
|
|
// Short alias for the nanobind namespace; the declarations below spell it `nb`.
namespace nb = nanobind;
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
// NOTE(review): a using-directive in a header leaks every mlx::core name into
// each translation unit that includes this file. It is kept because the
// declarations below refer to mlx::core types (array, Dtype) unqualified and
// includers may rely on it as well.
using namespace mlx::core;
|
|
|
|
|
|
|
|
// An optional "int or list of ints" argument. std::monostate represents the
// "not provided" case (presumably an omitted Python argument). get_reduce_axes()
// below interprets the three alternatives as reduction axes.
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
|
2024-05-07 07:02:49 +08:00
|
|
|
using ScalarOrArray = std::variant<
|
|
|
|
nb::bool_,
|
|
|
|
nb::int_,
|
|
|
|
nb::float_,
|
|
|
|
// Must be above ndarray
|
|
|
|
array,
|
|
|
|
// Must be above complex
|
|
|
|
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
|
|
|
|
std::complex<float>,
|
|
|
|
nb::object>;
|
2023-11-30 02:30:41 +08:00
|
|
|
|
|
|
|
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
|
|
|
|
std::vector<int> axes;
|
|
|
|
if (std::holds_alternative<std::monostate>(v)) {
|
|
|
|
axes.resize(dims);
|
|
|
|
std::iota(axes.begin(), axes.end(), 0);
|
|
|
|
} else if (auto pv = std::get_if<int>(&v); pv) {
|
|
|
|
axes.push_back(*pv);
|
|
|
|
} else {
|
|
|
|
axes = std::get<std::vector<int>>(v);
|
|
|
|
}
|
|
|
|
return axes;
|
|
|
|
}
|
|
|
|
|
2024-03-29 21:52:30 +08:00
|
|
|
inline bool is_comparable_with_array(const ScalarOrArray& v) {
|
|
|
|
// Checks if the value can be compared to an array (or is already an
|
|
|
|
// mlx array)
|
|
|
|
if (auto pv = std::get_if<nb::object>(&v); pv) {
|
|
|
|
return nb::isinstance<array>(*pv) || nb::hasattr(*pv, "__mlx_array__");
|
|
|
|
} else {
|
|
|
|
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
|
|
|
|
// and can be compared to an array
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2024-04-03 12:11:24 +08:00
|
|
|
// Returns a non-owning nb::handle to the nb::object stored in `v`.
// The handle borrows from `v`, so `v` must outlive it.
// Throws std::bad_variant_access if `v` does not hold an nb::object.
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
  const nb::object& held = std::get<nb::object>(v);
  return nb::handle(held.ptr());
}
|
|
|
|
|
|
|
|
inline void throw_invalid_operation(
|
|
|
|
const std::string& operation,
|
|
|
|
const ScalarOrArray operand) {
|
|
|
|
std::ostringstream msg;
|
|
|
|
msg << "Cannot perform " << operation << " on an mlx.core.array and "
|
|
|
|
<< nb::type_name(get_handle_of_object(operand).type()).c_str();
|
|
|
|
throw std::invalid_argument(msg.str());
|
|
|
|
}
|
|
|
|
|
2024-05-07 07:02:49 +08:00
|
|
|
// Converts a Python-side operand into an mlx array (defined out of line).
// `dtype`, when provided, presumably selects the result's dtype.
// NOTE(review): the exact conversion rules live in the definition; confirm.
array to_array(
    const ScalarOrArray& v, std::optional<Dtype> dtype = std::nullopt);
|
2023-11-30 02:30:41 +08:00
|
|
|
|
2024-05-07 07:02:49 +08:00
|
|
|
std::pair<array, array> to_arrays(
|
2023-11-30 02:30:41 +08:00
|
|
|
const ScalarOrArray& a,
|
2024-05-07 07:02:49 +08:00
|
|
|
const ScalarOrArray& b);
|
|
|
|
|
|
|
|
// Builds an mlx array from a Python object (defined out of line).
// NOTE(review): judging by the name, this presumably goes through the object's
// `__mlx_array__` accessor, the attribute is_comparable_with_array() checks
// for; confirm against the definition.
array to_array_with_accessor(nb::object obj);
|