mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 04:24:36 +08:00
Fix a couple bugs (#1161)
* fix jit reduce for RMS norm * make strides a single buffer * better eval error message * fix compiling with inf and bf16 * fix cpu compile with bf16
This commit is contained in:
@@ -704,6 +704,13 @@ class TestCompile(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y1.item(), y2.item())
|
||||
self.assertEqual(y1.item(), 6)
|
||||
|
||||
def test_inf_constant(self):
|
||||
def fn(x):
|
||||
return mx.where(mx.isinf(x), 0, 1)
|
||||
|
||||
x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16)
|
||||
self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user