mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-28 00:36:32 +08:00
[CUDA] Use ConcurrentContext in concatenate_gpu
This commit is contained in:
parent
30561229c7
commit
b04d6c224c
@ -1,6 +1,7 @@
|
||||
// Copyright © 2025 Apple Inc.
|
||||
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/cuda/device.h"
|
||||
#include "mlx/backend/gpu/copy.h"
|
||||
#include "mlx/backend/gpu/slicing.h"
|
||||
|
||||
@ -27,8 +28,7 @@ void concatenate_gpu(
|
||||
flags.row_contiguous = false;
|
||||
flags.col_contiguous = false;
|
||||
flags.contiguous = false;
|
||||
// TODO: Handle concurrent outputs:
|
||||
// https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
|
||||
auto concurrent = cu::get_command_encoder(s).concurrent_context();
|
||||
for (int i = 0; i < inputs.size(); i++) {
|
||||
array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
|
||||
size_t data_offset = strides[axis] * sizes[i];
|
||||
|
Loading…
Reference in New Issue
Block a user