Fix input buffer donation in compile (#2897)

This commit is contained in:
CCYeh
2025-12-11 15:07:03 +01:00
committed by GitHub
parent 937ce79660
commit 3c8ce9b00e
2 changed files with 23 additions and 3 deletions

View File

@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
// - Donatable // - Donatable
// - Not a constant // - Not a constant
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) && if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
in.is_donatable() && is_constant(i)) { in.is_donatable() && !is_constant(i)) {
outputs[o++].copy_shared_buffer(in); outputs[o++].copy_shared_buffer(in);
} }
// Get representative input flags to properly set non-donated outputs // Get representative input flags to properly set non-donated outputs
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
// - Not a constant // - Not a constant
if (in.flags().row_contiguous && in.size() == outputs[o].size() && if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
in.itemsize() == outputs[o].itemsize() && in.is_donatable() && in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
is_constant(i)) { !is_constant(i)) {
outputs[o].copy_shared_buffer( outputs[o].copy_shared_buffer(
in, outputs[o].strides(), in.flags(), in.data_size()); in, outputs[o].strides(), in.flags(), in.data_size());
o++; o++;

View File

@@ -4,12 +4,12 @@ import gc
import inspect import inspect
import io import io
import math import math
import unittest
from functools import partial, wraps from functools import partial, wraps
from io import StringIO from io import StringIO
import mlx.core as mx import mlx.core as mx
import mlx_tests import mlx_tests
import numpy as np
class TestCompile(mlx_tests.MLXTestCase): class TestCompile(mlx_tests.MLXTestCase):
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
loss, grads = step(emb, w, x) loss, grads = step(emb, w, x)
mx.eval(loss, grads) mx.eval(loss, grads)
def test_compile_donates_input_buffer(self):
mx.set_default_device(mx.cpu)
def fun(x):
return mx.sin(x) + 1
compiled_fn = mx.compile(fun)
input = mx.arange(16, dtype=mx.float32)
mx.eval(input)
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
out = compiled_fn(input)
del input # Ensure the reference is dropped
mx.eval(out)
self.assertEqual(
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
)
if __name__ == "__main__": if __name__ == "__main__":
mlx_tests.MLXTestRunner() mlx_tests.MLXTestRunner()