mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix input buffer donation in compile (#2897)
This commit is contained in:
@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
|
||||
// - Donatable
|
||||
// - Not a constant
|
||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||
in.is_donatable() && is_constant(i)) {
|
||||
in.is_donatable() && !is_constant(i)) {
|
||||
outputs[o++].copy_shared_buffer(in);
|
||||
}
|
||||
// Get representative input flags to properly set non-donated outputs
|
||||
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
|
||||
// - Not a constant
|
||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||
is_constant(i)) {
|
||||
!is_constant(i)) {
|
||||
outputs[o].copy_shared_buffer(
|
||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||
o++;
|
||||
|
||||
@@ -4,12 +4,12 @@ import gc
|
||||
import inspect
|
||||
import io
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial, wraps
|
||||
from io import StringIO
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx_tests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TestCompile(mlx_tests.MLXTestCase):
|
||||
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
loss, grads = step(emb, w, x)
|
||||
mx.eval(loss, grads)
|
||||
|
||||
def test_compile_donates_input_buffer(self):
|
||||
mx.set_default_device(mx.cpu)
|
||||
|
||||
def fun(x):
|
||||
return mx.sin(x) + 1
|
||||
|
||||
compiled_fn = mx.compile(fun)
|
||||
|
||||
input = mx.arange(16, dtype=mx.float32)
|
||||
mx.eval(input)
|
||||
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
|
||||
|
||||
out = compiled_fn(input)
|
||||
del input # Ensure the reference is dropped
|
||||
mx.eval(out)
|
||||
|
||||
self.assertEqual(
|
||||
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user