mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Quantize embedding (#994)
* quantize embedding * rename as_linear + comment * consistency in docs * fix test
This commit is contained in:
		| @@ -172,6 +172,19 @@ class TestBase(mlx_tests.MLXTestCase): | ||||
|         self.assertFalse(m.update(params_dict).eval()._training) | ||||
|         self.assertTrue(m.train()._training) | ||||
|  | ||||
|     def test_quantize(self): | ||||
|         m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256)) | ||||
|         nn.quantize(m) | ||||
|         self.assertTrue(isinstance(m.layers[0], nn.QuantizedEmbedding)) | ||||
|         self.assertTrue(isinstance(m.layers[1], nn.ReLU)) | ||||
|         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) | ||||
|  | ||||
|         m = nn.Sequential(nn.Embedding(5, 256), nn.ReLU(), nn.Linear(256, 256)) | ||||
|         nn.quantize(m, class_predicate=lambda _, m: isinstance(m, nn.Linear)) | ||||
|         self.assertTrue(isinstance(m.layers[0], nn.Embedding)) | ||||
|         self.assertTrue(isinstance(m.layers[1], nn.ReLU)) | ||||
|         self.assertTrue(isinstance(m.layers[2], nn.QuantizedLinear)) | ||||
|  | ||||
|  | ||||
| class TestLayers(mlx_tests.MLXTestCase): | ||||
|     def test_identity(self): | ||||
| @@ -1606,6 +1619,19 @@ class TestLayers(mlx_tests.MLXTestCase): | ||||
|         self.assertEqual(h_out.shape, (44, 12)) | ||||
|         self.assertEqual(c_out.shape, (44, 12)) | ||||
|  | ||||
|     def test_quantized_embedding(self): | ||||
|         emb = nn.Embedding(32, 256) | ||||
|         qemb = nn.QuantizedEmbedding.from_embedding(emb, bits=8) | ||||
|         x = mx.array([2, 6, 9, 3, 0, 3]) | ||||
|         y = emb(x) | ||||
|         yq = qemb(x) | ||||
|         self.assertLess((y - yq).abs().max(), 1e-3) | ||||
|  | ||||
|         x = mx.random.uniform(shape=(2, 256)) | ||||
|         y = emb.as_linear(x) | ||||
|         yq = qemb.as_linear(x) | ||||
|         self.assertLess((y - yq).abs().max(), 1e-2) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun