mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Fix input buffer donation in compile (#2897)
This commit is contained in:
@@ -130,7 +130,7 @@ void compiled_allocate_outputs(
|
|||||||
// - Donatable
|
// - Donatable
|
||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
|
||||||
in.is_donatable() && is_constant(i)) {
|
in.is_donatable() && !is_constant(i)) {
|
||||||
outputs[o++].copy_shared_buffer(in);
|
outputs[o++].copy_shared_buffer(in);
|
||||||
}
|
}
|
||||||
// Get representative input flags to properly set non-donated outputs
|
// Get representative input flags to properly set non-donated outputs
|
||||||
@@ -158,7 +158,7 @@ void compiled_allocate_outputs(
|
|||||||
// - Not a constant
|
// - Not a constant
|
||||||
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
|
||||||
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
|
||||||
is_constant(i)) {
|
!is_constant(i)) {
|
||||||
outputs[o].copy_shared_buffer(
|
outputs[o].copy_shared_buffer(
|
||||||
in, outputs[o].strides(), in.flags(), in.data_size());
|
in, outputs[o].strides(), in.flags(), in.data_size());
|
||||||
o++;
|
o++;
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import gc
|
|||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
import math
|
import math
|
||||||
import unittest
|
|
||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx_tests
|
import mlx_tests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TestCompile(mlx_tests.MLXTestCase):
|
class TestCompile(mlx_tests.MLXTestCase):
|
||||||
@@ -1252,6 +1252,26 @@ class TestCompile(mlx_tests.MLXTestCase):
|
|||||||
loss, grads = step(emb, w, x)
|
loss, grads = step(emb, w, x)
|
||||||
mx.eval(loss, grads)
|
mx.eval(loss, grads)
|
||||||
|
|
||||||
|
def test_compile_donates_input_buffer(self):
|
||||||
|
mx.set_default_device(mx.cpu)
|
||||||
|
|
||||||
|
def fun(x):
|
||||||
|
return mx.sin(x) + 1
|
||||||
|
|
||||||
|
compiled_fn = mx.compile(fun)
|
||||||
|
|
||||||
|
input = mx.arange(16, dtype=mx.float32)
|
||||||
|
mx.eval(input)
|
||||||
|
in_ptr = np.asarray(input, copy=False).__array_interface__["data"][0]
|
||||||
|
|
||||||
|
out = compiled_fn(input)
|
||||||
|
del input # Ensure the reference is dropped
|
||||||
|
mx.eval(out)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
np.asarray(out, copy=False).__array_interface__["data"][0], in_ptr
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mlx_tests.MLXTestRunner()
|
mlx_tests.MLXTestRunner()
|
||||||
|
|||||||
Reference in New Issue
Block a user