commit fc00f16e30
parent 529842fed9
Author: Awni Hannun
Date:   2025-11-04 07:41:17 -08:00

5 changed files with 46 additions and 20 deletions

View File

@@ -1,6 +1,8 @@
 // Copyright © 2025 Apple Inc.
 
 #include "mlx/fence.h"
+#include "mlx/backend/cuda/allocator.h"
+#include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 
 namespace mlx::core {
@@ -20,8 +22,23 @@ void Fence::wait(Stream s, const array&) {
   fence->event.wait(fence->count);
 }
 
-void Fence::update(Stream s, const array&) {
+void Fence::update(Stream s, const array& a, bool cross_device) {
   auto* fence = static_cast<FenceImpl*>(fence_.get());
+  if (cross_device) {
+    // Move to managed memory if there is a device switch
+    auto& cbuf =
+        *static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
+    if (cbuf.device != -1) {
+      void* new_data;
+      CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
+      cbuf.device = -1;
+      CHECK_CUDA_ERROR(
+          cudaMemcpyAsync(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
+      auto& encoder = cu::device(s.device).get_command_encoder(s);
+      CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
+      cbuf.data = new_data;
+    }
+  }
   fence->count++;
   fence->event.signal(s, fence->count);
 }
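The hunk above migrates a device-resident buffer into CUDA managed memory when a fence update crosses devices. A minimal standalone sketch of that migration pattern, with a hypothetical Buffer struct standing in for cu::CudaBuffer (the device == -1 convention for "managed" is taken from the diff):

#include <cuda_runtime.h>

// Hypothetical stand-in for cu::CudaBuffer; device == -1 marks managed memory.
struct Buffer {
  void* data;
  size_t size;
  int device;
};

void migrate_to_managed(Buffer& buf, cudaStream_t stream) {
  if (buf.device == -1) {
    return; // already in managed memory
  }
  void* new_data = nullptr;
  // Managed memory is accessible from the host and every device.
  cudaMallocManaged(&new_data, buf.size);
  // cudaMemcpyDefault lets the runtime infer the copy direction.
  cudaMemcpyAsync(new_data, buf.data, buf.size, cudaMemcpyDefault, stream);
  // Free the old allocation on the same stream so the copy finishes first.
  cudaFreeAsync(buf.data, stream);
  buf.data = new_data;
  buf.device = -1;
}

The sketch issues the copy and the free on the same stream to keep them ordered; the diff itself leaves the copy on the default stream and frees on the encoder's stream.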

View File

@@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) {
       [fence_ = fence_](MTL::CommandBuffer* cbuf) {});
 }
 
-void Fence::update(Stream stream, const array& x) {
+void Fence::update(Stream stream, const array& x, bool cross_device) {
   auto& f = *static_cast<FenceImpl*>(fence_.get());
   f.count++;
@@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) {
   // Launch input visibility kernels
   auto& compute_encoder = d.get_command_encoder(idx);
-  auto kernel = d.get_kernel("input_coherent");
-  uint32_t nthreads =
-      (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
-  MTL::Size group_dims = MTL::Size(1024, 1, 1);
-  MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
-  compute_encoder.set_compute_pipeline_state(kernel);
-  compute_encoder.set_input_array(x, 0);
-  compute_encoder.set_bytes(nthreads, 1);
-  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+  if (cross_device) {
+    auto kernel = d.get_kernel("input_coherent");
+    uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) /
+        sizeof(uint32_t);
+    MTL::Size group_dims = MTL::Size(1024, 1, 1);
+    MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(x, 0);
+    compute_encoder.set_bytes(nthreads, 1);
+    compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
+  }
 
   // Barrier on previous kernels
   compute_encoder.barrier();
 
   // Launch value update kernel
-  kernel = d.get_kernel("fence_update");
+  auto kernel = d.get_kernel("fence_update");
   MTL::Size kernel_dims = MTL::Size(1, 1, 1);
   compute_encoder.set_compute_pipeline_state(kernel);
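The input_coherent dispatch above launches one thread per 32-bit word of the array, rounding both the word count and the threadgroup count up. A small self-contained check of that ceiling-division arithmetic (the array size is made up for illustration):

#include <cstdint>
#include <cstdio>

constexpr uint32_t ceil_div(uint32_t a, uint32_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Say the array holds 1000 float16 values: 2000 bytes.
  uint32_t nbytes = 1000 * 2;
  // One thread per 32-bit word, rounding the byte count up.
  uint32_t nthreads = ceil_div(nbytes, sizeof(uint32_t)); // 500
  // 1024-thread groups, rounding the grid size up.
  uint32_t ngroups = ceil_div(nthreads, 1024); // 1
  std::printf("%u threads, %u threadgroup(s)\n", nthreads, ngroups);
  return 0;
}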

View File

@@ -36,7 +36,7 @@ void Fence::wait(Stream stream, const array&) {
   }
 }
 
-void Fence::update(Stream stream, const array&) {
+void Fence::update(Stream stream, const array&, bool) {
   auto& f = *static_cast<FenceImpl*>(fence_.get());
   f.count++;
   if (stream.device == Device::cpu) {

View File

@@ -15,7 +15,7 @@ namespace mlx::core {
  * `wait` returns. The array passed to `wait` will not be read until all
  * previous calls to `update` have completed.
  *
- * Note, calls to `update` should always from the same thread or explicitly
+ * Note, calls to `update` should always be from the same thread or explicitly
  * synchronized so that they occur in sequence. Calls to `wait` can be on any
  * thread.
  *
@@ -29,7 +29,7 @@ class Fence {
   Fence() {};
   explicit Fence(Stream stream);
 
-  void update(Stream stream, const array& x);
+  void update(Stream stream, const array& x, bool cross_device);
   void wait(Stream stream, const array& x);
 
  private:
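With the extra parameter, a producer signals update on its own stream and flags whether the consumer runs on a different device. A hypothetical usage sketch of the revised API (the handoff function and its arguments are invented for illustration, and assume the Stream/Device comparisons available in the library):

#include "mlx/fence.h"

namespace mlx::core {

void handoff(Stream producer, Stream consumer, const array& x) {
  Fence fence(producer);
  // Flag a device switch so the backend can make `x` visible across devices.
  bool cross_device = producer.device != consumer.device;
  // Signal on the producing stream once `x` is safe to read.
  fence.update(producer, x, cross_device);
  // The consumer will not read `x` until all prior updates complete.
  fence.wait(consumer, x);
}

} // namespace mlx::core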

View File

@@ -62,7 +62,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
   }
 
   // Map of array id that needs fence and stream it's computed on
-  std::unordered_map<uintptr_t, uint32_t> needs_fence;
+  std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;
 
   auto synchronizer = array(
       {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
@@ -114,7 +114,14 @@
           "https://github.com/ml-explore/mlx/issues.");
     }
     if (a.primitive().stream() != in.primitive().stream()) {
-      needs_fence.emplace(in.id(), in.primitive().stream().index);
+      bool device_switch =
+          a.primitive().stream().device != in.primitive().stream().device;
+      auto [it, inserted] = needs_fence.emplace(
+          in.id(),
+          std::make_pair(in.primitive().stream().index, device_switch));
+      if (!inserted) {
+        it->second.second |= device_switch;
+      }
     }
   }
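The hunk above widens needs_fence from a plain stream index to a (stream index, device-switch) pair and uses the inserted flag returned by emplace to merge repeat entries: the first stream index wins, while the boolean is OR-ed across all consumers. A standalone sketch of that emplace-or-merge pattern (the ids and indices are made up):

#include <cstdint>
#include <unordered_map>
#include <utility>

int main() {
  std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;

  auto record = [&](uintptr_t id, uint32_t stream_index, bool device_switch) {
    auto [it, inserted] =
        needs_fence.emplace(id, std::make_pair(stream_index, device_switch));
    if (!inserted) {
      // Already recorded: keep the flag sticky across repeat consumers.
      it->second.second |= device_switch;
    }
  };

  record(0x1, 3, false); // consumer on the same device
  record(0x1, 3, true);  // later consumer on a different device
  // needs_fence[0x1] is now {3, true}.
  return needs_fence.at(0x1).second ? 0 : 1;
}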
@@ -216,7 +223,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
         // Use fence to wait within a single eval
         // Get the input array's stream fence and wait on the
         // output arrays stream
-        fences[it->second].wait(stream, in);
+        fences[it->second.first].wait(stream, in);
       } else if (in.event().valid()) {
         if (in.event().is_signaled()) {
           in.detach_event();
@@ -251,12 +258,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
   }
 
   auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
-    if (needs_fence.find(a.id()) != needs_fence.end()) {
+    if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
       auto it = fences.find(stream.index);
       if (it == fences.end()) {
         it = fences.emplace(stream.index, Fence{stream}).first;
       }
-      it->second.update(stream, a);
+      it->second.update(stream, a, nf->second.second);
     }
   };
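The lambda above creates a fence for the current stream lazily on first use, then updates it with the sticky device-switch flag recorded earlier. A minimal sketch of that find-or-create pattern, with hypothetical Fence and Stream stand-ins:

#include <cstdint>
#include <unordered_map>

struct Stream { uint32_t index; };
struct Fence {
  explicit Fence(Stream) {}
  void update(Stream, bool /* cross_device */) {}
};

void update_fence(std::unordered_map<uint32_t, Fence>& fences,
                  Stream stream,
                  bool cross_device) {
  // Construct the fence for this stream only when it is first needed.
  auto it = fences.find(stream.index);
  if (it == fences.end()) {
    it = fences.emplace(stream.index, Fence{stream}).first;
  }
  it->second.update(stream, cross_device);
}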