From fc00f16e306f865dd749f12f6bbeda21c95e2acd Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 4 Nov 2025 07:41:17 -0800 Subject: [PATCH] fix --- mlx/backend/cuda/fence.cpp | 19 ++++++++++++++++++- mlx/backend/metal/fence.cpp | 24 +++++++++++++----------- mlx/backend/no_gpu/fence.cpp | 2 +- mlx/fence.h | 4 ++-- mlx/transforms.cpp | 17 ++++++++++++----- 5 files changed, 46 insertions(+), 20 deletions(-) diff --git a/mlx/backend/cuda/fence.cpp b/mlx/backend/cuda/fence.cpp index 9121c7f4e..a4aedeb74 100644 --- a/mlx/backend/cuda/fence.cpp +++ b/mlx/backend/cuda/fence.cpp @@ -1,6 +1,8 @@ // Copyright © 2025 Apple Inc. #include "mlx/fence.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/event.h" namespace mlx::core { @@ -20,8 +22,23 @@ void Fence::wait(Stream s, const array&) { fence->event.wait(fence->count); } -void Fence::update(Stream s, const array&) { +void Fence::update(Stream s, const array& a, bool cross_device) { auto* fence = static_cast<FenceImpl*>(fence_.get()); + if (cross_device) { + // Move to managed memory if there is a device switch + auto& cbuf = + *static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr()); + if (cbuf.device != -1) { + void* new_data; + CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size)); + cbuf.device = -1; + CHECK_CUDA_ERROR( + cudaMemcpyAsync(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault)); + auto& encoder = cu::device(s.device).get_command_encoder(s); + CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream())); + cbuf.data = new_data; + } + } fence->count++; fence->event.signal(s, fence->count); } diff --git a/mlx/backend/metal/fence.cpp b/mlx/backend/metal/fence.cpp index 1ae3f9c22..da73d5d91 100644 --- a/mlx/backend/metal/fence.cpp +++ b/mlx/backend/metal/fence.cpp @@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) { [fence_ = fence_](MTL::CommandBuffer* cbuf) {}); } -void Fence::update(Stream stream, const array& x) { +void Fence::update(Stream stream, const 
array& x, bool cross_device) { auto& f = *static_cast<FenceImpl*>(fence_.get()); f.count++; @@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) { // Launch input visibility kernels auto& compute_encoder = d.get_command_encoder(idx); - auto kernel = d.get_kernel("input_coherent"); - uint32_t nthreads = - (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t); - MTL::Size group_dims = MTL::Size(1024, 1, 1); - MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); - compute_encoder.set_compute_pipeline_state(kernel); - compute_encoder.set_input_array(x, 0); - compute_encoder.set_bytes(nthreads, 1); - compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + if (cross_device) { + auto kernel = d.get_kernel("input_coherent"); + uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / + sizeof(uint32_t); + MTL::Size group_dims = MTL::Size(1024, 1, 1); + MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1); + compute_encoder.set_compute_pipeline_state(kernel); + compute_encoder.set_input_array(x, 0); + compute_encoder.set_bytes(nthreads, 1); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + } // Barrier on previous kernels compute_encoder.barrier(); // Launch value update kernel - kernel = d.get_kernel("fence_update"); + auto kernel = d.get_kernel("fence_update"); MTL::Size kernel_dims = MTL::Size(1, 1, 1); compute_encoder.set_compute_pipeline_state(kernel); diff --git a/mlx/backend/no_gpu/fence.cpp b/mlx/backend/no_gpu/fence.cpp index af22108a8..cd66d23cf 100644 --- a/mlx/backend/no_gpu/fence.cpp +++ b/mlx/backend/no_gpu/fence.cpp @@ -36,7 +36,7 @@ void Fence::wait(Stream stream, const array&) { } } -void Fence::update(Stream stream, const array&) { +void Fence::update(Stream stream, const array&, bool) { auto& f = *static_cast<FenceImpl*>(fence_.get()); f.count++; if (stream.device == Device::cpu) { diff --git a/mlx/fence.h b/mlx/fence.h index 66c2b5f49..0ececdb6d 100644 --- 
a/mlx/fence.h +++ b/mlx/fence.h @@ -15,7 +15,7 @@ namespace mlx::core { * `wait` returns. The array passed to `wait` will not be read until all * previous calls to `update` have completed. * - * Note, calls to `update` should always from the same thread or explicitly + * Note, calls to `update` should always be from the same thread or explicitly * synchronized so that they occur in sequence. Calls to `wait` can be on any * thread. * @@ -29,7 +29,7 @@ class Fence { Fence() {}; explicit Fence(Stream stream); - void update(Stream stream, const array& x); + void update(Stream stream, const array& x, bool cross_device); void wait(Stream stream, const array& x); private: diff --git a/mlx/transforms.cpp b/mlx/transforms.cpp index 3ec64feea..f98293d80 100644 --- a/mlx/transforms.cpp +++ b/mlx/transforms.cpp @@ -62,7 +62,7 @@ array eval_impl(std::vector<array> outputs, bool async) { } // Map of array id that needs fence and stream it's computed on - std::unordered_map<std::uintptr_t, uint32_t> needs_fence; + std::unordered_map<std::uintptr_t, std::pair<uint32_t, bool>> needs_fence; auto synchronizer = array( {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs)); @@ -114,7 +114,14 @@ array eval_impl(std::vector<array> outputs, bool async) { "https://github.com/ml-explore/mlx/issues."); } if (a.primitive().stream() != in.primitive().stream()) { - needs_fence.emplace(in.id(), in.primitive().stream().index); + bool device_switch = + a.primitive().stream().device != in.primitive().stream().device; + auto [it, inserted] = needs_fence.emplace( + in.id(), + std::make_pair(in.primitive().stream().index, device_switch)); + if (!inserted) { + it->second.second |= device_switch; + } } } @@ -216,7 +223,7 @@ array eval_impl(std::vector<array> outputs, bool async) { // Use fence to wait within a single eval // Get the input array's stream fence and wait on the // output arrays stream - fences[it->second.first].wait(stream, in); + } else if (in.event().valid()) { if (in.event().is_signaled()) { in.detach_event(); @@ -251,12 +258,12 @@ 
array eval_impl(std::vector<array> outputs, bool async) { } auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) { - if (needs_fence.find(a.id()) != needs_fence.end()) { + if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) { auto it = fences.find(stream.index); if (it == fences.end()) { it = fences.emplace(stream.index, Fence{stream}).first; } - it->second.update(stream, a); + it->second.update(stream, a, nf->second.second); } };