mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
parent
62f297b51d
commit
f1951d6cce
@ -136,13 +136,8 @@ void CommandEncoder::set_input_array(
|
|||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
all_inputs_.insert(a.buffer().ptr());
|
all_inputs_.insert(a.buffer().ptr());
|
||||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
needs_barrier_ =
|
||||||
// Insert a barrier
|
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
|
||||||
enc_->memoryBarrier(&r_buf, 1);
|
|
||||||
|
|
||||||
// Remove the output
|
|
||||||
outputs_.erase(it);
|
|
||||||
}
|
|
||||||
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
auto a_buf = static_cast<const MTL::Buffer*>(a.buffer().ptr());
|
||||||
auto base_offset = a.data<char>() -
|
auto base_offset = a.data<char>() -
|
||||||
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
static_cast<char*>(const_cast<MTL::Buffer*>(a_buf)->contents());
|
||||||
@ -161,19 +156,32 @@ void CommandEncoder::set_output_array(
|
|||||||
if (concurrent_) {
|
if (concurrent_) {
|
||||||
concurrent_outputs_.insert(buf);
|
concurrent_outputs_.insert(buf);
|
||||||
} else {
|
} else {
|
||||||
outputs_.insert(buf);
|
next_outputs_.insert(buf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CommandEncoder::maybeInsertBarrier() {
|
||||||
|
if (needs_barrier_) {
|
||||||
|
enc_->memoryBarrier(MTL::BarrierScopeBuffers);
|
||||||
|
needs_barrier_ = false;
|
||||||
|
prev_outputs_ = std::move(next_outputs_);
|
||||||
|
} else {
|
||||||
|
prev_outputs_.insert(next_outputs_.begin(), next_outputs_.end());
|
||||||
|
}
|
||||||
|
next_outputs_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
void CommandEncoder::dispatchThreadgroups(
|
void CommandEncoder::dispatchThreadgroups(
|
||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
|
maybeInsertBarrier();
|
||||||
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
enc_->dispatchThreadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::dispatchThreads(
|
void CommandEncoder::dispatchThreads(
|
||||||
MTL::Size grid_dims,
|
MTL::Size grid_dims,
|
||||||
MTL::Size group_dims) {
|
MTL::Size group_dims) {
|
||||||
|
maybeInsertBarrier();
|
||||||
enc_->dispatchThreads(grid_dims, group_dims);
|
enc_->dispatchThreads(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ struct CommandEncoder {
|
|||||||
}
|
}
|
||||||
~ConcurrentContext() {
|
~ConcurrentContext() {
|
||||||
enc.concurrent_ = false;
|
enc.concurrent_ = false;
|
||||||
enc.outputs_.insert(
|
enc.prev_outputs_.insert(
|
||||||
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end());
|
||||||
enc.concurrent_outputs_.clear();
|
enc.concurrent_outputs_.clear();
|
||||||
}
|
}
|
||||||
@ -66,6 +66,7 @@ struct CommandEncoder {
|
|||||||
void set_output_array(array& a, int idx, int64_t offset = 0);
|
void set_output_array(array& a, int idx, int64_t offset = 0);
|
||||||
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
|
||||||
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
|
||||||
|
void maybeInsertBarrier();
|
||||||
|
|
||||||
ConcurrentContext start_concurrent() {
|
ConcurrentContext start_concurrent() {
|
||||||
return ConcurrentContext(*this);
|
return ConcurrentContext(*this);
|
||||||
@ -84,8 +85,10 @@ struct CommandEncoder {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
MTL::ComputeCommandEncoder* enc_;
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
|
bool needs_barrier_{false};
|
||||||
bool concurrent_{false};
|
bool concurrent_{false};
|
||||||
std::unordered_set<MTL::Resource*> outputs_;
|
std::unordered_set<MTL::Resource*> prev_outputs_;
|
||||||
|
std::unordered_set<MTL::Resource*> next_outputs_;
|
||||||
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
std::unordered_set<MTL::Resource*> concurrent_outputs_;
|
||||||
std::unordered_set<const void*> all_inputs_;
|
std::unordered_set<const void*> all_inputs_;
|
||||||
std::unordered_set<const void*> all_outputs_;
|
std::unordered_set<const void*> all_outputs_;
|
||||||
|
Loading…
Reference in New Issue
Block a user