mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
parent
29221fa238
commit
dc175f08d3
11
mlx/array.h
11
mlx/array.h
@ -256,6 +256,17 @@ class array {
|
|||||||
array_desc_->position = position;
|
array_desc_->position = position;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** The i-th output of the array's primitive. */
|
||||||
|
const array& output(int i) const {
|
||||||
|
if (i == array_desc_->position) {
|
||||||
|
return *this;
|
||||||
|
} else if (i < array_desc_->position) {
|
||||||
|
return siblings()[i];
|
||||||
|
} else {
|
||||||
|
return siblings()[i + 1];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/** The outputs of the array's primitive (i.e. this array and
|
/** The outputs of the array's primitive (i.e. this array and
|
||||||
* its siblings) in the order the primitive expects. */
|
* its siblings) in the order the primitive expects. */
|
||||||
std::vector<array> outputs() const {
|
std::vector<array> outputs() const {
|
||||||
|
@ -76,7 +76,7 @@ void eval(std::vector<array> outputs) {
|
|||||||
// If the input is being computed on a different stream, we need to
|
// If the input is being computed on a different stream, we need to
|
||||||
// manage the dependency.
|
// manage the dependency.
|
||||||
if (a.primitive().stream() != in.primitive().stream()) {
|
if (a.primitive().stream() != in.primitive().stream()) {
|
||||||
deps.insert({in.primitive_id(), std::shared_future<void>{}});
|
deps.insert({in.output(0).id(), std::shared_future<void>{}});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -96,8 +96,7 @@ void eval(std::vector<array> outputs) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
recurse(synchronizer);
|
recurse(synchronizer);
|
||||||
uintptr_t synch_id = synchronizer.primitive_id();
|
deps.insert({synchronizer.id(), std::shared_future<void>{}});
|
||||||
deps.insert({synch_id, std::shared_future<void>{}});
|
|
||||||
|
|
||||||
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
std::vector<std::shared_ptr<std::promise<void>>> ps;
|
||||||
while (!tape.empty()) {
|
while (!tape.empty()) {
|
||||||
@ -113,13 +112,12 @@ void eval(std::vector<array> outputs) {
|
|||||||
auto stream = arr.primitive().stream();
|
auto stream = arr.primitive().stream();
|
||||||
std::vector<std::shared_future<void>> arr_deps;
|
std::vector<std::shared_future<void>> arr_deps;
|
||||||
for (auto& in : arr.inputs()) {
|
for (auto& in : arr.inputs()) {
|
||||||
// TODO that's a bug
|
if (auto it = deps.find(in.output(0).id()); it != deps.end()) {
|
||||||
if (auto it = deps.find(in.primitive_id()); it != deps.end()) {
|
|
||||||
arr_deps.push_back(it->second);
|
arr_deps.push_back(it->second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::shared_ptr<std::promise<void>> p;
|
std::shared_ptr<std::promise<void>> p;
|
||||||
if (auto it = deps.find(arr.primitive_id()); it != deps.end()) {
|
if (auto it = deps.find(arr.output(0).id()); it != deps.end()) {
|
||||||
p = std::make_unique<std::promise<void>>();
|
p = std::make_unique<std::promise<void>>();
|
||||||
ps.push_back(p);
|
ps.push_back(p);
|
||||||
it->second = p->get_future().share();
|
it->second = p->get_future().share();
|
||||||
@ -154,7 +152,7 @@ void eval(std::vector<array> outputs) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
deps[synch_id].wait();
|
deps[synchronizer.id()].wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::vector<array>, std::vector<array>> vjp(
|
std::pair<std::vector<array>, std::vector<array>> vjp(
|
||||||
|
Loading…
Reference in New Issue
Block a user