mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
Compile stride bug (#812)
* fix compile stride bug * revert sdpa fix * fix cpu * fix bug with simplifying outputs
This commit is contained in:
@@ -605,6 +605,14 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
out = fun(mx.array(0.0), y=MyClass())
|
||||
|
||||
def test_compile_create_list(self):
|
||||
@mx.compile
|
||||
def fun():
|
||||
return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]
|
||||
|
||||
out = fun()
|
||||
mx.eval(out)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user