diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 699772950..f2f98da06 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -33,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); } - if (out.size() == 0) { - return; - } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } @@ -57,6 +54,10 @@ void copy_gpu_inplace( int64_t out_offset, CopyType ctype, const Stream& s) { + if (out.size() == 0) { + return; + } + // Try to collapse contiguous dims auto [shape, strides] = collapse_contiguous_dims( data_shape, std::vector{strides_in_pre, strides_out_pre}); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 209ddcae2..dc89a4afd 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1533,6 +1533,12 @@ class TestOps(mlx_tests.MLXTestCase): b = mx.array([1, 2]) mx.concatenate([a, b], axis=0) + # Cocnatenate with 0-sized array + a = mx.zeros((2, 0, 2)) + b = mx.zeros((2, 2, 2)) + out = mx.concatenate([a, b], axis=1) + self.assertTrue(mx.array_equal(out, b)) + def test_meshgrid(self): x = mx.array([1, 2, 3], dtype=mx.int32) y = np.array([1, 2, 3], dtype=np.int32)