fix multi-output with compile

This commit is contained in:
Awni Hannun
2024-01-17 06:55:11 -08:00
parent ed4d867092
commit 390e50b163
5 changed files with 39 additions and 21 deletions

View File

@@ -164,12 +164,12 @@ class TestCompile(mlx_tests.MLXTestCase):
n_compiled = count_prims(cfun(x))
# Check disabled
mx.disable_compiler()
mx.disable_compile()
n_uncompiled = count_prims(cfun(x))
self.assertTrue(n_compiled < n_uncompiled)
# Check renabled
mx.enable_compiler()
mx.enable_compile()
n_enable_compiled = count_prims(cfun(x))
self.assertEqual(n_compiled, n_enable_compiled)