mirror of
https://github.com/ml-explore/mlx.git
synced 2025-11-06 20:20:11 +08:00
fix
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
// Copyright © 2025 Apple Inc.
|
// Copyright © 2025 Apple Inc.
|
||||||
|
|
||||||
#include "mlx/fence.h"
|
#include "mlx/fence.h"
|
||||||
|
#include "mlx/backend/cuda/allocator.h"
|
||||||
|
#include "mlx/backend/cuda/device.h"
|
||||||
#include "mlx/backend/cuda/event.h"
|
#include "mlx/backend/cuda/event.h"
|
||||||
|
|
||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
@@ -20,8 +22,23 @@ void Fence::wait(Stream s, const array&) {
|
|||||||
fence->event.wait(fence->count);
|
fence->event.wait(fence->count);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fence::update(Stream s, const array&) {
|
void Fence::update(Stream s, const array& a, bool cross_device) {
|
||||||
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
auto* fence = static_cast<FenceImpl*>(fence_.get());
|
||||||
|
if (cross_device) {
|
||||||
|
// Move to managed memory if there is a device switch
|
||||||
|
auto& cbuf =
|
||||||
|
*static_cast<cu::CudaBuffer*>(const_cast<array&>(a).buffer().ptr());
|
||||||
|
if (cbuf.device != -1) {
|
||||||
|
void* new_data;
|
||||||
|
CHECK_CUDA_ERROR(cudaMallocManaged(&new_data, cbuf.size));
|
||||||
|
cbuf.device = -1;
|
||||||
|
CHECK_CUDA_ERROR(
|
||||||
|
cudaMemcpyAsync(new_data, cbuf.data, cbuf.size, cudaMemcpyDefault));
|
||||||
|
auto& encoder = cu::device(s.device).get_command_encoder(s);
|
||||||
|
CHECK_CUDA_ERROR(cudaFreeAsync(cbuf.data, encoder.stream()));
|
||||||
|
cbuf.data = new_data;
|
||||||
|
}
|
||||||
|
}
|
||||||
fence->count++;
|
fence->count++;
|
||||||
fence->event.signal(s, fence->count);
|
fence->event.signal(s, fence->count);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ void Fence::wait(Stream stream, const array& x) {
|
|||||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fence::update(Stream stream, const array& x) {
|
void Fence::update(Stream stream, const array& x, bool cross_device) {
|
||||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||||
f.count++;
|
f.count++;
|
||||||
|
|
||||||
@@ -130,21 +130,23 @@ void Fence::update(Stream stream, const array& x) {
|
|||||||
|
|
||||||
// Launch input visibility kernels
|
// Launch input visibility kernels
|
||||||
auto& compute_encoder = d.get_command_encoder(idx);
|
auto& compute_encoder = d.get_command_encoder(idx);
|
||||||
auto kernel = d.get_kernel("input_coherent");
|
if (cross_device) {
|
||||||
uint32_t nthreads =
|
auto kernel = d.get_kernel("input_coherent");
|
||||||
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
|
uint32_t nthreads = (x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) /
|
||||||
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
sizeof(uint32_t);
|
||||||
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||||
compute_encoder.set_input_array(x, 0);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
compute_encoder.set_bytes(nthreads, 1);
|
compute_encoder.set_input_array(x, 0);
|
||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.set_bytes(nthreads, 1);
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
// Barrier on previous kernels
|
// Barrier on previous kernels
|
||||||
compute_encoder.barrier();
|
compute_encoder.barrier();
|
||||||
|
|
||||||
// Launch value update kernel
|
// Launch value update kernel
|
||||||
kernel = d.get_kernel("fence_update");
|
auto kernel = d.get_kernel("fence_update");
|
||||||
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||||
compute_encoder.set_compute_pipeline_state(kernel);
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ void Fence::wait(Stream stream, const array&) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Fence::update(Stream stream, const array&) {
|
void Fence::update(Stream stream, const array&, bool) {
|
||||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||||
f.count++;
|
f.count++;
|
||||||
if (stream.device == Device::cpu) {
|
if (stream.device == Device::cpu) {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ namespace mlx::core {
|
|||||||
* `wait` returns. The array passed to `wait` will not be read until all
|
* `wait` returns. The array passed to `wait` will not be read until all
|
||||||
* previous calls to `update` have completed.
|
* previous calls to `update` have completed.
|
||||||
*
|
*
|
||||||
* Note, calls to `update` should always from the same thread or explicitly
|
* Note, calls to `update` should always be from the same thread or explicitly
|
||||||
* synchronized so that they occur in sequence. Calls to `wait` can be on any
|
* synchronized so that they occur in sequence. Calls to `wait` can be on any
|
||||||
* thread.
|
* thread.
|
||||||
*
|
*
|
||||||
@@ -29,7 +29,7 @@ class Fence {
|
|||||||
Fence() {};
|
Fence() {};
|
||||||
explicit Fence(Stream stream);
|
explicit Fence(Stream stream);
|
||||||
|
|
||||||
void update(Stream stream, const array& x);
|
void update(Stream stream, const array& x, bool cross_device);
|
||||||
void wait(Stream stream, const array& x);
|
void wait(Stream stream, const array& x);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Map of array id that needs fence and stream it's computed on
|
// Map of array id that needs fence and stream it's computed on
|
||||||
std::unordered_map<uintptr_t, uint32_t> needs_fence;
|
std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;
|
||||||
|
|
||||||
auto synchronizer = array(
|
auto synchronizer = array(
|
||||||
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
{}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
|
||||||
@@ -114,7 +114,14 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
"https://github.com/ml-explore/mlx/issues.");
|
"https://github.com/ml-explore/mlx/issues.");
|
||||||
}
|
}
|
||||||
if (a.primitive().stream() != in.primitive().stream()) {
|
if (a.primitive().stream() != in.primitive().stream()) {
|
||||||
needs_fence.emplace(in.id(), in.primitive().stream().index);
|
bool device_switch =
|
||||||
|
a.primitive().stream().device != in.primitive().stream().device;
|
||||||
|
auto [it, inserted] = needs_fence.emplace(
|
||||||
|
in.id(),
|
||||||
|
std::make_pair(in.primitive().stream().index, device_switch));
|
||||||
|
if (!inserted) {
|
||||||
|
it->second.second |= device_switch;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,7 +223,7 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
// Use fence to wait within a single eval
|
// Use fence to wait within a single eval
|
||||||
// Get the input array's stream fence and wait on the
|
// Get the input array's stream fence and wait on the
|
||||||
// output arrays stream
|
// output arrays stream
|
||||||
fences[it->second].wait(stream, in);
|
fences[it->second.first].wait(stream, in);
|
||||||
} else if (in.event().valid()) {
|
} else if (in.event().valid()) {
|
||||||
if (in.event().is_signaled()) {
|
if (in.event().is_signaled()) {
|
||||||
in.detach_event();
|
in.detach_event();
|
||||||
@@ -251,12 +258,12 @@ array eval_impl(std::vector<array> outputs, bool async) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
|
auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
|
||||||
if (needs_fence.find(a.id()) != needs_fence.end()) {
|
if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
|
||||||
auto it = fences.find(stream.index);
|
auto it = fences.find(stream.index);
|
||||||
if (it == fences.end()) {
|
if (it == fences.end()) {
|
||||||
it = fences.emplace(stream.index, Fence{stream}).first;
|
it = fences.emplace(stream.index, Fence{stream}).first;
|
||||||
}
|
}
|
||||||
it->second.update(stream, a);
|
it->second.update(stream, a, nf->second.second);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user