mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
fix clip (#1415)
This commit is contained in:
@@ -2008,6 +2008,22 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
clipped = mx.clip(mx.array(a), mx.array(mins), mx.array(maxs))
|
||||
self.assertTrue(np.array_equal(clipped, expected))
|
||||
|
||||
# Check clip output types
|
||||
a = mx.array([1, 2, 3], mx.int16)
|
||||
out_t = mx.clip(a, a_min=0, a_max=5).dtype
|
||||
self.assertEqual(out_t, mx.int16)
|
||||
|
||||
out_t = mx.clip(a, a_min=0.0, a_max=5).dtype
|
||||
self.assertEqual(out_t, mx.float32)
|
||||
|
||||
a = mx.array([1, 2, 3], mx.float16)
|
||||
out_t = mx.clip(a, a_min=0.0, a_max=5).dtype
|
||||
self.assertEqual(out_t, mx.float16)
|
||||
|
||||
a = mx.array([1, 2, 3], mx.float16)
|
||||
out_t = mx.clip(a, a_min=0.0, a_max=mx.array(1.0)).dtype
|
||||
self.assertEqual(out_t, mx.float32)
|
||||
|
||||
def test_linspace(self):
|
||||
# Test default num = 50
|
||||
a = mx.linspace(0, 1)
|
||||
|
Reference in New Issue
Block a user