Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-23 14:08:07 +08:00)
			
		
		
		
	Use fast rope (#945)
* use fast rope
* fix llama
* use fast rope for llama3.1
* requires unreleased mlx
* fix su
* fix deepseek v2
* only one of base or freqs
* nit
* fix
* hard code freqs
This commit is contained in:
		| @@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1): | ||||
|     return 0.1 * mscale * math.log(scale) + 1.0 | ||||
|  | ||||
|  | ||||
| def yarn_linear_ramp_mask(min, max, dim): | ||||
|     if min == max: | ||||
|         max += 0.001  # Prevent singularity | ||||
| def yarn_linear_ramp_mask(min_val, max_val, dim): | ||||
|     if min_val == max_val: | ||||
|         max_val += 0.001  # Prevent singularity | ||||
|  | ||||
|     linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min) | ||||
|     ramp_func = mx.clip(linear_func, 0, 1) | ||||
|     return ramp_func | ||||
|     linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) | ||||
|     return mx.clip(linear_func, 0, 1) | ||||
|  | ||||
|  | ||||
| class DeepseekV2YarnRotaryEmbedding(nn.Module): | ||||
| @@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module): | ||||
|         mscale_all_dim=0, | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.dim = dim | ||||
|         self.max_position_embeddings = max_position_embeddings | ||||
|         self.base = base | ||||
|         self.scaling_factor = scaling_factor | ||||
|         self.original_max_position_embeddings = original_max_position_embeddings | ||||
|         self.beta_fast = beta_fast | ||||
|         self.beta_slow = beta_slow | ||||
|         self.mscale = mscale | ||||
|         self.mscale_all_dim = mscale_all_dim | ||||
|  | ||||
|         self.max_seq_len_cached = None | ||||
|         self._cos_cached = None | ||||
|         self._sin_cached = None | ||||
|         self._inv_freq = None | ||||
|         self.set_cos_sin_cache(max_position_embeddings) | ||||
|  | ||||
|     def set_cos_sin_cache(self, seq_len): | ||||
|         self.max_seq_len_cached = seq_len | ||||
|         dim = self.dim | ||||
|         freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)) | ||||
|         freq_inter = 1.0 / ( | ||||
|             self.scaling_factor | ||||
|             * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) | ||||
|         self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( | ||||
|             scaling_factor, mscale_all_dim | ||||
|         ) | ||||
|         freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) | ||||
|         freq_inter = scaling_factor * base ** ( | ||||
|             mx.arange(0, dim, 2, dtype=mx.float32) / dim | ||||
|         ) | ||||
|  | ||||
|         low, high = yarn_find_correction_range( | ||||
|             self.beta_fast, | ||||
|             self.beta_slow, | ||||
|             beta_fast, | ||||
|             beta_slow, | ||||
|             dim, | ||||
|             self.base, | ||||
|             self.original_max_position_embeddings, | ||||
|             base, | ||||
|             original_max_position_embeddings, | ||||
|         ) | ||||
|         inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) | ||||
|         inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask | ||||
|         self._inv_freq = inv_freq | ||||
|  | ||||
|         t = mx.arange(seq_len, dtype=mx.float32) | ||||
|         freqs = mx.outer(t, inv_freq) | ||||
|  | ||||
|         mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale( | ||||
|             self.scaling_factor, self.mscale_all_dim | ||||
|         freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) | ||||
|         self._freqs = (freq_inter * freq_extra) / ( | ||||
|             freq_inter * freq_mask + freq_extra * (1 - freq_mask) | ||||
|         ) | ||||
|  | ||||
|         self._cos_cached = mx.cos(freqs) * mscale | ||||
|         self._sin_cached = mx.sin(freqs) * mscale | ||||
|  | ||||
|     def apply_rotary_pos_emb(self, x, cos, sin): | ||||
|         x1 = x[..., ::2] | ||||
|         x2 = x[..., 1::2] | ||||
|         rx1 = x1 * cos - x2 * sin | ||||
|         rx2 = x1 * sin + x2 * cos | ||||
|         return mx.concatenate([rx1, rx2], axis=-1) | ||||
|  | ||||
|     def __call__(self, x, offset=0): | ||||
|         seq_len = offset + x.shape[2] | ||||
|         if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: | ||||
|             self.set_cos_sin_cache(seq_len=seq_len) | ||||
|  | ||||
|         if self._cos_cached.dtype != x.dtype: | ||||
|             self._cos_cached = self._cos_cached.astype(x.dtype) | ||||
|             self._sin_cached = self._sin_cached.astype(x.dtype) | ||||
|  | ||||
|         return self.apply_rotary_pos_emb( | ||||
|         if self.mscale != 1.0: | ||||
|             x = self.mscale * x | ||||
|         return mx.fast.rope( | ||||
|             x, | ||||
|             self._cos_cached[offset:seq_len], | ||||
|             self._sin_cached[offset:seq_len], | ||||
|             x.shape[-1], | ||||
|             traditional=True, | ||||
|             base=None, | ||||
|             scale=1.0, | ||||
|             offset=offset, | ||||
|             freqs=self._freqs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module): | ||||
|         self.dims = dims | ||||
|         self.max_position_embeddings = max_position_embeddings | ||||
|         self.traditional = traditional | ||||
|         self.original_base = base | ||||
|         self.scale = scale | ||||
|         self.rope_type = rope_type | ||||
|         self.rope_scaling = rope_scaling | ||||
|         self.base = self.compute_base_freq() | ||||
|         self.base = base | ||||
|         self.compute_freqs() | ||||
|  | ||||
|     def compute_base_freq(self): | ||||
|         if self.rope_type == "llama3": | ||||
|             return self.compute_llama3_base_freq() | ||||
|         return self.original_base | ||||
|  | ||||
|     # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318 | ||||
|     def compute_llama3_base_freq(self): | ||||
|     def compute_freqs(self): | ||||
|         if self.rope_type != "llama3": | ||||
|             self._freqs = None | ||||
|             return | ||||
|         factor = self.rope_scaling["factor"] | ||||
|         low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) | ||||
|         high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) | ||||
| @@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module): | ||||
|         low_freq_wavelen = old_context_len / low_freq_factor | ||||
|         high_freq_wavelen = old_context_len / high_freq_factor | ||||
|  | ||||
|         freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims) | ||||
|         freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) | ||||
|         wavelens = 2 * mx.pi * freqs | ||||
|         new_base_freqs = [] | ||||
|  | ||||
|         smooths = (wavelens - high_freq_wavelen) / ( | ||||
|             low_freq_wavelen - high_freq_wavelen | ||||
|         freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) | ||||
|         is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) | ||||
|         smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( | ||||
|             high_freq_factor - low_freq_factor | ||||
|         ) | ||||
|         new_base_freqs = freqs * (1 - smooths) * factor + smooths | ||||
|         new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs) | ||||
|         new_base_freqs = mx.where( | ||||
|             wavelens > low_freq_wavelen, freqs * factor, new_base_freqs | ||||
|         ) | ||||
|         return new_base_freqs.mean().item() | ||||
|         smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) | ||||
|         self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) | ||||
|         self.base = None | ||||
|  | ||||
|     def extra_repr(self): | ||||
|         return ( | ||||
| @@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module): | ||||
|         ) | ||||
|  | ||||
|     def __call__(self, x, offset: int = 0): | ||||
|         seq_len = x.shape[1] + offset | ||||
|         base = self.base | ||||
|         if self.max_position_embeddings and seq_len > self.max_position_embeddings: | ||||
|             base *= ( | ||||
|                 (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) | ||||
|             ) ** (self.dims / (self.dims - 2)) | ||||
|  | ||||
|         return mx.fast.rope( | ||||
|             x, | ||||
|             self.dims, | ||||
|             traditional=self.traditional, | ||||
|             base=base, | ||||
|             base=self.base, | ||||
|             scale=self.scale, | ||||
|             offset=offset, | ||||
|             freqs=self._freqs, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -59,19 +59,17 @@ class Attention(nn.Module): | ||||
|         self.qkv_proj = nn.Linear(dim, op_size, bias=False) | ||||
|         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) | ||||
|  | ||||
|         rope_scale = 1.0 | ||||
|         if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: | ||||
|             self.rope = SuScaledRotaryEmbedding( | ||||
|                 head_dim, | ||||
|                 traditional=False, | ||||
|                 base=args.rope_theta, | ||||
|                 scale=rope_scale, | ||||
|                 max_position_embeddings=args.max_position_embeddings, | ||||
|                 original_max_position_embeddings=args.original_max_position_embeddings, | ||||
|                 short_factor=args.rope_scaling["short_factor"], | ||||
|                 long_factor=args.rope_scaling["long_factor"], | ||||
|             ) | ||||
|         else: | ||||
|             rope_scale = 1.0 | ||||
|             if args.rope_scaling and args.rope_scaling["type"] == "linear": | ||||
|                 assert isinstance(args.rope_scaling["factor"], float) | ||||
|                 rope_scale = 1 / args.rope_scaling["factor"] | ||||
|   | ||||
| @@ -4,15 +4,14 @@ import math | ||||
| from typing import List, Union | ||||
|  | ||||
| import mlx.core as mx | ||||
| import mlx.nn as nn | ||||
|  | ||||
|  | ||||
| class SuScaledRotaryEmbedding: | ||||
| class SuScaledRotaryEmbedding(nn.Module): | ||||
|     def __init__( | ||||
|         self, | ||||
|         dims: int, | ||||
|         traditional: bool = False, | ||||
|         base: float = 10000.0, | ||||
|         scale: float = 1.0, | ||||
|         max_position_embeddings: int = 131072, | ||||
|         original_max_position_embeddings: int = 4096, | ||||
|         short_factor: Union[List[float], float] = 1.0, | ||||
| @@ -23,10 +22,7 @@ class SuScaledRotaryEmbedding: | ||||
|  | ||||
|         Args: | ||||
|             dims (int): The feature dimensions to be rotated. | ||||
|             traditional (bool, optional): Unused. Default: ``False``. | ||||
|             base (int, optional): Base for the exponential scaling. | ||||
|             scale (float, optional): The scale used to scale the positions. | ||||
|               Default: ``1.0``. | ||||
|             max_position_embeddings (int, optional): The maximum sequence | ||||
|               length that this model was trained with. This is used to determine | ||||
|               the size of the original RoPE embeddings when using long scaling. | ||||
| @@ -42,40 +38,23 @@ class SuScaledRotaryEmbedding: | ||||
|               factors for sequences of length greater than | ||||
|               ``original_max_position_embeddings``.  Default: ``1.0``. | ||||
|         """ | ||||
|         self.inv_freq_short = 1.0 / ( | ||||
|             mx.array(short_factor, dtype=mx.float32) | ||||
|             * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) | ||||
|         ) | ||||
|         self.inv_freq_long = 1.0 / ( | ||||
|             scale | ||||
|             * mx.array(long_factor, dtype=mx.float32) | ||||
|             * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) | ||||
|         ) | ||||
|         super().__init__() | ||||
|         freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) | ||||
|         self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs | ||||
|         self.original_max_position_embeddings = original_max_position_embeddings | ||||
|         self.scaling_factor = math.sqrt( | ||||
|         self.scale = math.sqrt( | ||||
|             1 | ||||
|             + math.log(max_position_embeddings / original_max_position_embeddings) | ||||
|             / math.log(original_max_position_embeddings) | ||||
|         ) | ||||
|  | ||||
|     def _get_cos_sin(self, offset, L): | ||||
|         position_ids = mx.arange(offset, offset + L, dtype=mx.float32) | ||||
|         inv_freq = ( | ||||
|             self.inv_freq_long | ||||
|             if (offset + L) > self.original_max_position_embeddings | ||||
|             else self.inv_freq_short | ||||
|         ) | ||||
|         freqs = position_ids[:, None] * inv_freq[None, :] | ||||
|         emb = mx.concatenate([freqs, freqs], axis=-1) | ||||
|         cos = mx.cos(emb) * self.scaling_factor | ||||
|         sin = mx.sin(emb) * self.scaling_factor | ||||
|         return cos, sin | ||||
|  | ||||
|     def __call__(self, x, offset: int = 0): | ||||
|         def _rotate_half(_x): | ||||
|             midpoint = _x.shape[-1] // 2 | ||||
|             x1, x2 = _x[..., :midpoint], _x[..., midpoint:] | ||||
|             return mx.concatenate([-x2, x1], axis=-1) | ||||
|  | ||||
|         cos, sin = self._get_cos_sin(offset, x.shape[2]) | ||||
|         return (x * cos) + (_rotate_half(x) * sin) | ||||
|         return mx.fast.rope( | ||||
|             self.scale * x, | ||||
|             x.shape[-1], | ||||
|             traditional=False, | ||||
|             base=None, | ||||
|             scale=1.0, | ||||
|             offset=offset, | ||||
|             freqs=self._freqs, | ||||
|         ) | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| mlx>=0.14.1 | ||||
| mlx>=0.17.0 | ||||
| numpy | ||||
| transformers[sentencepiece]>=4.39.3 | ||||
| protobuf | ||||
|   | ||||
| @@ -1,3 +1,3 @@ | ||||
| # Copyright © 2023-2024 Apple Inc. | ||||
|  | ||||
| __version__ = "0.17.0" | ||||
| __version__ = "0.17.1" | ||||
|   | ||||
| @@ -110,7 +110,7 @@ class Transformer2D(nn.Module): | ||||
|  | ||||
|         # Perform the input norm and projection | ||||
|         B, H, W, C = x.shape | ||||
|         x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C) | ||||
|         x = self.norm(x).reshape(B, -1, C) | ||||
|         x = self.proj_in(x) | ||||
|  | ||||
|         # Apply the transformer | ||||
| @@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module): | ||||
|         if temb is not None: | ||||
|             temb = self.time_emb_proj(nn.silu(temb)) | ||||
|  | ||||
|         y = self.norm1(x.astype(mx.float32)).astype(dtype) | ||||
|         y = self.norm1(x) | ||||
|         y = nn.silu(y) | ||||
|         y = self.conv1(y) | ||||
|         if temb is not None: | ||||
|             y = y + temb[:, None, None, :] | ||||
|         y = self.norm2(y.astype(mx.float32)).astype(dtype) | ||||
|         y = self.norm2(y) | ||||
|         y = nn.silu(y) | ||||
|         y = self.conv2(y) | ||||
|  | ||||
| @@ -453,8 +453,7 @@ class UNetModel(nn.Module): | ||||
|             ) | ||||
|  | ||||
|         # Postprocess the output | ||||
|         dtype = x.dtype | ||||
|         x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype) | ||||
|         x = self.conv_norm_out(x) | ||||
|         x = nn.silu(x) | ||||
|         x = self.conv_out(x) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Awni Hannun
					Awni Hannun