Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-22 02:58:16 +08:00)
redesign for faster cpu/gpu synch (#1869)
* redesign for faster cpu/gpu synch
* load + more async CPU
* use command encoder API and move more ops to use it
* make fence back-end generic + CPU only fence
* faster build
* fix async eval
* fixes + handle temporaries
* fix / improve cpu conv
* remove unused status, fix siblings
* fix extensions
* fix
* fix no cpu build
* format
* comments
* fix perf regression, remove unecessary abort
* fix events, task limit cpu
* fix waiting
* fix donation / temporaries in normalization
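The commit message describes a fast CPU/GPU synchronization path built around a
shared counter instead of per-array events. A minimal standalone sketch of that
idea follows (host-side std::atomic here; in MLX the counter lives in a small
GPU-visible buffer written by a tiny kernel, see the fence.cpp hunks below). The
names are illustrative only, not the MLX API.

    // fast_fence_sketch.cpp -- illustrative only, not MLX code.
    #include <atomic>
    #include <cstdint>
    #include <thread>

    // One shared 32-bit value: the producer publishes a count when its work is
    // done, the consumer polls until the count reaches what it expects.
    struct FastFence {
      std::atomic<uint32_t> value{0};

      void signal(uint32_t count) {
        value.store(count, std::memory_order_release);
      }
      void wait(uint32_t count) {
        while (value.load(std::memory_order_acquire) < count) {
          std::this_thread::yield();
        }
      }
    };

    int main() {
      FastFence f;
      std::thread producer([&] { f.signal(1); });
      f.wait(1);  // returns once the producer has published 1
      producer.join();
      return 0;
    }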
@@ -102,16 +102,9 @@ void binary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);

// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
//   otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array(
    donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_input_array(a, arg_idx++);
compute_encoder.set_input_array(b, arg_idx++);
compute_encoder.set_output_array(outputs[0], arg_idx++);
if (outputs.size() == 2) {
  compute_encoder.set_output_array(outputs[1], arg_idx++);
@@ -164,8 +157,8 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace(inputs, outputs, op, s);
}

@@ -195,7 +188,7 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace(inputs, out, op, s);
}

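In both overloads above the encoder is now bound to the raw inputs, and
set_binary_op_output_data loses its donate_with_move flag. A sketch of why
(inferred from the move_shared_buffer -> copy_shared_buffer changes throughout
this diff): a donated input now keeps a valid view of the shared buffer, so the
old `donate_a ? outputs[0] : a` indirection at encode time is unnecessary.

    // Sketch only: binding `a` directly stays safe after donation when the
    // output shares (rather than steals) the input's buffer.
    if (a.is_donatable()) {
      out.copy_shared_buffer(a);  // out and a alias the same buffer
    } else {
      out.set_data(allocator::malloc_or_wait(out.nbytes()));
    }
    compute_encoder.set_input_array(a, 0);  // still valid even when donated
    compute_encoder.set_output_array(out, 1);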
@@ -457,7 +457,7 @@ void Compiled::eval_gpu(
}

compiled_allocate_outputs(
    inputs, outputs, inputs_, constant_ids_, contiguous, true);
    inputs, outputs, inputs_, constant_ids_, contiguous);

// Put the outputs in
for (auto& x : outputs) {

@@ -14,25 +14,11 @@ namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;

void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
if (in.dtype() == out.dtype()) {
return;
}
} else {
out.set_data(
    allocator::malloc_or_wait(in.data_size() * out.itemsize()),
    in.data_size(),
    in.strides(),
    in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;

@@ -168,9 +168,10 @@ void CommandEncoder::set_output_array(
register_output_array(a);
}

void CommandEncoder::register_output_array(array& a) {
void CommandEncoder::register_output_array(const array& a) {
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());

auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {

@@ -62,7 +62,7 @@ struct CommandEncoder {

void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void register_output_array(array& a);
void register_output_array(const array& a);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();

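register_output_array now accepts a const array&; since Metal's resource and
barrier bookkeeping wants a mutable MTL::Resource*, the implementation above
const_casts the buffer pointer. A brief sketch of that pattern (the array
itself is never modified, only its buffer pointer is recorded):

    // Sketch: recording a const array's buffer for barrier tracking.
    auto* resource =
        static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
    concurrent_outputs_.insert(resource);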
@@ -4,149 +4,30 @@
|
||||
|
||||
#include "mlx/allocator.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/fence.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/distributed/ops.h"
|
||||
#include "mlx/distributed/primitives.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core::distributed {
|
||||
|
||||
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
|
||||
if (e_signal.valid()) {
|
||||
encode_signal(e_signal);
|
||||
}
|
||||
encode_wait(e_wait);
|
||||
void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
void AllReduce::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
Fence f{stream()};
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
reduce_type = reduce_type_,
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
}
|
||||
switch (reduce_type) {
|
||||
case Sum:
|
||||
distributed::detail::all_sum(
|
||||
group, in.data_shared_ptr() == nullptr ? out : in, out);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Only all reduce sum is supported for now");
|
||||
}
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
void AllGather::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
auto& in = inputs[0];
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
f.wait_gpu(out);
|
||||
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group()]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::all_gather(group, in, out);
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
void Send::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 1);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& in = inputs[0];
|
||||
|
||||
// Encode a signal event for the input
|
||||
Fence f{stream()};
|
||||
if (in.event().valid()) {
|
||||
f.update_gpu(in);
|
||||
}
|
||||
|
||||
auto& out = outputs[0];
|
||||
move_or_copy(in, out);
|
||||
|
||||
// Schedule an async send on the comm stream
|
||||
auto task = [in = in,
|
||||
out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group(),
|
||||
dst = dst_]() mutable {
|
||||
if (in.event().valid()) {
|
||||
f.wait();
|
||||
}
|
||||
distributed::detail::send(group, out, dst);
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
}
|
||||
|
||||
void Recv::eval_gpu(
|
||||
const std::vector<array>& inputs,
|
||||
std::vector<array>& outputs) {
|
||||
assert(inputs.size() == 0);
|
||||
assert(outputs.size() == 1);
|
||||
|
||||
auto& out = outputs[0];
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
|
||||
Fence f{stream()};
|
||||
f.wait_gpu(out);
|
||||
|
||||
// Schedule an async recv on the comm stream
|
||||
auto task = [out = unsafe_weak_copy(out),
|
||||
f = std::move(f),
|
||||
group = group(),
|
||||
src = src_]() mutable {
|
||||
distributed::detail::recv(group, out, src);
|
||||
f.update();
|
||||
};
|
||||
scheduler::enqueue(detail::communication_stream(), std::move(task));
|
||||
void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
|
||||
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
|
||||
}
|
||||
|
||||
} // namespace mlx::core::distributed
|
||||
|
@@ -3,52 +3,58 @@
|
||||
#include "mlx/event.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
void encode_wait(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void encode_signal(Event e) {
|
||||
auto& d = metal::device(e.stream().device);
|
||||
d.end_encoding(e.stream().index);
|
||||
auto command_buffer = d.get_command_buffer(e.stream().index);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
|
||||
command_buffer->addCompletedHandler(
|
||||
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
Event::Event(const Stream& stream) : stream_(stream) {
|
||||
Event::Event(Stream stream) : stream_(stream) {
|
||||
auto dtor = [](void* ptr) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
static_cast<MTL::SharedEvent*>(ptr)->release();
|
||||
};
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
event_ = std::shared_ptr<void>(
|
||||
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
|
||||
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
|
||||
}
|
||||
|
||||
void Event::wait() {
|
||||
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
|
||||
if (!static_cast<MTL::SharedEvent*>(event_.get())
|
||||
->waitUntilSignaledValue(value(), -1)) {
|
||||
throw std::runtime_error("[Event::wait] Timed out");
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal() {
|
||||
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
|
||||
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
|
||||
}
|
||||
|
||||
void Event::wait(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { wait(); });
|
||||
} else {
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
command_buffer->encodeWait(static_cast<MTL::Event*>(event_.get()), value());
|
||||
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
|
||||
}
|
||||
}
|
||||
|
||||
void Event::signal(Stream stream) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [*this]() mutable { signal(); });
|
||||
} else {
|
||||
auto& d = metal::device(stream.device);
|
||||
d.end_encoding(stream.index);
|
||||
auto command_buffer = d.get_command_buffer(stream.index);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(event_.get()), value());
|
||||
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
|
||||
}
|
||||
}
|
||||
|
||||
bool Event::is_signaled() const {
|
||||
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
|
||||
return static_cast<MTL::SharedEvent*>(event_.get())->signaledValue() >=
|
||||
value();
|
||||
}
|
||||
|
||||
|
@@ -1,10 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once

namespace mlx::core {

void encode_wait(Event e);

void encode_signal(Event e);

} // namespace mlx::core
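With this header removed, the free functions encode_wait/encode_signal are
gone; waiting and signaling are now methods on Event that take the stream to
encode or enqueue into (see event.cpp above). A hedged sketch of how the old
pattern maps onto the new methods; `io_stream` and `gpu_stream` are
illustrative names, not values from this diff.

    // Sketch only: stream-aware Event methods replacing encode_wait/encode_signal.
    Event e(gpu_stream);
    e.set_value(1);
    e.wait(gpu_stream);   // encodes a wait into the GPU stream's command buffer
    e.signal(io_stream);  // on a CPU stream this enqueues signal() on the scheduler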
@@ -1,49 +1,81 @@
|
||||
// Copyright © 2024 Apple Inc.
|
||||
|
||||
#include "mlx/backend/metal/fence.h"
|
||||
#include "mlx/fence.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/metal_impl.h"
|
||||
#include "mlx/scheduler.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core {
|
||||
|
||||
Fence::Fence(const Stream& stream) : stream_(stream) {
|
||||
auto d = metal::device(stream.device).mtl_device();
|
||||
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
|
||||
use_fast_ = false;
|
||||
} else if (__builtin_available(macOS 15, iOS 18, *)) {
|
||||
use_fast_ = env::metal_fast_synch();
|
||||
}
|
||||
if (!use_fast_) {
|
||||
// Wraps Metal SharedEvent
|
||||
auto dtor = [](void* ptr) {
|
||||
struct FenceImpl {
|
||||
FenceImpl() {
|
||||
auto d = metal::device(Device::gpu).mtl_device();
|
||||
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
|
||||
use_fast = false;
|
||||
} else if (__builtin_available(macOS 15, iOS 18, *)) {
|
||||
use_fast = env::metal_fast_synch();
|
||||
}
|
||||
|
||||
if (!use_fast) {
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
static_cast<MTL::SharedEvent*>(ptr)->release();
|
||||
};
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
fence_ = std::shared_ptr<void>(
|
||||
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
|
||||
} else {
|
||||
auto dtor = [](void* buf) {
|
||||
allocator::free(static_cast<MTL::Buffer*>(buf));
|
||||
};
|
||||
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
|
||||
fence_ = std::shared_ptr<void>(buf, dtor);
|
||||
cpu_value()[0] = 0;
|
||||
fence = static_cast<void*>(d->newSharedEvent());
|
||||
} else {
|
||||
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
|
||||
fence = static_cast<void*>(buf);
|
||||
cpu_value()[0] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
~FenceImpl() {
|
||||
if (!use_fast) {
|
||||
// Wraps Metal SharedEvent
|
||||
auto p = metal::new_scoped_memory_pool();
|
||||
static_cast<MTL::SharedEvent*>(fence)->release();
|
||||
} else {
|
||||
allocator::free(static_cast<MTL::Buffer*>(fence));
|
||||
}
|
||||
}
|
||||
bool use_fast{false};
|
||||
uint32_t count{0};
|
||||
void* fence;
|
||||
|
||||
std::atomic_uint* cpu_value() {
|
||||
return static_cast<std::atomic_uint*>(
|
||||
static_cast<MTL::Buffer*>(fence)->contents());
|
||||
}
|
||||
};
|
||||
|
||||
Fence::Fence(Stream) {
|
||||
auto dtor = [](void* ptr) { delete static_cast<FenceImpl*>(ptr); };
|
||||
fence_ = std::shared_ptr<void>(new FenceImpl{}, dtor);
|
||||
}
|
||||
|
||||
void Fence::wait_gpu(array& x) {
|
||||
gpu_count_++;
|
||||
auto& d = metal::device(stream_.device);
|
||||
auto idx = stream_.index;
|
||||
void Fence::wait(Stream stream, const array& x) {
|
||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||
|
||||
if (!use_fast_) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {
|
||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||
if (!f.use_fast) {
|
||||
if (!static_cast<MTL::SharedEvent*>(f.fence)->waitUntilSignaledValue(
|
||||
count, -1)) {
|
||||
throw std::runtime_error("[Fence::wait] Timed out");
|
||||
}
|
||||
return;
|
||||
}
|
||||
while (f.cpu_value()[0] < count) {
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
|
||||
if (!f.use_fast) {
|
||||
d.end_encoding(idx);
|
||||
auto command_buffer = d.get_command_buffer(idx);
|
||||
command_buffer->encodeWait(
|
||||
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
|
||||
command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
return;
|
||||
@@ -51,7 +83,7 @@ void Fence::wait_gpu(array& x) {
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(idx);
|
||||
|
||||
// Register the output to ensure that no kernels which depends on the
|
||||
// Register outputs to ensure that no kernels which depends on the
|
||||
// output starts before this one is done
|
||||
compute_encoder.register_output_array(x);
|
||||
|
||||
@@ -59,36 +91,49 @@ void Fence::wait_gpu(array& x) {
|
||||
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto buf = static_cast<MTL::Buffer*>(fence_.get());
|
||||
auto buf = static_cast<MTL::Buffer*>(f.fence);
|
||||
compute_encoder.set_buffer(buf, 0);
|
||||
compute_encoder.set_bytes(gpu_count_, 1);
|
||||
compute_encoder.set_bytes(f.count, 1);
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void Fence::update_gpu(const array& x) {
|
||||
gpu_count_++;
|
||||
auto& d = metal::device(stream_.device);
|
||||
auto idx = stream_.index;
|
||||
void Fence::update(Stream stream, const array& x) {
|
||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||
f.count++;
|
||||
|
||||
if (!use_fast_) {
|
||||
if (stream.device == Device::cpu) {
|
||||
scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {
|
||||
auto& f = *static_cast<FenceImpl*>(fence_.get());
|
||||
if (!f.use_fast) {
|
||||
static_cast<MTL::SharedEvent*>(f.fence)->setSignaledValue(count);
|
||||
return;
|
||||
}
|
||||
|
||||
f.cpu_value()[0] = count;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
auto& d = metal::device(stream.device);
|
||||
auto idx = stream.index;
|
||||
if (!f.use_fast) {
|
||||
d.end_encoding(idx);
|
||||
auto command_buffer = d.get_command_buffer(idx);
|
||||
command_buffer->encodeSignalEvent(
|
||||
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
|
||||
static_cast<MTL::Event*>(f.fence), f.count);
|
||||
command_buffer->addCompletedHandler(
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
return;
|
||||
}
|
||||
|
||||
// Launch input visibility kernel
|
||||
// Launch input visibility kernels
|
||||
auto& compute_encoder = d.get_command_encoder(idx);
|
||||
auto kernel = d.get_kernel("input_coherent");
|
||||
uint32_t nthreads =
|
||||
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
|
||||
|
||||
MTL::Size group_dims = MTL::Size(1024, 1, 1);
|
||||
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
@@ -96,7 +141,7 @@ void Fence::update_gpu(const array& x) {
|
||||
compute_encoder.set_bytes(nthreads, 1);
|
||||
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
|
||||
|
||||
// Barrier on previous kernel
|
||||
// Barrier on previous kernels
|
||||
compute_encoder.barrier();
|
||||
|
||||
// Launch value update kernel
|
||||
@@ -104,42 +149,13 @@ void Fence::update_gpu(const array& x) {
|
||||
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
auto buf = static_cast<MTL::Buffer*>(fence_.get());
|
||||
auto buf = static_cast<MTL::Buffer*>(f.fence);
|
||||
compute_encoder.set_buffer(buf, 0);
|
||||
compute_encoder.set_bytes(gpu_count_, 1);
|
||||
compute_encoder.set_bytes(f.count, 1);
|
||||
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
|
||||
|
||||
d.get_command_buffer(idx)->addCompletedHandler(
|
||||
[fence = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
void Fence::wait() {
|
||||
cpu_count_++;
|
||||
if (!use_fast_) {
|
||||
if (!static_cast<MTL::SharedEvent*>(fence_.get())
|
||||
->waitUntilSignaledValue(cpu_count_, -1)) {
|
||||
throw std::runtime_error("[Event::wait] Timed out");
|
||||
}
|
||||
return;
|
||||
}
|
||||
while (cpu_value()[0] < cpu_count_) {
|
||||
}
|
||||
}
|
||||
|
||||
void Fence::update() {
|
||||
cpu_count_++;
|
||||
|
||||
if (!use_fast_) {
|
||||
static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
|
||||
return;
|
||||
}
|
||||
|
||||
cpu_value()[0] = cpu_count_;
|
||||
}
|
||||
|
||||
std::atomic_uint* Fence::cpu_value() {
|
||||
return static_cast<std::atomic_uint*>(
|
||||
static_cast<MTL::Buffer*>(fence_.get())->contents());
|
||||
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1,40 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/device.h"

namespace mlx::core {

/* A fence to be used for synchronizing work between the CPU and GPU
 *
 * Calls to `update_gpu` should be paired with calls to `wait`. This ensures
 * that the array passed to `update_gpu` is computed and visible to the CPU
 * after the call to `wait` returns.
 *
 * Calls to `update` should be paired with calls to `wait_gpu`. This ensures
 * that the array passed to `wait_gpu` will not be read by the GPU until the CPU
 * has called `update`.
 *
 * The fence supports slow (default) and fast mode. Fast mode requires setting
 * the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires
 * Metal 3.2+ (macOS 15+, iOS 18+).
 */
class Fence {
 public:
  Fence(const Stream& stream);

  void update_gpu(const array& x);
  void wait_gpu(array& x);

  void wait();
  void update();

 private:
  Stream stream_;
  std::shared_ptr<void> fence_;
  uint32_t cpu_count_{0};
  uint32_t gpu_count_{0};
  bool use_fast_;
  std::atomic_uint* cpu_value();
};

} // namespace mlx::core
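The removed header above documented the old Metal-only pairing (update_gpu with
wait, update with wait_gpu, plus an opt-in fast mode via MLX_METAL_FAST_SYNCH).
Its replacement, included as "mlx/fence.h" in the new fence.cpp, is
backend-generic: update(stream, x) publishes x on whichever stream produced it
and wait(stream, x) keeps another stream (CPU or GPU) from touching x too
early, with the fast/slow selection moved into FenceImpl. A hedged usage
sketch, with `producer_stream` and `consumer_stream` being illustrative names:

    // Sketch only: pairing the backend-generic Fence across two streams.
    Fence f(producer_stream);
    f.update(producer_stream, x);  // x is published once the producer's work lands
    f.wait(consumer_stream, x);    // consumer_stream will not read x before then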
@@ -82,7 +82,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
const array& in_contiguous = check_input(in);
|
||||
|
||||
if (in_contiguous.is_donatable()) {
|
||||
out.move_shared_buffer(in_contiguous);
|
||||
out.copy_shared_buffer(in_contiguous);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
@@ -2,7 +2,6 @@
|
||||
#include <memory>
|
||||
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
#include "mlx/primitives.h"
|
||||
#include "mlx/scheduler.h"
|
||||
@@ -23,88 +22,78 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> make_task(array arr, bool signal) {
|
||||
auto task = [arr = std::move(arr), signal]() mutable {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
void eval(array& arr) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto s = arr.primitive().stream();
|
||||
auto& d = metal::device(s.device);
|
||||
auto command_buffer = d.get_command_buffer(s.index);
|
||||
|
||||
for (auto& input : arr.inputs()) {
|
||||
if (input.event().valid() &&
|
||||
input.event().stream() != arr.primitive().stream()) {
|
||||
input.event().wait();
|
||||
}
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
auto outputs = arr.outputs();
|
||||
{
|
||||
// If the array is a tracer hold a reference
|
||||
// to its inputs so they don't get donated
|
||||
std::vector<array> inputs;
|
||||
if (arr.is_tracer()) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
}
|
||||
std::unordered_set<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.insert(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.insert(s.data_shared_ptr());
|
||||
}
|
||||
// Remove the output if it was donated to by an input
|
||||
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
|
||||
buffers.erase(it);
|
||||
}
|
||||
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
try {
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
} catch (const std::exception& error) {
|
||||
abort_with_exception(error);
|
||||
}
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
buffers.push_back(in.data_shared_ptr());
|
||||
}
|
||||
for (auto& s : arr.siblings()) {
|
||||
buffers.push_back(s.data_shared_ptr());
|
||||
}
|
||||
if (!arr.is_tracer()) {
|
||||
arr.detach();
|
||||
}
|
||||
for (auto& out : outputs) {
|
||||
out.set_status(array::Status::evaluated);
|
||||
}
|
||||
|
||||
if (signal || d.command_buffer_needs_commit(s.index)) {
|
||||
if (signal) {
|
||||
encode_signal(arr.event());
|
||||
}
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
};
|
||||
return task;
|
||||
if (d.command_buffer_needs_commit(s.index)) {
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
} else {
|
||||
command_buffer->addCompletedHandler(
|
||||
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
|
||||
check_error(cbuf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::function<void()> make_synchronize_task(
|
||||
Stream s,
|
||||
std::shared_ptr<std::promise<void>> p) {
|
||||
return [s, p = std::move(p)]() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
p->set_value();
|
||||
};
|
||||
void finalize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
d.end_encoding(s.index);
|
||||
scheduler::notify_new_task(s);
|
||||
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
|
||||
scheduler::notify_task_completion(s);
|
||||
check_error(cbuf);
|
||||
});
|
||||
d.commit_command_buffer(s.index);
|
||||
d.get_command_buffer(s.index);
|
||||
}
|
||||
|
||||
void synchronize(Stream s) {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
auto& d = metal::device(s.device);
|
||||
auto cb = d.get_command_buffer(s.index);
|
||||
cb->retain();
|
||||
d.end_encoding(s.index);
|
||||
d.commit_command_buffer(s.index);
|
||||
cb->waitUntilCompleted();
|
||||
check_error(cb);
|
||||
cb->release();
|
||||
}
|
||||
|
||||
void start_capture(std::string path, id object) {
|
||||
|
@@ -14,10 +14,8 @@ void new_stream(Stream stream);

std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();

std::function<void()> make_task(array arr, bool signal);

std::function<void()> make_synchronize_task(
    Stream s,
    std::shared_ptr<std::promise<void>> p);
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);

} // namespace mlx::core::metal

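make_task and make_synchronize_task returned closures for the scheduler to
queue; they are replaced by direct entry points: eval encodes one array's
primitive, finalize commits the stream's open command buffer, and synchronize
blocks until it completes. A hedged sketch of a caller; the surrounding loop
and `tape` are hypothetical, not code from this diff.

    // Hypothetical driver: encode everything, then flush and (optionally) block.
    for (auto& arr : tape) {
      metal::eval(arr);         // encode the primitive's GPU work for arr
    }
    metal::finalize(stream);    // end encoding and commit the command buffer
    metal::synchronize(stream); // wait for the committed work to finish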
@@ -18,33 +18,33 @@ void RMSNorm::eval_gpu(
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
if (x.is_donatable()) {
|
||||
out.move_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
@@ -88,8 +88,6 @@ void RMSNorm::eval_gpu(
|
||||
compute_encoder.set_bytes(w_stride, 5);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void RMSNormVJP::eval_gpu(
|
||||
@@ -101,20 +99,22 @@ void RMSNormVJP::eval_gpu(
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&d, &s](const array& x) -> array {
|
||||
auto check_input = [&d, &s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
return {x, false};
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
|
||||
return x_copy;
|
||||
return {x_copy, true};
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[2].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& g = check_input(inputs[2]);
|
||||
auto [g, g_copied] = check_input(inputs[2]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
|
||||
@@ -122,17 +122,18 @@ void RMSNormVJP::eval_gpu(
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
@@ -141,11 +142,9 @@ void RMSNormVJP::eval_gpu(
|
||||
// gradients before they are accumulated.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
bool g_in_gw = false;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
@@ -189,9 +188,9 @@ void RMSNormVJP::eval_gpu(
|
||||
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
compute_encoder.set_input_array(g, 2);
|
||||
compute_encoder.set_output_array(gx, 3);
|
||||
compute_encoder.set_output_array(gw_temp, 4);
|
||||
compute_encoder.set_bytes(eps_, 5);
|
||||
@@ -216,35 +215,35 @@ void LayerNorm::eval_gpu(
|
||||
auto& out = outputs[0];
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
|
||||
const array x = set_output(inputs[0]);
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
|
||||
if (x.is_donatable()) {
|
||||
out.move_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
|
||||
@@ -290,8 +289,6 @@ void LayerNorm::eval_gpu(
|
||||
compute_encoder.set_bytes(b_stride, 7);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
void LayerNormVJP::eval_gpu(
|
||||
@@ -303,21 +300,22 @@ void LayerNormVJP::eval_gpu(
|
||||
// Ensure row contiguity. We could relax this step by checking that the array
|
||||
// is contiguous (no broadcasts or holes) and that the input strides are the
|
||||
// same as the cotangent strides but for now this is simpler.
|
||||
auto check_input = [&d, &s](const array& x) -> array {
|
||||
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
|
||||
if (x.flags().row_contiguous) {
|
||||
return x;
|
||||
return {x, false};
|
||||
}
|
||||
|
||||
array x_copy(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
d.add_temporary(x_copy, s.index);
|
||||
|
||||
return x_copy;
|
||||
return {x_copy, true};
|
||||
};
|
||||
const array& x = check_input(inputs[0]);
|
||||
bool donate_x = inputs[0].is_donatable();
|
||||
bool donate_g = inputs[3].is_donatable();
|
||||
auto [x, copied] = check_input(inputs[0]);
|
||||
donate_x |= copied;
|
||||
const array& w = inputs[1];
|
||||
const array& b = inputs[2];
|
||||
const array& g = check_input(inputs[3]);
|
||||
auto [g, g_copied] = check_input(inputs[3]);
|
||||
donate_g |= g_copied;
|
||||
array& gx = outputs[0];
|
||||
array& gw = outputs[1];
|
||||
array& gb = outputs[2];
|
||||
@@ -326,17 +324,18 @@ void LayerNormVJP::eval_gpu(
|
||||
bool has_w = w.ndim() != 0;
|
||||
|
||||
// Allocate space for the outputs
|
||||
bool x_in_gx = false;
|
||||
bool g_in_gx = false;
|
||||
if (x.is_donatable()) {
|
||||
gx.move_shared_buffer(x);
|
||||
x_in_gx = true;
|
||||
} else if (g.is_donatable()) {
|
||||
gx.move_shared_buffer(g);
|
||||
if (donate_x) {
|
||||
gx.copy_shared_buffer(x);
|
||||
} else if (donate_g) {
|
||||
gx.copy_shared_buffer(g);
|
||||
g_in_gx = true;
|
||||
} else {
|
||||
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
|
||||
}
|
||||
if (g_copied && !g_in_gx) {
|
||||
d.add_temporary(g, s.index);
|
||||
}
|
||||
|
||||
auto axis_size = static_cast<uint32_t>(x.shape().back());
|
||||
int n_rows = x.data_size() / axis_size;
|
||||
@@ -345,15 +344,13 @@ void LayerNormVJP::eval_gpu(
|
||||
// gradient accumulators.
|
||||
array gw_temp =
|
||||
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
|
||||
bool g_in_gw = false;
|
||||
if (has_w) {
|
||||
if (!g_in_gx && g.is_donatable()) {
|
||||
gw_temp.move_shared_buffer(g);
|
||||
g_in_gw = true;
|
||||
if (!g_in_gx && donate_g) {
|
||||
gw_temp.copy_shared_buffer(g);
|
||||
} else {
|
||||
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
d.add_temporary(gw_temp, s.index);
|
||||
}
|
||||
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
|
||||
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
|
||||
@@ -364,14 +361,7 @@ void LayerNormVJP::eval_gpu(
|
||||
ReductionPlan plan(
|
||||
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
|
||||
strided_reduce_general_dispatch(
|
||||
g_in_gx ? gx : (g_in_gw ? gw_temp : g),
|
||||
gb,
|
||||
"sum",
|
||||
plan,
|
||||
{0},
|
||||
compute_encoder,
|
||||
d,
|
||||
s);
|
||||
g, gb, "sum", plan, {0}, compute_encoder, d, s);
|
||||
}
|
||||
|
||||
const int simd_size = 32;
|
||||
@@ -409,9 +399,9 @@ void LayerNormVJP::eval_gpu(
|
||||
|
||||
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
|
||||
compute_encoder.set_input_array(x, 0);
|
||||
compute_encoder.set_input_array(w, 1);
|
||||
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
|
||||
compute_encoder.set_input_array(g, 2);
|
||||
compute_encoder.set_output_array(gx, 3);
|
||||
compute_encoder.set_output_array(gw_temp, 4);
|
||||
compute_encoder.set_bytes(eps_, 5);
|
||||
|
@@ -5,12 +5,10 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "mlx/backend/common/compiled.h"
|
||||
#include "mlx/backend/common/load.h"
|
||||
#include "mlx/backend/common/slicing.h"
|
||||
#include "mlx/backend/common/utils.h"
|
||||
#include "mlx/backend/metal/copy.h"
|
||||
#include "mlx/backend/metal/device.h"
|
||||
#include "mlx/backend/metal/event.h"
|
||||
#include "mlx/backend/metal/kernels.h"
|
||||
#include "mlx/backend/metal/slicing.h"
|
||||
#include "mlx/backend/metal/utils.h"
|
||||
@@ -45,7 +43,8 @@ void reshape(const array& in, array& out, Stream s) {
|
||||
shared_buffer_reshape(in, out_strides, out);
|
||||
}
|
||||
}
|
||||
array compute_dynamic_offset(
|
||||
|
||||
static array compute_dynamic_offset(
|
||||
const array& indices,
|
||||
const Strides& strides,
|
||||
const std::vector<int>& axes,
|
||||
@@ -57,7 +56,7 @@ array compute_dynamic_offset(
|
||||
bool donate = indices.is_donatable() &&
|
||||
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
|
||||
if (donate) {
|
||||
offset.move_shared_buffer(indices);
|
||||
offset.copy_shared_buffer(indices);
|
||||
} else {
|
||||
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
|
||||
}
|
||||
@@ -88,7 +87,7 @@ array compute_dynamic_offset(
|
||||
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(donate ? offset : indices, 0);
|
||||
compute_encoder.set_input_array(indices, 0);
|
||||
compute_encoder.set_output_array(offset, 1);
|
||||
compute_encoder.set_vector_bytes(strides, 2);
|
||||
compute_encoder.set_vector_bytes(axes, 3);
|
||||
@@ -254,7 +253,7 @@ void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& in = inputs[0];
|
||||
if (in.flags().row_contiguous ||
|
||||
(allow_col_major_ && in.flags().col_contiguous)) {
|
||||
move_or_copy(in, out);
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
copy_gpu(in, out, CopyType::General);
|
||||
}
|
||||
@@ -302,31 +301,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
auto read_task = [out = unsafe_weak_copy(out),
|
||||
offset = offset_,
|
||||
reader = reader_,
|
||||
swap_endianness = swap_endianness_]() mutable {
|
||||
load(out, offset, reader, swap_endianness);
|
||||
};
|
||||
|
||||
// Limit the size that the command buffer will wait on to avoid timing out
|
||||
// on the event (<4 seconds).
|
||||
if (out.nbytes() > (1 << 28)) {
|
||||
read_task();
|
||||
return;
|
||||
}
|
||||
|
||||
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
|
||||
|
||||
auto e = Event(stream());
|
||||
e.set_value(1);
|
||||
encode_wait(e);
|
||||
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
|
||||
fut.wait();
|
||||
e.signal();
|
||||
};
|
||||
scheduler::enqueue(io_stream(), std::move(signal_task));
|
||||
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
|
||||
}
|
||||
|
||||
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -452,7 +427,7 @@ void DynamicSliceUpdate::eval_gpu(
|
||||
auto& start_indices = inputs[2];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
move_or_copy(in, out);
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -491,7 +466,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& upd = inputs[1];
|
||||
|
||||
if (upd.size() == 0) {
|
||||
move_or_copy(in, out);
|
||||
out.copy_shared_buffer(in);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -575,8 +550,8 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
strides[i] *= ibytes;
|
||||
strides[i] /= obytes;
|
||||
}
|
||||
move_or_copy(
|
||||
in, out, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
out.copy_shared_buffer(
|
||||
in, strides, in.flags(), in.data_size() * ibytes / obytes);
|
||||
} else {
|
||||
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
|
||||
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
|
||||
@@ -587,7 +562,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
flags.row_contiguous = true;
|
||||
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
|
||||
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
|
||||
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -41,7 +41,7 @@ void RoPE::eval_gpu(
|
||||
} else if (in.flags().row_contiguous) {
|
||||
if (in.is_donatable()) {
|
||||
donated = true;
|
||||
out.move_shared_buffer(in);
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
}
|
||||
|
@@ -152,7 +152,7 @@ void sdpa_vector(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(out, 3);
|
||||
@@ -242,7 +242,7 @@ void sdpa_vector_2pass(
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
|
||||
// Set its arguments
|
||||
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
|
||||
compute_encoder.set_input_array(q, 0);
|
||||
compute_encoder.set_input_array(k, 1);
|
||||
compute_encoder.set_input_array(v, 2);
|
||||
compute_encoder.set_output_array(intermediate, 3);
|
||||
@@ -357,7 +357,7 @@ void ScaledDotProductAttention::eval_gpu(
|
||||
// Donate the query if possible
|
||||
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
|
||||
q.size() == o.size()) {
|
||||
o.move_shared_buffer(q);
|
||||
o.copy_shared_buffer(q);
|
||||
} else {
|
||||
if (o.shape(2) == 1) {
|
||||
o.set_data(allocator::malloc_or_wait(o.nbytes()));
|
||||
|
@@ -21,7 +21,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto in = inputs[0];
|
||||
if (in.flags().contiguous && in.strides()[axis_] != 0) {
|
||||
if (donate && in.itemsize() == out.itemsize()) {
|
||||
out.move_shared_buffer(in);
|
||||
out.copy_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
|
||||
@@ -33,7 +33,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
array arr_copy(in.shape(), in.dtype(), nullptr, {});
|
||||
copy_gpu(in, arr_copy, CopyType::General, s);
|
||||
in = std::move(arr_copy);
|
||||
out.move_shared_buffer(in);
|
||||
out.copy_shared_buffer(in);
|
||||
}
|
||||
|
||||
bool contiguous = in.strides()[axis_] == 1;
|
||||
@@ -68,8 +68,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (contiguous) {
|
||||
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
size_t size = in.shape(axis_);
|
||||
compute_encoder.set_bytes(size, 2);
|
||||
|
@@ -22,31 +22,32 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
auto& d = metal::device(s.device);
|
||||
|
||||
// Make sure that the last dimension is contiguous
|
||||
std::vector<array> copies;
|
||||
auto check_input = [&copies, &s](const array& x) -> const array& {
|
||||
auto set_output = [&s, &out](const array& x) {
|
||||
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
|
||||
if (no_copy && x.ndim() > 1) {
|
||||
auto s = x.strides()[x.ndim() - 2];
|
||||
no_copy &= (s == 0 || s == x.shape().back());
|
||||
}
|
||||
if (no_copy) {
|
||||
if (x.is_donatable()) {
|
||||
out.copy_shared_buffer(x);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
|
||||
x.data_size(),
|
||||
x.strides(),
|
||||
x.flags());
|
||||
}
|
||||
return x;
|
||||
} else {
|
||||
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
|
||||
copy_gpu(x, copies.back(), CopyType::General, s);
|
||||
return copies.back();
|
||||
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
|
||||
copy_gpu(x, x_copy, CopyType::General, s);
|
||||
out.copy_shared_buffer(x_copy);
|
||||
return x_copy;
|
||||
}
|
||||
};
|
||||
const array& in = check_input(inputs[0]);
|
||||
if (in.is_donatable()) {
|
||||
out.move_shared_buffer(in);
|
||||
} else {
|
||||
out.set_data(
|
||||
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
|
||||
in.data_size(),
|
||||
in.strides(),
|
||||
in.flags());
|
||||
}
|
||||
|
||||
const array in = set_output(inputs[0]);
|
||||
|
||||
int axis_size = in.shape().back();
|
||||
int n_rows = in.data_size() / axis_size;
|
||||
@@ -82,14 +83,11 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
compute_encoder.set_compute_pipeline_state(kernel);
|
||||
compute_encoder.set_input_array(
|
||||
in.data_shared_ptr() == nullptr ? out : in, 0);
|
||||
compute_encoder.set_input_array(in, 0);
|
||||
compute_encoder.set_output_array(out, 1);
|
||||
compute_encoder.set_bytes(axis_size, 2);
|
||||
compute_encoder.dispatch_threads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
d.add_temporaries(std::move(copies), s.index);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -71,12 +71,9 @@ void ternary_op_gpu_inplace(

auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);

auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -129,7 +126,7 @@ void ternary_op_gpu(
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
set_ternary_op_output_data(a, b, c, out, topt);
ternary_op_gpu_inplace(inputs, out, op, s);
}

@@ -57,8 +57,7 @@ void unary_op_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
    in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
if (!contig) {
// Launch up to 3D grid of threads
@@ -95,7 +94,7 @@ void unary_op_gpu(
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
out.copy_shared_buffer(in);
} else {
out.set_data(
    allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -169,7 +168,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, get_primitive_string(this));
} else {
// No-op integer types
move_or_copy(in, out);
out.copy_shared_buffer(in);
}
}

@@ -69,19 +69,4 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...);
}

/**
 * Get a new array that refers to the same data but has a non-owning pointer to
 * them.
 */
inline array unsafe_weak_copy(const array& x) {
  return array(
      x.buffer(),
      x.shape(),
      x.dtype(),
      x.strides(),
      x.data_size(),
      x.flags(),
      [](auto b) {});
}

} // namespace mlx::core

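unsafe_weak_copy (removed from this header along with its call sites in the old
async Load and distributed paths) produced an array aliasing the same buffer
with a no-op deleter, so a background task could fill an output without
extending its lifetime; ordering was handled separately by events or fences. A
sketch of the old capture pattern, where `fill_output` is a hypothetical
stand-in for the real work:

    // Sketch only: capturing a non-owning alias so the closure does not own `out`.
    auto task = [out = unsafe_weak_copy(out)]() mutable {
      fill_output(out);  // hypothetical; the buffer is kept alive by the caller
    };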