From 83b11bc58ddd84c8e5ec83f9455b57db66df0c73 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Jun 2024 13:17:08 -0700 Subject: [PATCH] Fix Metal API validation for empty concat (#1183) --- mlx/backend/metal/copy.cpp | 7 ++++--- python/tests/test_ops.py | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp index 699772950..f2f98da06 100644 --- a/mlx/backend/metal/copy.cpp +++ b/mlx/backend/metal/copy.cpp @@ -33,9 +33,6 @@ void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { } else { out.set_data(allocator::malloc_or_wait(out.nbytes())); } - if (out.size() == 0) { - return; - } if (ctype == CopyType::GeneralGeneral) { ctype = CopyType::General; } @@ -57,6 +54,10 @@ void copy_gpu_inplace( int64_t out_offset, CopyType ctype, const Stream& s) { + if (out.size() == 0) { + return; + } + // Try to collapse contiguous dims auto [shape, strides] = collapse_contiguous_dims( data_shape, std::vector{strides_in_pre, strides_out_pre}); diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 209ddcae2..dc89a4afd 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1533,6 +1533,12 @@ class TestOps(mlx_tests.MLXTestCase): b = mx.array([1, 2]) mx.concatenate([a, b], axis=0) + # Cocnatenate with 0-sized array + a = mx.zeros((2, 0, 2)) + b = mx.zeros((2, 2, 2)) + out = mx.concatenate([a, b], axis=1) + self.assertTrue(mx.array_equal(out, b)) + def test_meshgrid(self): x = mx.array([1, 2, 3], dtype=mx.int32) y = np.array([1, 2, 3], dtype=np.int32)