mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 07:58:14 +08:00 
			
		
		
		
	Custom primitive + RoPE fat op (#676)
* extensions start * rope custom op * fix build * docs + rope benchmark * fix test * Add a Metal kernel for RoPE * Fix position of traditional * transform tests * Move rope computation to float and fix tests * Fix the test and a typo * change to fast * fix no metal build --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
		| @@ -1,4 +1,4 @@ | ||||
| # Copyright © 2023 Apple Inc. | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import math | ||||
| from typing import Optional | ||||
| @@ -20,20 +20,13 @@ class RoPE(Module): | ||||
|     Args: | ||||
|         dims (int): The feature dimensions to be rotated. If the input feature | ||||
|             is larger than dims then the rest is left unchanged. | ||||
|         traditional (bool, optional): If set to True choose the traditional | ||||
|         traditional (bool, optional): If set to ``True`` choose the traditional | ||||
|             implementation which is slightly less efficient. Default: ``False``. | ||||
|         base (float, optional): The base used to compute angular frequency for | ||||
|             each dimension in the positional encodings. Default: ``10000``. | ||||
|         scale (float, optional): The scale used to scale the positions. Default: ``1.0``. | ||||
|  | ||||
|     Attributes: | ||||
|         _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values. | ||||
|         _cos_sin_theta_value (tuple): Cached cosine and sine values. | ||||
|     """ | ||||
|  | ||||
|     _cos_sin_theta_key = None | ||||
|     _cos_sin_theta_value = None | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         dims: int, | ||||
| @@ -50,69 +43,18 @@ class RoPE(Module): | ||||
|     def _extra_repr(self): | ||||
|         return f"{self.dims}, traditional={self.traditional}" | ||||
|  | ||||
|     def _compute_rope(self, costheta, sintheta, x): | ||||
|         x1 = x[..., : self.dims // 2] | ||||
|         x2 = x[..., self.dims // 2 : self.dims] | ||||
|         rx1 = x1 * costheta - x2 * sintheta | ||||
|         rx2 = x1 * sintheta + x2 * costheta | ||||
|  | ||||
|         if self.dims < x.shape[-1]: | ||||
|             rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) | ||||
|         else: | ||||
|             rx = mx.concatenate([rx1, rx2], axis=-1) | ||||
|  | ||||
|         return rx | ||||
|  | ||||
|     def _compute_traditional_rope(self, costheta, sintheta, x): | ||||
|         x1 = x[..., ::2] | ||||
|         x2 = x[..., 1::2] | ||||
|         rx1 = x1 * costheta - x2 * sintheta | ||||
|         rx2 = x1 * sintheta + x2 * costheta | ||||
|  | ||||
|         if self.dims < x.shape[-1]: | ||||
|             raise NotImplementedError( | ||||
|                 "RoPE doesn't implement partial traditional application" | ||||
|             ) | ||||
|  | ||||
|         rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) | ||||
|  | ||||
|         return rx | ||||
|  | ||||
|     def __call__(self, x, offset: int = 0): | ||||
|         shape = x.shape | ||||
|         x = mx.reshape(x, (-1, shape[-2], shape[-1])) | ||||
|         N = x.shape[1] + offset | ||||
|         costheta, sintheta = RoPE.create_cos_sin_theta( | ||||
|             N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype | ||||
|         x = mx.fast.rope( | ||||
|             x, | ||||
|             self.dims, | ||||
|             traditional=self.traditional, | ||||
|             base=self.base, | ||||
|             scale=self.scale, | ||||
|             offset=offset, | ||||
|         ) | ||||
|  | ||||
|         rope = ( | ||||
|             self._compute_traditional_rope if self.traditional else self._compute_rope | ||||
|         ) | ||||
|         rx = rope(costheta, sintheta, x) | ||||
|  | ||||
|         return mx.reshape(rx, shape) | ||||
|  | ||||
|     @classmethod | ||||
|     def create_cos_sin_theta( | ||||
|         cls, | ||||
|         N: int, | ||||
|         D: int, | ||||
|         offset: int = 0, | ||||
|         base: float = 10000, | ||||
|         scale: float = 1.0, | ||||
|         dtype=mx.float32, | ||||
|     ): | ||||
|         if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: | ||||
|             half_D = D // 2 | ||||
|             positions = mx.arange(offset, N, dtype=dtype) * scale | ||||
|             freqs = mx.exp( | ||||
|                 -mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D) | ||||
|             ) | ||||
|             theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) | ||||
|             cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) | ||||
|             cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) | ||||
|         return cls._cos_sin_theta_value | ||||
|         return mx.reshape(x, shape) | ||||
|  | ||||
|  | ||||
| class SinusoidalPositionalEncoding(Module): | ||||
|   | ||||
| @@ -3,6 +3,7 @@ pybind11_add_module( | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp | ||||
|   ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp | ||||
|   | ||||
							
								
								
									
										59
									
								
								python/src/fast.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								python/src/fast.cpp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| // Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| #include <pybind11/pybind11.h> | ||||
| #include <pybind11/stl.h> | ||||
|  | ||||
| #include "mlx/fast.h" | ||||
| #include "mlx/ops.h" | ||||
| #include "python/src/utils.h" | ||||
|  | ||||
| namespace py = pybind11; | ||||
| using namespace py::literals; | ||||
| using namespace mlx::core; | ||||
|  | ||||
| void init_extensions(py::module_& parent_module) { | ||||
|   py::options options; | ||||
|   options.disable_function_signatures(); | ||||
|  | ||||
|   auto m = | ||||
|       parent_module.def_submodule("fast", "mlx.core.fast: fast operations"); | ||||
|  | ||||
|   m.def( | ||||
|       "rope", | ||||
|       [](const array& a, | ||||
|          int dims, | ||||
|          bool traditional, | ||||
|          float base, | ||||
|          float scale, | ||||
|          int offset, | ||||
|          const StreamOrDevice& s /* = {} */) { | ||||
|         return fast::rope(a, dims, traditional, base, scale, offset, s); | ||||
|       }, | ||||
|       "a"_a, | ||||
|       "dims"_a, | ||||
|       py::kw_only(), | ||||
|       "traditional"_a, | ||||
|       "base"_a, | ||||
|       "scale"_a, | ||||
|       "offset"_a, | ||||
|       "stream"_a = none, | ||||
|       R"pbdoc( | ||||
|         rope(a: array, dims: int, *, traditinoal: bool, base: float, scale: float, offset: int, stream: Union[None, Stream, Device] = None) -> array | ||||
|  | ||||
|         Apply rotary positional encoding to the input. | ||||
|  | ||||
|         Args: | ||||
|             a (array): Input array. | ||||
|             dims (int): The feature dimensions to be rotated. If the input feature | ||||
|                 is larger than dims then the rest is left unchanged. | ||||
|             traditional (bool): If set to ``True`` choose the traditional | ||||
|                 implementation which rotates consecutive dimensions. | ||||
|             base (float): The base used to compute angular frequency for | ||||
|                 each dimension in the positional encodings. | ||||
|             scale (float): The scale used to scale the positions. | ||||
|             offset (int): The position offset to start at. | ||||
|  | ||||
|         Returns: | ||||
|             array: The output array. | ||||
|       )pbdoc"); | ||||
| } | ||||
| @@ -17,6 +17,7 @@ void init_random(py::module_&); | ||||
| void init_fft(py::module_&); | ||||
| void init_linalg(py::module_&); | ||||
| void init_constants(py::module_&); | ||||
| void init_extensions(py::module_&); | ||||
|  | ||||
| PYBIND11_MODULE(core, m) { | ||||
|   m.doc() = "mlx: A framework for machine learning on Apple silicon."; | ||||
| @@ -33,5 +34,6 @@ PYBIND11_MODULE(core, m) { | ||||
|   init_fft(m); | ||||
|   init_linalg(m); | ||||
|   init_constants(m); | ||||
|   init_extensions(m); | ||||
|   m.attr("__version__") = TOSTRING(_VERSION_); | ||||
| } | ||||
|   | ||||
| @@ -133,7 +133,7 @@ void init_random(py::module_& parent_module) { | ||||
|             low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. | ||||
|             high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. | ||||
|             shape (list(int), optional): Shape of the output. Default is ``()``. | ||||
|             key (array, optional): A PRNG key. Default: None. | ||||
|             key (array, optional): A PRNG key. Default: ``None``. | ||||
|             dtype (Dtype, optional): Type of the output. Default is ``float32``. | ||||
|  | ||||
|         Returns: | ||||
|   | ||||
							
								
								
									
										158
									
								
								python/tests/test_fast.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										158
									
								
								python/tests/test_fast.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,158 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| import math | ||||
| import unittest | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx_tests | ||||
|  | ||||
|  | ||||
def rope_orig(x, dims, traditional, base, scale, offset):
    """Reference RoPE implementation used to validate ``mx.fast.rope``.

    Rotates the first ``dims`` features of ``x`` by position-dependent
    angles. With ``traditional=True`` consecutive feature pairs are
    rotated; otherwise the first and second halves of the rotated slice
    are paired.
    """
    seq_end = x.shape[1] + offset
    dtype = x.dtype
    half = dims // 2

    # Angles: outer product of (scaled) positions and inverse frequencies.
    positions = mx.arange(offset, seq_end, dtype=dtype) * scale
    inv_freq = mx.exp(-mx.arange(0.0, half, dtype=dtype) * (math.log(base) / half))
    angles = mx.reshape(positions, (-1, 1)) * mx.reshape(inv_freq, (1, -1))
    cos_t = mx.cos(angles)
    sin_t = mx.sin(angles)

    if traditional:
        # Pair even/odd features and re-interleave after rotation.
        even = x[..., ::2]
        odd = x[..., 1::2]
        r_even = even * cos_t - odd * sin_t
        r_odd = even * sin_t + odd * cos_t
        stacked = mx.concatenate([r_even[..., None], r_odd[..., None]], axis=-1)
        return mx.reshape(stacked, x.shape)

    # Pair the first half of the rotated slice with the second half.
    lo = x[..., :half]
    hi = x[..., half:dims]
    r_lo = lo * cos_t - hi * sin_t
    r_hi = lo * sin_t + hi * cos_t
    pieces = [r_lo, r_hi]
    if dims < x.shape[-1]:
        # Features beyond `dims` pass through unchanged.
        pieces.append(x[..., dims:])
    return mx.concatenate(pieces, axis=-1)
|  | ||||
|  | ||||
class TestFast(mlx_tests.MLXTestCase):
    """Tests for the fused fast ops against pure reference implementations."""

    def test_rope(self):
        """Compare mx.fast.rope against rope_orig over bases, dtypes,
        offsets, and scales, for both traditional and non-traditional modes.
        Each sweep varies one parameter while the rest stay at defaults."""
        T = 4

        # Defaults: dims, dtype, base, scale, offset, traditional
        defaults = (8, mx.float32, 10000.0, 1.0, 0, False)

        # Per dtype absolute tolerance
        tolerances = {mx.float32: 1e-6, mx.float16: 1e-3, mx.bfloat16: 1e-2}

        # Test cases:
        dtypes = [mx.float32, mx.float16, mx.bfloat16]
        bases = [10000.0, 1000000.0]
        scales = [1.0, 2.0]
        offsets = [0, 3]
        # Fix: was `traditional = [...]`, which was immediately shadowed by
        # the loop variable below (dead assignment).
        traditionals = [True, False]

        for traditional in traditionals:
            # Sweep the base.
            dims, dtype, _, scale, offset, _ = defaults
            for base in bases:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Sweep the dtype. For reduced-precision dtypes the fast op should
            # be at least as close to the float32 reference as the low-precision
            # reference is (it computes internally in float).
            dims, _, base, scale, offset, _ = defaults
            for dtype in dtypes:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                ry = rope_orig(
                    x.astype(mx.float32), dims, traditional, base, scale, offset
                )
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                if dtype != mx.float32:
                    self.assertLessEqual(
                        mx.abs(ry - rx_fast).max(), mx.abs(ry - rx).max()
                    )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Sweep the offset.
            dims, dtype, base, scale, _, _ = defaults
            for offset in offsets:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

            # Sweep the scale.
            dims, dtype, base, _, offset, _ = defaults
            for scale in scales:
                x = mx.random.uniform(shape=(2, T, dims)).astype(dtype)
                rx = rope_orig(x, dims, traditional, base, scale, offset)
                rx_fast = mx.fast.rope(
                    x,
                    dims,
                    traditional=traditional,
                    base=base,
                    scale=scale,
                    offset=offset,
                )
                self.assertLess(mx.abs(rx - rx_fast).max(), tolerances[dtype])

    def test_fast_transforms(self):
        """Check that vjp, jvp, and vmap of mx.fast.rope match the transforms
        of the pure reference implementation."""
        x = mx.random.uniform(shape=(2, 2, 8))

        defaults = (8, False, 10000.0, 1.0, 0)
        dims, traditional, base, scale, offset = defaults

        # VJP
        _, vjp_out = mx.vjp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, vjp_fast_out = mx.vjp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(vjp_out[0], vjp_fast_out[0]))

        # JVP
        _, jvp_out = mx.jvp(lambda x: rope_orig(x, *defaults), (x,), (mx.ones_like(x),))
        _, jvp_fast_out = mx.jvp(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            ),
            (x,),
            (mx.ones_like(x),),
        )
        self.assertTrue(mx.allclose(jvp_out[0], jvp_fast_out[0]))

        # VMAP over a leading batch dimension.
        x = mx.random.uniform(shape=(2, 2, 2, 8))
        vmap_out = mx.vmap(lambda x: rope_orig(x, *defaults))(x)
        vmap_fast_out = mx.vmap(
            lambda x: mx.fast.rope(
                x, dims, traditional=traditional, base=base, scale=scale, offset=offset
            )
        )(x)
        self.assertTrue(mx.allclose(vmap_out, vmap_fast_out))
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun