mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-20 10:27:41 +08:00
Added clip function (#159)
* Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
ee0c2835c5
commit
8eb56beb3a
@ -27,6 +27,7 @@ Operations
|
|||||||
array_equal
|
array_equal
|
||||||
broadcast_to
|
broadcast_to
|
||||||
ceil
|
ceil
|
||||||
|
clip
|
||||||
concatenate
|
concatenate
|
||||||
convolve
|
convolve
|
||||||
conv1d
|
conv1d
|
||||||
|
18
mlx/ops.cpp
18
mlx/ops.cpp
@ -595,6 +595,24 @@ split(const array& a, int num_splits, StreamOrDevice s /* = {} */) {
|
|||||||
return split(a, num_splits, 0, to_stream(s));
|
return split(a, num_splits, 0, to_stream(s));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
array clip(
|
||||||
|
const array& a,
|
||||||
|
const std::optional<array>& a_min,
|
||||||
|
const std::optional<array>& a_max,
|
||||||
|
StreamOrDevice s /* = {} */) {
|
||||||
|
if (!a_min.has_value() && !a_max.has_value()) {
|
||||||
|
throw std::invalid_argument("At most one of a_min and a_max may be None");
|
||||||
|
}
|
||||||
|
array result = astype(a, a.dtype(), s);
|
||||||
|
if (a_min.has_value()) {
|
||||||
|
result = maximum(result, a_min.value(), s);
|
||||||
|
}
|
||||||
|
if (a_max.has_value()) {
|
||||||
|
result = minimum(result, a_max.value(), s);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
array concatenate(
|
array concatenate(
|
||||||
const std::vector<array>& arrays,
|
const std::vector<array>& arrays,
|
||||||
int axis,
|
int axis,
|
||||||
|
@ -185,6 +185,15 @@ std::vector<array> split(
|
|||||||
std::vector<array>
|
std::vector<array>
|
||||||
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clip (limit) the values in an array.
|
||||||
|
*/
|
||||||
|
array clip(
|
||||||
|
const array& a,
|
||||||
|
const std::optional<array>& a_min = std::nullopt,
|
||||||
|
const std::optional<array>& a_max = std::nullopt,
|
||||||
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Concatenate arrays along a given axis. */
|
/** Concatenate arrays along a given axis. */
|
||||||
array concatenate(
|
array concatenate(
|
||||||
const std::vector<array>& arrays,
|
const std::vector<array>& arrays,
|
||||||
|
@ -2353,6 +2353,45 @@ void init_ops(py::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
array: The resulting stacked array.
|
array: The resulting stacked array.
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"clip",
|
||||||
|
[](const array& a,
|
||||||
|
const std::optional<ScalarOrArray>& min,
|
||||||
|
const std::optional<ScalarOrArray>& max,
|
||||||
|
StreamOrDevice s) {
|
||||||
|
std::optional<array> min_ = std::nullopt;
|
||||||
|
std::optional<array> max_ = std::nullopt;
|
||||||
|
if (min) {
|
||||||
|
min_ = to_array(min.value());
|
||||||
|
}
|
||||||
|
if (max) {
|
||||||
|
max_ = to_array(max.value());
|
||||||
|
}
|
||||||
|
return clip(a, min_, max_, s);
|
||||||
|
},
|
||||||
|
"a"_a,
|
||||||
|
py::pos_only(),
|
||||||
|
"a_min"_a,
|
||||||
|
"a_max"_a,
|
||||||
|
py::kw_only(),
|
||||||
|
"stream"_a = none,
|
||||||
|
R"pbdoc(
|
||||||
|
clip(a: array, /, a_min: Union[scalar, array, None], a_max: Union[scalar, array, None], *, stream: Union[None, Stream, Device] = None) -> array
|
||||||
|
|
||||||
|
Clip the values of the array between the given minimum and maximum.
|
||||||
|
|
||||||
|
If either ``a_min`` or ``a_max`` are ``None``, then corresponding edge
|
||||||
|
is ignored. At least one of ``a_min`` and ``a_max`` cannot be ``None``.
|
||||||
|
The input ``a`` and the limits must broadcast with one another.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a (array): Input array.
|
||||||
|
a_min (scalar or array or None): Minimum value to clip to.
|
||||||
|
a_max (scalar or array or None): Maximum value to clip to.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
array: The clipped array.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"pad",
|
"pad",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
|
@ -1435,6 +1435,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
||||||
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
||||||
|
|
||||||
|
def test_clip(self):
|
||||||
|
a = np.array([1, 4, 3, 8, 5], np.int32)
|
||||||
|
expected = np.clip(a, 2, 6)
|
||||||
|
clipped = mx.clip(mx.array(a), 2, 6)
|
||||||
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
a = np.array([-1, 1, 0, 5], np.int32)
|
||||||
|
expected = np.clip(a, 0, None)
|
||||||
|
clipped = mx.clip(mx.array(a), 0, None)
|
||||||
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
a = np.array([2, 3, 4, 5], np.int32)
|
||||||
|
expected = np.clip(a, None, 4)
|
||||||
|
clipped = mx.clip(mx.array(a), None, 4)
|
||||||
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
mins = np.array([3, 1, 5, 5])
|
||||||
|
a = np.array([2, 3, 4, 5], np.int32)
|
||||||
|
expected = np.clip(a, mins, 4)
|
||||||
|
clipped = mx.clip(mx.array(a), mx.array(mins), 4)
|
||||||
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
maxs = np.array([5, -1, 2, 9])
|
||||||
|
a = np.array([2, 3, 4, 5], np.int32)
|
||||||
|
expected = np.clip(a, mins, maxs)
|
||||||
|
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||||
|
self.assertTrue(np.array_equal(clipped, expected))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -2171,3 +2171,24 @@ TEST_CASE("test eye with negative k offset") {
|
|||||||
{4, 3});
|
{4, 3});
|
||||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test basic clipping") {
|
||||||
|
array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5});
|
||||||
|
array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5});
|
||||||
|
auto clipped = clip(a, array(2.0f), array(6.0f));
|
||||||
|
CHECK(array_equal(clipped, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test clipping with only min") {
|
||||||
|
array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4});
|
||||||
|
array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4});
|
||||||
|
auto clipped = clip(a, array(0.0f), std::nullopt);
|
||||||
|
CHECK(array_equal(clipped, expected).item<bool>());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test clipping with only max") {
|
||||||
|
array a({2.0f, 3.0f, 4.0f, 5.0f}, {4});
|
||||||
|
array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4});
|
||||||
|
auto clipped = clip(a, std::nullopt, array(4.0f));
|
||||||
|
CHECK(array_equal(clipped, expected).item<bool>());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user