mlx/python/src/utils.h
Awni Hannun f110357aaa
Bump nanobind to 2.4 + fix (#1710)
* bump nanobind to 2.4 + fix

* fix
2024-12-17 10:57:54 -08:00

79 lines
2.1 KiB
C++

// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <nanobind/stl/variant.h>
#include "mlx/array.h"
#include "python/src/convert.h"
namespace mx = mlx::core;
namespace nb = nanobind;
// Axis argument: either nothing (std::monostate), a single int, or a list of
// ints. get_reduce_axes() below turns any of the three into an explicit list.
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
// Variant over the operand kinds handled by the helpers in this header:
// Python bool/int/float wrappers, an mx::array, a read-only C-contiguous CPU
// ndarray, a std::complex<float>, or an ArrayLike (which carries a Python
// object in its `obj` member, see is_comparable_with_array below).
//
// The order of the alternatives is significant, as the inline comments note.
// NOTE(review): presumably the variant caster tries alternatives in
// declaration order, so earlier entries win on ambiguous inputs - confirm
// against the nanobind variant caster before reordering.
using ScalarOrArray = std::variant<
nb::bool_,
nb::int_,
nb::float_,
// Must be above ndarray
mx::array,
// Must be above complex
nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>,
std::complex<float>,
ArrayLike>;
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
std::vector<int> axes;
if (std::holds_alternative<std::monostate>(v)) {
axes.resize(dims);
std::iota(axes.begin(), axes.end(), 0);
} else if (auto pv = std::get_if<int>(&v); pv) {
axes.push_back(*pv);
} else {
axes = std::get<std::vector<int>>(v);
}
return axes;
}
inline bool is_comparable_with_array(const ScalarOrArray& v) {
// Checks if the value can be compared to an array (or is already an
// mlx array)
if (auto pv = std::get_if<ArrayLike>(&v); pv) {
auto obj = (*pv).obj;
return nb::isinstance<mx::array>(obj) || nb::hasattr(obj, "__mlx_array__");
} else {
// If it's not an object, it's a scalar (nb::int_, nb::float_, etc.)
// and can be compared to an array
return true;
}
}
// Returns a handle to the Python object carried by the ArrayLike alternative
// of `v`. std::get throws std::bad_variant_access when `v` holds any other
// alternative, so callers must only pass ArrayLike-holding variants.
inline nb::handle get_handle_of_object(const ScalarOrArray& v) {
  const ArrayLike& wrapped = std::get<ArrayLike>(v);
  return wrapped.obj.ptr();
}
// Throws std::invalid_argument with the message
//   "Cannot perform <operation> on an mlx.core.array and <type name>"
// where <type name> is the Python type name of the object held by `operand`.
//
// `operand` must hold the ArrayLike alternative: the type name is obtained via
// get_handle_of_object(), which throws std::bad_variant_access otherwise.
//
// The operand is taken by const reference (matching the other helpers in this
// header) to avoid copying the whole variant just to read its type, and the
// function is marked [[noreturn]] because every path ends in a throw.
[[noreturn]] inline void throw_invalid_operation(
    const std::string& operation,
    const ScalarOrArray& operand) {
  std::ostringstream msg;
  msg << "Cannot perform " << operation << " on an mlx.core.array and "
      << nb::type_name(get_handle_of_object(operand).type()).c_str();
  throw std::invalid_argument(msg.str());
}
// Declaration only; the definition is not part of this header.
// Takes a ScalarOrArray and returns an mx::array. `dtype` is optional and
// defaults to std::nullopt.
// NOTE(review): presumably std::nullopt lets the implementation choose the
// dtype from the held value - confirm against the definition.
mx::array to_array(
const ScalarOrArray& v,
std::optional<mx::Dtype> dtype = std::nullopt);
// Declaration only; the definition is not part of this header.
// Takes two ScalarOrArray operands and returns a pair of mx::array values.
// NOTE(review): presumably the pair is (converted a, converted b) with the
// two conversions coordinated (e.g. dtype resolution) - confirm against the
// definition.
std::pair<mx::array, mx::array> to_arrays(
const ScalarOrArray& a,
const ScalarOrArray& b);
// Declaration only; the definition is not part of this header.
// Takes a Python object by value and returns an mx::array.
// NOTE(review): the name suggests the object is converted through an accessor
// it exposes (cf. the `__mlx_array__` attribute checked in
// is_comparable_with_array) - confirm against the definition.
mx::array to_array_with_accessor(nb::object obj);