mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-10-31 16:21:27 +08:00 
			
		
		
		
	Add theta cache for Rope and mask cache for ALiBi (#375)
This commit is contained in:
		| @@ -25,8 +25,15 @@ class RoPE(Module): | ||||
|         base (float, optional): The base used to compute angular frequency for | ||||
|             each dimension in the positional encodings. Default: ``10000``. | ||||
|         scale (float, optional): The scale used to scale the positions. Default: ``1.0``. | ||||
|  | ||||
|     Attributes: | ||||
|         _cos_sin_theta_key (tuple): Cached key for the precomputed cosine and sine values. | ||||
|         _cos_sin_theta_value (tuple): Cached cosine and sine values. | ||||
|     """ | ||||
|  | ||||
|     _cos_sin_theta_key = None | ||||
|     _cos_sin_theta_value = None | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         dims: int, | ||||
| @@ -86,8 +93,9 @@ class RoPE(Module): | ||||
|  | ||||
|         return mx.reshape(rx, shape) | ||||
|  | ||||
|     @staticmethod | ||||
|     @classmethod | ||||
|     def create_cos_sin_theta( | ||||
|         cls, | ||||
|         N: int, | ||||
|         D: int, | ||||
|         offset: int = 0, | ||||
| @@ -95,11 +103,14 @@ class RoPE(Module): | ||||
|         scale: float = 1.0, | ||||
|         dtype=mx.float32, | ||||
|     ): | ||||
|         D = D // 2 | ||||
|         positions = mx.arange(offset, N, dtype=dtype) * scale | ||||
|         freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) | ||||
|         theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) | ||||
|         return mx.cos(theta), mx.sin(theta) | ||||
|         if (N, D, offset, base, scale, dtype) != cls._cos_sin_theta_key: | ||||
|             D = D // 2 | ||||
|             positions = mx.arange(offset, N, dtype=dtype) * scale | ||||
|             freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) | ||||
|             theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) | ||||
|             cls._cos_sin_theta_key = (N, D, offset, base, scale, dtype) | ||||
|             cls._cos_sin_theta_value = (mx.cos(theta), mx.sin(theta)) | ||||
|         return cls._cos_sin_theta_value | ||||
|  | ||||
|  | ||||
| class SinusoidalPositionalEncoding(Module): | ||||
| @@ -163,22 +174,42 @@ class SinusoidalPositionalEncoding(Module): | ||||
|  | ||||
|  | ||||
| class ALiBi(Module): | ||||
|     @staticmethod | ||||
|     _alibi_mask_key = None | ||||
|     _alibi_mask = None | ||||
|  | ||||
|     @classmethod | ||||
|     def create_alibi_matrix( | ||||
|         cls, | ||||
|         q_sequence_length: int, | ||||
|         k_sequence_length: int, | ||||
|         num_heads: int, | ||||
|         offset: int, | ||||
|         dtype=mx.float32, | ||||
|     ): | ||||
|         x1 = mx.arange(offset, q_sequence_length) | ||||
|         x2 = mx.arange(0, k_sequence_length) | ||||
|         distance_matrix = -mx.abs( | ||||
|             mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) | ||||
|         ) | ||||
|         alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads) | ||||
|         alibi_mask = (distance_matrix * alibi_slope).astype(dtype) | ||||
|         return alibi_mask | ||||
|         if ( | ||||
|             q_sequence_length, | ||||
|             k_sequence_length, | ||||
|             num_heads, | ||||
|             offset, | ||||
|             dtype, | ||||
|         ) != cls._alibi_mask_key: | ||||
|             x1 = mx.arange(offset, q_sequence_length) | ||||
|             x2 = mx.arange(0, k_sequence_length) | ||||
|             distance_matrix = -mx.abs( | ||||
|                 mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1)) | ||||
|             ) | ||||
|             alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads) | ||||
|             alibi_mask = (distance_matrix * alibi_slope).astype(dtype) | ||||
|             cls._alibi_mask_key = ( | ||||
|                 q_sequence_length, | ||||
|                 k_sequence_length, | ||||
|                 num_heads, | ||||
|                 offset, | ||||
|                 dtype, | ||||
|             ) | ||||
|             cls._alibi_mask = alibi_mask | ||||
|  | ||||
|         return cls._alibi_mask | ||||
|  | ||||
|     @staticmethod | ||||
|     def create_alibi_slope(num_heads): | ||||
| @@ -196,4 +227,4 @@ class ALiBi(Module): | ||||
|         ) | ||||
|         if mask is not None: | ||||
|             alibi_mask = alibi_mask + mask | ||||
|         return attention_scores + alibi_mask | ||||
|         return attention_scores + alibi_mask | ||||
		Reference in New Issue
	
	Block a user
	 Anchen
					Anchen