mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Split encoders in non-concurrent context with a max ops per encoder (#1085)
* split encoders
* fix race
This commit is contained in:
@@ -25,6 +25,7 @@ namespace {
|
||||
|
||||
// TODO: find a nicer way to set these limits, or possibly expose them as environment variables
|
||||
// NOTE(review): presumably an upper bound on command buffers per queue; its
// use is not visible in this excerpt -- confirm against the rest of the file.
constexpr int MAX_BUFFERS_PER_QUEUE = 12;
|
||||
// Dispatch-count threshold checked by CommandEncoder::maybe_split(): when the
// encoder is not concurrent, it is ended and replaced once num_dispatches
// exceeds this value.
constexpr int MAX_DISPATCHES_PER_ENCODER = 2;
|
||||
|
||||
constexpr const char* default_mtllib_path = METAL_PATH;
|
||||
|
||||
@@ -37,7 +38,6 @@ auto load_device() {
|
||||
}
|
||||
return device;
|
||||
}
|
||||
|
||||
std::pair<MTL::Library*, NS::Error*> load_library_from_path(
|
||||
MTL::Device* device,
|
||||
const char* path) {
|
||||
@@ -116,6 +116,33 @@ MTL::Library* load_library(
|
||||
|
||||
} // namespace
|
||||
|
||||
// Records one more dispatch, forwards the threadgroup dispatch to the wrapped
// Metal encoder, then gives maybe_split() the chance to start a new encoder.
//
// grid_dims  - forwarded unchanged as the first argument of the Metal call.
// group_dims - forwarded unchanged as the second argument of the Metal call.
void CommandEncoder::dispatchThreadgroups(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  // Count first so the split check below sees this dispatch.
  ++num_dispatches;
  enc->dispatchThreadgroups(grid_dims, group_dims);

  // May end the current encoder and replace it with a fresh one.
  maybe_split();
}
|
||||
|
||||
// Records one more dispatch, forwards the thread dispatch to the wrapped
// Metal encoder, then gives maybe_split() the chance to start a new encoder.
//
// grid_dims  - forwarded unchanged as the first argument of the Metal call.
// group_dims - forwarded unchanged as the second argument of the Metal call.
void CommandEncoder::dispatchThreads(
    MTL::Size grid_dims,
    MTL::Size group_dims) {
  // Count first so the split check below sees this dispatch.
  ++num_dispatches;
  enc->dispatchThreads(grid_dims, group_dims);

  // May end the current encoder and replace it with a fresh one.
  maybe_split();
}
|
||||
|
||||
void CommandEncoder::maybe_split() {
|
||||
if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
|
||||
enc->endEncoding();
|
||||
enc->release();
|
||||
num_dispatches = 0;
|
||||
outputs.clear();
|
||||
enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
enc->retain();
|
||||
}
|
||||
}
|
||||
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
@@ -130,9 +157,6 @@ Device::~Device() {
|
||||
for (auto& b : buffer_map_) {
|
||||
b.second.second->release();
|
||||
}
|
||||
for (auto& e : encoder_map_) {
|
||||
(*e.second)->release();
|
||||
}
|
||||
for (auto& k : kernel_map_) {
|
||||
k.second->release();
|
||||
}
|
||||
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
|
||||
|
||||
MTL::CommandBuffer* Device::get_command_buffer(int index) {
|
||||
auto bit = buffer_map_.find(index);
|
||||
return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
|
||||
}
|
||||
if (bit == buffer_map_.end()) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
}
|
||||
|
||||
MTL::CommandBuffer* Device::new_command_buffer(int index) {
|
||||
auto qit = queue_map_.find(index);
|
||||
if (qit == queue_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Attempting to get command buffer for invalid queue.");
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
bit = buffer_map_.insert({index, {0, cb}}).first;
|
||||
}
|
||||
|
||||
auto cb = qit->second->commandBufferWithUnretainedReferences();
|
||||
|
||||
if (!cb) {
|
||||
throw std::runtime_error(
|
||||
"[metal::Device] Unable to create new command buffer");
|
||||
}
|
||||
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
cb->retain();
|
||||
|
||||
return buffer_map_.insert({index, {0, cb}}).first->second.second;
|
||||
return bit->second.second;
|
||||
}
|
||||
|
||||
void Device::commit_command_buffer(int index) {
|
||||
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
|
||||
}
|
||||
|
||||
void Device::end_encoding(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit != encoder_map_.end()) {
|
||||
(*eit->second)->endEncoding();
|
||||
(*eit->second)->release();
|
||||
encoder_map_.erase(eit);
|
||||
}
|
||||
encoder_map_.erase(index);
|
||||
}
|
||||
|
||||
CommandEncoder& Device::get_command_encoder(int index) {
|
||||
auto eit = encoder_map_.find(index);
|
||||
if (eit == encoder_map_.end()) {
|
||||
auto cb = get_command_buffer(index);
|
||||
auto compute_encoder =
|
||||
cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
|
||||
// Increment ref count so the buffer is not garbage collected
|
||||
compute_encoder->retain();
|
||||
eit = encoder_map_
|
||||
.emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
|
||||
.first;
|
||||
eit =
|
||||
encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
|
||||
}
|
||||
return *(eit->second);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user