mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-01 00:28:11 +08:00 
			
		
		
		
	Fix a couple bugs (#1161)
* fix jit reduce for RMS norm * make strides a single buffer * better eval error message * fix compiling with inf and bf16 * fix cpu compile with bf16
This commit is contained in:
		| @@ -704,6 +704,13 @@ class TestCompile(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(y1.item(), y2.item()) | ||||
|         self.assertEqual(y1.item(), 6) | ||||
|  | ||||
|     def test_inf_constant(self): | ||||
|         def fn(x): | ||||
|             return mx.where(mx.isinf(x), 0, 1) | ||||
|  | ||||
|         x = mx.array([0, float("inf"), 1], dtype=mx.bfloat16) | ||||
|         self.assertTrue(mx.array_equal(mx.compile(fn)(x), fn(x))) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun