Return empty array when repeats is 0 in mx.repeat (#681)

* Return empty array when repeats is 0

* Add test case for repeats = 0
This commit is contained in:
Noah Farr 2024-02-14 02:49:31 +01:00 committed by GitHub
parent 2fdc2462c3
commit 0c65517e91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View File

@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) {
} }
if (repeats == 0) { if (repeats == 0) {
return array({}, arr.dtype()); return array(std::initializer_list<int>{}, arr.dtype());
} }
if (repeats == 1) { if (repeats == 1) {

View File

@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase):
def test_repeat(self): def test_repeat(self):
# Setup data for the tests # Setup data for the tests
data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]])
# Test repeat 0 times
self.assertCmpNumpy([data, 0], mx.repeat, np.repeat)
# Test repeat along axis 0 # Test repeat along axis 0
self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0)
# Test repeat along axis 1 # Test repeat along axis 1