From 8eb56beb3a76e242489dc025043a3e88d9db527e Mon Sep 17 00:00:00 2001 From: "Cyril Zakka, MD" <1841186+cyrilzakka@users.noreply.github.com> Date: Sun, 17 Dec 2023 22:00:29 -0600 Subject: [PATCH] Added clip function (#159) * Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun --- docs/src/python/ops.rst | 1 + mlx/ops.cpp | 18 ++++++++++++++++++ mlx/ops.h | 9 +++++++++ python/src/ops.cpp | 39 +++++++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 28 ++++++++++++++++++++++++++++ tests/ops_tests.cpp | 21 +++++++++++++++++++++ 6 files changed, 116 insertions(+) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index ea25b90f9..dcdf0ffd9 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -27,6 +27,7 @@ Operations array_equal broadcast_to ceil + clip concatenate convolve conv1d diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 147c2c111..125672af7 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -595,6 +595,24 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) { return split(a, num_splits, 0, to_stream(s)); } +array clip( + const array& a, + const std::optional& a_min, + const std::optional& a_max, + StreamOrDevice s /* = {} */) { + if (!a_min.has_value() && !a_max.has_value()) { + throw std::invalid_argument("At most one of a_min and a_max may be None"); + } + array result = astype(a, a.dtype(), s); + if (a_min.has_value()) { + result = maximum(result, a_min.value(), s); + } + if (a_max.has_value()) { + result = minimum(result, a_max.value(), s); + } + return result; +} + array concatenate( const std::vector& arrays, int axis, diff --git a/mlx/ops.h b/mlx/ops.h index 86c475e6e..50f29afa1 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -185,6 +185,15 @@ std::vector split( std::vector split(const array& a, const std::vector& indices, StreamOrDevice s = {}); +/** + * Clip (limit) the values in an array. + */ +array clip( + const array& a, + const std::optional& a_min = std::nullopt, + const std::optional& a_max = std::nullopt, + StreamOrDevice s = {}); + /** Concatenate arrays along a given axis. */ array concatenate( const std::vector& arrays, diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 58b15e1d6..438b6d80f 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2353,6 +2353,45 @@ void init_ops(py::module_& m) { Returns: array: The resulting stacked array. )pbdoc"); + m.def( + "clip", + [](const array& a, + const std::optional& min, + const std::optional& max, + StreamOrDevice s) { + std::optional min_ = std::nullopt; + std::optional max_ = std::nullopt; + if (min) { + min_ = to_array(min.value()); + } + if (max) { + max_ = to_array(max.value()); + } + return clip(a, min_, max_, s); + }, + "a"_a, + py::pos_only(), + "a_min"_a, + "a_max"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array + + Clip the values of the array between the given minimum and maximum. + + If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge + is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``. + The input ``a`` and the limits must broadcast with one another. + + Args: + a (array): Input array. + a_min (scalar or array or None): Minimum value to clip to. + a_max (scalar or array or None): Maximum value to clip to. + + Returns: + array: The clipped array. + )pbdoc"); m.def( "pad", [](const array& a, diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index eea726b16..0207fa630 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1435,6 +1435,34 @@ class TestOps(mlx_tests.MLXTestCase): self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4]) self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4]) + def test_clip(self): + a = np.array([1, 4, 3, 8, 5], np.int32) + expected = np.clip(a, 2, 6) + clipped = mx.clip(mx.array(a), 2, 6) + self.assertTrue(np.array_equal(clipped, expected)) + + a = np.array([-1, 1, 0, 5], np.int32) + expected = np.clip(a, 0, None) + clipped = mx.clip(mx.array(a), 0, None) + self.assertTrue(np.array_equal(clipped, expected)) + + a = np.array([2, 3, 4, 5], np.int32) + expected = np.clip(a, None, 4) + clipped = mx.clip(mx.array(a), None, 4) + self.assertTrue(np.array_equal(clipped, expected)) + + mins = np.array([3, 1, 5, 5]) + a = np.array([2, 3, 4, 5], np.int32) + expected = np.clip(a, mins, 4) + clipped = mx.clip(mx.array(a), mx.array(mins), 4) + self.assertTrue(np.array_equal(clipped, expected)) + + maxs = np.array([5, -1, 2, 9]) + a = np.array([2, 3, 4, 5], np.int32) + expected = np.clip(a, mins, maxs) + clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs)) + self.assertTrue(np.array_equal(clipped, expected)) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index af53ce8a1..c62fb4a39 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2171,3 +2171,24 @@ TEST_CASE("test eye with negative k offset") { {4, 3}); CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item()); } + +TEST_CASE("test basic clipping") { + array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5}); + array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5}); + auto clipped = clip(a, array(2.0f), array(6.0f)); + CHECK(array_equal(clipped, expected).item()); +} + +TEST_CASE("test clipping with only min") { + array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4}); + array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4}); + auto clipped = clip(a, array(0.0f), std::nullopt); + CHECK(array_equal(clipped, expected).item()); +} + +TEST_CASE("test clipping with only max") { + array a({2.0f, 3.0f, 4.0f, 5.0f}, {4}); + array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4}); + auto clipped = clip(a, std::nullopt, array(4.0f)); + CHECK(array_equal(clipped, expected).item()); +}