mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
register in task submission
This commit is contained in:
parent
76f275b4df
commit
fc6494cac7
@ -134,7 +134,6 @@ void CommandEncoder::set_input_array(
|
|||||||
const array& a,
|
const array& a,
|
||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
all_inputs_.insert(a.buffer().ptr());
|
|
||||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||||
needs_barrier_ =
|
needs_barrier_ =
|
||||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||||
@ -151,7 +150,6 @@ void CommandEncoder::set_output_array(
|
|||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
// Add barriers before adding the output to the output set
|
// Add barriers before adding the output to the output set
|
||||||
set_input_array(a, idx, offset);
|
set_input_array(a, idx, offset);
|
||||||
all_outputs_.insert(a.buffer().ptr());
|
|
||||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||||
if (concurrent_) {
|
if (concurrent_) {
|
||||||
concurrent_outputs_.insert(buf);
|
concurrent_outputs_.insert(buf);
|
||||||
@ -185,6 +183,24 @@ void CommandEncoder::dispatchThreads(
|
|||||||
enc_->dispatchThreads(grid_dims, group_dims);
|
enc_->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record which fences the work about to be encoded on this stream must wait
// for. For every array in `arrays`, look up its buffer pointer in `outputs`
// and, when an entry exists, add the associated fence to `waiting_on`.
// Buffers with no entry in `outputs` contribute nothing.
void DeviceStream::register_inputs(const std::vector<array>& arrays) {
|
||||||
|
std::lock_guard<std::mutex> lk(fence_mtx); // held for the whole scan of `outputs`
|
||||||
|
for (auto& a : arrays) {
|
||||||
|
auto buf = a.buffer().ptr();
|
||||||
|
if (auto it = outputs.find(buf); it != outputs.end()) {
|
||||||
|
waiting_on.insert(it->second); // set semantics: a fence already present is not added twice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the buffer pointer of each array in `arrays` to `all_outputs`.
// Arrays whose `data<void>()` is null are skipped.
// Note: unlike register_inputs, this function does not take `fence_mtx`.
void DeviceStream::register_outputs(const std::vector<array>& arrays) {
|
||||||
|
for (auto& a : arrays) {
|
||||||
|
if (a.data<void>() != nullptr) { // skip arrays with a null data pointer
|
||||||
|
all_outputs.insert(a.buffer().ptr());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Device::Device() {
|
Device::Device() {
|
||||||
auto pool = new_scoped_memory_pool();
|
auto pool = new_scoped_memory_pool();
|
||||||
device_ = load_device();
|
device_ = load_device();
|
||||||
@ -281,38 +297,22 @@ void Device::end_encoding(int index) {
|
|||||||
// boundaries. These can be removed early from the encoder's inputs and
|
// boundaries. These can be removed early from the encoder's inputs and
|
||||||
// outputs since they don't need synchronization.
|
// outputs since they don't need synchronization.
|
||||||
auto& enc = *stream.encoder;
|
auto& enc = *stream.encoder;
|
||||||
// Remove temporaries from inputs and outputs
|
|
||||||
for (auto& t : stream.temporaries) {
|
|
||||||
if (t.data<void>() != nullptr) {
|
|
||||||
enc.outputs().erase(t.buffer().ptr());
|
|
||||||
enc.inputs().erase(t.buffer().ptr());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Keep references to the fences we waited on and put them
|
for (auto& f : stream.waiting_on) {
|
||||||
// in the completion handler so they are not prematurely released
|
enc->waitForFence(f->fence);
|
||||||
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
}
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
std::lock_guard<std::mutex> lk(stream.fence_mtx);
|
||||||
for (auto in : enc.inputs()) {
|
for (auto out : stream.all_outputs) {
|
||||||
if (auto it = stream.outputs.find(in); it != stream.outputs.end()) {
|
|
||||||
// If we've already waited on a fence, don't wait on it again.
|
|
||||||
if (waiting_on.find(it->second) == waiting_on.end()) {
|
|
||||||
enc->waitForFence(it->second->fence);
|
|
||||||
waiting_on.insert(it->second);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto out : enc.outputs()) {
|
|
||||||
stream.outputs[out] = stream.fence;
|
stream.outputs[out] = stream.fence;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
enc->updateFence(stream.fence->fence);
|
enc->updateFence(stream.fence->fence);
|
||||||
stream.buffer->addCompletedHandler(
|
stream.buffer->addCompletedHandler(
|
||||||
[&stream,
|
[&stream,
|
||||||
waiting_on = std::move(waiting_on),
|
waiting_on = std::move(stream.waiting_on),
|
||||||
fence = std::move(stream.fence),
|
fence = std::move(stream.fence),
|
||||||
outputs = std::move(enc.outputs()),
|
outputs = std::move(stream.all_outputs),
|
||||||
temporaries =
|
temporaries =
|
||||||
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
std::move(stream.temporaries)](MTL::CommandBuffer*) mutable {
|
||||||
temporaries.clear();
|
temporaries.clear();
|
||||||
@ -325,6 +325,9 @@ void Device::end_encoding(int index) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
} else {
|
||||||
|
stream.all_outputs.clear();
|
||||||
|
stream.waiting_on.clear();
|
||||||
}
|
}
|
||||||
stream.encoder = nullptr;
|
stream.encoder = nullptr;
|
||||||
}
|
}
|
||||||
|
@ -73,16 +73,6 @@ struct CommandEncoder {
|
|||||||
}
|
}
|
||||||
~CommandEncoder();
|
~CommandEncoder();
|
||||||
|
|
||||||
// Inputs to all kernels in the encoder including temporaries
|
|
||||||
std::unordered_set<const void*>& inputs() {
|
|
||||||
return all_inputs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Outputs of all kernels in the encoder including temporaries
|
|
||||||
std::unordered_set<const void*> outputs() {
|
|
||||||
return all_outputs_;
|
|
||||||
};
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MTL::ComputeCommandEncoder* enc_;
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
bool needs_barrier_{false};
|
bool needs_barrier_{false};
|
||||||
@ -90,8 +80,6 @@ struct CommandEncoder {
|
|||||||
std::unordered_set<MTL::Resource*> prev_outputs_;
|
std::unordered_set<MTL::Resource*> prev_outputs_;
|
||||||
std::unordered_set<MTL::Resource*> next_outputs_;
|
std::unordered_set<MTL::Resource*> next_outputs_;
|
||||||
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||||
std::unordered_set<const void*> all_inputs_;
|
|
||||||
std::unordered_set<const void*> all_outputs_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Fence {
|
struct Fence {
|
||||||
@ -121,11 +109,15 @@ struct DeviceStream {
|
|||||||
MTL::CommandBuffer* buffer{nullptr};
|
MTL::CommandBuffer* buffer{nullptr};
|
||||||
int buffer_ops{0};
|
int buffer_ops{0};
|
||||||
|
|
||||||
// The command encoder, fence, and temporaries are updated between command
|
void register_inputs(const std::vector<array>& inputs);
|
||||||
// encoders
|
void register_outputs(const std::vector<array>& inputs);
|
||||||
|
|
||||||
|
// The following variables are all reset between command encoders
|
||||||
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
std::unique_ptr<CommandEncoder> encoder{nullptr};
|
||||||
std::shared_ptr<Fence> fence;
|
std::shared_ptr<Fence> fence;
|
||||||
std::vector<array> temporaries;
|
std::vector<array> temporaries;
|
||||||
|
std::unordered_set<std::shared_ptr<Fence>> waiting_on;
|
||||||
|
std::unordered_set<const void*> all_outputs;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Device {
|
class Device {
|
||||||
@ -190,10 +182,15 @@ class Device {
|
|||||||
|
|
||||||
void set_residency_set(const MTL::ResidencySet* residency_set);
|
void set_residency_set(const MTL::ResidencySet* residency_set);
|
||||||
|
|
||||||
|
// Return a reference to the DeviceStream stored in `stream_map_` under
// `index`. The result of find() is dereferenced without comparing it to
// end(), so `index` must already be present in `stream_map_`.
DeviceStream& get_stream(int index) {
|
||||||
|
return stream_map_.find(index)->second;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DeviceStream& get_stream_(int index) {
|
DeviceStream& get_stream_(int index) {
|
||||||
return stream_map_.find(index)->second;
|
return stream_map_.find(index)->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
MTL::Library* get_library_cache_(const std::string& name);
|
MTL::Library* get_library_cache_(const std::string& name);
|
||||||
|
|
||||||
MTL::Library* get_library_(const std::string& name);
|
MTL::Library* get_library_(const std::string& name);
|
||||||
|
@ -60,8 +60,20 @@ std::function<void()> make_task(array arr, bool signal) {
|
|||||||
inputs = arr.inputs();
|
inputs = arr.inputs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if (dstream.active_without_fence()) {
|
||||||
|
// for (auto& in : arr.inputs()) {
|
||||||
|
// // If any input needs a fence
|
||||||
|
// if (command_encoder.needs_fence(in)) {
|
||||||
|
// end_encoding();
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
auto& dstream = d.get_stream(s.index);
|
||||||
|
dstream.register_inputs(arr.inputs());
|
||||||
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
debug_set_primitive_buffer_label(command_buffer, arr.primitive());
|
||||||
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
arr.primitive().eval_gpu(arr.inputs(), outputs);
|
||||||
|
dstream.register_outputs(outputs);
|
||||||
}
|
}
|
||||||
std::vector<std::shared_ptr<array::Data>> buffers;
|
std::vector<std::shared_ptr<array::Data>> buffers;
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user