Mlx array accessor (#128)

* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
This commit is contained in:
Angelos Katharopoulos
2023-12-11 13:42:55 -08:00
committed by GitHub
parent 072044e28f
commit 3214629601
3 changed files with 342 additions and 133 deletions

View File

@@ -15,8 +15,8 @@ namespace py = pybind11;
using namespace mlx::core;
using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray =
std::variant<py::bool_, py::int_, py::float_, std::complex<float>, array>;
using ScalarOrArray = std::
variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
static constexpr std::monostate none{};
inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
@@ -32,6 +32,14 @@ inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
return axes;
}
inline array to_array_with_accessor(py::object obj) {
if (py::hasattr(obj, "__mlx_array__")) {
return obj.attr("__mlx_array__")().cast<array>();
} else {
return obj.cast<array>();
}
}
inline array to_array(
const ScalarOrArray& v,
std::optional<Dtype> dtype = std::nullopt) {
@@ -48,7 +56,7 @@ inline array to_array(
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), complex64);
} else {
return std::get<array>(v);
return to_array_with_accessor(std::get<py::object>(v));
}
}
@@ -60,13 +68,16 @@ inline std::pair<array, array> to_arrays(
// - If a is an array but b is not, treat b as a weak python type
// - If b is an array but a is not, treat a as a weak python type
// - If neither is an array convert to arrays but leave their types alone
if (auto pa = std::get_if<array>(&a); pa) {
if (auto pb = std::get_if<array>(&b); pb) {
return {*pa, *pb};
if (auto pa = std::get_if<py::object>(&a); pa) {
auto arr_a = to_array_with_accessor(*pa);
if (auto pb = std::get_if<py::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {arr_a, arr_b};
}
return {*pa, to_array(b, pa->dtype())};
} else if (auto pb = std::get_if<array>(&b); pb) {
return {to_array(a, pb->dtype()), *pb};
return {arr_a, to_array(b, arr_a.dtype())};
} else if (auto pb = std::get_if<py::object>(&b); pb) {
auto arr_b = to_array_with_accessor(*pb);
return {to_array(a, arr_b.dtype()), arr_b};
} else {
return {to_array(a), to_array(b)};
}