mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0 * Add test case for repeats = 0
This commit is contained in:
@@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
def test_repeat(self):
|
||||
# Setup data for the tests
|
||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||
# Test repeat 0 times
|
||||
self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
|
||||
# Test repeat along axis 0
|
||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
|
||||
# Test repeat along axis 1
|
||||
|
Reference in New Issue
Block a user