mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 12:49:44 +08:00
Quantize embedding (#994)
* quantize embedding
* rename as_linear + comment
* consistency in docs
* fix test
This commit is contained in:
@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
|
||||
self.assertFalse(m.update(params_dict).eval()._training)
|
||||
self.assertTrue(m.train()._training)
|
||||
|
||||
def test_quantize(self):
    """Check which layers ``nn.quantize`` replaces in a small model.

    The model is an Embedding -> ReLU -> Linear stack. Two calls are
    exercised: one with default arguments, and one with a
    ``class_predicate`` that only accepts ``nn.Linear`` modules.
    """

    def build_model():
        # A fresh model per scenario, since nn.quantize is applied to the
        # model object passed in and the assertions inspect that object.
        return nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))

    # Default call: the embedding and the linear layer are expected to be
    # swapped for their quantized counterparts; ReLU stays a ReLU.
    m = build_model()
    nn.quantize(m)
    self.assertIsInstance(m.layers[0], nn.QuantizedEmbedding)
    self.assertIsInstance(m.layers[1], nn.ReLU)
    self.assertIsInstance(m.layers[2], nn.QuantizedLinear)

    # Predicate restricted to nn.Linear: the embedding must still be an
    # nn.Embedding while the linear layer is quantized. The second lambda
    # argument is named `module` so it does not shadow the model `m`.
    m = build_model()
    nn.quantize(m, class_predicate=lambda _, module: isinstance(module, nn.Linear))
    self.assertIsInstance(m.layers[0], nn.Embedding)
    self.assertIsInstance(m.layers[1], nn.ReLU)
    self.assertIsInstance(m.layers[2], nn.QuantizedLinear)
|
||||
|
||||
|
||||
class TestLayers(mlx_tests.MLXTestCase):
|
||||
def test_identity(self):
|
||||
@@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(h_out.shape, (44, 12))
|
||||
self.assertEqual(c_out.shape, (44, 12))
|
||||
|
||||
def test_quantized_embedding(self):
    """Compare an 8-bit ``QuantizedEmbedding`` with the ``Embedding`` it was built from.

    Both the regular call and ``as_linear`` are checked; the maximum
    absolute difference between the two layers' outputs must stay below
    the tolerance used for each path.
    """
    dense = nn.Embedding(32, 256)
    quantized = nn.QuantizedEmbedding.from_embedding(dense, bits=8)

    # Regular call with integer indices.
    indices = mx.array([2, 6, 9, 3, 0, 3])
    lookup_error = (dense(indices) - quantized(indices)).abs().max()
    self.assertLess(lookup_error, 1e-3)

    # `as_linear` call with random inputs; a looser tolerance is used here.
    inputs = mx.random.uniform(shape=(2, 256))
    linear_error = (dense.as_linear(inputs) - quantized.as_linear(inputs)).abs().max()
    self.assertLess(linear_error, 1e-2)
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
|
Reference in New Issue
Block a user