diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index af67fbbdd..64a2afe74 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -1,6 +1,7 @@ // Copyright © 2025 Apple Inc. #include "mlx/backend/common/slicing.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/gpu/slicing.h" @@ -27,8 +28,7 @@ void concatenate_gpu( flags.row_contiguous = false; flags.col_contiguous = false; flags.contiguous = false; - // TODO: Handle concurrent outputs: - // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816 + auto concurrent = cu::get_command_encoder(s).concurrent_context(); for (int i = 0; i < inputs.size(); i++) { array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); size_t data_offset = strides[axis] * sizes[i];