mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
@@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||
// Insert a barrier
|
||||
enc_->memoryBarrier(&r_buf, 1);
|
||||
|
||||
// Remove the output
|
||||
outputs_.erase(it);
|
||||
}
|
||||
needs_barrier_ =
|
||||
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||
auto base_offset = a.data<char>() -
|
||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||
@@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
|
||||
if (concurrent_) {
|
||||
concurrent_outputs_.insert(buf);
|
||||
} else {
|
||||
outputs_.insert(buf);
|
||||
next_outputs_.insert(buf);
|
||||
}
|
||||
}
|
||||
|
||||
void CommandEncoder::maybeInsertBarrier() {
|
||||
if (needs_barrier_) {
|
||||
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||
needs_barrier_ = false;
|
||||
prev_outputs_ = std::move(next_outputs_);
|
||||
} else {
|
||||
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
|
||||
}
|
||||
next_outputs_.clear();
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreadgroups(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
void CommandEncoder::dispatchThreads(
|
||||
MTL::Size grid_dims,
|
||||
MTL::Size group_dims) {
|
||||
maybeInsertBarrier();
|
||||
enc_->dispatchThreads(grid_dims, group_dims);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user