mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
		| @@ -1066,6 +1066,16 @@ class TestOps(mlx_tests.MLXTestCase): | ||||
|         out = mx.take(a, 1, axis=1) | ||||
|         self.assertTrue(mx.array_equal(out, mx.array([1, 5]))) | ||||
|  | ||||
|         # Take with multi-dim scalar preserves dims | ||||
|         out = mx.take(a, mx.array(1), axis=0) | ||||
|         self.assertEqual(out.shape, (4,)) | ||||
|  | ||||
|         out = mx.take(a, mx.array([1]), axis=0) | ||||
|         self.assertEqual(out.shape, (1, 4)) | ||||
|  | ||||
|         out = mx.take(a, mx.array([[1]]), axis=0) | ||||
|         self.assertEqual(out.shape, (1, 1, 4)) | ||||
|  | ||||
|     def test_take_along_axis(self): | ||||
|         a_np = np.arange(8).reshape(2, 2, 2) | ||||
|         a_mlx = mx.array(a_np) | ||||
|   | ||||
| @@ -348,6 +348,10 @@ class TestRandom(mlx_tests.MLXTestCase): | ||||
|         x = mx.random.permutation(16384) | ||||
|         self.assertFalse(mx.array_equal(sorted_x, x)) | ||||
|  | ||||
|         # Preserves shape / doesn't cast input to int | ||||
|         x = mx.random.permutation(mx.array([[1]])) | ||||
|         self.assertEqual(x.shape, (1, 1)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
| @@ -353,7 +353,7 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         for i in range(a.shape[0]): | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5) | ||||
|                 mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, atol=1e-5) | ||||
|             ) | ||||
|  | ||||
|         a = mx.random.uniform(shape=(4, 3, 4)) | ||||
| @@ -367,7 +367,9 @@ class TestVmap(mlx_tests.MLXTestCase): | ||||
|  | ||||
|         for i in range(a.shape[1]): | ||||
|             self.assertTrue( | ||||
|                 mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5) | ||||
|                 mx.allclose( | ||||
|                     a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5 | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|     def test_vmap_gather(self): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun