register in task submission

This commit is contained in:
Awni Hannun 2024-11-05 14:42:51 -08:00
parent 76f275b4df
commit fc6494cac7
3 changed files with 50 additions and 38 deletions

View File

@ -134,7 +134,6 @@ void CommandEncoder::set_input_array(
const array& a,
int idx,
int64_t offset /* = 0 */) {
all_inputs_.insert(a.buffer().ptr());
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
needs_barrier_ =
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
@ -151,7 +150,6 @@ void CommandEncoder::set_output_array(
int64_t offset /* = 0 */) {
// Add barriers before adding the output to the output set
set_input_array(a, idx, offset);
all_outputs_.insert(a.buffer().ptr());
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
if (concurrent_) {
concurrent_outputs_.insert(buf);
@ -185,6 +183,24 @@ void CommandEncoder::dispatchThreads(
enc_->dispatchThreads(grid_dims, group_dims);
}
// Collect the fences associated with the given arrays.
//
// Each array's `buffer().ptr()` is looked up in `outputs`; when an entry is
// found, the fence mapped to that pointer is inserted into `waiting_on`.
// `fence_mtx` is held for the duration of the scan.
void DeviceStream::register_inputs(const std::vector<array>& arrays) {
  std::lock_guard<std::mutex> guard(fence_mtx);
  for (const auto& arr : arrays) {
    auto found = outputs.find(arr.buffer().ptr());
    if (found == outputs.end()) {
      // No fence recorded for this buffer; nothing to add.
      continue;
    }
    waiting_on.insert(found->second);
  }
}
// Record the buffer pointers of the given arrays in `all_outputs`.
//
// Arrays whose `data<void>()` is null are skipped; for every other array the
// pointer returned by `buffer().ptr()` is inserted into the set.
void DeviceStream::register_outputs(const std::vector<array>& arrays) {
  for (const auto& arr : arrays) {
    if (arr.data<void>() == nullptr) {
      continue;
    }
    all_outputs.insert(arr.buffer().ptr());
  }
}
Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
@ -281,38 +297,22 @@ void Device::end_encoding(int index) {
// boundaries. These can be removed early from the encoders inputs and
// outputs since they don't need synchronization.
auto& enc = *stream.encoder;
// Remove temporaries from inputs and outputs
for (auto& t : stream.temporaries) {
if (t.data<void>() != nullptr) {
enc.outputs().erase(t.buffer().ptr());
enc.inputs().erase(t.buffer().ptr());
}
}
// Keep references to the fences we waited on and put them
// in the completion handler so they are not prematurely released
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
for (auto& f : stream.waiting_on) {
enc->waitForFence(f->fence);
}
{
std::lock_guard<std::mutex> lk(stream.fence_mtx);
for (auto in : enc.inputs()) {
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
// If we've already waited on a fence, don't wait on it again.
if (waiting_on.find(it->second) == waiting_on.end()) {
enc->waitForFence(it->second->fence);
waiting_on.insert(it->second);
}
}
}
for (auto out : enc.outputs()) {
for (auto out : stream.all_outputs) {
stream.outputs[out] = stream.fence;
}
}
enc->updateFence(stream.fence->fence);
stream.buffer->addCompletedHandler(
[&stream,
waiting_on = std::move(waiting_on),
waiting_on = std::move(stream.waiting_on),
fence = std::move(stream.fence),
outputs = std::move(enc.outputs()),
outputs = std::move(stream.all_outputs),
temporaries =
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
temporaries.clear();
@ -325,6 +325,9 @@ void Device::end_encoding(int index) {
}
}
});
} else {
stream.all_outputs.clear();
stream.waiting_on.clear();
}
stream.encoder = nullptr;
}

View File

@ -73,16 +73,6 @@ struct CommandEncoder {
}
~CommandEncoder();
// Inputs to all kernels in the encoder including temporaries
std::unordered_set<const void*>& inputs() {
return all_inputs_;
};
// Outputs of all kernels in the encoder including temporaries
std::unordered_set<const void*> outputs() {
return all_outputs_;
};
private:
MTL::ComputeCommandEncoder* enc_;
bool needs_barrier_{false};
@ -90,8 +80,6 @@ struct CommandEncoder {
std::unordered_set<MTL::Resource*> prev_outputs_;
std::unordered_set<MTL::Resource*> next_outputs_;
std::unordered_set<MTL::Resource*> concurrent_outputs_;
std::unordered_set<const void*> all_inputs_;
std::unordered_set<const void*> all_outputs_;
};
struct Fence {
@ -121,11 +109,15 @@ struct DeviceStream {
MTL::CommandBuffer* buffer{nullptr};
int buffer_ops{0};
// The command encoder, fence, and temporaries are updated between command
// encoders
void register_inputs(const std::vector<array>& inputs);
// Adds the buffer pointer of every array in `outputs` that has non-null data
// to `all_outputs`.
void register_outputs(const std::vector<array>& outputs);
// The following variables are all reset between command encoders
std::unique_ptr<CommandEncoder> encoder{nullptr};
std::shared_ptr<Fence> fence;
std::vector<array> temporaries;
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
std::unordered_set<const void*> all_outputs;
};
class Device {
@ -190,10 +182,15 @@ class Device {
void set_residency_set(const MTL::ResidencySet* residency_set);
// Returns a reference to the DeviceStream stored in `stream_map_` under
// `index`. The lookup result is dereferenced without comparing against
// end(), so `index` must already be present in the map.
DeviceStream& get_stream(int index) {
  auto entry = stream_map_.find(index);
  return entry->second;
}
private:
// Looks up `index` in `stream_map_` and returns a reference to the mapped
// DeviceStream. No end() check is performed on the lookup result, so
// `index` must already be present in the map.
DeviceStream& get_stream_(int index) {
  return (*stream_map_.find(index)).second;
}
MTL::Library* get_library_cache_(const std::string& name);
MTL::Library* get_library_(const std::string& name);

View File

@ -60,8 +60,20 @@ std::function<void()> make_task(array arr, bool signal) {
inputs = arr.inputs();
}
// if (dstream.active_without_fence()) {
// for (auto& in : arr.inputs()) {
// // If any input needs a fence
// if (command_encoder.needs_fence(in)) {
// end_encoding();
// break;
// }
// }
// }
auto& dstream = d.get_stream(s.index);
dstream.register_inputs(arr.inputs());
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
arr.primitive().eval_gpu(arr.inputs(), outputs);
dstream.register_outputs(outputs);
}
std::vector<std::shared_ptr<array::Data>> buffers;
for (auto& in : arr.inputs()) {