mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
No reshapes in quantized embedding (#1682)
* no reshapes in quantized embedding * fix inadvertant cast * add tol
This commit is contained in:
parent
87d7a2520e
commit
29a620cab2
@ -98,16 +98,13 @@ class QuantizedEmbedding(Module):
|
||||
self.freeze()
|
||||
|
||||
def __call__(self, x):
|
||||
s = x.shape
|
||||
x = x.flatten()
|
||||
out = mx.dequantize(
|
||||
return mx.dequantize(
|
||||
self["weight"][x],
|
||||
scales=self["scales"][x],
|
||||
biases=self["biases"][x],
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
)
|
||||
return out.reshape(*s, -1)
|
||||
|
||||
def as_linear(self, x):
|
||||
"""
|
||||
|
@ -1395,11 +1395,12 @@ void init_ops(nb::module_& m) {
|
||||
m.def(
|
||||
"take",
|
||||
[](const array& a,
|
||||
const std::variant<int, array>& indices,
|
||||
const std::variant<nb::int_, array>& indices,
|
||||
const std::optional<int>& axis,
|
||||
StreamOrDevice s) {
|
||||
if (auto pv = std::get_if<int>(&indices); pv) {
|
||||
return axis ? take(a, *pv, axis.value(), s) : take(a, *pv, s);
|
||||
if (auto pv = std::get_if<nb::int_>(&indices); pv) {
|
||||
auto idx = nb::cast<int>(*pv);
|
||||
return axis ? take(a, idx, axis.value(), s) : take(a, idx, s);
|
||||
} else {
|
||||
auto indices_ = std::get<array>(indices);
|
||||
return axis ? take(a, indices_, axis.value(), s)
|
||||
|
@ -459,13 +459,13 @@ void init_random(nb::module_& parent_module) {
|
||||
)pbdoc");
|
||||
m.def(
|
||||
"permuation",
|
||||
[](const std::variant<int, array>& x,
|
||||
[](const std::variant<nb::int_, array>& x,
|
||||
int axis,
|
||||
const std::optional<array>& key_,
|
||||
StreamOrDevice s) {
|
||||
auto key = key_ ? key_.value() : default_key().next();
|
||||
if (auto pv = std::get_if<int>(&x); pv) {
|
||||
return permutation(*pv, key, s);
|
||||
if (auto pv = std::get_if<nb::int_>(&x); pv) {
|
||||
return permutation(nb::cast<int>(*pv), key, s);
|
||||
} else {
|
||||
return permutation(std::get<array>(x), axis, key, s);
|
||||
}
|
||||
|
@ -1066,6 +1066,16 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.take(a, 1, axis=1)
|
||||
self.assertTrue(mx.array_equal(out, mx.array([1, 5])))
|
||||
|
||||
# Take with multi-dim scalar preserves dims
|
||||
out = mx.take(a, mx.array(1), axis=0)
|
||||
self.assertEqual(out.shape, (4,))
|
||||
|
||||
out = mx.take(a, mx.array([1]), axis=0)
|
||||
self.assertEqual(out.shape, (1, 4))
|
||||
|
||||
out = mx.take(a, mx.array([[1]]), axis=0)
|
||||
self.assertEqual(out.shape, (1, 1, 4))
|
||||
|
||||
def test_take_along_axis(self):
|
||||
a_np = np.arange(8).reshape(2, 2, 2)
|
||||
a_mlx = mx.array(a_np)
|
||||
|
@ -348,6 +348,10 @@ class TestRandom(mlx_tests.MLXTestCase):
|
||||
x = mx.random.permutation(16384)
|
||||
self.assertFalse(mx.array_equal(sorted_x, x))
|
||||
|
||||
# Preserves shape / doesn't cast input to int
|
||||
x = mx.random.permutation(mx.array([[1]]))
|
||||
self.assertEqual(x.shape, (1, 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -353,7 +353,7 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
for i in range(a.shape[0]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=0, atol=1e-5)
|
||||
mx.allclose(a[i] @ invs[i], mx.eye(a.shape[1]), rtol=1e-4, atol=1e-5)
|
||||
)
|
||||
|
||||
a = mx.random.uniform(shape=(4, 3, 4))
|
||||
@ -367,7 +367,9 @@ class TestVmap(mlx_tests.MLXTestCase):
|
||||
|
||||
for i in range(a.shape[1]):
|
||||
self.assertTrue(
|
||||
mx.allclose(a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=0, atol=1e-5)
|
||||
mx.allclose(
|
||||
a[:, i, :] @ invs[i], mx.eye(a.shape[0]), rtol=1e-4, atol=1e-5
|
||||
)
|
||||
)
|
||||
|
||||
def test_vmap_gather(self):
|
||||
|
Loading…
Reference in New Issue
Block a user