diff --git a/python/mlx/nn/layers/quantized.py b/python/mlx/nn/layers/quantized.py
index aa5996dab..823a0084f 100644
--- a/python/mlx/nn/layers/quantized.py
+++ b/python/mlx/nn/layers/quantized.py
@@ -98,16 +98,13 @@ class QuantizedEmbedding(Module):
         self.freeze()
 
     def __call__(self, x):
-        s = x.shape
-        x = x.flatten()
-        out = mx.dequantize(
+        return mx.dequantize(
             self["weight"][x],
             scales=self["scales"][x],
             biases=self["biases"][x],
             group_size=self.group_size,
             bits=self.bits,
         )
-        return out.reshape(*s, -1)
 
     def as_linear(self, x):
         """
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index eb69b2659..1becce7e8 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -1395,11 +1395,12 @@
   m.def(
       "take",
       [](const array& a,
-         const std::variant<int, array>& indices,
+         const std::variant<nb::int_, array>& indices,
          const std::optional<int>& axis,
          StreamOrDevice s) {
-        if (auto pv = std::get_if<int>(&indices); pv) {
-          return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
+        if (auto pv = std::get_if<nb::int_>(&indices); pv) {
+          auto idx = nb::cast<int>(*pv);
+          return axis ? take(a, idx, axis.value(), s) : take(a, idx, s);
         } else {
           auto indices_ = std::get<array>(indices);
           return axis ? take(a, indices_, axis.value(), s)
diff --git a/python/src/random.cpp b/python/src/random.cpp
index 538a46aaf..7b3200764 100644
--- a/python/src/random.cpp
+++ b/python/src/random.cpp
@@ -459,13 +459,13 @@
       )pbdoc");
   m.def(
       "permutation",
-      [](const std::variant<int, array>& x,
+      [](const std::variant<nb::int_, array>& x,
          int axis,
          const std::optional<array>& key_,
          StreamOrDevice s) {
        auto key = key_ ? key_.value() : default_key().next();
-        if (auto pv = std::get_if<int>(&x); pv) {
-          return permutation(*pv, key, s);
+        if (auto pv = std::get_if<nb::int_>(&x); pv) {
+          return permutation(nb::cast<int>(*pv), key, s);
         } else {
           return permutation(std::get<array>(x), axis, key, s);
         }
diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py
index 54a3cf8c4..7763a2e58 100644
--- a/python/tests/test_ops.py
+++ b/python/tests/test_ops.py
@@ -1066,6 +1066,16 @@ class TestOps(mlx_tests.MLXTestCase):
         out = mx.take(a, 1, axis=1)
         self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
 
+        # Take with multi-dim scalar preserves dims
+        out = mx.take(a, mx.array(1), axis=0)
+        self.assertEqual(out.shape, (4,))
+
+        out = mx.take(a, mx.array([1]), axis=0)
+        self.assertEqual(out.shape, (1, 4))
+
+        out = mx.take(a, mx.array([[1]]), axis=0)
+        self.assertEqual(out.shape, (1, 1, 4))
+
     def test_take_along_axis(self):
         a_np = np.arange(8).reshape(2, 2, 2)
         a_mlx = mx.array(a_np)
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 3491297ed..9efbfb5f6 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -348,6 +348,10 @@ class TestRandom(mlx_tests.MLXTestCase):
         x = mx.random.permutation(16384)
         self.assertFalse(mx.array_equal(sorted_x, x))
+
+        # Preserves shape / doesn't cast input to int
+        x = mx.random.permutation(mx.array([[1]]))
+        self.assertEqual(x.shape, (1, 1))
 
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/tests/test_vmap.py b/python/tests/test_vmap.py
index f71c500e2..0789593c5 100644
--- a/python/tests/test_vmap.py
+++ b/python/tests/test_vmap.py
@@ -353,7 +353,7 @@ class TestVmap(mlx_tests.MLXTestCase):
 
         for i in range(a.shape[0]):
             self.assertTrue(
-                mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
+                mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, atol=1e-5)
             )
 
         a = mx.random.uniform(shape=(4, 3, 4))
@@ -367,7 +367,9 @@
 
         for i in range(a.shape[1]):
             self.assertTrue(
-                mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
+                mx.allclose(
+                    a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5
+                )
             )
 
     def test_vmap_gather(self):