mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Added clip function (#159)
* Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2353,6 +2353,45 @@ void init_ops(py::module_& m) {
|
||||
Returns:
|
||||
array: The resulting stacked array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"clip",
|
||||
[](const array& a,
|
||||
const std::optional<ScalarOrArray>& min,
|
||||
const std::optional<ScalarOrArray>& max,
|
||||
StreamOrDevice s) {
|
||||
std::optional<array> min_ = std::nullopt;
|
||||
std::optional<array> max_ = std::nullopt;
|
||||
if (min) {
|
||||
min_ = to_array(min.value());
|
||||
}
|
||||
if (max) {
|
||||
max_ = to_array(max.value());
|
||||
}
|
||||
return clip(a, min_, max_, s);
|
||||
},
|
||||
"a"_a,
|
||||
py::pos_only(),
|
||||
"a_min"_a,
|
||||
"a_max"_a,
|
||||
py::kw_only(),
|
||||
"stream"_a = none,
|
||||
R"pbdoc(
|
||||
clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array
|
||||
|
||||
Clip the values of the array between the given minimum and maximum.
|
||||
|
||||
If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge
|
||||
is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``.
|
||||
The input ``a`` and the limits must broadcast with one another.
|
||||
|
||||
Args:
|
||||
a (array): Input array.
|
||||
a_min (scalar or array or None): Minimum value to clip to.
|
||||
a_max (scalar or array or None): Maximum value to clip to.
|
||||
|
||||
Returns:
|
||||
array: The clipped array.
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"pad",
|
||||
[](const array& a,
|
||||
|
@@ -1435,6 +1435,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
||||
|
||||
def test_clip(self):
|
||||
a = np.array([1, 4, 3, 8, 5], np.int32)
|
||||
expected = np.clip(a, 2, 6)
|
||||
clipped = mx.clip(mx.array(a), 2, 6)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
a = np.array([-1, 1, 0, 5], np.int32)
|
||||
expected = np.clip(a, 0, None)
|
||||
clipped = mx.clip(mx.array(a), 0, None)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, None, 4)
|
||||
clipped = mx.clip(mx.array(a), None, 4)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
mins = np.array([3, 1, 5, 5])
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, mins, 4)
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), 4)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
maxs = np.array([5, -1, 2, 9])
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, mins, maxs)
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user