mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Remove Hazard tracking with Fences (#1509)
* remove hazard tracking * with fence map * no hazard tracking with fences * nits * fix fence retain * cleanup * fix quantized rebase
This commit is contained in:
@@ -20,7 +20,6 @@ namespace {
|
||||
|
||||
// TODO nicer way to set this or possibly expose as an environment variable
|
||||
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
@@ -121,33 +120,41 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
|
||||
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc_->retain();
|
||||
}
|
||||
|
||||
CommandEncoder::~CommandEncoder() {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
enc_->endEncoding();
|
||||
enc_->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||
// Insert a barrier
|
||||
enc_->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs_.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc_->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs.find(r_buf); it != outputs.end()) {
|
||||
// Insert a barrier
|
||||
enc->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs.erase(it);
|
||||
}
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
base_offset += offset;
|
||||
enc->setBuffer(a_buf, base_offset, idx);
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
set_array(a, idx, offset);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
@@ -155,40 +162,26 @@ void CommandEncoder::set_output_array(
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
set_array(a, idx, offset);
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent) {
|
||||
concurrent_outputs.insert(buf);
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
outputs.insert(buf);
|
||||
outputs_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreadgroups(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
num_dispatches++;
|
||||
enc->dispatchThreads(grid_dims, group_dims);
|
||||
maybe_split();
|
||||
}
|
||||
|
||||
void CommandEncoder::maybe_split() {
|
||||
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
num_dispatches = 0;
|
||||
outputs.clear();
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
}
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
Device::Device() {
|
||||
@@ -199,12 +192,6 @@ Device::Device() {
|
||||
|
||||
Device::~Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
for (auto& q : queue_map_) {
|
||||
q.second->release();
|
||||
}
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
@@ -225,61 +212,125 @@ void Device::new_queue(int index) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Failed to make new command queue.");
|
||||
}
|
||||
queue_map_.insert({index, q});
|
||||
buffer_map_.insert({index, {0, nullptr}});
|
||||
encoder_map_.insert({index, nullptr});
|
||||
stream_map_.emplace(index, q);
|
||||
}
|
||||
|
||||
int Device::get_command_buffer_ops(int index) {
|
||||
return buffer_map_[index].first;
|
||||
return get_stream_(index).buffer_ops;
|
||||
}
|
||||
|
||||
void Device::increment_command_buffer_ops(int index) {
|
||||
buffer_map_[index].first++;
|
||||
get_stream_(index).buffer_ops++;
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
if (bit->second.second == nullptr) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
}
|
||||
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.buffer == nullptr) {
|
||||
stream.buffer = stream.queue->commandBufferWithUnretainedReferences();
|
||||
if (!stream.buffer) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
bit->second = {0, cb};
|
||||
stream.buffer->retain();
|
||||
}
|
||||
return bit->second.second;
|
||||
return stream.buffer;
|
||||
}
|
||||
|
||||
void Device::commit_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
bit->second.second->commit();
|
||||
bit->second.second->release();
|
||||
bit->second = {0, nullptr};
|
||||
auto& stream = get_stream_(index);
|
||||
stream.buffer->commit();
|
||||
stream.buffer->release();
|
||||
stream.buffer = nullptr;
|
||||
stream.buffer_ops = 0;
|
||||
}
|
||||
|
||||
void Device::add_temporary(array arr, int index) {
|
||||
get_stream_(index).temporaries.push_back(std::move(arr));
|
||||
}
|
||||
|
||||
void Device::add_temporaries(std::vector<array> arrays, int index) {
|
||||
if (arrays.empty()) {
|
||||
return;
|
||||
}
|
||||
auto& stream = get_stream_(index);
|
||||
stream.temporaries.insert(
|
||||
stream.temporaries.end(),
|
||||
std::make_move_iterator(arrays.begin()),
|
||||
std::make_move_iterator(arrays.end()));
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
encoder_map_[index] = nullptr;
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.encoder != nullptr) {
|
||||
// Each command encoder has a unique fence. We also store a map of
|
||||
// all previous outputs of command encoders to their corresponding fence.
|
||||
// - The command encoder records its inputs and outputs.
|
||||
// - Wait on a fence if any inputs in the encoder are outputs of a previous
|
||||
// encoder.
|
||||
// - Update the map of outputs to include this command encoder's outputs.
|
||||
// - Always signal this command encoders fence.
|
||||
// - Add a completion handler for this command encoder that removes outputs
|
||||
// from the map to limit the growth of the map and avoid unecessary waits
|
||||
// - Temporaries are a special case as they do not cross command encoder
|
||||
// boundaries. These can be removed early from the encoders inputs and
|
||||
// outputs since they don't need synchronization.
|
||||
auto& enc = *stream.encoder;
|
||||
// Remove temporaries from inputs and outputs
|
||||
for (auto& t : stream.temporaries) {
|
||||
if (t.data<void>() != nullptr) {
|
||||
enc.outputs().erase(t.buffer().ptr());
|
||||
enc.inputs().erase(t.buffer().ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Keep references to the fences we waited on and put them
|
||||
// in the completion handler so they are not prematurely released
|
||||
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||
for (auto in : enc.inputs()) {
|
||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||
// If we've already waited on a fence, don't wait on it again.
|
||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||
enc->waitForFence(it->second->fence);
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto out : enc.outputs()) {
|
||||
stream.outputs[out] = stream.fence;
|
||||
}
|
||||
}
|
||||
enc->updateFence(stream.fence->fence);
|
||||
stream.buffer->addCompletedHandler(
|
||||
[&stream,
|
||||
waiting_on = std::move(waiting_on),
|
||||
fence = std::move(stream.fence),
|
||||
outputs = std::move(enc.outputs()),
|
||||
temporaries =
|
||||
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
||||
temporaries.clear();
|
||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||
for (auto o : outputs) {
|
||||
if (auto it = stream.outputs.find(o); it != stream.outputs.end()) {
|
||||
if (it->second == fence) {
|
||||
stream.outputs.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
stream.encoder = nullptr;
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit->second == nullptr) {
|
||||
auto cb = get_command_buffer(index);
|
||||
eit->second = std::make_unique<CommandEncoder>(cb);
|
||||
auto& stream = get_stream_(index);
|
||||
if (stream.encoder == nullptr) {
|
||||
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
|
||||
stream.fence = std::make_shared<Fence>(device_->newFence());
|
||||
}
|
||||
return *(eit->second);
|
||||
return *stream.encoder;
|
||||
}
|
||||
|
||||
void Device::register_library(
|
||||
|
||||
Reference in New Issue
Block a user