mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-12 07:14:34 +08:00
Array api (#1289)
* some updates for numpy 2.0 and array api * some updates for numpy 2.0 and array api * fix array api doc
This commit is contained in:
@@ -294,6 +294,29 @@ void init_array(nb::module_& m) {
|
||||
Returns:
|
||||
array: The array with type ``dtype``.
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__array_namespace__",
|
||||
[](const array& a, const std::optional<std::string>& api_version) {
|
||||
if (api_version) {
|
||||
throw std::invalid_argument(
|
||||
"Explicitly specifying api_version is not yet implemented.");
|
||||
}
|
||||
return nb::module_::import_("mlx.core");
|
||||
},
|
||||
"api_version"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Returns an object that has all the array API functions on it.
|
||||
|
||||
See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
api_version (str, optional): String representing the version
|
||||
of the array API spec to return. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
out (Any): An object representing the array API namespace.
|
||||
)pbdoc")
|
||||
.def("__getitem__", mlx_get_item, nb::arg().none())
|
||||
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
|
||||
.def_prop_ro(
|
||||
|
@@ -6,18 +6,9 @@
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_constants(nb::module_& m) {
|
||||
m.attr("Inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("Infinity") = std::numeric_limits<double>::infinity();
|
||||
m.attr("NAN") = NAN;
|
||||
m.attr("NINF") = -std::numeric_limits<double>::infinity();
|
||||
m.attr("NZERO") = -0.0;
|
||||
m.attr("NaN") = NAN;
|
||||
m.attr("PINF") = std::numeric_limits<double>::infinity();
|
||||
m.attr("PZERO") = 0.0;
|
||||
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
|
||||
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
|
||||
m.attr("inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("infty") = std::numeric_limits<double>::infinity();
|
||||
m.attr("nan") = NAN;
|
||||
m.attr("newaxis") = nb::none();
|
||||
m.attr("pi") = 3.1415926535897932384626433;
|
||||
|
@@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value()) {
|
||||
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
|
||||
return transpose(a, *axes, s);
|
||||
} else {
|
||||
return transpose(a, s);
|
||||
}
|
||||
@@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The transposed array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permute_dims",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value()) {
|
||||
return transpose(a, *axes, s);
|
||||
} else {
|
||||
return transpose(a, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
"axes"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
See :func:`transpose`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"sum",
|
||||
[](const array& a,
|
||||
@@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The concatenated array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"concat",
|
||||
[](const std::vector<array>& arrays,
|
||||
std::optional<int> axis,
|
||||
StreamOrDevice s) {
|
||||
if (axis) {
|
||||
return concatenate(arrays, *axis, s);
|
||||
} else {
|
||||
return concatenate(arrays, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
"axis"_a.none() = 0,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
See :func:`concatenate`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"stack",
|
||||
[](const std::vector<array>& arrays,
|
||||
|
Reference in New Issue
Block a user