From 9c5e7da5079cf98f48df150c8bed5c3c0043d22c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 2 May 2025 15:08:50 -0700 Subject: [PATCH] fix compile merging (#2150) --- mlx/compile.cpp | 9 +++++++++ tests/compile_tests.cpp | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/mlx/compile.cpp b/mlx/compile.cpp index 7ff5c8f9e..2baeb6fcf 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -168,6 +168,15 @@ void merge_one(array& dst, array& src, ParentsMap& parents_map) { parent.first.inputs()[parent.second] = dst; pairs.push_back(parent); } + + // If src is a parent of dst, remove it from dst's parents + for (auto it = pairs.begin(); it != pairs.end();) { + if (it->first.id() == src.id()) { + it = pairs.erase(it); + } else { + it++; + } + } // Remove the source from the map to avoid fusing with it again parents_map.erase(src_parents); } diff --git a/tests/compile_tests.cpp b/tests/compile_tests.cpp index 66511682d..96552ef9d 100644 --- a/tests/compile_tests.cpp +++ b/tests/compile_tests.cpp @@ -795,3 +795,12 @@ TEST_CASE("test compile lambda") { out = cfun2({array(0)}); CHECK_EQ(out[0].item(), 3); } + +TEST_CASE("test compile with no-ops") { + auto fun = [](const std::vector& inputs) { + return std::vector{abs(stop_gradient(abs(inputs[0])))}; + }; + auto in = array(1.0); + auto out = compile(fun)({in})[0]; + CHECK_EQ(out.inputs()[0].id(), in.id()); +}