diff --git a/mlx/compile.cpp b/mlx/compile.cpp index e6cd58755..7ed3dd455 100644 --- a/mlx/compile.cpp +++ b/mlx/compile.cpp @@ -95,6 +95,11 @@ Compiled::Compiled( std::ostringstream os; std::ostringstream constant_hasher; + std::unordered_set output_ids; + for (auto& o : outputs_) { + output_ids.insert(o.id()); + } + // Fill the input names. This is not really necessary, I just like having A, // B, C, ... as the inputs. for (const auto& x : inputs_) { @@ -106,6 +111,12 @@ Compiled::Compiled( for (const auto& a : tape_) { // name and type of output os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize(); + // whether or not it's an output + if (output_ids.find(a.id()) != output_ids.end()) { + os << "O"; + } else { + os << "I"; + } // computation performed os << a.primitive().name(); // name of inputs to the function diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index 572123c60..b78673027 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -1115,6 +1115,25 @@ class TestCompile(mlx_tests.MLXTestCase): self.assertEqual(state[0].item(), 4) + def test_outputs_changing(self): + @mx.compile + def fun(x): + x = mx.abs(mx.negative(x)) + y = mx.abs(x) + return x, y + + @mx.compile + def fun2(x): + x = mx.abs(mx.negative(x)) + y = mx.abs(x) + return y + + a, b = fun(mx.array(-1.0)) + mx.eval(a, b) + + a = fun2(mx.array(-1.0)) + self.assertEqual(a.item(), 1.0) + if __name__ == "__main__": mlx_tests.MLXTestRunner()