mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Added clip function (#159)
* Added clip * Added Python bindings * Formatting * Added cpp tests * Added Python tests * python bindings work * rebase --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -1435,6 +1435,34 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(x.flatten(start_axis=1).shape, [2, 3 * 4])
|
||||
self.assertEqual(x.flatten(end_axis=1).shape, [2 * 3, 4])
|
||||
|
||||
def test_clip(self):
|
||||
a = np.array([1, 4, 3, 8, 5], np.int32)
|
||||
expected = np.clip(a, 2, 6)
|
||||
clipped = mx.clip(mx.array(a), 2, 6)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
a = np.array([-1, 1, 0, 5], np.int32)
|
||||
expected = np.clip(a, 0, None)
|
||||
clipped = mx.clip(mx.array(a), 0, None)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, None, 4)
|
||||
clipped = mx.clip(mx.array(a), None, 4)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
mins = np.array([3, 1, 5, 5])
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, mins, 4)
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), 4)
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
maxs = np.array([5, -1, 2, 9])
|
||||
a = np.array([2, 3, 4, 5], np.int32)
|
||||
expected = np.clip(a, mins, maxs)
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user