mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
29221fa238
commit
dc175f08d3
11
mlx/array.h
11
mlx/array.h
@ -256,6 +256,17 @@ class array {
|
||||
array_desc_->position = position;
|
||||
}
|
||||
|
||||
/** The i-th output of the array's primitive. */
|
||||
const array& output(int i) const {
|
||||
if (i == array_desc_->position) {
|
||||
return *this;
|
||||
} else if (i < array_desc_->position) {
|
||||
return siblings()[i];
|
||||
} else {
|
||||
return siblings()[i + 1];
|
||||
}
|
||||
};
|
||||
|
||||
/** The outputs of the array's primitive (i.e. this array and
|
||||
* its siblings) in the order the primitive expects. */
|
||||
std::vector<array> outputs() const {
|
||||
|
@ -76,7 +76,7 @@ void eval(std::vector<array> outputs) {
|
||||
// If the input is being computed on a different stream, we need to
|
||||
// manage the dependency.
|
||||
if (a.primitive().stream() != in.primitive().stream()) {
|
||||
deps.insert({in.primitive_id(), std::shared_future<void>{}});
|
||||
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -96,8 +96,7 @@ void eval(std::vector<array> outputs) {
|
||||
};
|
||||
|
||||
recurse(synchronizer);
|
||||
uintptr_t synch_id = synchronizer.primitive_id();
|
||||
deps.insert({synch_id, std::shared_future<void>{}});
|
||||
deps.insert({synchronizer.id(), std::shared_future<void>{}});
|
||||
|
||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||
while (!tape.empty()) {
|
||||
@ -113,13 +112,12 @@ void eval(std::vector<array> outputs) {
|
||||
auto stream = arr.primitive().stream();
|
||||
std::vector<std::shared_future<void>> arr_deps;
|
||||
for (auto& in : arr.inputs()) {
|
||||
// TODO that's a bug
|
||||
if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
|
||||
if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
|
||||
arr_deps.push_back(it->second);
|
||||
}
|
||||
}
|
||||
std::shared_ptr<std::promise<void>> p;
|
||||
if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
|
||||
if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
|
||||
p = std::make_unique<std::promise<void>>();
|
||||
ps.push_back(p);
|
||||
it->second = p->get_future().share();
|
||||
@ -154,7 +152,7 @@ void eval(std::vector<array> outputs) {
|
||||
}
|
||||
}
|
||||
|
||||
deps[synch_id].wait();
|
||||
deps[synchronizer.id()].wait();
|
||||
}
|
||||
|
||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||
|
Loading…
Reference in New Issue
Block a user