mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-05 07:34:42 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
@@ -1066,6 +1066,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.take(a, 1, axis=1)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
|
||||
|
||||
# Take with multi-dim scalar preserves dims
|
||||
out = mx.take(a, mx.array(1), axis=0)
|
||||
self.assertEqual(out.shape, (4,))
|
||||
|
||||
out = mx.take(a, mx.array([1]), axis=0)
|
||||
self.assertEqual(out.shape, (1, 4))
|
||||
|
||||
out = mx.take(a, mx.array([[1]]), axis=0)
|
||||
self.assertEqual(out.shape, (1, 1, 4))
|
||||
|
||||
def test_take_along_axis(self):
|
||||
a_np = np.arange(8).reshape(2, 2, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
Reference in New Issue
Block a user