mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Array api (#1289)
* some updates for numpy 2.0 and array api * some updates for numpy 2.0 and array api * fix array api doc
This commit is contained in:
parent
e9e53856d2
commit
7b456fd2c0
@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// TODO: Fix GPU kernel
|
||||
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
|
||||
std::ostringstream msg;
|
||||
msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements,"
|
||||
<< " got array with sort axis size " << a.shape(axis) << "."
|
||||
<< " Please place this operation on the CPU instead.";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return array(
|
||||
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
|
||||
}
|
||||
|
@ -294,6 +294,29 @@ void init_array(nb::module_& m) {
|
||||
Returns:
|
||||
array: The array with type ``dtype``.
|
||||
)pbdoc")
|
||||
.def(
|
||||
"__array_namespace__",
|
||||
[](const array& a, const std::optional<std::string>& api_version) {
|
||||
if (api_version) {
|
||||
throw std::invalid_argument(
|
||||
"Explicitly specifying api_version is not yet implemented.");
|
||||
}
|
||||
return nb::module_::import_("mlx.core");
|
||||
},
|
||||
"api_version"_a = nb::none(),
|
||||
R"pbdoc(
|
||||
Returns an object that has all the array API functions on it.
|
||||
|
||||
See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_
|
||||
for more information.
|
||||
|
||||
Args:
|
||||
api_version (str, optional): String representing the version
|
||||
of the array API spec to return. Default: ``None``.
|
||||
|
||||
Returns:
|
||||
out (Any): An object representing the array API namespace.
|
||||
)pbdoc")
|
||||
.def("__getitem__", mlx_get_item, nb::arg().none())
|
||||
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
|
||||
.def_prop_ro(
|
||||
|
@ -6,18 +6,9 @@
|
||||
namespace nb = nanobind;
|
||||
|
||||
void init_constants(nb::module_& m) {
|
||||
m.attr("Inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("Infinity") = std::numeric_limits<double>::infinity();
|
||||
m.attr("NAN") = NAN;
|
||||
m.attr("NINF") = -std::numeric_limits<double>::infinity();
|
||||
m.attr("NZERO") = -0.0;
|
||||
m.attr("NaN") = NAN;
|
||||
m.attr("PINF") = std::numeric_limits<double>::infinity();
|
||||
m.attr("PZERO") = 0.0;
|
||||
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
|
||||
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
|
||||
m.attr("inf") = std::numeric_limits<double>::infinity();
|
||||
m.attr("infty") = std::numeric_limits<double>::infinity();
|
||||
m.attr("nan") = NAN;
|
||||
m.attr("newaxis") = nb::none();
|
||||
m.attr("pi") = 3.1415926535897932384626433;
|
||||
|
@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value()) {
|
||||
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
|
||||
return transpose(a, *axes, s);
|
||||
} else {
|
||||
return transpose(a, s);
|
||||
}
|
||||
@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The transposed array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permute_dims",
|
||||
[](const array& a,
|
||||
const std::optional<std::vector<int>>& axes,
|
||||
StreamOrDevice s) {
|
||||
if (axes.has_value()) {
|
||||
return transpose(a, *axes, s);
|
||||
} else {
|
||||
return transpose(a, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
"axes"_a = nb::none(),
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
See :func:`transpose`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"sum",
|
||||
[](const array& a,
|
||||
@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
|
||||
Returns:
|
||||
array: The concatenated array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"concat",
|
||||
[](const std::vector<array>& arrays,
|
||||
std::optional<int> axis,
|
||||
StreamOrDevice s) {
|
||||
if (axis) {
|
||||
return concatenate(arrays, *axis, s);
|
||||
} else {
|
||||
return concatenate(arrays, s);
|
||||
}
|
||||
},
|
||||
nb::arg(),
|
||||
"axis"_a.none() = 0,
|
||||
nb::kw_only(),
|
||||
"stream"_a = nb::none(),
|
||||
nb::sig(
|
||||
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||
R"pbdoc(
|
||||
See :func:`concatenate`.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"stack",
|
||||
[](const std::vector<array>& arrays,
|
||||
|
@ -1828,6 +1828,12 @@ class TestArray(mlx_tests.MLXTestCase):
|
||||
anp[:, idx] = 4
|
||||
self.assertTrue(np.array_equal(a, anp))
|
||||
|
||||
def test_array_namespace(self):
|
||||
a = mx.array(1.0)
|
||||
api = a.__array_namespace__()
|
||||
self.assertTrue(hasattr(api, "array"))
|
||||
self.assertTrue(hasattr(api, "add"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -10,14 +10,6 @@ import numpy as np
|
||||
class TestConstants(mlx_tests.MLXTestCase):
|
||||
def test_constants_values(self):
|
||||
# Check if mlx constants match expected values
|
||||
self.assertAlmostEqual(mx.Inf, float("inf"))
|
||||
self.assertAlmostEqual(mx.Infinity, float("inf"))
|
||||
self.assertTrue(np.isnan(mx.NAN))
|
||||
self.assertAlmostEqual(mx.NINF, float("-inf"))
|
||||
self.assertEqual(mx.NZERO, -0.0)
|
||||
self.assertTrue(np.isnan(mx.NaN))
|
||||
self.assertAlmostEqual(mx.PINF, float("inf"))
|
||||
self.assertEqual(mx.PZERO, 0.0)
|
||||
self.assertAlmostEqual(
|
||||
mx.e, 2.71828182845904523536028747135266249775724709369995
|
||||
)
|
||||
@ -25,25 +17,15 @@ class TestConstants(mlx_tests.MLXTestCase):
|
||||
mx.euler_gamma, 0.5772156649015328606065120900824024310421
|
||||
)
|
||||
self.assertAlmostEqual(mx.inf, float("inf"))
|
||||
self.assertAlmostEqual(mx.infty, float("inf"))
|
||||
self.assertTrue(np.isnan(mx.nan))
|
||||
self.assertIsNone(mx.newaxis)
|
||||
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)
|
||||
|
||||
def test_constants_availability(self):
|
||||
# Check if mlx constants are available
|
||||
self.assertTrue(hasattr(mx, "Inf"))
|
||||
self.assertTrue(hasattr(mx, "Infinity"))
|
||||
self.assertTrue(hasattr(mx, "NAN"))
|
||||
self.assertTrue(hasattr(mx, "NINF"))
|
||||
self.assertTrue(hasattr(mx, "NaN"))
|
||||
self.assertTrue(hasattr(mx, "PINF"))
|
||||
self.assertTrue(hasattr(mx, "NZERO"))
|
||||
self.assertTrue(hasattr(mx, "PZERO"))
|
||||
self.assertTrue(hasattr(mx, "e"))
|
||||
self.assertTrue(hasattr(mx, "euler_gamma"))
|
||||
self.assertTrue(hasattr(mx, "inf"))
|
||||
self.assertTrue(hasattr(mx, "infty"))
|
||||
self.assertTrue(hasattr(mx, "nan"))
|
||||
self.assertTrue(hasattr(mx, "newaxis"))
|
||||
self.assertTrue(hasattr(mx, "pi"))
|
||||
|
Loading…
Reference in New Issue
Block a user