Simplify removes no-ops from the tape (#1759)

* simplify removes no-ops from the tape

* comment
This commit is contained in:
Awni Hannun
2025-01-09 11:23:19 -08:00
committed by GitHub
parent 1ccaf80575
commit da8c885784
4 changed files with 41 additions and 11 deletions

View File

@@ -11,9 +11,7 @@
namespace mlx::core {
// Reports whether the dynamic type of `p` (compared via typeid) is one of a
// fixed list of primitive types.
// NOTE(review): two `return` statements appear back to back below. This looks
// like a rendered diff in which the removed version (Broadcast, Copy,
// StopGradient, AsType) is followed by the added version (Broadcast, AsType).
// Confirm which one is live; compiled as written, only the first is reachable.
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(

View File

@@ -68,8 +68,7 @@ bool is_reduction(const Primitive& p) {
}
// Reports whether `p` satisfies at least one of the listed predicates
// (is_unary, is_binary, is_ternary, is_broadcast, ...).
// NOTE(review): two `return` statements appear back to back below. This looks
// like a rendered diff in which the removed version (which also accepted
// is_noop(p)) is followed by the added version (without is_noop). Confirm
// which one is live; compiled as written, only the first is reachable.
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
}
Compiled::Compiled(
@@ -351,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
return {tape, parents_map};
}
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
// Simplify the tape. Note, this function modifies in place the tape, the
// parents map (to remove orphaned arrays), and potentially the outputs
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
std::vector<array>& outputs,
int passes) {
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
@@ -433,6 +432,28 @@ void compile_simplify(
}
tape = std::move(new_tape);
// Remove no-ops: every tape entry whose primitive satisfies is_noop() is
// dropped from the tape, and any output that was such an entry is replaced by
// that entry's first input.
{
  // Maps each output's id to the array that should stand in for it after
  // no-op removal (initially the output itself).
  std::unordered_map<std::uintptr_t, array> output_map;
  for (auto& o : outputs) {
    output_map.insert({o.id(), o});
  }

  // Collect the surviving entries in a fresh local vector instead of reusing
  // a vector that was just moved from: the contents of a moved-from
  // std::vector are valid but unspecified, so appending to it is fragile.
  std::vector<array> kept_tape;
  kept_tape.reserve(tape.size());
  for (auto& arr : tape) {
    // Arrays without a primitive, or with a primitive that is not a no-op,
    // stay on the tape unchanged.
    if (!arr.has_primitive() || !is_noop(arr.primitive())) {
      kept_tape.push_back(std::move(arr));
      continue;
    }
    // NOTE(review): merge_one presumably redirects consumers of `arr` to its
    // first input via the parents map -- confirm against merge_one.
    merge_one(arr.inputs()[0], arr, parents_map);
    // If the no-op was itself an output, substitute its input for it.
    if (auto it = output_map.find(arr.id()); it != output_map.end()) {
      it->second = arr.inputs()[0];
    }
  }
  tape = std::move(kept_tape);

  // Rewrite the caller-visible outputs using the (possibly updated) mapping.
  // The lookup key is each output's original id, which is always present.
  for (auto& o : outputs) {
    o = output_map.at(o.id());
  }
}
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
for (uint32_t i = 0; i < tape.size(); ++i) {
tape_order.insert({tape[i].id(), i});
@@ -442,6 +463,7 @@ void compile_simplify(
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {

View File

@@ -39,12 +39,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& outputs,
const std::vector<array>& original_inputs);
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
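// Simplify the tape. Note, this function modifies in place the tape, the
// parents map (to remove orphaned arrays), and potentially the outputs.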
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
std::vector<array>& outputs,
int passes);
std::vector<array> compile_replace(