mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	fix broadcast bug in bitwise ops (#1157)
This commit is contained in:
		| @@ -2291,6 +2291,13 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|                 out_np = getattr(np, op)(a_np, b_np) | ||||
|                 self.assertTrue(np.array_equal(np.array(out_mlx), out_np)) | ||||
|  | ||||
|         # Check broadcasting | ||||
|         a = mx.ones((3, 1, 5), dtype=mx.bool_) | ||||
|         b = mx.zeros((1, 2, 5), dtype=mx.bool_) | ||||
|         c = a | b | ||||
|         self.assertEqual(c.shape, (3, 2, 5)) | ||||
|         self.assertTrue(mx.array_equal(c, mx.ones((3, 2, 5), dtype=mx.bool_))) | ||||
|  | ||||
|     def test_conjugate(self): | ||||
|         shape = (3, 5, 7) | ||||
|         a = np.random.normal(size=shape) + 1j * np.random.normal(size=shape) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun