diff --git a/mlx/backend/common/compiled.cpp b/mlx/backend/common/compiled.cpp index e5e1d4350..aceeb1f7f 100644 --- a/mlx/backend/common/compiled.cpp +++ b/mlx/backend/common/compiled.cpp @@ -130,7 +130,7 @@ void compiled_allocate_outputs( // - Donatable // - Not a constant if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && - in.is_donatable() && is_constant(i)) { + in.is_donatable() && !is_constant(i)) { outputs[o++].copy_shared_buffer(in); } // Get representative input flags to properly set non-donated outputs @@ -158,7 +158,7 @@ void compiled_allocate_outputs( // - Not a constant if (in.flags().row_contiguous && in.size() == outputs[o].size() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() && - is_constant(i)) { + !is_constant(i)) { outputs[o].copy_shared_buffer( in, outputs[o].strides(), in.flags(), in.data_size()); o++; diff --git a/python/tests/test_compile.py b/python/tests/test_compile.py index b9bff614f..d64c057fd 100644 --- a/python/tests/test_compile.py +++ b/python/tests/test_compile.py @@ -4,12 +4,12 @@ import gc import inspect import io import math -import unittest from functools import partial, wraps from io import StringIO import mlx.core as mx import mlx_tests +import numpy as np class TestCompile(mlx_tests.MLXTestCase): @@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase): loss, grads = step(emb, w, x) mx.eval(loss, grads) + def test_compile_donates_input_buffer(self): + mx.set_default_device(mx.cpu) + + def fun(x): + return mx.sin(x) + 1 + + compiled_fn = mx.compile(fun) + + input = mx.arange(16, dtype=mx.float32) + mx.eval(input) + in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0] + + out = compiled_fn(input) + del input # Ensure the reference is dropped + mx.eval(out) + + self.assertEqual( + np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr + ) + if __name__ == "__main__": mlx_tests.MLXTestRunner()