mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-18 15:28:16 +08:00
Fix compile when outputs change (#2648)
This commit is contained in:
@@ -95,6 +95,11 @@ Compiled::Compiled(
|
|||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
std::ostringstream constant_hasher;
|
std::ostringstream constant_hasher;
|
||||||
|
|
||||||
|
std::unordered_set<uintptr_t> output_ids;
|
||||||
|
for (auto& o : outputs_) {
|
||||||
|
output_ids.insert(o.id());
|
||||||
|
}
|
||||||
|
|
||||||
// Fill the input names. This is not really necessary, I just like having A,
|
// Fill the input names. This is not really necessary, I just like having A,
|
||||||
// B, C, ... as the inputs.
|
// B, C, ... as the inputs.
|
||||||
for (const auto& x : inputs_) {
|
for (const auto& x : inputs_) {
|
||||||
@@ -106,6 +111,12 @@ Compiled::Compiled(
|
|||||||
for (const auto& a : tape_) {
|
for (const auto& a : tape_) {
|
||||||
// name and type of output
|
// name and type of output
|
||||||
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
|
||||||
|
// whether or not it's an output
|
||||||
|
if (output_ids.find(a.id()) != output_ids.end()) {
|
||||||
|
os << "O";
|
||||||
|
} else {
|
||||||
|
os << "I";
|
||||||
|
}
|
||||||
// computation performed
|
// computation performed
|
||||||
os << a.primitive().name();
|
os << a.primitive().name();
|
||||||
// name of inputs to the function
|
// name of inputs to the function
|
||||||
|
@@ -1115,6 +1115,25 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
|
|
||||||
self.assertEqual(state[0].item(), 4)
|
self.assertEqual(state[0].item(), 4)
|
||||||
|
|
||||||
|
def test_outputs_changing(self):
|
||||||
|
@mx.compile
|
||||||
|
def fun(x):
|
||||||
|
x = mx.abs(mx.negative(x))
|
||||||
|
y = mx.abs(x)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
@mx.compile
|
||||||
|
def fun2(x):
|
||||||
|
x = mx.abs(mx.negative(x))
|
||||||
|
y = mx.abs(x)
|
||||||
|
return y
|
||||||
|
|
||||||
|
a, b = fun(mx.array(-1.0))
|
||||||
|
mx.eval(a, b)
|
||||||
|
|
||||||
|
a = fun2(mx.array(-1.0))
|
||||||
|
self.assertEqual(a.item(), 1.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
Reference in New Issue
Block a user