mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-11 19:56:40 +08:00
Faster synchronization Fence
primitive (#1773)
* try faster synchronization move event fixes update bench fix fix * non-functioning kernel * try alternative fence * cleanup barrier * get rid of event_fence * update benchmarks * doc string in metal fence
This commit is contained in:
parent
0c259961ac
commit
a4667da1eb
55
benchmarks/python/synchronize_bench.py
Normal file
55
benchmarks/python/synchronize_bench.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
rank = mx.distributed.init().rank()
|
||||||
|
|
||||||
|
|
||||||
|
def timeit(fn, a):
    """Return the mean time in milliseconds of one ``mx.eval(fn(a))`` call.

    ``fn(a)`` is evaluated a few times untimed as warmup, then a fixed number
    of timed evaluations are averaged.
    """
    num_warmup = 5
    num_timed = 10

    # Untimed warmup evaluations.
    for _ in range(num_warmup):
        mx.eval(fn(a))

    start = time.perf_counter()
    for _ in range(num_timed):
        mx.eval(fn(a))
    stop = time.perf_counter()

    # Seconds -> milliseconds, averaged over the timed evaluations.
    return 1000 * (stop - start) / num_timed
|
||||||
|
|
||||||
|
|
||||||
|
def all_reduce_benchmark():
    """Time a chain of ``mx.distributed.all_sum`` calls on a 5x5 int32 array.

    Each evaluation chains ``its_per_eval`` all-sum operations, so the
    reported figure is the per-evaluation time divided by that count. Only
    rank 0 prints the result.
    """
    its_per_eval = 100
    a = mx.ones((5, 5), mx.int32)

    def fn(x):
        # Each step feeds the previous result into the next all_sum.
        for _ in range(its_per_eval):
            x = mx.distributed.all_sum(x) - 1
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank != 0:
        return
    print(f"All Reduce: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_benchmark():
    """Time a chain of ``mx.distributed.all_gather`` calls.

    Starts from a 5x5 int32 array; each step gathers and keeps element ``[0]``
    of the result as the input to the next step. The reported figure is the
    per-evaluation time divided by ``its_per_eval``. Only rank 0 prints.
    """
    its_per_eval = 100
    a = mx.ones((5, 5), mx.int32)

    def fn(x):
        for _ in range(its_per_eval):
            gathered = mx.distributed.all_gather(x)
            x = gathered[0]
        return x

    ms = timeit(fn, a) / its_per_eval
    if rank != 0:
        return
    print(f"All gather: time per iteration {ms:.6f} (ms)")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the benchmarks in order: all-reduce first, then all-gather.
    for benchmark in (all_reduce_benchmark, all_gather_benchmark):
        benchmark()
|
@ -89,6 +89,7 @@ target_sources(
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
|
||||||
|
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
|
||||||
|
@ -134,6 +134,13 @@ CommandEncoder::~CommandEncoder() {
|
|||||||
enc_->release();
|
enc_->release();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::set_buffer(
|
||||||
|
const MTL::Buffer* buf,
|
||||||
|
int idx,
|
||||||
|
int64_t offset /* = 0 */) {
|
||||||
|
enc_->setBuffer(buf, offset, idx);
|
||||||
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(
|
void CommandEncoder::set_input_array(
|
||||||
const array& a,
|
const array& a,
|
||||||
int idx,
|
int idx,
|
||||||
@ -155,6 +162,10 @@ void CommandEncoder::set_output_array(
|
|||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
// Add barriers before adding the output to the output set
|
// Add barriers before adding the output to the output set
|
||||||
set_input_array(a, idx, offset);
|
set_input_array(a, idx, offset);
|
||||||
|
register_output_array(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::register_output_array(array& a) {
|
||||||
all_outputs_.insert(a.buffer().ptr());
|
all_outputs_.insert(a.buffer().ptr());
|
||||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||||
if (concurrent_) {
|
if (concurrent_) {
|
||||||
@ -189,6 +200,10 @@ void CommandEncoder::dispatch_threads(
|
|||||||
enc_->dispatchThreads(grid_dims, group_dims);
|
enc_->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::barrier() {
|
||||||
|
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||||
|
}
|
||||||
|
|
||||||
Device::Device() {
|
Device::Device() {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
device_ = load_device();
|
device_ = load_device();
|
||||||
|
@ -60,9 +60,11 @@ struct CommandEncoder {
|
|||||||
|
|
||||||
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
void set_input_array(const array& a, int idx, int64_t offset = 0);
|
||||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||||
|
void register_output_array(array& a);
|
||||||
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||||
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||||
void maybeInsertBarrier();
|
void maybeInsertBarrier();
|
||||||
|
void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0);
|
||||||
|
|
||||||
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) {
|
||||||
enc_->setComputePipelineState(kernel);
|
enc_->setComputePipelineState(kernel);
|
||||||
@ -110,6 +112,8 @@ struct CommandEncoder {
|
|||||||
return all_outputs_;
|
return all_outputs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void barrier();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MTL::ComputeCommandEncoder* enc_;
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
bool needs_barrier_{false};
|
bool needs_barrier_{false};
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
#include "mlx/backend/common/utils.h"
|
#include "mlx/backend/common/utils.h"
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
#include "mlx/backend/metal/event.h"
|
#include "mlx/backend/metal/event.h"
|
||||||
|
#include "mlx/backend/metal/fence.h"
|
||||||
#include "mlx/distributed/ops.h"
|
#include "mlx/distributed/ops.h"
|
||||||
#include "mlx/distributed/primitives.h"
|
#include "mlx/distributed/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
@ -26,23 +27,28 @@ void AllReduce::eval_gpu(
|
|||||||
assert(outputs.size() == 1);
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
Fence f{stream()};
|
||||||
|
|
||||||
|
if (in.event().valid()) {
|
||||||
|
f.update_gpu(in);
|
||||||
|
}
|
||||||
|
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
if (in.is_donatable()) {
|
if (in.is_donatable()) {
|
||||||
out.move_shared_buffer(in);
|
out.move_shared_buffer(in);
|
||||||
} else {
|
} else {
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
}
|
}
|
||||||
|
f.wait_gpu(out);
|
||||||
|
|
||||||
auto e = Event(stream());
|
|
||||||
e.set_value(1);
|
|
||||||
signal_and_wait(in.event(), e);
|
|
||||||
auto task = [in = in,
|
auto task = [in = in,
|
||||||
out = out,
|
out = out,
|
||||||
e = std::move(e),
|
f = std::move(f),
|
||||||
reduce_type = reduce_type_,
|
reduce_type = reduce_type_,
|
||||||
group = group()]() mutable {
|
group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
in.event().wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
switch (reduce_type) {
|
switch (reduce_type) {
|
||||||
case Sum:
|
case Sum:
|
||||||
@ -52,7 +58,7 @@ void AllReduce::eval_gpu(
|
|||||||
default:
|
default:
|
||||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||||
}
|
}
|
||||||
e.signal();
|
f.update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
@ -67,17 +73,20 @@ void AllGather::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
auto e = Event(stream());
|
Fence f{stream()};
|
||||||
e.set_value(1);
|
|
||||||
signal_and_wait(in.event(), e);
|
if (in.event().valid()) {
|
||||||
|
f.update_gpu(in);
|
||||||
|
}
|
||||||
|
f.wait_gpu(out);
|
||||||
|
|
||||||
auto task =
|
auto task =
|
||||||
[in = in, out = out, e = std::move(e), group = group()]() mutable {
|
[in = in, out = out, f = std::move(f), group = group()]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
in.event().wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
distributed::detail::all_gather(group, in, out);
|
distributed::detail::all_gather(group, in, out);
|
||||||
e.signal();
|
f.update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
@ -89,22 +98,28 @@ void Send::eval_gpu(
|
|||||||
assert(outputs.size() == 1);
|
assert(outputs.size() == 1);
|
||||||
|
|
||||||
auto& in = inputs[0];
|
auto& in = inputs[0];
|
||||||
|
|
||||||
|
// Encode a signal event for the input
|
||||||
|
Fence f{stream()};
|
||||||
|
if (in.event().valid()) {
|
||||||
|
f.update_gpu(in);
|
||||||
|
}
|
||||||
|
|
||||||
auto& out = outputs[0];
|
auto& out = outputs[0];
|
||||||
move_or_copy(in, out);
|
move_or_copy(in, out);
|
||||||
|
|
||||||
// Schedule an async send on the comm stream
|
// Schedule an async send on the comm stream
|
||||||
auto task = [in = in, out = out, group = group(), dst = dst_]() mutable {
|
auto task = [in = in,
|
||||||
|
out = out,
|
||||||
|
f = std::move(f),
|
||||||
|
group = group(),
|
||||||
|
dst = dst_]() mutable {
|
||||||
if (in.event().valid()) {
|
if (in.event().valid()) {
|
||||||
in.event().wait();
|
f.wait();
|
||||||
}
|
}
|
||||||
distributed::detail::send(group, out, dst);
|
distributed::detail::send(group, out, dst);
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
|
|
||||||
// Encode a signal event for the input
|
|
||||||
if (in.event().valid()) {
|
|
||||||
encode_signal(in.event());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Recv::eval_gpu(
|
void Recv::eval_gpu(
|
||||||
@ -117,16 +132,14 @@ void Recv::eval_gpu(
|
|||||||
|
|
||||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
auto e = Event(stream());
|
Fence f{stream()};
|
||||||
e.set_value(1);
|
f.wait_gpu(out);
|
||||||
|
|
||||||
encode_wait(e);
|
|
||||||
|
|
||||||
// Schedule an async recv on the comm stream
|
// Schedule an async recv on the comm stream
|
||||||
auto task =
|
auto task =
|
||||||
[out = out, e = std::move(e), group = group(), src = src_]() mutable {
|
[out = out, f = std::move(f), group = group(), src = src_]() mutable {
|
||||||
distributed::detail::recv(group, out, src);
|
distributed::detail::recv(group, out, src);
|
||||||
e.signal();
|
f.update();
|
||||||
};
|
};
|
||||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||||
}
|
}
|
||||||
|
145
mlx/backend/metal/fence.cpp
Normal file
145
mlx/backend/metal/fence.cpp
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/fence.h"
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/event.h"
|
||||||
|
#include "mlx/backend/metal/metal_impl.h"
|
||||||
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
// Construct a fence bound to `stream` and create its backing primitive.
//
// Fast mode is selected only when all of the following hold:
//   - the device supports MTL::GPUFamilyMetal3,
//   - the OS is macOS 15 / iOS 18 or newer,
//   - env::metal_fast_synch() returns true.
// Otherwise the fence runs in slow mode.
//
// Slow mode wraps an MTL::SharedEvent created on the stream's device.
// Fast mode wraps a buffer of sizeof(uint32_t) bytes obtained from the
// allocator, whose contents are used as a counter and set to 0 here.
Fence::Fence(const Stream& stream) : stream_(stream) {
  auto d = metal::device(stream.device).mtl_device();

  // Start from slow mode so `use_fast_` always has a defined value, including
  // the case where the device supports Metal3 but the OS availability check
  // below fails (previously neither branch assigned it in that case).
  use_fast_ = false;
  if (d->supportsFamily(MTL::GPUFamilyMetal3)) {
    if (__builtin_available(macOS 15, iOS 18, *)) {
      use_fast_ = env::metal_fast_synch();
    }
  }

  if (!use_fast_) {
    // Wraps Metal SharedEvent
    auto dtor = [](void* ptr) {
      auto p = metal::new_scoped_memory_pool();
      static_cast<MTL::SharedEvent*>(ptr)->release();
    };
    auto p = metal::new_scoped_memory_pool();
    fence_ = std::shared_ptr<void>(d->newSharedEvent(), dtor);
  } else {
    // Wraps a small allocator-owned buffer holding the counter value
    auto dtor = [](void* buf) {
      allocator::free(static_cast<MTL::Buffer*>(buf));
    };
    auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
    fence_ = std::shared_ptr<void>(buf, dtor);
    cpu_value()[0] = 0;
  }
}
|
||||||
|
|
||||||
|
// Make the GPU stream wait until the fence reaches the next GPU count.
//
// Increments `gpu_count_` and then:
//   - slow mode: ends the current encoding and encodes a wait on the shared
//     event for `gpu_count_` in the stream's command buffer;
//   - fast mode: registers `x` as an output of the compute encoder and
//     dispatches the single-thread "fence_wait" kernel on the fence buffer
//     with `gpu_count_` as the target value.
// In both modes a completed handler holding a copy of `fence_` is added to
// the command buffer.
void Fence::wait_gpu(array& x) {
  gpu_count_++;
  auto& d = metal::device(stream_.device);
  auto idx = stream_.index;

  if (use_fast_) {
    auto& compute_encoder = d.get_command_encoder(idx);

    // Register the output so that no kernel depending on it starts before
    // the wait kernel is done
    compute_encoder.register_output_array(x);

    auto kernel = d.get_kernel("fence_wait");
    MTL::Size single_thread = MTL::Size(1, 1, 1);
    compute_encoder.set_compute_pipeline_state(kernel);

    auto fence_buffer = static_cast<MTL::Buffer*>(fence_.get());
    compute_encoder.set_buffer(fence_buffer, 0);
    compute_encoder.set_bytes(gpu_count_, 1);
    compute_encoder.dispatch_threads(single_thread, single_thread);

    // The empty handler only holds a reference to the fence
    d.get_command_buffer(idx)->addCompletedHandler(
        [fence = fence_](MTL::CommandBuffer*) {});
    return;
  }

  // Slow mode: encode an event wait directly into the command buffer
  d.end_encoding(idx);
  auto command_buffer = d.get_command_buffer(idx);
  auto shared_event = static_cast<MTL::Event*>(fence_.get());
  command_buffer->encodeWait(shared_event, gpu_count_);
  command_buffer->addCompletedHandler(
      [fence = fence_](MTL::CommandBuffer*) {});
}
|
||||||
|
|
||||||
|
void Fence::update_gpu(const array& x) {
|
||||||
|
gpu_count_++;
|
||||||
|
auto& d = metal::device(stream_.device);
|
||||||
|
auto idx = stream_.index;
|
||||||
|
|
||||||
|
if (!use_fast_) {
|
||||||
|
d.end_encoding(idx);
|
||||||
|
auto command_buffer = d.get_command_buffer(idx);
|
||||||
|
command_buffer->encodeSignalEvent(
|
||||||
|
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
|
||||||
|
command_buffer->addCompletedHandler(
|
||||||
|
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Launch input visibility kernel
|
||||||
|
auto& compute_encoder = d.get_command_encoder(idx);
|
||||||
|
auto kernel = d.get_kernel("input_coherent");
|
||||||
|
uint32_t nthreads =
|
||||||
|
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
|
||||||
|
|
||||||
|
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||||
|
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
compute_encoder.set_input_array(x, 0);
|
||||||
|
compute_encoder.set_bytes(nthreads, 1);
|
||||||
|
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||||
|
|
||||||
|
// Barrier on previous kernel
|
||||||
|
compute_encoder.barrier();
|
||||||
|
|
||||||
|
// Launch value update kernel
|
||||||
|
kernel = d.get_kernel("fence_update");
|
||||||
|
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
auto buf = static_cast<MTL::Buffer*>(fence_.get());
|
||||||
|
compute_encoder.set_buffer(buf, 0);
|
||||||
|
compute_encoder.set_bytes(gpu_count_, 1);
|
||||||
|
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||||
|
|
||||||
|
d.get_command_buffer(idx)->addCompletedHandler(
|
||||||
|
[fence = fence_](MTL::CommandBuffer* cbuf) {});
|
||||||
|
}
|
||||||
|
|
||||||
|
void Fence::wait() {
|
||||||
|
cpu_count_++;
|
||||||
|
if (!use_fast_) {
|
||||||
|
if (!static_cast<MTL::SharedEvent*>(fence_.get())
|
||||||
|
->waitUntilSignaledValue(cpu_count_, -1)) {
|
||||||
|
throw std::runtime_error("[Event::wait] Timed out");
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
while (cpu_value()[0] < cpu_count_) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Fence::update() {
|
||||||
|
cpu_count_++;
|
||||||
|
|
||||||
|
if (!use_fast_) {
|
||||||
|
static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
cpu_value()[0] = cpu_count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::atomic_uint* Fence::cpu_value() {
|
||||||
|
return static_cast<std::atomic_uint*>(
|
||||||
|
static_cast<MTL::Buffer*>(fence_.get())->contents());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
40
mlx/backend/metal/fence.h
Normal file
40
mlx/backend/metal/fence.h
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#include "mlx/backend/metal/device.h"
|
||||||
|
|
||||||
|
namespace mlx::core {
|
||||||
|
|
||||||
|
/* A fence to be used for synchronizing work between the CPU and GPU
|
||||||
|
*
|
||||||
|
* Calls to `update_gpu` should be paired with calls to `wait`. This ensures
|
||||||
|
* that the array passed to `update_gpu` is computed and visible to the CPU
|
||||||
|
* after the call to `wait` returns.
|
||||||
|
*
|
||||||
|
* Calls to `update` should be paired with calls to `wait_gpu`. This ensures
|
||||||
|
* that the array passed to `wait_gpu` will not be read by the GPU until the CPU
|
||||||
|
* has called `update`.
|
||||||
|
*
|
||||||
|
* The fence supports slow (default) and fast mode. Fast mode requires setting
|
||||||
|
* the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires
|
||||||
|
* Metal 3.2+ (macOS 15+, iOS 18+).
|
||||||
|
*/
|
||||||
|
class Fence {
|
||||||
|
public:
|
||||||
|
Fence(const Stream& stream);
|
||||||
|
|
||||||
|
void update_gpu(const array& x);
|
||||||
|
void wait_gpu(array& x);
|
||||||
|
|
||||||
|
void wait();
|
||||||
|
void update();
|
||||||
|
|
||||||
|
private:
|
||||||
|
Stream stream_;
|
||||||
|
std::shared_ptr<void> fence_;
|
||||||
|
uint32_t cpu_count_{0};
|
||||||
|
uint32_t gpu_count_{0};
|
||||||
|
bool use_fast_;
|
||||||
|
std::atomic_uint* cpu_value();
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlx::core
|
@ -45,6 +45,9 @@ build_kernel(random)
|
|||||||
build_kernel(rms_norm)
|
build_kernel(rms_norm)
|
||||||
build_kernel(rope)
|
build_kernel(rope)
|
||||||
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
build_kernel(scaled_dot_product_attention sdpa_vector.h)
|
||||||
|
if(MLX_METAL_VERSION GREATER_EQUAL 320)
|
||||||
|
build_kernel(fence)
|
||||||
|
endif()
|
||||||
|
|
||||||
set(STEEL_HEADERS
|
set(STEEL_HEADERS
|
||||||
steel/defines.h
|
steel/defines.h
|
||||||
|
52
mlx/backend/metal/kernels/fence.metal
Normal file
52
mlx/backend/metal/kernels/fence.metal
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
|
#pragma METAL internals : enable
|
||||||
|
|
||||||
|
#ifndef __METAL_MEMORY_SCOPE_SYSTEM__
|
||||||
|
#define __METAL_MEMORY_SCOPE_SYSTEM__ 3
|
||||||
|
namespace metal {
|
||||||
|
constexpr constant metal::thread_scope thread_scope_system =
|
||||||
|
static_cast<thread_scope>(__METAL_MEMORY_SCOPE_SYSTEM__);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <metal_atomic>
|
||||||
|
|
||||||
|
// Reads each 32-bit word of `input` and writes it back through a
// system-coherent pointer, then issues a system-scope thread fence.
// One thread handles one word; threads at or past `size` only run the fence.
[[kernel]] void input_coherent(
    volatile coherent(system) device uint* input [[buffer(0)]],
    const constant uint& size [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  const bool in_bounds = index < size;
  if (in_bounds) {
    const uint word = input[index];
    input[index] = word;
  }
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}
|
||||||
|
|
||||||
|
// single thread kernel to update timestamp
|
||||||
|
// Stores `value` into the first element of `timestamp`, then issues a
// system-scope thread fence.
[[kernel]] void fence_update(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  *timestamp = value;
  metal::atomic_thread_fence(
      metal::mem_flags::mem_device,
      metal::memory_order_seq_cst,
      metal::thread_scope_system);
}
|
||||||
|
|
||||||
|
// single thread kernel to spin wait for timestamp value
|
||||||
|
// Spins until the first element of `timestamp` is at least `value`,
// issuing a system-scope thread fence before every read.
[[kernel]] void fence_wait(
    volatile coherent(system) device uint* timestamp [[buffer(0)]],
    constant uint& value [[buffer(1)]]) {
  do {
    metal::atomic_thread_fence(
        metal::mem_flags::mem_device,
        metal::memory_order_seq_cst,
        metal::thread_scope_system);
  } while (timestamp[0] < value);
}
|
@ -2,6 +2,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "mlx/backend/metal/device.h"
|
#include "mlx/backend/metal/device.h"
|
||||||
|
#include "mlx/backend/metal/event.h"
|
||||||
#include "mlx/backend/metal/utils.h"
|
#include "mlx/backend/metal/utils.h"
|
||||||
#include "mlx/primitives.h"
|
#include "mlx/primitives.h"
|
||||||
#include "mlx/scheduler.h"
|
#include "mlx/scheduler.h"
|
||||||
@ -69,16 +70,13 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
|
|
||||||
if (signal ||
|
if (signal ||
|
||||||
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
|
||||||
d.end_encoding(s.index);
|
|
||||||
if (signal) {
|
if (signal) {
|
||||||
command_buffer->encodeSignalEvent(
|
encode_signal(arr.event());
|
||||||
static_cast<MTL::Event*>(arr.event().raw_event().get()),
|
|
||||||
arr.event().value());
|
|
||||||
}
|
}
|
||||||
|
d.end_encoding(s.index);
|
||||||
scheduler::notify_new_task(s);
|
scheduler::notify_new_task(s);
|
||||||
command_buffer->addCompletedHandler(
|
command_buffer->addCompletedHandler(
|
||||||
[s, buffers = std::move(buffers), event = arr.event()](
|
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||||
MTL::CommandBuffer* cbuf) {
|
|
||||||
scheduler::notify_task_completion(s);
|
scheduler::notify_task_completion(s);
|
||||||
check_error(cbuf);
|
check_error(cbuf);
|
||||||
});
|
});
|
||||||
|
@ -126,6 +126,11 @@ inline int max_ops_per_buffer() {
|
|||||||
return max_ops_per_buffer_;
|
return max_ops_per_buffer_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool metal_fast_synch() {
|
||||||
|
static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
|
||||||
|
return metal_fast_synch;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace env
|
} // namespace env
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user