mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-26 04:21:17 +08:00
simplify inputs also
This commit is contained in:
parent
df1f6c221b
commit
ba8ce23597
@ -1,5 +1,4 @@
|
|||||||
// Copyright © 2023-2024 Apple Inc.
|
// Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
@ -112,15 +111,16 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
const std::vector<array>& outputs) {
|
const std::vector<array>& outputs) {
|
||||||
std::function<void(const array&)> recurse;
|
std::function<void(const array&)> recurse;
|
||||||
std::vector<array> tape;
|
std::vector<array> tape;
|
||||||
std::unordered_set<std::uintptr_t> cache;
|
std::unordered_set<std::uintptr_t> input_set;
|
||||||
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
|
||||||
parents_map;
|
parents_map;
|
||||||
for (int i = 0; i < inputs.size(); ++i) {
|
for (int i = 0; i < inputs.size(); ++i) {
|
||||||
auto in = inputs[i];
|
auto in = inputs[i];
|
||||||
cache.insert(in.id());
|
input_set.insert(in.id());
|
||||||
}
|
}
|
||||||
|
|
||||||
// DFS the graph to build the tape, and log parents and scalars
|
// DFS the graph to build the tape, and log parents and scalars
|
||||||
|
std::unordered_set<std::uintptr_t> cache;
|
||||||
recurse = [&](const array& a) {
|
recurse = [&](const array& a) {
|
||||||
auto id = a.id();
|
auto id = a.id();
|
||||||
if (cache.find(id) != cache.end()) {
|
if (cache.find(id) != cache.end()) {
|
||||||
@ -132,8 +132,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
for (auto& s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
parents_map[in.id()].push_back({s, i});
|
parents_map[in.id()].push_back({s, i});
|
||||||
}
|
}
|
||||||
|
// Don't recurse on inputs (but add them to the tape for the purpose
|
||||||
|
// of future optimizations)
|
||||||
|
if (input_set.find(a.id()) == input_set.end()) {
|
||||||
recurse(in);
|
recurse(in);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
cache.insert(id);
|
cache.insert(id);
|
||||||
for (auto& s : a.siblings()) {
|
for (auto& s : a.siblings()) {
|
||||||
cache.insert(s.id());
|
cache.insert(s.id());
|
||||||
@ -277,6 +281,12 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Erase orphaned parents so we don't keep fusing with them
|
||||||
|
for (int i = N - 1; i > 0; --i) {
|
||||||
|
if (mask[i]) {
|
||||||
|
parents->second.erase(parents->second.begin() + i);
|
||||||
|
}
|
||||||
|
}
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
return output_set.find(a.id()) == output_set.end();
|
return output_set.find(a.id()) == output_set.end();
|
||||||
@ -352,7 +362,7 @@ std::function<std::vector<array>(const std::vector<array>&)> compile(
|
|||||||
compile_dfs(entry.inputs, entry.outputs);
|
compile_dfs(entry.inputs, entry.outputs);
|
||||||
|
|
||||||
// Simplify the tape
|
// Simplify the tape
|
||||||
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 2);
|
compile_simplify(entry.tape, parents_map, entry.outputs, /* passes */ 3);
|
||||||
|
|
||||||
// This is a good point to do more optimizations, e.g. kernel fusion to
|
// This is a good point to do more optimizations, e.g. kernel fusion to
|
||||||
// generate new primitives. The tape needs to be updated accordingly
|
// generate new primitives. The tape needs to be updated accordingly
|
||||||
|
Loading…
Reference in New Issue
Block a user