mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-29 06:53:18 +08:00
Added clip function (#159)
* Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -2171,3 +2171,24 @@ TEST_CASE("test eye with negative k offset") {
|
||||
{4, 3});
|
||||
CHECK(array_equal(eye_4_k_minus1, expected_eye_4_k_minus1).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test basic clipping") {
|
||||
array a({1.0f, 4.0f, 3.0f, 8.0f, 5.0f}, {5});
|
||||
array expected({2.0f, 4.0f, 3.0f, 6.0f, 5.0f}, {5});
|
||||
auto clipped = clip(a, array(2.0f), array(6.0f));
|
||||
CHECK(array_equal(clipped, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test clipping with only min") {
|
||||
array a({-1.0f, 1.0f, 0.0f, 5.0f}, {4});
|
||||
array expected({0.0f, 1.0f, 0.0f, 5.0f}, {4});
|
||||
auto clipped = clip(a, array(0.0f), std::nullopt);
|
||||
CHECK(array_equal(clipped, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test clipping with only max") {
|
||||
array a({2.0f, 3.0f, 4.0f, 5.0f}, {4});
|
||||
array expected({2.0f, 3.0f, 4.0f, 4.0f}, {4});
|
||||
auto clipped = clip(a, std::nullopt, array(4.0f));
|
||||
CHECK(array_equal(clipped, expected).item<bool>());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user