mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-14 09:07:12 +08:00
Mlx array accessor (#128)
* Add an accessor to interoperate with custom types * Change the docs to custom signatures
This commit is contained in:
committed by
GitHub
parent
072044e28f
commit
3214629601
@@ -458,38 +458,54 @@ void init_array(py::module_& m) {
|
||||
m.attr("bfloat16") = py::cast(bfloat16);
|
||||
m.attr("complex64") = py::cast(complex64);
|
||||
|
||||
py::class_<array>(m, "array", R"pbdoc(An N-dimensional array object.)pbdoc")
|
||||
.def(
|
||||
py::init([](ScalarOrArray v, std::optional<Dtype> t) {
|
||||
auto arr = to_array(v, t);
|
||||
auto array_class = py::class_<array>(
|
||||
m, "array", R"pbdoc(An N-dimensional array object.)pbdoc");
|
||||
|
||||
{
|
||||
py::options options;
|
||||
options.disable_function_signatures();
|
||||
|
||||
array_class.def(
|
||||
py::init([](std::variant<
|
||||
py::bool_,
|
||||
py::int_,
|
||||
py::float_,
|
||||
std::complex<float>,
|
||||
py::list,
|
||||
py::tuple,
|
||||
py::array,
|
||||
py::buffer,
|
||||
py::object> v,
|
||||
std::optional<Dtype> t) {
|
||||
if (auto pv = std::get_if<py::bool_>(&v); pv) {
|
||||
return array(py::cast<bool>(*pv), t.value_or(bool_));
|
||||
} else if (auto pv = std::get_if<py::int_>(&v); pv) {
|
||||
return array(py::cast<int>(*pv), t.value_or(int32));
|
||||
} else if (auto pv = std::get_if<py::float_>(&v); pv) {
|
||||
return array(py::cast<float>(*pv), t.value_or(float32));
|
||||
} else if (auto pv = std::get_if<std::complex<float>>(&v); pv) {
|
||||
return array(static_cast<complex64_t>(*pv), t.value_or(complex64));
|
||||
} else if (auto pv = std::get_if<py::list>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::tuple>(&v); pv) {
|
||||
return array_from_list(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::array>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else if (auto pv = std::get_if<py::buffer>(&v); pv) {
|
||||
return np_array_to_mlx(*pv, t);
|
||||
} else {
|
||||
auto arr = to_array_with_accessor(std::get<py::object>(v));
|
||||
return astype(arr, t.value_or(arr.dtype()));
|
||||
}),
|
||||
"val"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](std::variant<py::list, py::tuple> pl,
|
||||
std::optional<Dtype> dtype) {
|
||||
if (auto pv = std::get_if<py::list>(&pl); pv) {
|
||||
return array_from_list(*pv, dtype);
|
||||
} else {
|
||||
auto v = std::get<py::tuple>(pl);
|
||||
return array_from_list(v, dtype);
|
||||
}
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](py::array np_array, std::optional<Dtype> dtype) {
|
||||
return np_array_to_mlx(np_array, dtype);
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
.def(
|
||||
py::init([](py::buffer np_buffer, std::optional<Dtype> dtype) {
|
||||
return np_array_to_mlx(np_buffer, dtype);
|
||||
}),
|
||||
"vals"_a,
|
||||
"dtype"_a = std::nullopt)
|
||||
}
|
||||
}),
|
||||
"val"_a,
|
||||
"dtype"_a = std::nullopt,
|
||||
R"pbdoc(
|
||||
__init__(self: array, val: Union[scalar, list, tuple, numpy.ndarray, array], dtype: Optional[Dtype] = None)
|
||||
)pbdoc");
|
||||
}
|
||||
|
||||
array_class
|
||||
.def_property_readonly(
|
||||
"size", &array::size, R"pbdoc(Number of elments in the array.)pbdoc")
|
||||
.def_property_readonly(
|
||||
|
||||
Reference in New Issue
Block a user