mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-27 11:31:21 +08:00
Simplify removes no-ops from the tape (#1759)
* simplify removes no-ops from the tape * comment
This commit is contained in:
parent
1ccaf80575
commit
da8c885784
@ -11,9 +11,7 @@
|
|||||||
namespace mlx::core {
|
namespace mlx::core {
|
||||||
|
|
||||||
inline bool is_static_cast(const Primitive& p) {
|
inline bool is_static_cast(const Primitive& p) {
|
||||||
return (
|
return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
|
||||||
typeid(p) == typeid(Broadcast) || typeid(p) == typeid(Copy) ||
|
|
||||||
typeid(p) == typeid(StopGradient) || typeid(p) == typeid(AsType));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string build_lib_name(
|
std::string build_lib_name(
|
||||||
|
@ -68,8 +68,7 @@ bool is_reduction(const Primitive& p) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool is_fusable(const Primitive& p) {
|
bool is_fusable(const Primitive& p) {
|
||||||
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p) ||
|
return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
|
||||||
is_noop(p);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Compiled::Compiled(
|
Compiled::Compiled(
|
||||||
@ -351,12 +350,12 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
return {tape, parents_map};
|
return {tape, parents_map};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
// Simplify the tape. Note, this function modifies in-place both the tape,
|
||||||
// the parents map to remove orphaned arrays
|
// the parents map to remove orphaned arrays, and potentially the outputs
|
||||||
void compile_simplify(
|
void compile_simplify(
|
||||||
std::vector<array>& tape,
|
std::vector<array>& tape,
|
||||||
ParentsMap& parents_map,
|
ParentsMap& parents_map,
|
||||||
const std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
int passes) {
|
int passes) {
|
||||||
// Helpers to identify identical scalars
|
// Helpers to identify identical scalars
|
||||||
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
|
||||||
@ -433,6 +432,28 @@ void compile_simplify(
|
|||||||
}
|
}
|
||||||
tape = std::move(new_tape);
|
tape = std::move(new_tape);
|
||||||
|
|
||||||
|
// Remove no-ops
|
||||||
|
{
|
||||||
|
std::unordered_map<uintptr_t, array> output_map;
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
output_map.insert({o.id(), o});
|
||||||
|
}
|
||||||
|
for (auto& arr : tape) {
|
||||||
|
if (!arr.has_primitive() || !is_noop(arr.primitive())) {
|
||||||
|
new_tape.push_back(std::move(arr));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
merge_one(arr.inputs()[0], arr, parents_map);
|
||||||
|
if (auto it = output_map.find(arr.id()); it != output_map.end()) {
|
||||||
|
it->second = arr.inputs()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tape = std::move(new_tape);
|
||||||
|
for (auto& o : outputs) {
|
||||||
|
o = output_map.at(o.id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
std::unordered_map<std::uintptr_t, uint32_t> tape_order;
|
||||||
for (uint32_t i = 0; i < tape.size(); ++i) {
|
for (uint32_t i = 0; i < tape.size(); ++i) {
|
||||||
tape_order.insert({tape[i].id(), i});
|
tape_order.insert({tape[i].id(), i});
|
||||||
@ -442,6 +463,7 @@ void compile_simplify(
|
|||||||
for (auto& o : outputs) {
|
for (auto& o : outputs) {
|
||||||
output_set.insert(o.id());
|
output_set.insert(o.id());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
// Multi-pass merge only keeping non-orphaned arrays in the tape
|
||||||
for (int pass = 0; pass < passes; ++pass) {
|
for (int pass = 0; pass < passes; ++pass) {
|
||||||
for (auto& arr : tape) {
|
for (auto& arr : tape) {
|
||||||
|
@ -39,12 +39,11 @@ std::pair<std::vector<array>, ParentsMap> compile_dfs(
|
|||||||
const std::vector<array>& outputs,
|
const std::vector<array>& outputs,
|
||||||
const std::vector<array>& original_inputs);
|
const std::vector<array>& original_inputs);
|
||||||
|
|
||||||
// Simplify the tape. Note, this function modifies in-place both the tape and
|
// Simplify the tape.
|
||||||
// the parents map to remove orphaned arrays
|
|
||||||
void compile_simplify(
|
void compile_simplify(
|
||||||
std::vector<array>& tape,
|
std::vector<array>& tape,
|
||||||
ParentsMap& parents_map,
|
ParentsMap& parents_map,
|
||||||
const std::vector<array>& outputs,
|
std::vector<array>& outputs,
|
||||||
int passes);
|
int passes);
|
||||||
|
|
||||||
std::vector<array> compile_replace(
|
std::vector<array> compile_replace(
|
||||||
|
@ -157,6 +157,17 @@ TEST_CASE("test simplify") {
|
|||||||
set_compile_mode(CompileMode::enabled);
|
set_compile_mode(CompileMode::enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_CASE("test simplify noops") {
|
||||||
|
set_compile_mode(CompileMode::no_fuse);
|
||||||
|
auto a = array({1.0f, 2.0f});
|
||||||
|
auto fun = [](const std::vector<array>& inputs) -> std::vector<array> {
|
||||||
|
return {copy(stop_gradient(exp(stop_gradient(inputs[0]))))};
|
||||||
|
};
|
||||||
|
auto b = compile(fun)({a})[0];
|
||||||
|
CHECK(b.inputs()[0].id() == a.id());
|
||||||
|
set_compile_mode(CompileMode::enabled);
|
||||||
|
}
|
||||||
|
|
||||||
auto add_diff(const std::vector<array>& inputs) {
|
auto add_diff(const std::vector<array>& inputs) {
|
||||||
auto a = inputs[0];
|
auto a = inputs[0];
|
||||||
return std::vector<array>{cos(a) + sin(a)};
|
return std::vector<array>{cos(a) + sin(a)};
|
||||||
|
Loading…
Reference in New Issue
Block a user