mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0 * Add test case for repeats = 0
This commit is contained in:
parent
2fdc2462c3
commit
0c65517e91
@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
||||
}
|
||||
|
||||
if (repeats == 0) {
|
||||
return array({}, arr.dtype());
|
||||
return array(std::initializer_list<int>{}, arr.dtype());
|
||||
}
|
||||
|
||||
if (repeats == 1) {
|
||||
|
@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
def test_repeat(self):
|
||||
# Setup data for the tests
|
||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||
# Test repeat 0 times
|
||||
self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
|
||||
# Test repeat along axis 0
|
||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
|
||||
# Test repeat along axis 1
|
||||
|
Loading…
Reference in New Issue
Block a user