* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
This commit is contained in:
Awni Hannun 2024-07-26 10:40:49 -07:00 committed by GitHub
parent e9e53856d2
commit 7b456fd2c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 70 additions and 37 deletions

View File

@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
// TODO: Fix GPU kernel
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
std::ostringstream msg;
msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements,"
<< " got array with sort axis size " << a.shape(axis) << "."
<< " Please place this operation on the CPU instead.";
throw std::runtime_error(msg.str());
}
return array( return array(
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a}); a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
} }

View File

@ -294,6 +294,29 @@ void init_array(nb::module_& m) {
Returns: Returns:
array: The array with type ``dtype``. array: The array with type ``dtype``.
)pbdoc") )pbdoc")
.def(
"__array_namespace__",
[](const array& a, const std::optional<std::string>& api_version) {
if (api_version) {
throw std::invalid_argument(
"Explicitly specifying api_version is not yet implemented.");
}
return nb::module_::import_("mlx.core");
},
"api_version"_a = nb::none(),
R"pbdoc(
Returns an object that has all the array API functions on it.
See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_
for more information.
Args:
api_version (str, optional): String representing the version
of the array API spec to return. Default: ``None``.
Returns:
out (Any): An object representing the array API namespace.
)pbdoc")
.def("__getitem__", mlx_get_item, nb::arg().none()) .def("__getitem__", mlx_get_item, nb::arg().none())
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg()) .def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
.def_prop_ro( .def_prop_ro(

View File

@ -6,18 +6,9 @@
namespace nb = nanobind; namespace nb = nanobind;
void init_constants(nb::module_& m) { void init_constants(nb::module_& m) {
m.attr("Inf") = std::numeric_limits<double>::infinity();
m.attr("Infinity") = std::numeric_limits<double>::infinity();
m.attr("NAN") = NAN;
m.attr("NINF") = -std::numeric_limits<double>::infinity();
m.attr("NZERO") = -0.0;
m.attr("NaN") = NAN;
m.attr("PINF") = std::numeric_limits<double>::infinity();
m.attr("PZERO") = 0.0;
m.attr("e") = 2.71828182845904523536028747135266249775724709369995; m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421; m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
m.attr("inf") = std::numeric_limits<double>::infinity(); m.attr("inf") = std::numeric_limits<double>::infinity();
m.attr("infty") = std::numeric_limits<double>::infinity();
m.attr("nan") = NAN; m.attr("nan") = NAN;
m.attr("newaxis") = nb::none(); m.attr("newaxis") = nb::none();
m.attr("pi") = 3.1415926535897932384626433; m.attr("pi") = 3.1415926535897932384626433;

View File

@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
const std::optional<std::vector<int>>& axes, const std::optional<std::vector<int>>& axes,
StreamOrDevice s) { StreamOrDevice s) {
if (axes.has_value()) { if (axes.has_value()) {
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s); return transpose(a, *axes, s);
} else { } else {
return transpose(a, s); return transpose(a, s);
} }
@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
Returns: Returns:
array: The transposed array. array: The transposed array.
)pbdoc"); )pbdoc");
m.def(
"permute_dims",
[](const array& a,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
},
nb::arg(),
"axes"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`transpose`.
)pbdoc");
m.def( m.def(
"sum", "sum",
[](const array& a, [](const array& a,
@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
Returns: Returns:
array: The concatenated array. array: The concatenated array.
)pbdoc"); )pbdoc");
m.def(
"concat",
[](const std::vector<array>& arrays,
std::optional<int> axis,
StreamOrDevice s) {
if (axis) {
return concatenate(arrays, *axis, s);
} else {
return concatenate(arrays, s);
}
},
nb::arg(),
"axis"_a.none() = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`concatenate`.
)pbdoc");
m.def( m.def(
"stack", "stack",
[](const std::vector<array>& arrays, [](const std::vector<array>& arrays,

View File

@ -1828,6 +1828,12 @@ class TestArray(mlx_tests.MLXTestCase):
anp[:, idx] = 4 anp[:, idx] = 4
self.assertTrue(np.array_equal(a, anp)) self.assertTrue(np.array_equal(a, anp))
def test_array_namespace(self):
a = mx.array(1.0)
api = a.__array_namespace__()
self.assertTrue(hasattr(api, "array"))
self.assertTrue(hasattr(api, "add"))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -10,14 +10,6 @@ import numpy as np
class TestConstants(mlx_tests.MLXTestCase): class TestConstants(mlx_tests.MLXTestCase):
def test_constants_values(self): def test_constants_values(self):
# Check if mlx constants match expected values # Check if mlx constants match expected values
self.assertAlmostEqual(mx.Inf, float("inf"))
self.assertAlmostEqual(mx.Infinity, float("inf"))
self.assertTrue(np.isnan(mx.NAN))
self.assertAlmostEqual(mx.NINF, float("-inf"))
self.assertEqual(mx.NZERO, -0.0)
self.assertTrue(np.isnan(mx.NaN))
self.assertAlmostEqual(mx.PINF, float("inf"))
self.assertEqual(mx.PZERO, 0.0)
self.assertAlmostEqual( self.assertAlmostEqual(
mx.e, 2.71828182845904523536028747135266249775724709369995 mx.e, 2.71828182845904523536028747135266249775724709369995
) )
@ -25,25 +17,15 @@ class TestConstants(mlx_tests.MLXTestCase):
mx.euler_gamma, 0.5772156649015328606065120900824024310421 mx.euler_gamma, 0.5772156649015328606065120900824024310421
) )
self.assertAlmostEqual(mx.inf, float("inf")) self.assertAlmostEqual(mx.inf, float("inf"))
self.assertAlmostEqual(mx.infty, float("inf"))
self.assertTrue(np.isnan(mx.nan)) self.assertTrue(np.isnan(mx.nan))
self.assertIsNone(mx.newaxis) self.assertIsNone(mx.newaxis)
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433) self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)
def test_constants_availability(self): def test_constants_availability(self):
# Check if mlx constants are available # Check if mlx constants are available
self.assertTrue(hasattr(mx, "Inf"))
self.assertTrue(hasattr(mx, "Infinity"))
self.assertTrue(hasattr(mx, "NAN"))
self.assertTrue(hasattr(mx, "NINF"))
self.assertTrue(hasattr(mx, "NaN"))
self.assertTrue(hasattr(mx, "PINF"))
self.assertTrue(hasattr(mx, "NZERO"))
self.assertTrue(hasattr(mx, "PZERO"))
self.assertTrue(hasattr(mx, "e")) self.assertTrue(hasattr(mx, "e"))
self.assertTrue(hasattr(mx, "euler_gamma")) self.assertTrue(hasattr(mx, "euler_gamma"))
self.assertTrue(hasattr(mx, "inf")) self.assertTrue(hasattr(mx, "inf"))
self.assertTrue(hasattr(mx, "infty"))
self.assertTrue(hasattr(mx, "nan")) self.assertTrue(hasattr(mx, "nan"))
self.assertTrue(hasattr(mx, "newaxis")) self.assertTrue(hasattr(mx, "newaxis"))
self.assertTrue(hasattr(mx, "pi")) self.assertTrue(hasattr(mx, "pi"))