From 6731254e761f69d3c0925fd682a4d988fbf3fe7a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 23 Aug 2024 13:18:51 -0700 Subject: [PATCH] Use fast rope (#945) * use fast rope * fix llama * use fast rope for llama3.1 * requires unreleased mlx * fix su * fix deepseek v2 * only one of base or freqs * nit * fix * hard code freqs --- llms/mlx_lm/models/deepseek_v2.py | 91 +++++++---------------- llms/mlx_lm/models/llama.py | 43 ++++------- llms/mlx_lm/models/phi3.py | 4 +- llms/mlx_lm/models/su_rope.py | 51 ++++--------- llms/mlx_lm/requirements.txt | 2 +- llms/mlx_lm/version.py | 2 +- stable_diffusion/stable_diffusion/unet.py | 9 +-- 7 files changed, 65 insertions(+), 137 deletions(-) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index f320b564..602a9710 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1): return 0.1 * mscale * math.log(scale) + 1.0 -def yarn_linear_ramp_mask(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity - linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min) - ramp_func = mx.clip(linear_func, 0, 1) - return ramp_func + linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) + return mx.clip(linear_func, 0, 1) class DeepseekV2YarnRotaryEmbedding(nn.Module): @@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module): mscale_all_dim=0, ): super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.scaling_factor = scaling_factor - self.original_max_position_embeddings = original_max_position_embeddings - self.beta_fast = beta_fast - self.beta_slow = beta_slow - self.mscale = mscale - self.mscale_all_dim = mscale_all_dim - - self.max_seq_len_cached = None - self._cos_cached = None - self._sin_cached = None - self._inv_freq = None - self.set_cos_sin_cache(max_position_embeddings) - - def set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - dim = self.dim - freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)) - freq_inter = 1.0 / ( - self.scaling_factor - * self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + freq_inter = scaling_factor * base ** ( + mx.arange(0, dim, 2, dtype=mx.float32) / dim ) - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, + beta_fast, + beta_slow, dim, - self.base, - self.original_max_position_embeddings, + base, + original_max_position_embeddings, ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - self._inv_freq = inv_freq - - t = mx.arange(seq_len, dtype=mx.float32) - freqs = mx.outer(t, inv_freq) - - mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale( - self.scaling_factor, self.mscale_all_dim + freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self._freqs = (freq_inter * freq_extra) / ( + freq_inter * freq_mask + freq_extra * (1 - freq_mask) ) - self._cos_cached = mx.cos(freqs) * mscale - self._sin_cached = mx.sin(freqs) * mscale - - def apply_rotary_pos_emb(self, x, cos, sin): - x1 = x[..., ::2] - x2 = x[..., 1::2] - rx1 = x1 * cos - x2 * sin - rx2 = x1 * sin + x2 * cos - return mx.concatenate([rx1, rx2], axis=-1) - def __call__(self, x, offset=0): - seq_len = offset + x.shape[2] - if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - self.set_cos_sin_cache(seq_len=seq_len) - - if self._cos_cached.dtype != x.dtype: - self._cos_cached = self._cos_cached.astype(x.dtype) - self._sin_cached = self._sin_cached.astype(x.dtype) - - return self.apply_rotary_pos_emb( + if self.mscale != 1.0: + x = self.mscale * x + return mx.fast.rope( x, - self._cos_cached[offset:seq_len], - self._sin_cached[offset:seq_len], + x.shape[-1], + traditional=True, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, ) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 192e591f..c4a947a5 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module): self.dims = dims self.max_position_embeddings = max_position_embeddings self.traditional = traditional - self.original_base = base self.scale = scale self.rope_type = rope_type self.rope_scaling = rope_scaling - self.base = self.compute_base_freq() + self.base = base + self.compute_freqs() - def compute_base_freq(self): - if self.rope_type == "llama3": - return self.compute_llama3_base_freq() - return self.original_base - - # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318 - def compute_llama3_base_freq(self): + def compute_freqs(self): + if self.rope_type != "llama3": + self._freqs = None + return factor = self.rope_scaling["factor"] low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) @@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module): low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor - freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims) + freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) wavelens = 2 * mx.pi * freqs - new_base_freqs = [] - smooths = (wavelens - high_freq_wavelen) / ( - low_freq_wavelen - high_freq_wavelen + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor ) - new_base_freqs = freqs * (1 - smooths) * factor + smooths - new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs) - new_base_freqs = mx.where( - wavelens > low_freq_wavelen, freqs * factor, new_base_freqs - ) - return new_base_freqs.mean().item() + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + self.base = None def extra_repr(self): return ( @@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module): ) def __call__(self, x, offset: int = 0): - seq_len = x.shape[1] + offset - base = self.base - if self.max_position_embeddings and seq_len > self.max_position_embeddings: - base *= ( - (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) - ) ** (self.dims / (self.dims - 2)) - return mx.fast.rope( x, self.dims, traditional=self.traditional, - base=base, + base=self.base, scale=self.scale, offset=offset, + freqs=self._freqs, ) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index f8facdb1..112ade7d 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -59,19 +59,17 @@ class Attention(nn.Module): self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - rope_scale = 1.0 if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: self.rope = SuScaledRotaryEmbedding( head_dim, - traditional=False, base=args.rope_theta, - scale=rope_scale, max_position_embeddings=args.max_position_embeddings, original_max_position_embeddings=args.original_max_position_embeddings, short_factor=args.rope_scaling["short_factor"], long_factor=args.rope_scaling["long_factor"], ) else: + rope_scale = 1.0 if args.rope_scaling and args.rope_scaling["type"] == "linear": assert isinstance(args.rope_scaling["factor"], float) rope_scale = 1 / args.rope_scaling["factor"] diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py index 2ee20a63..f96b9957 100644 --- a/llms/mlx_lm/models/su_rope.py +++ b/llms/mlx_lm/models/su_rope.py @@ -4,15 +4,14 @@ import math from typing import List, Union import mlx.core as mx +import mlx.nn as nn -class SuScaledRotaryEmbedding: +class SuScaledRotaryEmbedding(nn.Module): def __init__( self, dims: int, - traditional: bool = False, base: float = 10000.0, - scale: float = 1.0, max_position_embeddings: int = 131072, original_max_position_embeddings: int = 4096, short_factor: Union[List[float], float] = 1.0, @@ -23,10 +22,7 @@ class SuScaledRotaryEmbedding: Args: dims (int): The feature dimensions to be rotated. - traditional (bool, optional): Unused. Default: ``False``. base (int, optional): Base for the exponential scaling. - scale (float, optional): The scale used to scale the positions. - Default: ``1.0``. max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. @@ -42,40 +38,23 @@ class SuScaledRotaryEmbedding: factors for sequences of length greater than ``original_max_position_embeddings``. Default: ``1.0``. """ - self.inv_freq_short = 1.0 / ( - mx.array(short_factor, dtype=mx.float32) - * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) - ) - self.inv_freq_long = 1.0 / ( - scale - * mx.array(long_factor, dtype=mx.float32) - * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) - ) + super().__init__() + freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs self.original_max_position_embeddings = original_max_position_embeddings - self.scaling_factor = math.sqrt( + self.scale = math.sqrt( 1 + math.log(max_position_embeddings / original_max_position_embeddings) / math.log(original_max_position_embeddings) ) - def _get_cos_sin(self, offset, L): - position_ids = mx.arange(offset, offset + L, dtype=mx.float32) - inv_freq = ( - self.inv_freq_long - if (offset + L) > self.original_max_position_embeddings - else self.inv_freq_short - ) - freqs = position_ids[:, None] * inv_freq[None, :] - emb = mx.concatenate([freqs, freqs], axis=-1) - cos = mx.cos(emb) * self.scaling_factor - sin = mx.sin(emb) * self.scaling_factor - return cos, sin - def __call__(self, x, offset: int = 0): - def _rotate_half(_x): - midpoint = _x.shape[-1] // 2 - x1, x2 = _x[..., :midpoint], _x[..., midpoint:] - return mx.concatenate([-x2, x1], axis=-1) - - cos, sin = self._get_cos_sin(offset, x.shape[2]) - return (x * cos) + (_rotate_half(x) * sin) + return mx.fast.rope( + self.scale * x, + x.shape[-1], + traditional=False, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 4875f931..814c03cc 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.14.1 +mlx>=0.17.0 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index f73aaa0a..41237905 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.17.0" +__version__ = "0.17.1" diff --git a/stable_diffusion/stable_diffusion/unet.py b/stable_diffusion/stable_diffusion/unet.py index ec2915e5..cfad7fcc 100644 --- a/stable_diffusion/stable_diffusion/unet.py +++ b/stable_diffusion/stable_diffusion/unet.py @@ -110,7 +110,7 @@ class Transformer2D(nn.Module): # Perform the input norm and projection B, H, W, C = x.shape - x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C) + x = self.norm(x).reshape(B, -1, C) x = self.proj_in(x) # Apply the transformer @@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module): if temb is not None: temb = self.time_emb_proj(nn.silu(temb)) - y = self.norm1(x.astype(mx.float32)).astype(dtype) + y = self.norm1(x) y = nn.silu(y) y = self.conv1(y) if temb is not None: y = y + temb[:, None, None, :] - y = self.norm2(y.astype(mx.float32)).astype(dtype) + y = self.norm2(y) y = nn.silu(y) y = self.conv2(y) @@ -453,8 +453,7 @@ class UNetModel(nn.Module): ) # Postprocess the output - dtype = x.dtype - x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype) + x = self.conv_norm_out(x) x = nn.silu(x) x = self.conv_out(x)