Compile stride bug (#812)

* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
This commit is contained in:
Awni Hannun
2024-03-11 06:31:31 -07:00
committed by GitHub
parent a4d290adb9
commit 7c441600fe
9 changed files with 58 additions and 12 deletions

View File

@@ -605,6 +605,14 @@ class TestCompile(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
out = fun(mx.array(0.0), y=MyClass())
def test_compile_create_list(self):
@mx.compile
def fun():
return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]
out = fun()
mx.eval(out)
if __name__ == "__main__":
unittest.main()