mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
62f297b51d
commit
f1951d6cce
@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||
// Insert a barrier
|
||||
enc_->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs_.erase(it);
|
||||
}
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
outputs_.insert(buf);
|
||||
next_outputs_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::maybeInsertBarrier() {
|
||||
if (needs_barrier_) {
|
||||
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||
needs_barrier_ = false;
|
||||
prev_outputs_ = std::move(next_outputs_);
|
||||
} else {
|
||||
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
|
||||
}
|
||||
next_outputs_.clear();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
@ -49,7 +49,7 @@ struct CommandEncoder {
|
||||
}
|
||||
~ConcurrentContext() {
|
||||
enc.concurrent_ = false;
|
||||
enc.outputs_.insert(
|
||||
enc.prev_outputs_.insert(
|
||||
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
||||
enc.concurrent_outputs_.clear();
|
||||
}
|
||||
@ -66,6 +66,7 @@ struct CommandEncoder {
|
||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||
void maybeInsertBarrier();
|
||||
|
||||
ConcurrentContext start_concurrent() {
|
||||
return ConcurrentContext(*this);
|
||||
@ -84,8 +85,10 @@ struct CommandEncoder {
|
||||
|
||||
private:
|
||||
MTL::ComputeCommandEncoder* enc_;
|
||||
bool needs_barrier_{false};
|
||||
bool concurrent_{false};
|
||||
std::unordered_set<MTL::Resource*> outputs_;
|
||||
std::unordered_set<MTL::Resource*> prev_outputs_;
|
||||
std::unordered_set<MTL::Resource*> next_outputs_;
|
||||
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||
std::unordered_set<const void*> all_inputs_;
|
||||
std::unordered_set<const void*> all_outputs_;
|
||||
|
Loading…
Reference in New Issue
Block a user