Fix input buffer donation in compile (#2897)

This commit is contained in:
CCYeh
2025-12-11 15:07:03 +01:00
committed by GitHub
parent 937ce79660
commit 3c8ce9b00e
2 changed files with 23 additions and 3 deletions

View File

@@ -4,12 +4,12 @@ import gc
import inspect
import io
import math
import unittest
from functools import partial, wraps
from io import StringIO
import mlx.core as mx
import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x)
mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__":
mlx_tests.MLXTestRunner()