redesign for faster cpu/gpu synch (#1869)

* redesign for faster cpu/gpu synch

* load + more async CPU

* use command encoder API and move more ops to use it

* make fence back-end generic + CPU only fence

* faster build

* fix async eval

* fixes + handle temporaries

* fix / improve cpu conv

* remove unused status, fix siblings

* fix extensions

* fix

* fix no cpu build

* format

* comments

* fix perf regression, remove unnecessary abort

* fix events, task limit cpu

* fix waiting

* fix donation / temporaries in normalization
Awni Hannun
2025-03-06 19:23:38 -08:00
committed by GitHub
parent 5245f12a46
commit c4230747a1
103 changed files with 5013 additions and 3873 deletions

View File

@@ -102,16 +102,9 @@ void binary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
// - If a is donated it goes to the first output
// - If b is donated it goes to the first output if a was not donated
// otherwise it goes to the second output.
// - If there is only one output only one of a and b will be donated.
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
int arg_idx = 0;
compute_encoder.set_input_array(donate_a ? outputs[0] : a, arg_idx++);
compute_encoder.set_input_array(
donate_b ? (donate_a ? outputs[1] : outputs[0]) : b, arg_idx++);
compute_encoder.set_input_array(a, arg_idx++);
compute_encoder.set_input_array(b, arg_idx++);
compute_encoder.set_output_array(outputs[0], arg_idx++);
if (outputs.size() == 2) {
compute_encoder.set_output_array(outputs[1], arg_idx++);
@@ -164,8 +157,8 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, outputs[0], bopt, true);
set_binary_op_output_data(a, b, outputs[1], bopt, true);
set_binary_op_output_data(a, b, outputs[0], bopt);
set_binary_op_output_data(a, b, outputs[1], bopt);
binary_op_gpu_inplace(inputs, outputs, op, s);
}
@@ -195,7 +188,7 @@ void binary_op_gpu(
auto& a = inputs[0];
auto& b = inputs[1];
auto bopt = get_binary_op_type(a, b);
set_binary_op_output_data(a, b, out, bopt, true);
set_binary_op_output_data(a, b, out, bopt);
binary_op_gpu_inplace(inputs, out, op, s);
}
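
The donation comment removed above described how binary_op_gpu_inplace used to route donated input buffers to the outputs at encode time; after this change, donation is resolved earlier in set_binary_op_output_data, so the encoder binds a and b directly. As a minimal standalone sketch of the underlying idea (toy types, not MLX's actual array or allocator API): when an input buffer has no other owners, the output can take it over instead of allocating.

#include <cstddef>
#include <cstdio>
#include <memory>
#include <vector>

// Toy stand-ins for illustration only.
struct ToyArray {
  std::shared_ptr<std::vector<float>> buf;
  bool donatable() const { return buf && buf.use_count() == 1; }
};

// If `in` can be donated, let `out` take over its storage; otherwise allocate.
void set_output_data(ToyArray& in, ToyArray& out, std::size_t n) {
  if (in.donatable()) {
    out.buf = std::move(in.buf);  // reuse the input's buffer
  } else {
    out.buf = std::make_shared<std::vector<float>>(n);  // fresh allocation
  }
}

int main() {
  ToyArray a{std::make_shared<std::vector<float>>(4, 1.0f)};
  ToyArray out;
  set_output_data(a, out, 4);
  std::printf("input donated: %s\n", a.buf == nullptr ? "yes" : "no");
}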

View File

@@ -457,7 +457,7 @@ void Compiled::eval_gpu(
}
compiled_allocate_outputs(
inputs, outputs, inputs_, constant_ids_, contiguous, true);
inputs, outputs, inputs_, constant_ids_, contiguous);
// Put the outputs in
for (auto& x : outputs) {

View File

@@ -14,25 +14,11 @@ namespace mlx::core {
constexpr int MAX_COPY_SPECIALIZED_DIMS = 3;
void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
if (ctype == CopyType::Vector) {
// If the input is donateable, we are doing a vector copy and the types
// have the same size, then the input buffer can hold the output.
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
if (in.dtype() == out.dtype()) {
return;
}
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
bool donated = set_copy_output_data(in, out, ctype);
if (donated && in.dtype() == out.dtype()) {
// If the output has the same type as the input then there is nothing to
// copy, just use the buffer.
return;
}
if (ctype == CopyType::GeneralGeneral) {
ctype = CopyType::General;
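
copy_gpu now delegates output allocation to set_copy_output_data, which reports whether the input buffer was donated so the vector-copy fast path can return early when no dtype conversion is needed. A plausible shape for that helper, reconstructed from the removed inline logic above (the real definition lives in the shared backend code and is not part of this diff, so treat this as a sketch):

// Sketch only: reconstructed from the removed lines; the actual helper may differ.
bool set_copy_output_data(const array& in, array& out, CopyType ctype) {
  if (ctype == CopyType::Vector) {
    // A vector copy can reuse the input buffer when it is donatable and the
    // element sizes match.
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      out.copy_shared_buffer(in);
      return true;  // donated
    }
    out.set_data(
        allocator::malloc_or_wait(in.data_size() * out.itemsize()),
        in.data_size(),
        in.strides(),
        in.flags());
  } else {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
  }
  return false;  // not donated
}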

View File

@@ -168,9 +168,10 @@ void CommandEncoder::set_output_array(
register_output_array(a);
}
void CommandEncoder::register_output_array(array& a) {
void CommandEncoder::register_output_array(const array& a) {
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (concurrent_) {
concurrent_outputs_.insert(buf);
} else {

View File

@@ -62,7 +62,7 @@ struct CommandEncoder {
void set_input_array(const array& a, int idx, int64_t offset = 0);
void set_output_array(array& a, int idx, int64_t offset = 0);
void register_output_array(array& a);
void register_output_array(const array& a);
void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims);
void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims);
void maybeInsertBarrier();

View File

@@ -4,149 +4,30 @@
#include "mlx/allocator.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/fence.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/distributed/ops.h"
#include "mlx/distributed/primitives.h"
#include "mlx/fence.h"
#include "mlx/scheduler.h"
namespace mlx::core::distributed {
void signal_and_wait(const Event& e_signal, const Event& e_wait) {
if (e_signal.valid()) {
encode_signal(e_signal);
}
encode_wait(e_wait);
void AllReduce::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation.");
}
void AllReduce::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
auto& out = outputs[0];
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}
f.wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
reduce_type = reduce_type_,
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
switch (reduce_type) {
case Sum:
distributed::detail::all_sum(
group, in.data_shared_ptr() == nullptr ? out : in, out);
break;
default:
throw std::runtime_error("Only all reduce sum is supported for now");
}
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
void AllGather::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[AllGather::eval_gpu] has no GPU implementation.");
}
void AllGather::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
f.wait_gpu(out);
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group()]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::all_gather(group, in, out);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
void Send::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[Send::eval_gpu] has no GPU implementation.");
}
void Send::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 1);
assert(outputs.size() == 1);
auto& in = inputs[0];
// Encode a signal event for the input
Fence f{stream()};
if (in.event().valid()) {
f.update_gpu(in);
}
auto& out = outputs[0];
move_or_copy(in, out);
// Schedule an async send on the comm stream
auto task = [in = in,
out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
dst = dst_]() mutable {
if (in.event().valid()) {
f.wait();
}
distributed::detail::send(group, out, dst);
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
}
void Recv::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
assert(inputs.size() == 0);
assert(outputs.size() == 1);
auto& out = outputs[0];
out.set_data(allocator::malloc_or_wait(out.nbytes()));
Fence f{stream()};
f.wait_gpu(out);
// Schedule an async recv on the comm stream
auto task = [out = unsafe_weak_copy(out),
f = std::move(f),
group = group(),
src = src_]() mutable {
distributed::detail::recv(group, out, src);
f.update();
};
scheduler::enqueue(detail::communication_stream(), std::move(task));
void Recv::eval_gpu(const std::vector<array>&, std::vector<array>&) {
throw std::runtime_error("[Recv::eval_gpu] has no GPU implementation.");
}
} // namespace mlx::core::distributed

View File

@@ -3,52 +3,58 @@
#include "mlx/event.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
namespace mlx::core {
void encode_wait(Event e) {
auto& d = metal::device(e.stream().device);
d.end_encoding(e.stream().index);
auto command_buffer = d.get_command_buffer(e.stream().index);
command_buffer->encodeWait(
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
command_buffer->addCompletedHandler(
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
}
void encode_signal(Event e) {
auto& d = metal::device(e.stream().device);
d.end_encoding(e.stream().index);
auto command_buffer = d.get_command_buffer(e.stream().index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(e.raw_event().get()), e.value());
command_buffer->addCompletedHandler(
[e = std::move(e)](MTL::CommandBuffer* cbuf) {});
}
Event::Event(const Stream& stream) : stream_(stream) {
Event::Event(Stream stream) : stream_(stream) {
auto dtor = [](void* ptr) {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(ptr)->release();
};
auto p = metal::new_scoped_memory_pool();
event_ = std::shared_ptr<void>(
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
metal::device(Device::gpu).mtl_device()->newSharedEvent(), dtor);
}
void Event::wait() {
if (!static_cast<MTL::SharedEvent*>(raw_event().get())
if (!static_cast<MTL::SharedEvent*>(event_.get())
->waitUntilSignaledValue(value(), -1)) {
throw std::runtime_error("[Event::wait] Timed out");
}
}
void Event::signal() {
static_cast<MTL::SharedEvent*>(raw_event().get())->setSignaledValue(value());
static_cast<MTL::SharedEvent*>(event_.get())->setSignaledValue(value());
}
void Event::wait(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { wait(); });
} else {
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);
auto command_buffer = d.get_command_buffer(stream.index);
command_buffer->encodeWait(static_cast<MTL::Event*>(event_.get()), value());
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
}
}
void Event::signal(Stream stream) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [*this]() mutable { signal(); });
} else {
auto& d = metal::device(stream.device);
d.end_encoding(stream.index);
auto command_buffer = d.get_command_buffer(stream.index);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(event_.get()), value());
command_buffer->addCompletedHandler([*this](MTL::CommandBuffer*) {});
}
}
bool Event::is_signaled() const {
return static_cast<MTL::SharedEvent*>(raw_event().get())->signaledValue() >=
return static_cast<MTL::SharedEvent*>(event_.get())->signaledValue() >=
value();
}
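
Event now exposes wait(Stream) and signal(Stream): on a CPU stream the operation is enqueued on the scheduler, while on a GPU stream it is encoded as an encodeWait/encodeSignalEvent on that stream's command buffer. The shared-event semantics being relied on are "block until the signaled value reaches N". A small standalone model of that counter behavior (illustrative only, no Metal):

#include <algorithm>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

// Minimal model of MTL::SharedEvent-style semantics: a monotonically
// increasing value that waiters block on until it reaches their target.
class CounterEvent {
 public:
  void signal(uint64_t value) {
    {
      std::lock_guard<std::mutex> lk(m_);
      value_ = std::max(value_, value);
    }
    cv_.notify_all();
  }
  void wait(uint64_t target) {
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [&] { return value_ >= target; });
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  uint64_t value_{0};
};

int main() {
  CounterEvent e;
  std::thread producer([&] { e.signal(1); });  // "GPU" side signals
  e.wait(1);  // "CPU" side blocks until the value is reached
  producer.join();
}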

View File

@@ -1,10 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
namespace mlx::core {
void encode_wait(Event e);
void encode_signal(Event e);
} // namespace mlx::core

View File

@@ -1,49 +1,81 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/fence.h"
#include "mlx/fence.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/metal_impl.h"
#include "mlx/scheduler.h"
#include "mlx/utils.h"
namespace mlx::core {
Fence::Fence(const Stream& stream) : stream_(stream) {
auto d = metal::device(stream.device).mtl_device();
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
use_fast_ = false;
} else if (__builtin_available(macOS 15, iOS 18, *)) {
use_fast_ = env::metal_fast_synch();
}
if (!use_fast_) {
// Wraps Metal SharedEvent
auto dtor = [](void* ptr) {
struct FenceImpl {
FenceImpl() {
auto d = metal::device(Device::gpu).mtl_device();
if (!d->supportsFamily(MTL::GPUFamilyMetal3)) {
use_fast = false;
} else if (__builtin_available(macOS 15, iOS 18, *)) {
use_fast = env::metal_fast_synch();
}
if (!use_fast) {
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(ptr)->release();
};
auto p = metal::new_scoped_memory_pool();
fence_ = std::shared_ptr<void>(
metal::device(stream.device).mtl_device()->newSharedEvent(), dtor);
} else {
auto dtor = [](void* buf) {
allocator::free(static_cast<MTL::Buffer*>(buf));
};
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
fence_ = std::shared_ptr<void>(buf, dtor);
cpu_value()[0] = 0;
fence = static_cast<void*>(d->newSharedEvent());
} else {
auto buf = allocator::malloc_or_wait(sizeof(uint32_t)).ptr();
fence = static_cast<void*>(buf);
cpu_value()[0] = 0;
}
}
~FenceImpl() {
if (!use_fast) {
// Wraps Metal SharedEvent
auto p = metal::new_scoped_memory_pool();
static_cast<MTL::SharedEvent*>(fence)->release();
} else {
allocator::free(static_cast<MTL::Buffer*>(fence));
}
}
bool use_fast{false};
uint32_t count{0};
void* fence;
std::atomic_uint* cpu_value() {
return static_cast<std::atomic_uint*>(
static_cast<MTL::Buffer*>(fence)->contents());
}
};
Fence::Fence(Stream) {
auto dtor = [](void* ptr) { delete static_cast<FenceImpl*>(ptr); };
fence_ = std::shared_ptr<void>(new FenceImpl{}, dtor);
}
void Fence::wait_gpu(array& x) {
gpu_count_++;
auto& d = metal::device(stream_.device);
auto idx = stream_.index;
void Fence::wait(Stream stream, const array& x) {
auto& f = *static_cast<FenceImpl*>(fence_.get());
if (!use_fast_) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {
auto& f = *static_cast<FenceImpl*>(fence_.get());
if (!f.use_fast) {
if (!static_cast<MTL::SharedEvent*>(f.fence)->waitUntilSignaledValue(
count, -1)) {
throw std::runtime_error("[Fence::wait] Timed out");
}
return;
}
while (f.cpu_value()[0] < count) {
}
});
return;
}
auto& d = metal::device(stream.device);
auto idx = stream.index;
if (!f.use_fast) {
d.end_encoding(idx);
auto command_buffer = d.get_command_buffer(idx);
command_buffer->encodeWait(
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
command_buffer->encodeWait(static_cast<MTL::Event*>(f.fence), f.count);
command_buffer->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
return;
@@ -51,7 +83,7 @@ void Fence::wait_gpu(array& x) {
auto& compute_encoder = d.get_command_encoder(idx);
// Register the output to ensure that no kernels which depends on the
// Register outputs to ensure that no kernels which depends on the
// output starts before this one is done
compute_encoder.register_output_array(x);
@@ -59,36 +91,49 @@ void Fence::wait_gpu(array& x) {
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
auto buf = static_cast<MTL::Buffer*>(fence_.get());
auto buf = static_cast<MTL::Buffer*>(f.fence);
compute_encoder.set_buffer(buf, 0);
compute_encoder.set_bytes(gpu_count_, 1);
compute_encoder.set_bytes(f.count, 1);
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler(
[fence = fence_](MTL::CommandBuffer* cbuf) {});
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
}
void Fence::update_gpu(const array& x) {
gpu_count_++;
auto& d = metal::device(stream_.device);
auto idx = stream_.index;
void Fence::update(Stream stream, const array& x) {
auto& f = *static_cast<FenceImpl*>(fence_.get());
f.count++;
if (!use_fast_) {
if (stream.device == Device::cpu) {
scheduler::enqueue(stream, [fence_ = fence_, count = f.count]() mutable {
auto& f = *static_cast<FenceImpl*>(fence_.get());
if (!f.use_fast) {
static_cast<MTL::SharedEvent*>(f.fence)->setSignaledValue(count);
return;
}
f.cpu_value()[0] = count;
});
return;
}
auto& d = metal::device(stream.device);
auto idx = stream.index;
if (!f.use_fast) {
d.end_encoding(idx);
auto command_buffer = d.get_command_buffer(idx);
command_buffer->encodeSignalEvent(
static_cast<MTL::Event*>(fence_.get()), gpu_count_);
static_cast<MTL::Event*>(f.fence), f.count);
command_buffer->addCompletedHandler(
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
return;
}
// Launch input visibility kernel
// Launch input visibility kernels
auto& compute_encoder = d.get_command_encoder(idx);
auto kernel = d.get_kernel("input_coherent");
uint32_t nthreads =
(x.data_size() * x.itemsize() + sizeof(uint32_t) - 1) / sizeof(uint32_t);
MTL::Size group_dims = MTL::Size(1024, 1, 1);
MTL::Size grid_dims = MTL::Size((nthreads + 1024 - 1) / 1024, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
@@ -96,7 +141,7 @@ void Fence::update_gpu(const array& x) {
compute_encoder.set_bytes(nthreads, 1);
compute_encoder.dispatch_threadgroups(group_dims, grid_dims);
// Barrier on previous kernel
// Barrier on previous kernels
compute_encoder.barrier();
// Launch value update kernel
@@ -104,42 +149,13 @@ void Fence::update_gpu(const array& x) {
MTL::Size kernel_dims = MTL::Size(1, 1, 1);
compute_encoder.set_compute_pipeline_state(kernel);
auto buf = static_cast<MTL::Buffer*>(fence_.get());
auto buf = static_cast<MTL::Buffer*>(f.fence);
compute_encoder.set_buffer(buf, 0);
compute_encoder.set_bytes(gpu_count_, 1);
compute_encoder.set_bytes(f.count, 1);
compute_encoder.dispatch_threads(kernel_dims, kernel_dims);
d.get_command_buffer(idx)->addCompletedHandler(
[fence = fence_](MTL::CommandBuffer* cbuf) {});
}
void Fence::wait() {
cpu_count_++;
if (!use_fast_) {
if (!static_cast<MTL::SharedEvent*>(fence_.get())
->waitUntilSignaledValue(cpu_count_, -1)) {
throw std::runtime_error("[Event::wait] Timed out");
}
return;
}
while (cpu_value()[0] < cpu_count_) {
}
}
void Fence::update() {
cpu_count_++;
if (!use_fast_) {
static_cast<MTL::SharedEvent*>(fence_.get())->setSignaledValue(cpu_count_);
return;
}
cpu_value()[0] = cpu_count_;
}
std::atomic_uint* Fence::cpu_value() {
return static_cast<std::atomic_uint*>(
static_cast<MTL::Buffer*>(fence_.get())->contents());
[fence_ = fence_](MTL::CommandBuffer* cbuf) {});
}
} // namespace mlx::core

View File

@@ -1,40 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/device.h"
namespace mlx::core {
/* A fence to be used for synchronizing work between the CPU and GPU
*
* Calls to `update_gpu` should be paired with calls to `wait`. This ensures
* that the array passed to `update_gpu` is computed and visible to the CPU
* after the call to `wait` returns.
*
* Calls to `update` should be paired with calls to `wait_gpu`. This ensures
* that the array passed to `wait_gpu` will not be read by the GPU until the CPU
* has called `update`.
*
* The fence supports slow (default) and fast mode. Fast mode requires setting
* the environment variable `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires
* Metal 3.2+ (macOS 15+, iOS 18+).
*/
class Fence {
public:
Fence(const Stream& stream);
void update_gpu(const array& x);
void wait_gpu(array& x);
void wait();
void update();
private:
Stream stream_;
std::shared_ptr<void> fence_;
uint32_t cpu_count_{0};
uint32_t gpu_count_{0};
bool use_fast_;
std::atomic_uint* cpu_value();
};
} // namespace mlx::core
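
The removed header above documented the old pairing (update_gpu with wait, update with wait_gpu); the fence is now backend-generic, and wait(Stream, array) / update(Stream, array) pick the CPU or GPU path per call. In fast mode the producer bumps a counter in a shared buffer and the consumer spins on it, as in the `while (f.cpu_value()[0] < count)` loop in fence.cpp. A standalone model of that handshake (purely illustrative, no Metal):

#include <atomic>
#include <cstdint>
#include <thread>

// Fast-path fence model: the producer bumps an atomic counter once its work is
// visible; the consumer spins until the counter reaches the expected value.
struct FastFence {
  std::atomic<uint32_t> value{0};
  uint32_t count{0};  // only touched by the producer

  void update() {  // producer side ("GPU" writing the shared buffer)
    value.store(++count, std::memory_order_release);
  }
  void wait(uint32_t target) {  // consumer side ("CPU" waiting for visibility)
    while (value.load(std::memory_order_acquire) < target) {
      std::this_thread::yield();
    }
  }
};

int main() {
  FastFence f;
  std::thread producer([&] { f.update(); });
  f.wait(1);  // returns once the producer's update is visible
  producer.join();
}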

View File

@@ -82,7 +82,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
const array& in_contiguous = check_input(in);
if (in_contiguous.is_donatable()) {
out.move_shared_buffer(in_contiguous);
out.copy_shared_buffer(in_contiguous);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}

View File

@@ -2,7 +2,6 @@
#include <memory>
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"
@@ -23,88 +22,78 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
}
}
std::function<void()> make_task(array arr, bool signal) {
auto task = [arr = std::move(arr), signal]() mutable {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
void eval(array& arr) {
auto pool = new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);
for (auto& input : arr.inputs()) {
if (input.event().valid() &&
input.event().stream() != arr.primitive().stream()) {
input.event().wait();
}
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
auto outputs = arr.outputs();
{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
std::vector<array> inputs;
if (arr.is_tracer()) {
inputs = arr.inputs();
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
}
std::unordered_set<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.insert(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.insert(s.data_shared_ptr());
}
// Remove the output if it was donated to by an input
if (auto it = buffers.find(arr.data_shared_ptr()); it != buffers.end()) {
buffers.erase(it);
}
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
try {
arr.primitive().eval_gpu(arr.inputs(), outputs);
} catch (const std::exception& error) {
abort_with_exception(error);
}
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {
buffers.push_back(in.data_shared_ptr());
}
for (auto& s : arr.siblings()) {
buffers.push_back(s.data_shared_ptr());
}
if (!arr.is_tracer()) {
arr.detach();
}
for (auto& out : outputs) {
out.set_status(array::Status::evaluated);
}
if (signal || d.command_buffer_needs_commit(s.index)) {
if (signal) {
encode_signal(arr.event());
}
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
};
return task;
if (d.command_buffer_needs_commit(s.index)) {
d.end_encoding(s.index);
scheduler::notify_new_task(s);
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
} else {
command_buffer->addCompletedHandler(
[s, buffers = std::move(buffers)](MTL::CommandBuffer* cbuf) {
check_error(cbuf);
});
}
}
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p) {
return [s, p = std::move(p)]() {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
p->set_value();
};
void finalize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
d.end_encoding(s.index);
scheduler::notify_new_task(s);
cb->addCompletedHandler([s](MTL::CommandBuffer* cbuf) {
scheduler::notify_task_completion(s);
check_error(cbuf);
});
d.commit_command_buffer(s.index);
d.get_command_buffer(s.index);
}
void synchronize(Stream s) {
auto pool = new_scoped_memory_pool();
auto& d = metal::device(s.device);
auto cb = d.get_command_buffer(s.index);
cb->retain();
d.end_encoding(s.index);
d.commit_command_buffer(s.index);
cb->waitUntilCompleted();
check_error(cb);
cb->release();
}
void start_capture(std::string path, id object) {

View File

@@ -14,10 +14,8 @@ void new_stream(Stream stream);
std::unique_ptr<void, std::function<void(void*)>> new_scoped_memory_pool();
std::function<void()> make_task(array arr, bool signal);
std::function<void()> make_synchronize_task(
Stream s,
std::shared_ptr<std::promise<void>> p);
void eval(array& arr);
void finalize(Stream s);
void synchronize(Stream s);
} // namespace mlx::core::metal
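
make_task and make_synchronize_task are replaced by three entry points: eval(array&) encodes a primitive into the stream's current command buffer and only commits when command_buffer_needs_commit says so, finalize(Stream) flushes whatever is pending, and synchronize(Stream) commits and blocks until completion. A small standalone model of that batch-and-flush pattern (toy types, not the Metal device API):

#include <cstddef>
#include <functional>
#include <future>
#include <utility>
#include <vector>

// Toy model: eval() records work and commits in batches; finalize() flushes the
// current batch; synchronize() flushes and blocks until all batches finish.
class ToyStream {
 public:
  void eval(std::function<void()> op) {
    pending_.push_back(std::move(op));
    if (pending_.size() >= batch_limit_) commit();  // like command_buffer_needs_commit
  }
  void finalize() { commit(); }
  void synchronize() {
    commit();
    for (auto& f : inflight_) f.wait();
    inflight_.clear();
  }

 private:
  void commit() {
    if (pending_.empty()) return;
    auto batch = std::move(pending_);
    pending_.clear();
    inflight_.push_back(std::async(std::launch::async, [batch = std::move(batch)] {
      for (auto& op : batch) op();  // "run" the committed command buffer
    }));
  }
  std::vector<std::function<void()>> pending_;
  std::vector<std::future<void>> inflight_;
  std::size_t batch_limit_{16};
};

int main() {
  ToyStream s;
  int x = 0;
  s.eval([&] { x += 1; });
  s.eval([&] { x += 2; });
  s.finalize();     // flush the partial batch
  s.synchronize();  // block until it completes
  return x == 3 ? 0 : 1;
}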

View File

@@ -18,33 +18,33 @@ void RMSNorm::eval_gpu(
auto& out = outputs[0];
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array& x = check_input(inputs[0]);
const array& w = inputs[1];
if (x.is_donatable()) {
out.move_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
const array x = set_output(inputs[0]);
const array& w = inputs[1];
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
@@ -88,8 +88,6 @@ void RMSNorm::eval_gpu(
compute_encoder.set_bytes(w_stride, 5);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
}
void RMSNormVJP::eval_gpu(
@@ -101,20 +99,22 @@ void RMSNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&d, &s](const array& x) -> array {
auto check_input = [&d, &s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return x;
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
return {x_copy, true};
};
const array& x = check_input(inputs[0]);
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[2].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& g = check_input(inputs[2]);
auto [g, g_copied] = check_input(inputs[2]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
@@ -122,17 +122,18 @@ void RMSNormVJP::eval_gpu(
bool has_w = w.ndim() != 0;
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
gx.copy_shared_buffer(x);
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
@@ -141,11 +142,9 @@ void RMSNormVJP::eval_gpu(
// gradients before they are accumulated.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
bool g_in_gw = false;
if (has_w) {
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
@@ -189,9 +188,9 @@ void RMSNormVJP::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_input_array(g, 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5);
@@ -216,35 +215,35 @@ void LayerNorm::eval_gpu(
auto& out = outputs[0];
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array& x = check_input(inputs[0]);
const array x = set_output(inputs[0]);
const array& w = inputs[1];
const array& b = inputs[2];
if (x.is_donatable()) {
out.move_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
@@ -290,8 +289,6 @@ void LayerNorm::eval_gpu(
compute_encoder.set_bytes(b_stride, 7);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
}
void LayerNormVJP::eval_gpu(
@@ -303,21 +300,22 @@ void LayerNormVJP::eval_gpu(
// Ensure row contiguity. We could relax this step by checking that the array
// is contiguous (no broadcasts or holes) and that the input strides are the
// same as the cotangent strides but for now this is simpler.
auto check_input = [&d, &s](const array& x) -> array {
auto check_input = [&s](const array& x) -> std::pair<array, bool> {
if (x.flags().row_contiguous) {
return x;
return {x, false};
}
array x_copy(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
d.add_temporary(x_copy, s.index);
return x_copy;
return {x_copy, true};
};
const array& x = check_input(inputs[0]);
bool donate_x = inputs[0].is_donatable();
bool donate_g = inputs[3].is_donatable();
auto [x, copied] = check_input(inputs[0]);
donate_x |= copied;
const array& w = inputs[1];
const array& b = inputs[2];
const array& g = check_input(inputs[3]);
auto [g, g_copied] = check_input(inputs[3]);
donate_g |= g_copied;
array& gx = outputs[0];
array& gw = outputs[1];
array& gb = outputs[2];
@@ -326,17 +324,18 @@ void LayerNormVJP::eval_gpu(
bool has_w = w.ndim() != 0;
// Allocate space for the outputs
bool x_in_gx = false;
bool g_in_gx = false;
if (x.is_donatable()) {
gx.move_shared_buffer(x);
x_in_gx = true;
} else if (g.is_donatable()) {
gx.move_shared_buffer(g);
if (donate_x) {
gx.copy_shared_buffer(x);
} else if (donate_g) {
gx.copy_shared_buffer(g);
g_in_gx = true;
} else {
gx.set_data(allocator::malloc_or_wait(gx.nbytes()));
}
if (g_copied && !g_in_gx) {
d.add_temporary(g, s.index);
}
auto axis_size = static_cast<uint32_t>(x.shape().back());
int n_rows = x.data_size() / axis_size;
@@ -345,15 +344,13 @@ void LayerNormVJP::eval_gpu(
// gradient accumulators.
array gw_temp =
(has_w) ? array({n_rows, x.shape().back()}, gw.dtype(), nullptr, {}) : w;
bool g_in_gw = false;
if (has_w) {
if (!g_in_gx && g.is_donatable()) {
gw_temp.move_shared_buffer(g);
g_in_gw = true;
if (!g_in_gx && donate_g) {
gw_temp.copy_shared_buffer(g);
} else {
gw_temp.set_data(allocator::malloc_or_wait(gw_temp.nbytes()));
d.add_temporary(gw_temp, s.index);
}
d.add_temporary(gw_temp, s.index);
}
gw.set_data(allocator::malloc_or_wait(gw.nbytes()));
gb.set_data(allocator::malloc_or_wait(gb.nbytes()));
@@ -364,14 +361,7 @@ void LayerNormVJP::eval_gpu(
ReductionPlan plan(
ReductionOpType::ContiguousStridedReduce, {n_rows}, {axis_size});
strided_reduce_general_dispatch(
g_in_gx ? gx : (g_in_gw ? gw_temp : g),
gb,
"sum",
plan,
{0},
compute_encoder,
d,
s);
g, gb, "sum", plan, {0}, compute_encoder, d, s);
}
const int simd_size = 32;
@@ -409,9 +399,9 @@ void LayerNormVJP::eval_gpu(
uint32_t w_stride = (w.ndim() == 1) ? w.strides()[0] : 0;
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(x_in_gx ? gx : x, 0);
compute_encoder.set_input_array(x, 0);
compute_encoder.set_input_array(w, 1);
compute_encoder.set_input_array(g_in_gx ? gx : (g_in_gw ? gw_temp : g), 2);
compute_encoder.set_input_array(g, 2);
compute_encoder.set_output_array(gx, 3);
compute_encoder.set_output_array(gw_temp, 4);
compute_encoder.set_bytes(eps_, 5);
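
The VJP paths above now compute donate_x and donate_g up front, OR in whether check_input had to make a row-contiguous copy (a copy is always safe to donate), and register any copy that was not consumed as a temporary so it survives until the command buffer completes. A compact standalone illustration of the {array, copied} flag pattern (toy types only):

#include <utility>

// check_input returns the (possibly copied) value plus whether a copy was made;
// a fresh contiguous copy may back an output buffer, so the caller ORs the flag.
struct ToyArr {
  bool row_contiguous{true};
  bool donatable{false};
};

std::pair<ToyArr, bool> check_input(const ToyArr& x) {
  if (x.row_contiguous) return {x, false};
  ToyArr copy{true, true};  // contiguous copy, safe to reuse
  return {copy, true};
}

int main() {
  ToyArr input{false, false};  // non-contiguous and not donatable
  bool donate = input.donatable;
  auto [x, copied] = check_input(input);
  donate |= copied;  // the copy can be donated to the output
  return (donate && x.row_contiguous) ? 0 : 1;  // returns 0 here
}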

View File

@@ -5,12 +5,10 @@
#include <sstream>
#include "mlx/backend/common/compiled.h"
#include "mlx/backend/common/load.h"
#include "mlx/backend/common/slicing.h"
#include "mlx/backend/common/utils.h"
#include "mlx/backend/metal/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/event.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
@@ -45,7 +43,8 @@ void reshape(const array& in, array& out, Stream s) {
shared_buffer_reshape(in, out_strides, out);
}
}
array compute_dynamic_offset(
static array compute_dynamic_offset(
const array& indices,
const Strides& strides,
const std::vector<int>& axes,
@@ -57,7 +56,7 @@ array compute_dynamic_offset(
bool donate = indices.is_donatable() &&
(indices.data_size() * indices.itemsize()) >= offset.itemsize();
if (donate) {
offset.move_shared_buffer(indices);
offset.copy_shared_buffer(indices);
} else {
offset.set_data(allocator::malloc_or_wait(offset.itemsize()));
}
@@ -88,7 +87,7 @@ array compute_dynamic_offset(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(donate ? offset : indices, 0);
compute_encoder.set_input_array(indices, 0);
compute_encoder.set_output_array(offset, 1);
compute_encoder.set_vector_bytes(strides, 2);
compute_encoder.set_vector_bytes(axes, 3);
@@ -254,7 +253,7 @@ void Contiguous::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
if (in.flags().row_contiguous ||
(allow_col_major_ && in.flags().col_contiguous)) {
move_or_copy(in, out);
out.copy_shared_buffer(in);
} else {
copy_gpu(in, out, CopyType::General);
}
@@ -302,31 +301,7 @@ void Unflatten::eval_gpu(const std::vector<array>& inputs, array& out) {
}
void Load::eval_gpu(const std::vector<array>& inputs, array& out) {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
auto read_task = [out = unsafe_weak_copy(out),
offset = offset_,
reader = reader_,
swap_endianness = swap_endianness_]() mutable {
load(out, offset, reader, swap_endianness);
};
// Limit the size that the command buffer will wait on to avoid timing out
// on the event (<4 seconds).
if (out.nbytes() > (1 << 28)) {
read_task();
return;
}
auto fut = io::thread_pool().enqueue(std::move(read_task)).share();
auto e = Event(stream());
e.set_value(1);
encode_wait(e);
auto signal_task = [e = std::move(e), fut = std::move(fut)]() mutable {
fut.wait();
e.signal();
};
scheduler::enqueue(io_stream(), std::move(signal_task));
throw std::runtime_error("[Load::eval_gpu] Not implemented.");
}
void NumberOfElements::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -452,7 +427,7 @@ void DynamicSliceUpdate::eval_gpu(
auto& start_indices = inputs[2];
if (upd.size() == 0) {
move_or_copy(in, out);
out.copy_shared_buffer(in);
return;
}
@@ -491,7 +466,7 @@ void SliceUpdate::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& upd = inputs[1];
if (upd.size() == 0) {
move_or_copy(in, out);
out.copy_shared_buffer(in);
return;
}
@@ -575,8 +550,8 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
strides[i] *= ibytes;
strides[i] /= obytes;
}
move_or_copy(
in, out, strides, in.flags(), in.data_size() * ibytes / obytes);
out.copy_shared_buffer(
in, strides, in.flags(), in.data_size() * ibytes / obytes);
} else {
auto tmp = array(in.shape(), in.dtype(), nullptr, {});
tmp.set_data(allocator::malloc_or_wait(tmp.nbytes()));
@@ -587,7 +562,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
flags.row_contiguous = true;
auto max_dim = std::max_element(out.shape().begin(), out.shape().end());
flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim;
out.move_shared_buffer(tmp, out.strides(), flags, out.size());
out.copy_shared_buffer(tmp, out.strides(), flags, out.size());
}
}

View File

@@ -41,7 +41,7 @@ void RoPE::eval_gpu(
} else if (in.flags().row_contiguous) {
if (in.is_donatable()) {
donated = true;
out.move_shared_buffer(in);
out.copy_shared_buffer(in);
} else {
out.set_data(allocator::malloc_or_wait(out.nbytes()));
}

View File

@@ -152,7 +152,7 @@ void sdpa_vector(
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(out, 3);
@@ -242,7 +242,7 @@ void sdpa_vector_2pass(
compute_encoder.set_compute_pipeline_state(kernel);
// Set its arguments
compute_encoder.set_input_array(q.data_shared_ptr() == nullptr ? out : q, 0);
compute_encoder.set_input_array(q, 0);
compute_encoder.set_input_array(k, 1);
compute_encoder.set_input_array(v, 2);
compute_encoder.set_output_array(intermediate, 3);
@@ -357,7 +357,7 @@ void ScaledDotProductAttention::eval_gpu(
// Donate the query if possible
if (q.is_donatable() && (q.shape(2) == 1 || !q.flags().row_contiguous) &&
q.size() == o.size()) {
o.move_shared_buffer(q);
o.copy_shared_buffer(q);
} else {
if (o.shape(2) == 1) {
o.set_data(allocator::malloc_or_wait(o.nbytes()));

View File

@@ -21,7 +21,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
auto in = inputs[0];
if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (donate && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -33,7 +33,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
array arr_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, arr_copy, CopyType::General, s);
in = std::move(arr_copy);
out.move_shared_buffer(in);
out.copy_shared_buffer(in);
}
bool contiguous = in.strides()[axis_] == 1;
@@ -68,8 +68,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
if (contiguous) {
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
size_t size = in.shape(axis_);
compute_encoder.set_bytes(size, 2);

View File

@@ -22,31 +22,32 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& d = metal::device(s.device);
// Make sure that the last dimension is contiguous
std::vector<array> copies;
auto check_input = [&copies, &s](const array& x) -> const array& {
auto set_output = [&s, &out](const array& x) {
bool no_copy = x.flags().contiguous && x.strides()[x.ndim() - 1] == 1;
if (no_copy && x.ndim() > 1) {
auto s = x.strides()[x.ndim() - 2];
no_copy &= (s == 0 || s == x.shape().back());
}
if (no_copy) {
if (x.is_donatable()) {
out.copy_shared_buffer(x);
} else {
out.set_data(
allocator::malloc_or_wait(x.data_size() * x.itemsize()),
x.data_size(),
x.strides(),
x.flags());
}
return x;
} else {
copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
copy_gpu(x, copies.back(), CopyType::General, s);
return copies.back();
auto x_copy = array(x.shape(), x.dtype(), nullptr, {});
copy_gpu(x, x_copy, CopyType::General, s);
out.copy_shared_buffer(x_copy);
return x_copy;
}
};
const array& in = check_input(inputs[0]);
if (in.is_donatable()) {
out.move_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * in.itemsize()),
in.data_size(),
in.strides(),
in.flags());
}
const array in = set_output(inputs[0]);
int axis_size = in.shape().back();
int n_rows = in.data_size() / axis_size;
@@ -82,14 +83,11 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
}
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
compute_encoder.set_bytes(axis_size, 2);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -71,12 +71,9 @@ void ternary_op_gpu_inplace(
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
bool donate_a = a.data_shared_ptr() == nullptr;
bool donate_b = b.data_shared_ptr() == nullptr;
bool donate_c = c.data_shared_ptr() == nullptr;
compute_encoder.set_input_array(donate_a ? out : a, 0);
compute_encoder.set_input_array(donate_b ? out : b, 1);
compute_encoder.set_input_array(donate_c ? out : c, 2);
compute_encoder.set_input_array(a, 0);
compute_encoder.set_input_array(b, 1);
compute_encoder.set_input_array(c, 2);
compute_encoder.set_output_array(out, 3);
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
@@ -129,7 +126,7 @@ void ternary_op_gpu(
auto& b = inputs[1];
auto& c = inputs[2];
TernaryOpType topt = get_ternary_op_type(a, b, c);
set_ternary_op_output_data(a, b, c, out, topt, true /* donate_with_move */);
set_ternary_op_output_data(a, b, c, out, topt);
ternary_op_gpu_inplace(inputs, out, op, s);
}

View File

@@ -57,8 +57,7 @@ void unary_op_gpu_inplace(
auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_compute_pipeline_state(kernel);
compute_encoder.set_input_array(
in.data_shared_ptr() == nullptr ? out : in, 0);
compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(out, 1);
if (!contig) {
// Launch up to 3D grid of threads
@@ -95,7 +94,7 @@ void unary_op_gpu(
bool contig = in.flags().contiguous;
if (contig) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.move_shared_buffer(in);
out.copy_shared_buffer(in);
} else {
out.set_data(
allocator::malloc_or_wait(in.data_size() * out.itemsize()),
@@ -169,7 +168,7 @@ void Round::eval_gpu(const std::vector<array>& inputs, array& out) {
unary_op_gpu(inputs, out, get_primitive_string(this));
} else {
// No-op integer types
move_or_copy(in, out);
out.copy_shared_buffer(in);
}
}

View File

@@ -69,19 +69,4 @@ void concatenate(std::string& acc, T first, Args... args) {
concatenate(acc, args...);
}
/**
* Get a new array that refers to the same data but has a non-owning pointer to
* them.
*/
inline array unsafe_weak_copy(const array& x) {
return array(
x.buffer(),
x.shape(),
x.dtype(),
x.strides(),
x.data_size(),
x.flags(),
[](auto b) {});
}
} // namespace mlx::core
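
The removed unsafe_weak_copy helper built an array over the same buffer with a no-op deleter, i.e. a non-owning reference that never frees the data when it is destroyed (handy for handing arrays to async tasks without extending their lifetime). The same idiom in plain C++, for illustration:

#include <cassert>
#include <memory>
#include <vector>

int main() {
  auto owner = std::make_shared<std::vector<float>>(8, 1.0f);
  // Non-owning view: the no-op deleter means destroying `view` never frees the
  // data, so the caller must guarantee `owner` outlives every use of `view`.
  std::shared_ptr<std::vector<float>> view(owner.get(), [](std::vector<float>*) {});
  assert(view->size() == 8);
  assert(owner.use_count() == 1);  // the view does not participate in ownership
  return 0;
}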