Fix a couple bugs (#1161)

* fix jit reduce for RMS norm

* make strides a single buffer

* better eval error message

* fix compiling with inf and bf16

* fix cpu compile with bf16
This commit is contained in:
Awni Hannun
2024-05-28 15:18:18 -07:00
committed by GitHub
parent a87ef5bfc1
commit e7a2a3dcd1
9 changed files with 59 additions and 27 deletions

View File

@@ -704,6 +704,13 @@ class TestCompile(mlx_tests.MLXTestCase):
self.assertEqual(y1.item(), y2.item())
self.assertEqual(y1.item(), 6)
def test_inf_constant(self):
def fn(x):
return mx.where(mx.isinf(x), 0, 1)
x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
if __name__ == "__main__":
unittest.main()