mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-24 02:41:19 +08:00
Simplify removes no-ops from the tape (#1759)
* simplify removes no-ops from the tape * comment
This commit is contained in:
parent
1ccaf80575
commit
da8c885784
@ -11,9 +11,7 @@
|
||||
namespace mlx::core {
|
||||
|
||||
inline bool is_static_cast(const Primitive& p) {
|
||||
return (
|
||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
||||
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||
}
|
||||
|
||||
std::string build_lib_name(
|
||||
|
@ -68,8 +68,7 @@ bool is_reduction(const Primitive& p) {
|
||||
}
|
||||
|
||||
bool is_fusable(const Primitive& p) {
|
||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
|
||||
is_noop(p);
|
||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
|
||||
}
|
||||
|
||||
Compiled::Compiled(
|
||||
@ -351,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
return {tape, parents_map};
|
||||
}
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||
// the parents map to remove orphaned arrays
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape,
|
||||
// the parents map to remove orphaned arrays, and potentially the outputs
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& outputs,
|
||||
std::vector<array>& outputs,
|
||||
int passes) {
|
||||
// Helpers to identify identical scalars
|
||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||
@ -433,6 +432,28 @@ void compile_simplify(
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
|
||||
// Remove no-ops
|
||||
{
|
||||
std::unordered_map<uintptr_t, array> output_map;
|
||||
for (auto& o : outputs) {
|
||||
output_map.insert({o.id(), o});
|
||||
}
|
||||
for (auto& arr : tape) {
|
||||
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
|
||||
new_tape.push_back(std::move(arr));
|
||||
continue;
|
||||
}
|
||||
merge_one(arr.inputs()[0], arr, parents_map);
|
||||
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
|
||||
it->second = arr.inputs()[0];
|
||||
}
|
||||
}
|
||||
tape = std::move(new_tape);
|
||||
for (auto& o : outputs) {
|
||||
o = output_map.at(o.id());
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
||||
for (uint32_t i = 0; i < tape.size(); ++i) {
|
||||
tape_order.insert({tape[i].id(), i});
|
||||
@ -442,6 +463,7 @@ void compile_simplify(
|
||||
for (auto& o : outputs) {
|
||||
output_set.insert(o.id());
|
||||
}
|
||||
|
||||
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||
for (int pass = 0; pass < passes; ++pass) {
|
||||
for (auto& arr : tape) {
|
||||
|
@ -39,12 +39,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
||||
const std::vector<array>& outputs,
|
||||
const std::vector<array>& original_inputs);
|
||||
|
||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
||||
// the parents map to remove orphaned arrays
|
||||
// Simplify the tape.
|
||||
void compile_simplify(
|
||||
std::vector<array>& tape,
|
||||
ParentsMap& parents_map,
|
||||
const std::vector<array>& outputs,
|
||||
std::vector<array>& outputs,
|
||||
int passes);
|
||||
|
||||
std::vector<array> compile_replace(
|
||||
|
@ -157,6 +157,17 @@ TEST_CASE("test simplify") {
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
TEST_CASE("test simplify noops") {
|
||||
set_compile_mode(CompileMode::no_fuse);
|
||||
auto a = array({1.0f, 2.0f});
|
||||
auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
|
||||
return {copy(stop_gradient(exp(stop_gradient(inputs[0]))))};
|
||||
};
|
||||
auto b = compile(fun)({a})[0];
|
||||
CHECK(b.inputs()[0].id() == a.id());
|
||||
set_compile_mode(CompileMode::enabled);
|
||||
}
|
||||
|
||||
auto add_diff(const std::vector<array>& inputs) {
|
||||
auto a = inputs[0];
|
||||
return std::vector<array>{cos(a) + sin(a)};
|
||||
|
Loading…
Reference in New Issue
Block a user