Fix compile when outputs change (#2648)

This commit is contained in:
Awni Hannun
2025-10-03 08:40:57 -07:00
committed by GitHub
parent 22a5da76c8
commit a7a94b29d7
2 changed files with 30 additions and 0 deletions

View File

@@ -95,6 +95,11 @@ Compiled::Compiled(
std::ostringstream os;
std::ostringstream constant_hasher;
std::unordered_set<uintptr_t> output_ids;
for (auto& o : outputs_) {
output_ids.insert(o.id());
}
// Fill the input names. This is not really necessary, I just like having A,
// B, C, ... as the inputs.
for (const auto& x : inputs_) {
@@ -106,6 +111,12 @@ Compiled::Compiled(
for (const auto& a : tape_) {
// name and type of output
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
// whether or not it's an output
if (output_ids.find(a.id()) != output_ids.end()) {
os << "O";
} else {
os << "I";
}
// computation performed
os << a.primitive().name();
// name of inputs to the function

View File

@@ -1115,6 +1115,25 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(state[0].item(), 4)
def test_outputs_changing(self):
@mx.compile
def fun(x):
x = mx.abs(mx.negative(x))
y = mx.abs(x)
return x, y
@mx.compile
def fun2(x):
x = mx.abs(mx.negative(x))
y = mx.abs(x)
return y
a, b = fun(mx.array(-1.0))
mx.eval(a, b)
a = fun2(mx.array(-1.0))
self.assertEqual(a.item(), 1.0)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()