mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 07:18:29 +08:00
Fix compile when outputs change (#2648)
This commit is contained in:
@@ -95,6 +95,11 @@ Compiled::Compiled(
|
||||
std::ostringstream os;
|
||||
std::ostringstream constant_hasher;
|
||||
|
||||
std::unordered_set<uintptr_t> output_ids;
|
||||
for (auto& o : outputs_) {
|
||||
output_ids.insert(o.id());
|
||||
}
|
||||
|
||||
// Fill the input names. This is not really necessary, I just like having A,
|
||||
// B, C, ... as the inputs.
|
||||
for (const auto& x : inputs_) {
|
||||
@@ -106,6 +111,12 @@ Compiled::Compiled(
|
||||
for (const auto& a : tape_) {
|
||||
// name and type of output
|
||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||
// whether or not it's an output
|
||||
if (output_ids.find(a.id()) != output_ids.end()) {
|
||||
os << "O";
|
||||
} else {
|
||||
os << "I";
|
||||
}
|
||||
// computation performed
|
||||
os << a.primitive().name();
|
||||
// name of inputs to the function
|
||||
|
@@ -1115,6 +1115,25 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
|
||||
self.assertEqual(state[0].item(), 4)
|
||||
|
||||
def test_outputs_changing(self):
|
||||
@mx.compile
|
||||
def fun(x):
|
||||
x = mx.abs(mx.negative(x))
|
||||
y = mx.abs(x)
|
||||
return x, y
|
||||
|
||||
@mx.compile
|
||||
def fun2(x):
|
||||
x = mx.abs(mx.negative(x))
|
||||
y = mx.abs(x)
|
||||
return y
|
||||
|
||||
a, b = fun(mx.array(-1.0))
|
||||
mx.eval(a, b)
|
||||
|
||||
a = fun2(mx.array(-1.0))
|
||||
self.assertEqual(a.item(), 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user