mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
simplify inputs also
This commit is contained in:
parent
df1f6c221b
commit
ba8ce23597
@ -1,5 +1,4 @@
|
||||
// Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
@ -112,15 +111,16 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& outputs) {
|
||||
std::function<void(const array&)> recurse;
|
||||
std::vector<array> tape;
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
std::unordered_set<std::uintptr_t> input_set;
|
||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||
parents_map;
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
auto in = inputs[i];
|
||||
cache.insert(in.id());
|
||||
input_set.insert(in.id());
|
||||
}
|
||||
|
||||
// DFS the graph to build the tape, and log parents and scalars
|
||||
std::unordered_set<std::uintptr_t> cache;
|
||||
recurse = [&](const array& a) {
|
||||
auto id = a.id();
|
||||
if (cache.find(id) != cache.end()) {
|
||||
@ -132,7 +132,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
for (auto& s : a.siblings()) {
|
||||
parents_map[in.id()].push_back({s, i});
|
||||
}
|
||||
recurse(in);
|
||||
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||
// of future optimizations)
|
||||
if (input_set.find(a.id()) == input_set.end()) {
|
||||
recurse(in);
|
||||
}
|
||||
}
|
||||
cache.insert(id);
|
||||
for (auto& s : a.siblings()) {
|
||||
@ -277,6 +281,12 @@ void compile_simplify(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Erase orphaned parents so we don't keep fusing with them
|
||||
for (int i = N - 1; i > 0; --i) {
|
||||
if (mask[i]) {
|
||||
parents->second.erase(parents->second.begin() + i);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
return output_set.find(a.id()) == output_set.end();
|
||||
@ -352,7 +362,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
||||
compile_dfs(entry.inputs, entry.outputs);
|
||||
|
||||
// Simplify the tape
|
||||
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2);
|
||||
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||
|
||||
// This is a good point to do more optimizations, e.g. kernel fusion to
|
||||
// generate new primitives. The tape needs to be updated accordingly
|
||||
|
Loading…
Reference in New Issue
Block a user