Mirror of https://github.com/ml-explore/mlx.git (synced 2025-10-31 07:58:14 +08:00)
			
		
		
		
	| @@ -161,49 +161,6 @@ void init_fast(nb::module_& parent_module) { | ||||
|             array: The output array. | ||||
|       )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "affine_quantize", | ||||
|       nb::overload_cast< | ||||
|           const array&, | ||||
|           const array&, | ||||
|           const array&, | ||||
|           int, | ||||
|           int, | ||||
|           StreamOrDevice>(&fast::affine_quantize), | ||||
|       "w"_a, | ||||
|       "scales"_a, | ||||
|       "biases"_a, | ||||
|       "group_size"_a = 64, | ||||
|       "bits"_a = 4, | ||||
|       nb::kw_only(), | ||||
|       "stream"_a = nb::none(), | ||||
|       nb::sig( | ||||
|           "def affine_quantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), | ||||
|       R"pbdoc( | ||||
|         Quantize the matrix ``w`` using the provided ``scales`` and | ||||
|         ``biases`` and the ``group_size`` and ``bits`` configuration. | ||||
|  | ||||
|         Formally, given the notation in :func:`quantize`, we compute | ||||
|         :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and | ||||
|         :math:`\beta` as follows | ||||
|  | ||||
|         .. math:: | ||||
|  | ||||
|           w_i = s (\hat{w_i} + \beta) | ||||
|  | ||||
|         Args: | ||||
|           w (array): Matrix to be quantized | ||||
|           scales (array): The scales to use per ``group_size`` elements of ``w`` | ||||
|           biases (array): The biases to use per ``group_size`` elements of ``w`` | ||||
|           group_size (int, optional): The size of the group in ``w`` that shares a | ||||
|             scale and bias. (default: ``64``) | ||||
|           bits (int, optional): The number of bits occupied by each element in | ||||
|             ``w``. (default: ``4``) | ||||
|  | ||||
|         Returns: | ||||
|           array: The quantized version of ``w`` | ||||
|       )pbdoc"); | ||||
|  | ||||
|   m.def( | ||||
|       "metal_kernel", | ||||
|       [](const std::string& name, | ||||
|   | ||||
| @@ -549,18 +549,6 @@ class TestFast(mlx_tests.MLXTestCase): | ||||
|         )(x) | ||||
|         self.assertTrue(mx.allclose(vmap_out, vmap_fast_out)) | ||||
|  | ||||
|     def test_affine_quantize(self): | ||||
|         mx.random.seed(7) | ||||
|         x = mx.random.uniform(shape=(4, 1024)) | ||||
|         for bits in (2, 4, 8): | ||||
|             for group_size in (32, 64, 128): | ||||
|                 with self.subTest(bits=bits, group_size=group_size): | ||||
|                     w, scales, biases = mx.quantize(x, bits=bits, group_size=group_size) | ||||
|                     w_p = mx.fast.affine_quantize( | ||||
|                         x, scales, biases, bits=bits, group_size=group_size | ||||
|                     ) | ||||
|                     self.assertTrue(mx.allclose(w, w_p)) | ||||
|  | ||||
|     @unittest.skipIf(not mx.metal.is_available(), "Metal is not available") | ||||
|     def test_custom_kernel_basic(self): | ||||
|         mx.random.seed(7) | ||||
|   | ||||
| @@ -11,7 +11,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|     def test_quantize_dequantize(self): | ||||
|         w = mx.random.normal(shape=(128, 512)) | ||||
|         for gs in [32, 64, 128]: | ||||
|             for b in [2, 4, 8]: | ||||
|             for b in [2, 3, 6, 4, 8]: | ||||
|                 with self.subTest(gs=gs, b=b): | ||||
|                     w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b) | ||||
|                     w_hat = mx.dequantize(w_q, scales, biases, gs, b) | ||||
| @@ -22,7 +22,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         # test quantize/dequantize 0s | ||||
|         a = mx.zeros((256, 512)) | ||||
|         for gs in [32, 64, 128]: | ||||
|             for b in [2, 4, 8]: | ||||
|             for b in [2, 3, 4, 6, 8]: | ||||
|                 w_q, scales, biases = mx.quantize(a, gs, b) | ||||
|                 a_hat = mx.dequantize(w_q, scales, biases, gs, b) | ||||
|                 self.assertTrue(mx.all(a_hat == 0)) | ||||
| @@ -116,7 +116,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         k1, k2 = mx.random.split(key) | ||||
|         tests = product( | ||||
|             [128, 64, 32],  # group_size | ||||
|             [2, 4, 8],  # bits | ||||
|             [2, 3, 4, 6, 8],  # bits | ||||
|             [512, 1024, 67],  # M | ||||
|             [64, 128, 512, 1024],  # N | ||||
|             [0, 1, 3, 8],  # B | ||||
| @@ -143,7 +143,7 @@ class TestQuantized(mlx_tests.MLXTestCase): | ||||
|         k1, k2 = mx.random.split(key) | ||||
|         tests = product( | ||||
|             [128, 64, 32],  # group_size | ||||
|             [2, 4, 8],  # bits | ||||
|             [2, 3, 4, 6, 8],  # bits | ||||
|             [512, 1024],  # M | ||||
|             [512, 1024, 67],  # N | ||||
|             [0, 1, 3, 8],  # B | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Alex Barron
					Alex Barron