No reshapes in quantized embedding (#1682)

* no reshapes in quantized embedding

* fix inadvertant cast

* add tol
This commit is contained in:
Awni Hannun
2024-12-09 18:57:38 -08:00
committed by GitHub
parent 87d7a2520e
commit 29a620cab2
6 changed files with 26 additions and 12 deletions

View File

@@ -1066,6 +1066,16 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.take(a, 1, axis=1)
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
# Take with multi-dim scalar preserves dims
out = mx.take(a, mx.array(1), axis=0)
self.assertEqual(out.shape, (4,))
out = mx.take(a, mx.array([1]), axis=0)
self.assertEqual(out.shape, (1, 4))
out = mx.take(a, mx.array([[1]]), axis=0)
self.assertEqual(out.shape, (1, 1, 4))
def test_take_along_axis(self):
a_np = np.arange(8).reshape(2, 2, 2)
a_mlx = mx.array(a_np)