Mlx array accessor (#128)

* Add an accessor to interoperate with custom types
* Change the docs to custom signatures
This commit is contained in:
Angelos Katharopoulos
2023-12-11 13:42:55 -08:00
committed by GitHub
parent 072044e28f
commit 3214629601
3 changed files with 342 additions and 133 deletions

View File

@@ -458,38 +458,54 @@ void init_array(py::module_& m) {
m.attr("bfloat16") = py::cast(bfloat16);
m.attr("complex64") = py::cast(complex64);
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
.def(
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
auto arr = to_array(v, t);
auto array_class = py::class_<array>(
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
{
py::options options;
options.disable_function_signatures();
array_class.def(
py::init([](std::variant<
py::bool_,
py::int_,
py::float_,
std::complex<float>,
py::list,
py::tuple,
py::array,
py::buffer,
py::object> v,
std::optional<Dtype> t) {
if (auto pv = std::get_if<py::bool_>(&v); pv) {
return array(py::cast<bool>(*pv), t.value_or(bool_));
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
return array(py::cast<int>(*pv), t.value_or(int32));
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
return array(py::cast<float>(*pv), t.value_or(float32));
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
} else if (auto pv = std::get_if<py::list>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
return array_from_list(*pv, t);
} else if (auto pv = std::get_if<py::array>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
return np_array_to_mlx(*pv, t);
} else {
auto arr = to_array_with_accessor(std::get<py::object>(v));
return astype(arr, t.value_or(arr.dtype()));
}),
"val"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](std::variant<py::list, py::tuple> pl,
std::optional<Dtype> dtype) {
if (auto pv = std::get_if<py::list>(&pl); pv) {
return array_from_list(*pv, dtype);
} else {
auto v = std::get<py::tuple>(pl);
return array_from_list(v, dtype);
}
}),
"vals"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](py::array np_array, std::optional<Dtype> dtype) {
return np_array_to_mlx(np_array, dtype);
}),
"vals"_a,
"dtype"_a = std::nullopt)
.def(
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
return np_array_to_mlx(np_buffer, dtype);
}),
"vals"_a,
"dtype"_a = std::nullopt)
}
}),
"val"_a,
"dtype"_a = std::nullopt,
R"pbdoc(
__init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
)pbdoc");
}
array_class
.def_property_readonly(
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
.def_property_readonly(