mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
register in task submission
This commit is contained in:
parent
76f275b4df
commit
fc6494cac7
@ -134,7 +134,6 @@ void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
@ -151,7 +150,6 @@ void CommandEncoder::set_output_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_input_array(a, idx, offset);
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
@ -185,6 +183,24 @@ void CommandEncoder::dispatchThreads(
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void DeviceStream::register_inputs(const std::vector<array>& arrays) {
|
||||
std::lock_guard<std::mutex> lk(fence_mtx);
|
||||
for (auto& a : arrays) {
|
||||
auto buf = a.buffer().ptr();
|
||||
if (auto it = outputs.find(buf); it != outputs.end()) {
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceStream::register_outputs(const std::vector<array>& arrays) {
|
||||
for (auto& a : arrays) {
|
||||
if (a.data<void>() != nullptr) {
|
||||
all_outputs.insert(a.buffer().ptr());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Device::Device() {
|
||||
auto pool = new_scoped_memory_pool();
|
||||
device_ = load_device();
|
||||
@ -281,38 +297,22 @@ void Device::end_encoding(int index) {
|
||||
// boundaries. These can be removed early from the encoders inputs and
|
||||
// outputs since they don't need synchronization.
|
||||
auto& enc = *stream.encoder;
|
||||
// Remove temporaries from inputs and outputs
|
||||
for (auto& t : stream.temporaries) {
|
||||
if (t.data<void>() != nullptr) {
|
||||
enc.outputs().erase(t.buffer().ptr());
|
||||
enc.inputs().erase(t.buffer().ptr());
|
||||
}
|
||||
}
|
||||
|
||||
// Keep references to the fences we waited on and put them
|
||||
// in the completion handler so they are not prematurely released
|
||||
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||
for (auto& f : stream.waiting_on) {
|
||||
enc->waitForFence(f->fence);
|
||||
}
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||
for (auto in : enc.inputs()) {
|
||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
||||
// If we've already waited on a fence, don't wait on it again.
|
||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
||||
enc->waitForFence(it->second->fence);
|
||||
waiting_on.insert(it->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto out : enc.outputs()) {
|
||||
for (auto out : stream.all_outputs) {
|
||||
stream.outputs[out] = stream.fence;
|
||||
}
|
||||
}
|
||||
enc->updateFence(stream.fence->fence);
|
||||
stream.buffer->addCompletedHandler(
|
||||
[&stream,
|
||||
waiting_on = std::move(waiting_on),
|
||||
waiting_on = std::move(stream.waiting_on),
|
||||
fence = std::move(stream.fence),
|
||||
outputs = std::move(enc.outputs()),
|
||||
outputs = std::move(stream.all_outputs),
|
||||
temporaries =
|
||||
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
||||
temporaries.clear();
|
||||
@ -325,6 +325,9 @@ void Device::end_encoding(int index) {
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
stream.all_outputs.clear();
|
||||
stream.waiting_on.clear();
|
||||
}
|
||||
stream.encoder = nullptr;
|
||||
}
|
||||
|
@ -73,16 +73,6 @@ struct CommandEncoder {
|
||||
}
|
||||
~CommandEncoder();
|
||||
|
||||
// Inputs to all kernels in the encoder including temporaries
|
||||
std::unordered_set<const void*>& inputs() {
|
||||
return all_inputs_;
|
||||
};
|
||||
|
||||
// Outputs of all kernels in the encoder including temporaries
|
||||
std::unordered_set<const void*> outputs() {
|
||||
return all_outputs_;
|
||||
};
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc_;
|
||||
bool needs_barrier_{false};
|
||||
@ -90,8 +80,6 @@ struct CommandEncoder {
|
||||
std::unordered_set<MTL::Resource*> prev_outputs_;
|
||||
std::unordered_set<MTL::Resource*> next_outputs_;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||
std::unordered_set<const void*> all_inputs_;
|
||||
std::unordered_set<const void*> all_outputs_;
|
||||
};
|
||||
|
||||
struct Fence {
|
||||
@ -121,11 +109,15 @@ struct DeviceStream {
|
||||
MTL::CommandBuffer* buffer{nullptr};
|
||||
int buffer_ops{0};
|
||||
|
||||
// The command encoder, fence, and temporaries are updated between command
|
||||
// encoders
|
||||
void register_inputs(const std::vector<array>& inputs);
|
||||
// Adds the buffer of every array in `outputs` that has data to
// `all_outputs`.
void register_outputs(const std::vector<array>& outputs);
|
||||
|
||||
// The following variables are all reset between command encoders
|
||||
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
||||
std::shared_ptr<Fence> fence;
|
||||
std::vector<array> temporaries;
|
||||
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||
std::unordered_set<const void*> all_outputs;
|
||||
};
|
||||
|
||||
class Device {
|
||||
@ -190,10 +182,15 @@ class Device {
|
||||
|
||||
void set_residency_set(const MTL::ResidencySet* residency_set);
|
||||
|
||||
// Returns the DeviceStream stored in `stream_map_` under `index`.
// at() is used instead of find()->second so that an unknown index raises
// std::out_of_range rather than dereferencing the end iterator, which is
// undefined behavior.
// NOTE(review): assumes stream_map_ is a std::map / std::unordered_map —
// confirm against its declaration.
DeviceStream& get_stream(int index) {
  return stream_map_.at(index);
}
|
||||
|
||||
private:
|
||||
// Returns the DeviceStream stored in `stream_map_` under `index`.
// at() is used instead of find()->second so that an unknown index raises
// std::out_of_range rather than dereferencing the end iterator, which is
// undefined behavior.
// NOTE(review): assumes stream_map_ is a std::map / std::unordered_map —
// confirm against its declaration.
DeviceStream& get_stream_(int index) {
  return stream_map_.at(index);
}
|
||||
|
||||
MTL::Library* get_library_cache_(const std::string& name);
|
||||
|
||||
MTL::Library* get_library_(const std::string& name);
|
||||
|
@ -60,8 +60,20 @@ std::function<void()> make_task(array arr, bool signal) {
|
||||
inputs = arr.inputs();
|
||||
}
|
||||
|
||||
// if (dstream.active_without_fence()) {
|
||||
// for (auto& in : arr.inputs()) {
|
||||
// // If any input needs a fence
|
||||
// if (command_encoder.needs_fence(in)) {
|
||||
// end_encoding();
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
auto& dstream = d.get_stream(s.index);
|
||||
dstream.register_inputs(arr.inputs());
|
||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||
dstream.register_outputs(outputs);
|
||||
}
|
||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||
for (auto& in : arr.inputs()) {
|
||||
|
Loading…
Reference in New Issue
Block a user