Added clip function (#159)

* Added clip

* Added Python bindings

* Formatting

* Added cpp tests

* Added Python tests

* python bindings work

* rebase

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Cyril Zakka, MD 2023-12-17 22:00:29 -06:00 committed by GitHub
parent ee0c2835c5
commit 8eb56beb3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 116 additions and 0 deletions

View File

@ -27,6 +27,7 @@ Operations
array_equal array_equal
broadcast_to broadcast_to
ceil ceil
clip
concatenate concatenate
convolve convolve
conv1d conv1d

View File

@ -595,6 +595,24 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
return split(a, num_splits, 0, to_stream(s)); return split(a, num_splits, 0, to_stream(s));
} }
array clip(
const array& a,
const std::optional<array>& a_min,
const std::optional<array>& a_max,
StreamOrDevice s /* = {} */) {
if (!a_min.has_value() && !a_max.has_value()) {
throw std::invalid_argument("At most one of a_min and a_max may be None");
}
array result = astype(a, a.dtype(), s);
if (a_min.has_value()) {
result = maximum(result, a_min.value(), s);
}
if (a_max.has_value()) {
result = minimum(result, a_max.value(), s);
}
return result;
}
array concatenate( array concatenate(
const std::vector<array>& arrays, const std::vector<array>& arrays,
int axis, int axis,

View File

@ -185,6 +185,15 @@ std::vector<array> split(
std::vector<array> std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {}); split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
/**
* Clip (limit) the values in an array.
*/
array clip(
const array& a,
const std::optional<array>& a_min = std::nullopt,
const std::optional<array>& a_max = std::nullopt,
StreamOrDevice s = {});
/** Concatenate arrays along a given axis. */ /** Concatenate arrays along a given axis. */
array concatenate( array concatenate(
const std::vector<array>& arrays, const std::vector<array>& arrays,

View File

@ -2353,6 +2353,45 @@ void init_ops(py::module_& m) {
Returns: Returns:
array: The resulting stacked array. array: The resulting stacked array.
)pbdoc"); )pbdoc");
m.def(
"clip",
[](const array& a,
const std::optional<ScalarOrArray>& min,
const std::optional<ScalarOrArray>& max,
StreamOrDevice s) {
std::optional<array> min_ = std::nullopt;
std::optional<array> max_ = std::nullopt;
if (min) {
min_ = to_array(min.value());
}
if (max) {
max_ = to_array(max.value());
}
return clip(a, min_, max_, s);
},
"a"_a,
py::pos_only(),
"a_min"_a,
"a_max"_a,
py::kw_only(),
"stream"_a = none,
R"pbdoc(
clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array
Clip the values of the array between the given minimum and maximum.
If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge
is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``.
The input ``a`` and the limits must broadcast with one another.
Args:
a (array): Input array.
a_min (scalar or array or None): Minimum value to clip to.
a_max (scalar or array or None): Maximum value to clip to.
Returns:
array: The clipped array.
)pbdoc");
m.def( m.def(
"pad", "pad",
[](const array& a, [](const array& a,

View File

@ -1435,6 +1435,34 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4]) self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4]) self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
def test_clip(self):
a = np.array([1, 4, 3, 8, 5], np.int32)
expected = np.clip(a, 2, 6)
clipped = mx.clip(mx.array(a), 2, 6)
self.assertTrue(np.array_equal(clipped, expected))
a = np.array([-1, 1, 0, 5], np.int32)
expected = np.clip(a, 0, None)
clipped = mx.clip(mx.array(a), 0, None)
self.assertTrue(np.array_equal(clipped, expected))
a = np.array([2, 3, 4, 5], np.int32)
expected = np.clip(a, None, 4)
clipped = mx.clip(mx.array(a), None, 4)
self.assertTrue(np.array_equal(clipped, expected))
mins = np.array([3, 1, 5, 5])
a = np.array([2, 3, 4, 5], np.int32)
expected = np.clip(a, mins, 4)
clipped = mx.clip(mx.array(a), mx.array(mins), 4)
self.assertTrue(np.array_equal(clipped, expected))
maxs = np.array([5, -1, 2, 9])
a = np.array([2, 3, 4, 5], np.int32)
expected = np.clip(a, mins, maxs)
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
self.assertTrue(np.array_equal(clipped, expected))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -2171,3 +2171,24 @@ TEST_CASE("test eye with negative k offset") {
{4, 3}); {4, 3});
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>()); CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
} }
TEST_CASE("test basic clipping") {
array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5});
array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5});
auto clipped = clip(a, array(2.0f), array(6.0f));
CHECK(array_equal(clipped, expected).item<bool>());
}
TEST_CASE("test clipping with only min") {
array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4});
array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4});
auto clipped = clip(a, array(0.0f), std::nullopt);
CHECK(array_equal(clipped, expected).item<bool>());
}
TEST_CASE("test clipping with only max") {
array a({2.0f, 3.0f, 4.0f, 5.0f}, {4});
array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4});
auto clipped = clip(a, std::nullopt, array(4.0f));
CHECK(array_equal(clipped, expected).item<bool>());
}