[CUDA] Use ConcurrentContext in concatenate_gpu

Cheng 2025-08-22 18:51:59 -07:00
parent 30561229c7
commit b04d6c224c


@@ -1,6 +1,7 @@
 // Copyright © 2025 Apple Inc.
 #include "mlx/backend/common/slicing.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
@@ -27,8 +28,7 @@ void concatenate_gpu(
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  // TODO: Handle concurrent outputs:
-  // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
+  auto concurrent = cu::get_command_encoder(s).concurrent_context();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
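
For context, a hedged sketch of the loop after this change. Only the lines shown in the hunk above are verbatim; the surrounding setup is elided, and the copy_gpu_inplace and array::copy_shared_buffer helpers are assumed from the MLX GPU backend headers included above.

// Sketch, not the verbatim file: each input is copied into its slice of the
// output while the ConcurrentContext guard is alive, so the command encoder
// can record the per-input copies as independent, concurrent work.
void concatenate_gpu(
    const std::vector<array>& inputs,
    array& out,
    int axis,
    const Stream& s) {
  // ... allocate `out`, compute the cumulative per-input offsets `sizes`,
  // and set up `strides`/`flags` as in the hunk above ...

  // RAII guard: the concurrent region ends when `concurrent` is destroyed.
  auto concurrent = cu::get_command_encoder(s).concurrent_context();
  for (int i = 0; i < inputs.size(); i++) {
    // A view into `out` at this input's offset along `axis`.
    array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
    size_t data_offset = strides[axis] * sizes[i];
    out_slice.copy_shared_buffer(
        out, strides, flags, out_slice.size(), data_offset);
    // Copy the input into its slice; the copies do not depend on each other.
    copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
  }
}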