diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 01ee6d388..96107d515 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -774,7 +774,7 @@ array repeat(const array& arr, int repeats, int axis, StreamOrDevice s) { } if (repeats == 0) { - return array({}, arr.dtype()); + return array(std::initializer_list{}, arr.dtype()); } if (repeats == 1) { diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index edb98032b..5588ebd62 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1699,6 +1699,8 @@ class TestOps(mlx_tests.MLXTestCase): def test_repeat(self): # Setup data for the tests data = mx.array([[[13, 3], [16, 6]], [[14, 4], [15, 5]], [[11, 1], [12, 2]]]) + # Test repeat 0 times + self.assertCmpNumpy([data, 0], mx.repeat, np.repeat) # Test repeat along axis 0 self.assertCmpNumpy([data, 2], mx.repeat, np.repeat, axis=0) # Test repeat along axis 1