Remove Hazard tracking with Fences (#1509)

* remove hazard tracking

* with fence map

* no hazard tracking with fences

* nits

* fix fence retain

* cleanup

* fix quantized rebase
This commit is contained in:
Awni Hannun 2024-10-21 19:33:32 -07:00 committed by GitHub
parent d15fa13daf
commit c26208f67d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 268 additions and 299 deletions

View File

@ -136,13 +136,6 @@ Now make an array, and benchmark both functions:
On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is On an M1 Max the times are 15.5 and 3.1 milliseconds. The compiled ``gelu`` is
five times faster. five times faster.
.. note::
As of the latest MLX, CPU functions are not fully compiled. Compiling CPU
functions can still be helpful, but won't typically result in as large a
speedup as compiling operations that run on the GPU.
Debugging Debugging
--------- ---------

View File

@ -180,6 +180,7 @@ void array::move_shared_buffer(
auto char_offset = sizeof(char) * itemsize() * offset; auto char_offset = sizeof(char) * itemsize() * offset;
array_desc_->data_ptr = static_cast<void*>( array_desc_->data_ptr = static_cast<void*>(
static_cast<char*>(other.array_desc_->data_ptr) + char_offset); static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
other.array_desc_->data_ptr = nullptr;
} }
void array::move_shared_buffer(array other) { void array::move_shared_buffer(array other) {

View File

@ -205,7 +205,7 @@ Buffer MetalAllocator::malloc(size_t size, bool allow_swap /* = false */) {
// Allocate new buffer if needed // Allocate new buffer if needed
size_t res_opt = MTL::ResourceStorageModeShared; size_t res_opt = MTL::ResourceStorageModeShared;
res_opt |= MTL::ResourceHazardTrackingModeTracked; res_opt |= MTL::ResourceHazardTrackingModeUntracked;
lk.unlock(); lk.unlock();
buf = device_->newBuffer(size, res_opt); buf = device_->newBuffer(size, res_opt);
lk.lock(); lk.lock();

View File

@ -918,14 +918,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions."); "[Convolution::eval_gpu] Only supports 1D, 2D or 3D convolutions.");
} }
// Clear copies // Record copies
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -77,12 +77,7 @@ void CustomKernel::eval_gpu(
MTL::Size grid_dims = MTL::Size(gx, gy, gz); MTL::Size grid_dims = MTL::Size(gx, gy, gz);
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@ -20,7 +20,6 @@ namespace {
// TODO nicer way to set this or possibly expose as an environment variable // TODO nicer way to set this or possibly expose as an environment variable
constexpr int MAX_BUFFERS_PER_QUEUE = 12; constexpr int MAX_BUFFERS_PER_QUEUE = 12;
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
constexpr const char* default_mtllib_path = METAL_PATH; constexpr const char* default_mtllib_path = METAL_PATH;
@ -121,33 +120,41 @@ MTL::Library* load_library(
} // namespace } // namespace
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) { CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent); enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain(); enc_->retain();
} }
CommandEncoder::~CommandEncoder() { CommandEncoder::~CommandEncoder() {
enc->endEncoding(); enc_->endEncoding();
enc->release(); enc_->release();
}
void CommandEncoder::set_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
// Insert a barrier
enc_->memoryBarrier(&r_buf, 1);
// Remove the output
outputs_.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc_->setBuffer(a_buf, base_offset, idx);
} }
void CommandEncoder::set_input_array( void CommandEncoder::set_input_array(
const array& a, const array& a,
int idx, int idx,
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr())); all_inputs_.insert(a.buffer().ptr());
if (auto it = outputs.find(r_buf); it != outputs.end()) { set_array(a, idx, offset);
// Insert a barrier
enc->memoryBarrier(&r_buf, 1);
// Remove the output
outputs.erase(it);
}
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
auto base_offset = a.data<char>() -
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
base_offset += offset;
enc->setBuffer(a_buf, base_offset, idx);
} }
void CommandEncoder::set_output_array( void CommandEncoder::set_output_array(
@ -155,40 +162,26 @@ void CommandEncoder::set_output_array(
int idx, int idx,
int64_t offset /* = 0 */) { int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set // Add barriers before adding the output to the output set
set_input_array(a, idx, offset); set_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr()); auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent) { if (concurrent_) {
concurrent_outputs.insert(buf); concurrent_outputs_.insert(buf);
} else { } else {
outputs.insert(buf); outputs_.insert(buf);
} }
} }
void CommandEncoder::dispatchThreadgroups( void CommandEncoder::dispatchThreadgroups(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
num_dispatches++; enc_->dispatchThreadgroups(grid_dims, group_dims);
enc->dispatchThreadgroups(grid_dims, group_dims);
maybe_split();
} }
void CommandEncoder::dispatchThreads( void CommandEncoder::dispatchThreads(
MTL::Size grid_dims, MTL::Size grid_dims,
MTL::Size group_dims) { MTL::Size group_dims) {
num_dispatches++; enc_->dispatchThreads(grid_dims, group_dims);
enc->dispatchThreads(grid_dims, group_dims);
maybe_split();
}
void CommandEncoder::maybe_split() {
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
enc->endEncoding();
enc->release();
num_dispatches = 0;
outputs.clear();
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
enc->retain();
}
} }
Device::Device() { Device::Device() {
@ -199,12 +192,6 @@ Device::Device() {
Device::~Device() { Device::~Device() {
auto pool = new_scoped_memory_pool(); auto pool = new_scoped_memory_pool();
for (auto& q : queue_map_) {
q.second->release();
}
for (auto& b : buffer_map_) {
b.second.second->release();
}
for (auto& k : kernel_map_) { for (auto& k : kernel_map_) {
k.second->release(); k.second->release();
} }
@ -225,61 +212,125 @@ void Device::new_queue(int index) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Failed to make new command queue."); "[metal::Device] Failed to make new command queue.");
} }
queue_map_.insert({index, q}); stream_map_.emplace(index, q);
buffer_map_.insert({index, {0, nullptr}});
encoder_map_.insert({index, nullptr});
} }
int Device::get_command_buffer_ops(int index) { int Device::get_command_buffer_ops(int index) {
return buffer_map_[index].first; return get_stream_(index).buffer_ops;
} }
void Device::increment_command_buffer_ops(int index) { void Device::increment_command_buffer_ops(int index) {
buffer_map_[index].first++; get_stream_(index).buffer_ops++;
} }
MTL::CommandBuffer* Device::get_command_buffer(int index) { MTL::CommandBuffer* Device::get_command_buffer(int index) {
auto bit = buffer_map_.find(index); auto& stream = get_stream_(index);
if (bit->second.second == nullptr) { if (stream.buffer == nullptr) {
auto qit = queue_map_.find(index); stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
if (qit == queue_map_.end()) { if (!stream.buffer) {
throw std::runtime_error(
"[metal::Device] Attempting to get command buffer for invalid queue.");
}
auto cb = qit->second->commandBufferWithUnretainedReferences();
if (!cb) {
throw std::runtime_error( throw std::runtime_error(
"[metal::Device] Unable to create new command buffer"); "[metal::Device] Unable to create new command buffer");
} }
// Increment ref count so the buffer is not garbage collected // Increment ref count so the buffer is not garbage collected
cb->retain(); stream.buffer->retain();
bit->second = {0, cb};
} }
return bit->second.second; return stream.buffer;
} }
void Device::commit_command_buffer(int index) { void Device::commit_command_buffer(int index) {
auto bit = buffer_map_.find(index); auto& stream = get_stream_(index);
bit->second.second->commit(); stream.buffer->commit();
bit->second.second->release(); stream.buffer->release();
bit->second = {0, nullptr}; stream.buffer = nullptr;
stream.buffer_ops = 0;
}
// Record a temporary array on the stream for `index` so it stays alive
// until the command buffer that uses it completes (the temporaries are
// moved into the completion handler set up in end_encoding).
void Device::add_temporary(array arr, int index) {
get_stream_(index).temporaries.push_back(std::move(arr));
}
void Device::add_temporaries(std::vector<array> arrays, int index) {
if (arrays.empty()) {
return;
}
auto& stream = get_stream_(index);
stream.temporaries.insert(
stream.temporaries.end(),
std::make_move_iterator(arrays.begin()),
std::make_move_iterator(arrays.end()));
} }
void Device::end_encoding(int index) { void Device::end_encoding(int index) {
encoder_map_[index] = nullptr; auto& stream = get_stream_(index);
if (stream.encoder != nullptr) {
// Each command encoder has a unique fence. We also store a map of
// all previous outputs of command encoders to their corresponding fence.
// - The command encoder records its inputs and outputs.
// - Wait on a fence if any inputs in the encoder are outputs of a previous
// encoder.
// - Update the map of outputs to include this command encoder's outputs.
// - Always signal this command encoder's fence.
// - Add a completion handler for this command encoder that removes outputs
// from the map to limit the growth of the map and avoid unnecessary waits
// - Temporaries are a special case as they do not cross command encoder
// boundaries. These can be removed early from the encoder's inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
stream.outputs[out] = stream.fence;
}
}
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto o : outputs) {
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
if (it->second == fence) {
stream.outputs.erase(it);
}
}
}
});
}
stream.encoder = nullptr;
} }
CommandEncoder& Device::get_command_encoder(int index) { CommandEncoder& Device::get_command_encoder(int index) {
auto eit = encoder_map_.find(index); auto& stream = get_stream_(index);
if (eit->second == nullptr) { if (stream.encoder == nullptr) {
auto cb = get_command_buffer(index); stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
eit->second = std::make_unique<CommandEncoder>(cb); stream.fence = std::make_shared<Fence>(device_->newFence());
} }
return *(eit->second); return *stream.encoder;
} }
void Device::register_library( void Device::register_library(

View File

@ -45,13 +45,13 @@ struct CommandEncoder {
struct ConcurrentContext { struct ConcurrentContext {
ConcurrentContext(CommandEncoder& enc) : enc(enc) { ConcurrentContext(CommandEncoder& enc) : enc(enc) {
enc.concurrent = true; enc.concurrent_ = true;
} }
~ConcurrentContext() { ~ConcurrentContext() {
enc.concurrent = false; enc.concurrent_ = false;
enc.outputs.insert( enc.outputs_.insert(
enc.concurrent_outputs.begin(), enc.concurrent_outputs.end()); enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
enc.concurrent_outputs.clear(); enc.concurrent_outputs_.clear();
} }
private: private:
@ -59,7 +59,7 @@ struct CommandEncoder {
}; };
MTL::ComputeCommandEncoder* operator->() { MTL::ComputeCommandEncoder* operator->() {
return enc; return enc_;
} }
void set_input_array(const array& a, int idx, int64_t offset = 0); void set_input_array(const array& a, int idx, int64_t offset = 0);
@ -70,18 +70,60 @@ struct CommandEncoder {
ConcurrentContext start_concurrent() { ConcurrentContext start_concurrent() {
return ConcurrentContext(*this); return ConcurrentContext(*this);
} }
~CommandEncoder(); ~CommandEncoder();
private: // Inputs to all kernels in the encoder including temporaries
void maybe_split(); std::unordered_set<const void*>& inputs() {
return all_inputs_;
};
int num_dispatches{0}; // Outputs of all kernels in the encoder including temporaries
MTL::CommandBuffer* cbuf; std::unordered_set<const void*> outputs() {
MTL::ComputeCommandEncoder* enc; return all_outputs_;
bool concurrent{false}; };
std::unordered_set<MTL::Resource*> outputs;
std::unordered_set<MTL::Resource*> concurrent_outputs; private:
void set_array(const array& a, int idx, int64_t offset);
MTL::ComputeCommandEncoder* enc_;
bool concurrent_{false};
std::unordered_set<MTL::Resource*> outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
// RAII wrapper that owns an MTL::Fence and releases it on destruction.
// Shared between command encoders via std::shared_ptr (see DeviceStream).
struct Fence {
  explicit Fence(MTL::Fence* fence) : fence(fence) {}
  // Non-copyable / non-movable: a copy would double-release the fence.
  Fence(const Fence&) = delete;
  Fence& operator=(const Fence&) = delete;
  ~Fence() {
    fence->release();
  }
  MTL::Fence* fence;
};
struct DeviceStream {
DeviceStream(MTL::CommandQueue* queue) : queue(queue) {};
~DeviceStream() {
queue->release();
if (buffer != nullptr) {
buffer->release();
}
};
MTL::CommandQueue* queue;
// A map of prior command encoder outputs to their corresponding fence
std::unordered_map<const void*, std::shared_ptr<Fence>> outputs;
// Used to allow thread-safe access to the outputs map
std::mutex fence_mtx;
// The buffer and buffer op count are updated
// between command buffers
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
}; };
class Device { class Device {
@ -136,7 +178,14 @@ class Device {
MTL::ArgumentEncoder* argument_encoder( MTL::ArgumentEncoder* argument_encoder(
const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const; const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
// Record temporary arrays for the given stream index
void add_temporary(array arr, int index);
void add_temporaries(std::vector<array> arrays, int index);
private: private:
// Look up the DeviceStream for a stream index.
// NOTE(review): find() is unchecked — assumes new_queue(index) has already
// inserted the stream; a missing index dereferences end() (UB). Confirm
// callers always create the queue first.
DeviceStream& get_stream_(int index) {
return stream_map_.find(index)->second;
}
MTL::Library* get_library_cache_(const std::string& name); MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& name); MTL::Library* get_library_(const std::string& name);
@ -170,9 +219,7 @@ class Device {
const std::vector<MTL::Function*>& linked_functions = {}); const std::vector<MTL::Function*>& linked_functions = {});
MTL::Device* device_; MTL::Device* device_;
std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_; std::unordered_map<int32_t, DeviceStream> stream_map_;
std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
std::unordered_map<int32_t, std::unique_ptr<CommandEncoder>> encoder_map_;
std::shared_mutex kernel_mtx_; std::shared_mutex kernel_mtx_;
std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_; std::unordered_map<std::string, MTL::ComputePipelineState*> kernel_map_;

View File

@ -575,10 +575,7 @@ void fft_op(
auto plan = plan_fft(n); auto plan = plan_fft(n);
if (plan.four_step) { if (plan.four_step) {
four_step_fft(in, out, axis, inverse, real, plan, copies, s); four_step_fft(in, out, axis, inverse, real, plan, copies, s);
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
return; return;
} }
@ -744,12 +741,7 @@ void fft_op(
compute_encoder->dispatchThreads(grid_dims, group_dims); compute_encoder->dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void fft_op( void fft_op(
@ -792,8 +784,7 @@ void nd_fft_op(
} }
auto& d = metal::device(s.device); auto& d = metal::device(s.device);
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(temp_arrs), s.index);
[temp_arrs](MTL::CommandBuffer*) mutable { temp_arrs.clear(); });
} }
void FFT::eval_gpu(const std::vector<array>& inputs, array& out) { void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@ -174,12 +174,7 @@ void Hadamard::eval_gpu(const std::vector<array>& inputs, array& out) {
launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_); launch_hadamard(in_contiguous, out, "n" + kernel_name, scale_);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -226,13 +226,8 @@ void steel_matmul_regular(
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies // Record copies
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void steel_matmul( void steel_matmul(
@ -382,12 +377,7 @@ void steel_matmul(
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
@ -435,8 +425,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) { if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype()); array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s); fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index); d.add_temporary(std::move(zero), s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return; return;
} }
@ -588,12 +577,7 @@ void Matmul::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
///////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////
@ -798,12 +782,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
@ -916,12 +895,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
@ -1056,12 +1030,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) { void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -1080,8 +1049,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) { if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype()); array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s); fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index); d.add_temporary(std::move(zero), s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return; return;
} }
@ -1356,12 +1324,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
@ -1471,13 +1434,7 @@ void BlockMaskedMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies d.add_temporaries(std::move(copies), s.index);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) { void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -1496,8 +1453,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
if (a_pre.size() == 0 || b_pre.size() == 0) { if (a_pre.size() == 0 || b_pre.size() == 0) {
array zero = array(0, a_pre.dtype()); array zero = array(0, a_pre.dtype());
fill_gpu(zero, out, s); fill_gpu(zero, out, s);
auto command_buffer = d.get_command_buffer(s.index); d.add_temporary(std::move(zero), s.index);
command_buffer->addCompletedHandler([zero](MTL::CommandBuffer*) {});
return; return;
} }
@ -1703,12 +1659,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
return; return;
} }
@ -1847,13 +1798,7 @@ void GatherMM::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
// Clear copies d.add_temporaries(std::move(copies), s.index);
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -91,12 +91,8 @@ void RMSNorm::eval_gpu(
compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1); compute_encoder->setThreadgroupMemoryLength(simd_size * sizeof(float), 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void RMSNormVJP::eval_gpu( void RMSNormVJP::eval_gpu(
@ -204,10 +200,7 @@ void RMSNormVJP::eval_gpu(
strided_reduce_general_dispatch( strided_reduce_general_dispatch(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
} }
void LayerNorm::eval_gpu( void LayerNorm::eval_gpu(
@ -292,12 +285,8 @@ void LayerNorm::eval_gpu(
compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7); compute_encoder->setBytes(&b_stride, sizeof(uint32_t), 7);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
void LayerNormVJP::eval_gpu( void LayerNormVJP::eval_gpu(
@ -425,10 +414,7 @@ void LayerNormVJP::eval_gpu(
gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s); gw_temp, gw, "sum", plan, {0}, compute_encoder, d, s);
} }
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@ -145,6 +145,7 @@ void launch_qmm(
} }
compute_encoder.dispatchThreadgroups(grid_dims, group_dims); compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
d.add_temporaries(std::move(copies), s.index);
} }
void qmm_op( void qmm_op(
@ -350,12 +351,7 @@ void fast::AffineQuantize::eval_gpu(
: MTL::Size(nthreads, 1, 1); : MTL::Size(nthreads, 1, 1);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -660,12 +660,7 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
in, out, op_name, plan, axes_, compute_encoder, d, s); in, out, op_name, plan, axes_, compute_encoder, d, s);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
// Nothing to reduce just initialize the output // Nothing to reduce just initialize the output

View File

@ -262,12 +262,7 @@ void ScaledDotProductAttention::eval_gpu(
sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o); sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core::fast } // namespace mlx::core::fast

View File

@ -107,13 +107,7 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) { d.add_temporaries(std::move(copies), s.index);
auto command_buffer = d.get_command_buffer(s.index);
command_buffer->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -88,12 +88,8 @@ void Softmax::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder->setBytes(&axis_size, sizeof(int), 2); compute_encoder->setBytes(&axis_size, sizeof(int), 2);
compute_encoder.dispatchThreads(grid_dims, group_dims); compute_encoder.dispatchThreads(grid_dims, group_dims);
} }
if (!copies.empty()) {
d.get_command_buffer(s.index)->addCompletedHandler( d.add_temporaries(std::move(copies), s.index);
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
}
} }
} // namespace mlx::core } // namespace mlx::core

View File

@ -252,11 +252,7 @@ void multi_block_sort(
(axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General, (axis == in.ndim() - 1) ? CopyType::Vector : CopyType::General,
s); s);
// Clear copies d.add_temporaries(std::move(copies), s.index);
d.get_command_buffer(s.index)->addCompletedHandler(
[copies = std::move(copies)](MTL::CommandBuffer*) mutable {
copies.clear();
});
} }
void gpu_merge_sort( void gpu_merge_sort(

View File

@ -769,7 +769,6 @@ class TestCompile(mlx_tests.MLXTestCase):
out = mx.compile(fn)(a, b) out = mx.compile(fn)(a, b)
expected = fn(a, b) expected = fn(a, b)
print((out - expected).abs().max())
self.assertTrue(mx.allclose(out, expected)) self.assertTrue(mx.allclose(out, expected))
def test_compile_many_inputs(self): def test_compile_many_inputs(self):