mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add types, fix kwarg ordering bug + test
This commit is contained in:
@@ -485,6 +485,19 @@ class TestExportImport(mlx_tests.MLXTestCase):
|
||||
mx.array_equal(imported_fn(input_data)[0], model(input_data))
|
||||
)
|
||||
|
||||
def test_export_kwarg_ordering(self):
|
||||
path = os.path.join(self.test_dir, "fun.mlxfn")
|
||||
|
||||
def fn(x, y):
|
||||
return x - y
|
||||
|
||||
mx.export_function(path, fn, x=mx.array(1.0), y=mx.array(1.0))
|
||||
imported = mx.import_function(path)
|
||||
out = imported(x=mx.array(2.0), y=mx.array(3.0))[0]
|
||||
self.assertEqual(out.item(), -1.0)
|
||||
out = imported(y=mx.array(2.0), x=mx.array(3.0))[0]
|
||||
self.assertEqual(out.item(), 1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlx_tests.MLXTestRunner()
|
||||
|
||||
Reference in New Issue
Block a user