fix wraps compile (#2461)

This commit is contained in:
Awni Hannun
2025-08-04 16:14:18 -07:00
committed by GitHub
parent 6ad0889c8a
commit 0b807893a7
4 changed files with 48 additions and 38 deletions

View File

@@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.
import gc
import inspect
import io
import math
import unittest
from functools import partial
from functools import partial, wraps
from io import StringIO
import mlx.core as mx
@@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertTrue(mx.allclose(d[0], d_hat[0]))
self.assertTrue(mx.allclose(d[1], d_hat[1]))
def test_wrap_compiled(self):
@mx.compile
def inner():
pass
@wraps(inner)
def wrapper():
pass
def test_compiled_preserves_attributes(self):
def inner(x: mx.array, y: str):
"""
A useful function.
"""
pass
c_inner = mx.compile(inner)
self.assertEqual(inner.__name__, c_inner.__name__)
self.assertEqual(inner.__qualname__, c_inner.__qualname__)
self.assertEqual(inner.__doc__, c_inner.__doc__)
self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))
if __name__ == "__main__":
mlx_tests.MLXTestRunner()