diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst
index 6109a3e5f..d848473d8 100644
--- a/docs/src/python/ops.rst
+++ b/docs/src/python/ops.rst
@@ -106,6 +106,7 @@ Operations
    minimum
    moveaxis
    multiply
+   nan_to_num
    negative
    not_equal
    ones
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 430fdc6f0..f7801f24c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -1344,6 +1344,46 @@ array where(
       inputs);
 }
 
+array nan_to_num(
+    const array& a,
+    float nan /* = 0.0f */,
+    const std::optional<float>& posinf_ /* = std::nullopt */,
+    const std::optional<float>& neginf_ /* = std::nullopt */,
+    StreamOrDevice s /* = {} */) {
+  Dtype dtype = a.dtype();
+  // Dtypes outside the inexact category are returned unchanged.
+  if (!issubdtype(dtype, inexact)) {
+    return a;
+  }
+
+  // Largest finite value of `dtype`: the default magnitude for infinities.
+  auto type_to_max = [](const auto& dtype) -> float {
+    if (dtype == float32) {
+      return std::numeric_limits<float>::max();
+    } else if (dtype == bfloat16) {
+      return std::numeric_limits<bfloat16_t>::max();
+    } else if (dtype == float16) {
+      return std::numeric_limits<float16_t>::max();
+    } else {
+      std::ostringstream msg;
+      msg << "[nan_to_num] Does not yet support given type: " << dtype << ".";
+      throw std::invalid_argument(msg.str());
+    }
+  };
+
+  // The default is only computed (and so can only throw) when the caller did
+  // not supply a replacement value.
+  float posinf = posinf_ ? *posinf_ : type_to_max(dtype);
+  float neginf = neginf_ ? *neginf_ : -type_to_max(dtype);
+
+  // All three masks are computed from the original input `a`, so a
+  // replacement value never triggers a later substitution.
+  auto out = where(isnan(a, s), array(nan, dtype), a, s);
+  out = where(isposinf(a, s), array(posinf, dtype), out, s);
+  out = where(isneginf(a, s), array(neginf, dtype), out, s);
+  return out;
+}
+
 array allclose(
     const array& a,
     const array& b,
diff --git a/mlx/ops.h b/mlx/ops.h
index d637633b6..448c62a89 100644
--- a/mlx/ops.h
+++ b/mlx/ops.h
@@ -406,6 +406,20 @@ array where(
     const array& y,
     StreamOrDevice s = {});
 
+/**
+ * Replace NaN and infinities with finite numbers.
+ *
+ * NaN becomes `nan`, +inf becomes `posinf` and -inf becomes `neginf`. When
+ * `posinf` / `neginf` are not given, the largest finite value of the input
+ * dtype (respectively its negative) is used.
+ */
+array nan_to_num(
+    const array& a,
+    float nan = 0.0f,
+    const std::optional<float>& posinf = std::nullopt,
+    const std::optional<float>& neginf = std::nullopt,
+    StreamOrDevice s = {});
+
 /** True if all elements in the array are true (or non-zero). **/
 array all(const array& a, bool keepdims, StreamOrDevice s = {});
 inline array all(const array& a, StreamOrDevice s = {}) {
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index be396f308..0f9220473 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -3595,6 +3595,39 @@ void init_ops(nb::module_& m) {
             array: The output containing elements selected from
             ``x`` and ``y``.
       )pbdoc");
+  m.def(
+      "nan_to_num",
+      [](const ScalarOrArray& a,
+         float nan,
+         const std::optional<float>& posinf,
+         const std::optional<float>& neginf,
+         StreamOrDevice s) {
+        return nan_to_num(to_array(a), nan, posinf, neginf, s);
+      },
+      nb::arg(),
+      "nan"_a = 0.0f,
+      "posinf"_a = nb::none(),
+      "neginf"_a = nb::none(),
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def nan_to_num(a: Union[scalar, array], nan: float = 0, posinf: Optional[float] = None, neginf: Optional[float] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Replace NaN and Inf values with finite numbers.
+
+        Args:
+            a (array): Input array
+            nan (float, optional): Value to replace NaN with. Default: ``0``.
+            posinf (float, optional): Value to replace positive infinities
+              with. If ``None``, defaults to largest finite value for the
+              given data type. Default: ``None``.
+            neginf (float, optional): Value to replace negative infinities
+              with. If ``None``, defaults to the negative of the largest
+              finite value for the given data type. Default: ``None``.
+
+        Returns:
+            array: Output array with NaN and Inf replaced.
+      )pbdoc");
   m.def(
       "round",
       [](const ScalarOrArray& a, int decimals, StreamOrDevice s) {
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 16799965f..c3fe0866b 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1653,6 +1653,25 @@ class TestOps(mlx_tests.MLXTestCase):
             np.where,
         )
 
+    def test_nan_to_num(self):
+        a = mx.array([6, float("inf"), 2, 0])
+        out_mx = mx.nan_to_num(a)
+        out_np = np.nan_to_num(a)
+        self.assertTrue(np.allclose(out_mx, out_np))
+
+        for t in [mx.float32, mx.float16]:
+            # NOTE(review): this case never casts to `t`, so both iterations
+            # run it with the default dtype -- confirm that is intended.
+            a = mx.array([float("inf"), 6.9, float("nan"), float("-inf")])
+            out_mx = mx.nan_to_num(a)
+            out_np = np.nan_to_num(a)
+            self.assertTrue(np.allclose(out_mx, out_np))
+
+            a = mx.array([float("inf"), 6.9, float("nan"), float("-inf")]).astype(t)
+            out_np = np.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000)
+            out_mx = mx.nan_to_num(a, nan=0.0, posinf=1000, neginf=-1000)
+            self.assertTrue(np.allclose(out_mx, out_np))
+
     def test_as_strided(self):
         x_npy = np.random.randn(128).astype(np.float32)
         x_mlx = mx.array(x_npy)