mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
fix wraps compile (#2461)
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import gc
|
||||
import inspect
|
||||
import io
|
||||
import math
|
||||
import unittest
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from io import StringIO
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -1014,6 +1015,28 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertTrue(mx.allclose(d[0], d_hat[0]))
|
||||
self.assertTrue(mx.allclose(d[1], d_hat[1]))
|
||||
|
||||
def test_wrap_compiled(self):
|
||||
@mx.compile
|
||||
def inner():
|
||||
pass
|
||||
|
||||
@wraps(inner)
|
||||
def wrapper():
|
||||
pass
|
||||
|
||||
def test_compiled_preserves_attributes(self):
|
||||
def inner(x: mx.array, y: str):
|
||||
"""
|
||||
A useful function.
|
||||
"""
|
||||
pass
|
||||
|
||||
c_inner = mx.compile(inner)
|
||||
self.assertEqual(inner.__name__, c_inner.__name__)
|
||||
self.assertEqual(inner.__qualname__, c_inner.__qualname__)
|
||||
self.assertEqual(inner.__doc__, c_inner.__doc__)
|
||||
self.assertEqual(inspect.signature(inner), inspect.signature(c_inner))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
Reference in New Issue
Block a user