mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
Return empty array when repeats is 0 in mx.repeat (#681)
* Return empty array when repeats is 0 * Add test case for repeats = 0
This commit is contained in:
parent
2fdc2462c3
commit
0c65517e91
@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (repeats == 0) {
|
if (repeats == 0) {
|
||||||
return array({}, arr.dtype());
|
return array(std::initializer_list<int>{}, arr.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (repeats == 1) {
|
if (repeats == 1) {
|
||||||
|
@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase):
|
|||||||
def test_repeat(self):
|
def test_repeat(self):
|
||||||
# Setup data for the tests
|
# Setup data for the tests
|
||||||
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
|
||||||
|
# Test repeat 0 times
|
||||||
|
self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
|
||||||
# Test repeat along axis 0
|
# Test repeat along axis 0
|
||||||
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
|
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
|
||||||
# Test repeat along axis 1
|
# Test repeat along axis 1
|
||||||
|
Loading…
Reference in New Issue
Block a user