Simplify removes no-ops from the tape (#1759)

* simplify removes no-ops from the tape

* comment
This commit is contained in:
Awni Hannun 2025-01-09 11:23:19 -08:00 committed by GitHub
parent 1ccaf80575
commit da8c885784
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 11 deletions

View File

@ -11,9 +11,7 @@
namespace mlx::core {
inline bool is_static_cast(const Primitive& p) {
return (
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
}
std::string build_lib_name(

View File

@ -68,8 +68,7 @@ bool is_reduction(const Primitive& p) {
}
bool is_fusable(const Primitive& p) {
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
is_noop(p);
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
}
Compiled::Compiled(
@ -351,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
return {tape, parents_map};
}
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
// Simplify the tape. Note, this function modifies in-place both the tape,
// the parents map to remove orphaned arrays, and potentially the outputs
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
std::vector<array>& outputs,
int passes) {
// Helpers to identify identical scalars
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
@ -433,6 +432,28 @@ void compile_simplify(
}
tape = std::move(new_tape);
// Remove no-ops
{
std::unordered_map<uintptr_t, array> output_map;
for (auto& o : outputs) {
output_map.insert({o.id(), o});
}
for (auto& arr : tape) {
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
new_tape.push_back(std::move(arr));
continue;
}
merge_one(arr.inputs()[0], arr, parents_map);
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
it->second = arr.inputs()[0];
}
}
tape = std::move(new_tape);
for (auto& o : outputs) {
o = output_map.at(o.id());
}
}
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
for (uint32_t i = 0; i < tape.size(); ++i) {
tape_order.insert({tape[i].id(), i});
@ -442,6 +463,7 @@ void compile_simplify(
for (auto& o : outputs) {
output_set.insert(o.id());
}
// Multi-pass merge only keeping non-orphaned arrays in the tape
for (int pass = 0; pass < passes; ++pass) {
for (auto& arr : tape) {

View File

@ -39,12 +39,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
const std::vector<array>& outputs,
const std::vector<array>& original_inputs);
// Simplify the tape. Note, this function modifies in-place both the tape and
// the parents map to remove orphaned arrays
// Simplify the tape.
void compile_simplify(
std::vector<array>& tape,
ParentsMap& parents_map,
const std::vector<array>& outputs,
std::vector<array>& outputs,
int passes);
std::vector<array> compile_replace(

View File

@ -157,6 +157,17 @@ TEST_CASE("test simplify") {
set_compile_mode(CompileMode::enabled);
}
TEST_CASE("test simplify noops") {
set_compile_mode(CompileMode::no_fuse);
auto a = array({1.0f, 2.0f});
auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
return {copy(stop_gradient(exp(stop_gradient(inputs[0]))))};
};
auto b = compile(fun)({a})[0];
CHECK(b.inputs()[0].id() == a.id());
set_compile_mode(CompileMode::enabled);
}
auto add_diff(const std::vector<array>& inputs) {
auto a = inputs[0];
return std::vector<array>{cos(a) + sin(a)};