mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-19 00:04:41 +08:00)
[CUDA] Use ConcurrentContext in concatenate_gpu (#2549)
@@ -30,8 +30,7 @@ void concatenate_gpu(
   flags.row_contiguous = false;
   flags.col_contiguous = false;
   flags.contiguous = false;
-  // TODO: Handle concurrent outputs:
-  // https://github.com/ml-explore/mlx/pull/2145#discussion_r2070753816
+  auto concurrent = cu::get_command_encoder(s).concurrent_context();
   for (int i = 0; i < inputs.size(); i++) {
     array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
     size_t data_offset = strides[axis] * sizes[i];
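For context, a rough sketch of how the change reads in place. It assumes that the object returned by concurrent_context() is an RAII guard (as with the analogous ConcurrentContext in MLX's Metal command encoder): while it is alive, work recorded on the encoder need not be ordered against other work recorded in the same scope, so the per-input slice copies, each of which writes a disjoint region of out, can overlap. The loop body past data_offset is paraphrased and illustrative, not verbatim source.

// Sketch only; lines after data_offset are illustrative, not verbatim source.
auto concurrent = cu::get_command_encoder(s).concurrent_context();
for (int i = 0; i < inputs.size(); i++) {
  // View into the region of `out` that input i will fill.
  array out_slice(inputs[i].shape(), out.dtype(), nullptr, {});
  size_t data_offset = strides[axis] * sizes[i];
  out_slice.copy_shared_buffer(out, strides, flags, out_slice.size(), data_offset);
  // Each copy targets a disjoint slice of `out`, so with the concurrent
  // context alive these copies may be encoded without mutual ordering.
  copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, s);
}
// When `concurrent` goes out of scope, normal ordering of subsequently
// recorded work resumes.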