diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index 18cc14bbd..93241936b 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -30,8 +30,7 @@ void concatenate_gpu( flags.row_contiguous = false; flags.col_contiguous = false; flags.contiguous = false; - // TODO: Handle concurrent outputs: - // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816 + auto concurrent = cu::get_command_encoder(s).concurrent_context(); for (int i = 0; i < inputs.size(); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); size_t data_offset = strides[axis] * sizes[i];