mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
fix (#1523)
This commit is contained in:
parent
430ffef58a
commit
dad1b00b13
@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() {
|
||||
enc_->release();
|
||||
}
|
||||
|
||||
void CommandEncoder::set_array(
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||
// Insert a barrier
|
||||
@ -149,20 +150,12 @@ void CommandEncoder::set_array(
|
||||
enc_->setBuffer(a_buf, base_offset, idx);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_input_array(
|
||||
const array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
all_inputs_.insert(a.buffer().ptr());
|
||||
set_array(a, idx, offset);
|
||||
}
|
||||
|
||||
void CommandEncoder::set_output_array(
|
||||
array& a,
|
||||
int idx,
|
||||
int64_t offset /* = 0 */) {
|
||||
// Add barriers before adding the output to the output set
|
||||
set_array(a, idx, offset);
|
||||
set_input_array(a, idx, offset);
|
||||
all_outputs_.insert(a.buffer().ptr());
|
||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||
if (concurrent_) {
|
||||
|
@ -83,7 +83,6 @@ struct CommandEncoder {
|
||||
};
|
||||
|
||||
private:
|
||||
void set_array(const array& a, int idx, int64_t offset);
|
||||
MTL::ComputeCommandEncoder* enc_;
|
||||
bool concurrent_{false};
|
||||
std::unordered_set<MTL::Resource*> outputs_;
|
||||
|
Loading…
Reference in New Issue
Block a user