From d6492b01630b5cfe5c4f9531a29e2fda2943d299 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 14 Sep 2024 16:09:09 -0700 Subject: [PATCH] fix clip (#1415) --- mlx/ops.cpp | 2 +- python/src/ops.cpp | 4 ++-- python/tests/test_ops.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index a27f7ca91..200afd9dc 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -838,7 +838,7 @@ array clip( if (!a_min.has_value() && !a_max.has_value()) { throw std::invalid_argument("At most one of a_min and a_max may be None"); } - array result = astype(a, a.dtype(), s); + array result = a; if (a_min.has_value()) { result = maximum(result, a_min.value(), s); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index a594fd287..ba814f1a2 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2827,10 +2827,10 @@ void init_ops(nb::module_& m) { std::optional min_ = std::nullopt; std::optional max_ = std::nullopt; if (min) { - min_ = to_array(min.value()); + min_ = to_arrays(a, min.value()).second; } if (max) { - max_ = to_array(max.value()); + max_ = to_arrays(a, max.value()).second; } return clip(a, min_, max_, s); }, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 2837d9a30..7dcde16d7 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -2008,6 +2008,22 @@ class TestOps(mlx_tests.MLXTestCase): clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs)) self.assertTrue(np.array_equal(clipped, expected)) + # Check clip output types + a = mx.array([1, 2, 3], mx.int16) + out_t = mx.clip(a, a_min=0, a_max=5).dtype + self.assertEqual(out_t, mx.int16) + + out_t = mx.clip(a, a_min=0.0, a_max=5).dtype + self.assertEqual(out_t, mx.float32) + + a = mx.array([1, 2, 3], mx.float16) + out_t = mx.clip(a, a_min=0.0, a_max=5).dtype + self.assertEqual(out_t, mx.float16) + + a = mx.array([1, 2, 3], mx.float16) + out_t = mx.clip(a, a_min=0.0, a_max=mx.array(1.0)).dtype + self.assertEqual(out_t, mx.float32) + def test_linspace(self): # Test default num = 50 a = mx.linspace(0, 1)