Quantize embedding (#994)

* quantize embedding

* rename as_linear + comment

* consistency in docs

* fix test
This commit is contained in:
Awni Hannun
2024-04-15 16:42:10 -07:00
committed by GitHub
parent 2e7c02d5cd
commit cd9e184529
9 changed files with 269 additions and 54 deletions

View File

@@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase):
self.assertFalse(m.update(params_dict).eval()._training)
self.assertTrue(m.train()._training)
def test_quantize(self):
    """Check which layers ``nn.quantize`` replaces in a small model.

    Two cases are exercised on an Embedding -> ReLU -> Linear stack:
    the default call, and a call with a ``class_predicate`` that only
    selects ``nn.Linear`` modules.
    """

    def build_model():
        return nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256))

    # Default call: the embedding and linear layers are expected to be
    # swapped for their quantized counterparts; ReLU is left as is.
    m = build_model()
    nn.quantize(m)
    self.assertIsInstance(m.layers[0], nn.QuantizedEmbedding)
    self.assertIsInstance(m.layers[1], nn.ReLU)
    self.assertIsInstance(m.layers[2], nn.QuantizedLinear)

    # Predicate restricted to nn.Linear: the embedding must stay
    # unquantized while the linear layer is still replaced.
    m = build_model()
    nn.quantize(m, class_predicate=lambda _, layer: isinstance(layer, nn.Linear))
    self.assertIsInstance(m.layers[0], nn.Embedding)
    self.assertNotIsInstance(m.layers[0], nn.QuantizedEmbedding)
    self.assertIsInstance(m.layers[1], nn.ReLU)
    self.assertIsInstance(m.layers[2], nn.QuantizedLinear)
class TestLayers(mlx_tests.MLXTestCase):
def test_identity(self):
@@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase):
self.assertEqual(h_out.shape, (44, 12))
self.assertEqual(c_out.shape, (44, 12))
def test_quantized_embedding(self):
    """Compare an 8-bit ``QuantizedEmbedding`` with the ``Embedding`` it was built from."""
    embedding = nn.Embedding(32, 256)
    quantized = nn.QuantizedEmbedding.from_embedding(embedding, bits=8)

    # Lookup path: calling both layers on the same indices should agree
    # to within 1e-3 in the max absolute difference.
    indices = mx.array([2, 6, 9, 3, 0, 3])
    lookup_error = (embedding(indices) - quantized(indices)).abs().max()
    self.assertLess(lookup_error, 1e-3)

    # ``as_linear`` path: same comparison with a looser 1e-2 tolerance.
    inputs = mx.random.uniform(shape=(2, 256))
    linear_error = (
        embedding.as_linear(inputs) - quantized.as_linear(inputs)
    ).abs().max()
    self.assertLess(linear_error, 1e-2)
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
unittest.main()