mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-25 01:41:17 +08:00
fix (#1523)
This commit is contained in:
parent
430ffef58a
commit
dad1b00b13
@ -130,10 +130,11 @@ CommandEncoder::~CommandEncoder() {
|
|||||||
enc_->release();
|
enc_->release();
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_array(
|
void CommandEncoder::set_input_array(
|
||||||
const array& a,
|
const array& a,
|
||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
|
all_inputs_.insert(a.buffer().ptr());
|
||||||
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
|
||||||
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
if (auto it = outputs_.find(r_buf); it != outputs_.end()) {
|
||||||
// Insert a barrier
|
// Insert a barrier
|
||||||
@ -149,20 +150,12 @@ void CommandEncoder::set_array(
|
|||||||
enc_->setBuffer(a_buf, base_offset, idx);
|
enc_->setBuffer(a_buf, base_offset, idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
void CommandEncoder::set_input_array(
|
|
||||||
const array& a,
|
|
||||||
int idx,
|
|
||||||
int64_t offset /* = 0 */) {
|
|
||||||
all_inputs_.insert(a.buffer().ptr());
|
|
||||||
set_array(a, idx, offset);
|
|
||||||
}
|
|
||||||
|
|
||||||
void CommandEncoder::set_output_array(
|
void CommandEncoder::set_output_array(
|
||||||
array& a,
|
array& a,
|
||||||
int idx,
|
int idx,
|
||||||
int64_t offset /* = 0 */) {
|
int64_t offset /* = 0 */) {
|
||||||
// Add barriers before adding the output to the output set
|
// Add barriers before adding the output to the output set
|
||||||
set_array(a, idx, offset);
|
set_input_array(a, idx, offset);
|
||||||
all_outputs_.insert(a.buffer().ptr());
|
all_outputs_.insert(a.buffer().ptr());
|
||||||
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
auto buf = static_cast<MTL::Resource*>(a.buffer().ptr());
|
||||||
if (concurrent_) {
|
if (concurrent_) {
|
||||||
|
@ -83,7 +83,6 @@ struct CommandEncoder {
|
|||||||
};
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void set_array(const array& a, int idx, int64_t offset);
|
|
||||||
MTL::ComputeCommandEncoder* enc_;
|
MTL::ComputeCommandEncoder* enc_;
|
||||||
bool concurrent_{false};
|
bool concurrent_{false};
|
||||||
std::unordered_set<MTL::Resource*> outputs_;
|
std::unordered_set<MTL::Resource*> outputs_;
|
||||||
|
Loading…
Reference in New Issue
Block a user