Use fast rope (#945)

* use fast rope (mx.fast.rope)

* fix llama

* use fast rope for llama3.1

* requires unreleased mlx

* fix su (SuScaledRotaryEmbedding)

* fix deepseek v2

* pass only one of base or freqs to mx.fast.rope

* nit

* fix

* hard code freqs
This commit is contained in:
Awni Hannun 2024-08-23 13:18:51 -07:00 committed by GitHub
parent 58591a1b41
commit 6731254e76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 65 additions and 137 deletions

View File

@ -68,13 +68,12 @@ def yarn_get_mscale(scale=1, mscale=1):
return 0.1 * mscale * math.log(scale) + 1.0 return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min, max, dim): def yarn_linear_ramp_mask(min_val, max_val, dim):
if min == max: if min_val == max_val:
max += 0.001 # Prevent singularity max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min) / (max - min) linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
ramp_func = mx.clip(linear_func, 0, 1) return mx.clip(linear_func, 0, 1)
return ramp_func
class DeepseekV2YarnRotaryEmbedding(nn.Module): class DeepseekV2YarnRotaryEmbedding(nn.Module):
@ -91,72 +90,36 @@ class DeepseekV2YarnRotaryEmbedding(nn.Module):
mscale_all_dim=0, mscale_all_dim=0,
): ):
super().__init__() super().__init__()
self.dim = dim self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
self.max_position_embeddings = max_position_embeddings scaling_factor, mscale_all_dim
self.base = base )
self.scaling_factor = scaling_factor freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
self.original_max_position_embeddings = original_max_position_embeddings freq_inter = scaling_factor * base ** (
self.beta_fast = beta_fast mx.arange(0, dim, 2, dtype=mx.float32) / dim
self.beta_slow = beta_slow
self.mscale = mscale
self.mscale_all_dim = mscale_all_dim
self.max_seq_len_cached = None
self._cos_cached = None
self._sin_cached = None
self._inv_freq = None
self.set_cos_sin_cache(max_position_embeddings)
def set_cos_sin_cache(self, seq_len):
self.max_seq_len_cached = seq_len
dim = self.dim
freq_extra = 1.0 / (self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim))
freq_inter = 1.0 / (
self.scaling_factor
* self.base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
) )
low, high = yarn_find_correction_range( low, high = yarn_find_correction_range(
self.beta_fast, beta_fast,
self.beta_slow, beta_slow,
dim, dim,
self.base, base,
self.original_max_position_embeddings, original_max_position_embeddings,
) )
inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self._freqs = (freq_inter * freq_extra) / (
self._inv_freq = inv_freq freq_inter * freq_mask + freq_extra * (1 - freq_mask)
t = mx.arange(seq_len, dtype=mx.float32)
freqs = mx.outer(t, inv_freq)
mscale = yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(
self.scaling_factor, self.mscale_all_dim
) )
self._cos_cached = mx.cos(freqs) * mscale
self._sin_cached = mx.sin(freqs) * mscale
def apply_rotary_pos_emb(self, x, cos, sin):
x1 = x[..., ::2]
x2 = x[..., 1::2]
rx1 = x1 * cos - x2 * sin
rx2 = x1 * sin + x2 * cos
return mx.concatenate([rx1, rx2], axis=-1)
def __call__(self, x, offset=0): def __call__(self, x, offset=0):
seq_len = offset + x.shape[2] if self.mscale != 1.0:
if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: x = self.mscale * x
self.set_cos_sin_cache(seq_len=seq_len) return mx.fast.rope(
if self._cos_cached.dtype != x.dtype:
self._cos_cached = self._cos_cached.astype(x.dtype)
self._sin_cached = self._sin_cached.astype(x.dtype)
return self.apply_rotary_pos_emb(
x, x,
self._cos_cached[offset:seq_len], x.shape[-1],
self._sin_cached[offset:seq_len], traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
) )

View File

@ -65,19 +65,16 @@ class DynamicNTKScalingRoPE(nn.Module):
self.dims = dims self.dims = dims
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.traditional = traditional self.traditional = traditional
self.original_base = base
self.scale = scale self.scale = scale
self.rope_type = rope_type self.rope_type = rope_type
self.rope_scaling = rope_scaling self.rope_scaling = rope_scaling
self.base = self.compute_base_freq() self.base = base
self.compute_freqs()
def compute_base_freq(self): def compute_freqs(self):
if self.rope_type == "llama3": if self.rope_type != "llama3":
return self.compute_llama3_base_freq() self._freqs = None
return self.original_base return
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
def compute_llama3_base_freq(self):
factor = self.rope_scaling["factor"] factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
@ -89,19 +86,17 @@ class DynamicNTKScalingRoPE(nn.Module):
low_freq_wavelen = old_context_len / low_freq_factor low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor high_freq_wavelen = old_context_len / high_freq_factor
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims) freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs wavelens = 2 * mx.pi * freqs
new_base_freqs = []
smooths = (wavelens - high_freq_wavelen) / ( freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
low_freq_wavelen - high_freq_wavelen is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
) )
new_base_freqs = freqs * (1 - smooths) * factor + smooths smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs) self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
new_base_freqs = mx.where( self.base = None
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
)
return new_base_freqs.mean().item()
def extra_repr(self): def extra_repr(self):
return ( return (
@ -111,20 +106,14 @@ class DynamicNTKScalingRoPE(nn.Module):
) )
def __call__(self, x, offset: int = 0): def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
base = self.base
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
base *= (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
return mx.fast.rope( return mx.fast.rope(
x, x,
self.dims, self.dims,
traditional=self.traditional, traditional=self.traditional,
base=base, base=self.base,
scale=self.scale, scale=self.scale,
offset=offset, offset=offset,
freqs=self._freqs,
) )

View File

@ -59,19 +59,17 @@ class Attention(nn.Module):
self.qkv_proj = nn.Linear(dim, op_size, bias=False) self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]: if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding( self.rope = SuScaledRotaryEmbedding(
head_dim, head_dim,
traditional=False,
base=args.rope_theta, base=args.rope_theta,
scale=rope_scale,
max_position_embeddings=args.max_position_embeddings, max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings, original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"], short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"], long_factor=args.rope_scaling["long_factor"],
) )
else: else:
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear": if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float) assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"] rope_scale = 1 / args.rope_scaling["factor"]

View File

@ -4,15 +4,14 @@ import math
from typing import List, Union from typing import List, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn
class SuScaledRotaryEmbedding: class SuScaledRotaryEmbedding(nn.Module):
def __init__( def __init__(
self, self,
dims: int, dims: int,
traditional: bool = False,
base: float = 10000.0, base: float = 10000.0,
scale: float = 1.0,
max_position_embeddings: int = 131072, max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096, original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0, short_factor: Union[List[float], float] = 1.0,
@ -23,10 +22,7 @@ class SuScaledRotaryEmbedding:
Args: Args:
dims (int): The feature dimensions to be rotated. dims (int): The feature dimensions to be rotated.
traditional (bool, optional): Unused. Default: ``False``.
base (int, optional): Base for the exponential scaling. base (int, optional): Base for the exponential scaling.
scale (float, optional): The scale used to scale the positions.
Default: ``1.0``.
max_position_embeddings (int, optional): The maximum sequence max_position_embeddings (int, optional): The maximum sequence
length that this model was trained with. This is used to determine length that this model was trained with. This is used to determine
the size of the original RoPE embeddings when using long scaling. the size of the original RoPE embeddings when using long scaling.
@ -42,40 +38,23 @@ class SuScaledRotaryEmbedding:
factors for sequences of length greater than factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``. ``original_max_position_embeddings``. Default: ``1.0``.
""" """
self.inv_freq_short = 1.0 / ( super().__init__()
mx.array(short_factor, dtype=mx.float32) freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
)
self.inv_freq_long = 1.0 / (
scale
* mx.array(long_factor, dtype=mx.float32)
* base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
)
self.original_max_position_embeddings = original_max_position_embeddings self.original_max_position_embeddings = original_max_position_embeddings
self.scaling_factor = math.sqrt( self.scale = math.sqrt(
1 1
+ math.log(max_position_embeddings / original_max_position_embeddings) + math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings) / math.log(original_max_position_embeddings)
) )
def _get_cos_sin(self, offset, L):
position_ids = mx.arange(offset, offset + L, dtype=mx.float32)
inv_freq = (
self.inv_freq_long
if (offset + L) > self.original_max_position_embeddings
else self.inv_freq_short
)
freqs = position_ids[:, None] * inv_freq[None, :]
emb = mx.concatenate([freqs, freqs], axis=-1)
cos = mx.cos(emb) * self.scaling_factor
sin = mx.sin(emb) * self.scaling_factor
return cos, sin
def __call__(self, x, offset: int = 0): def __call__(self, x, offset: int = 0):
def _rotate_half(_x): return mx.fast.rope(
midpoint = _x.shape[-1] // 2 self.scale * x,
x1, x2 = _x[..., :midpoint], _x[..., midpoint:] x.shape[-1],
return mx.concatenate([-x2, x1], axis=-1) traditional=False,
base=None,
cos, sin = self._get_cos_sin(offset, x.shape[2]) scale=1.0,
return (x * cos) + (_rotate_half(x) * sin) offset=offset,
freqs=self._freqs,
)

View File

@ -1,4 +1,4 @@
mlx>=0.14.1 mlx>=0.17.0
numpy numpy
transformers[sentencepiece]>=4.39.3 transformers[sentencepiece]>=4.39.3
protobuf protobuf

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
__version__ = "0.17.0" __version__ = "0.17.1"

View File

@ -110,7 +110,7 @@ class Transformer2D(nn.Module):
# Perform the input norm and projection # Perform the input norm and projection
B, H, W, C = x.shape B, H, W, C = x.shape
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C) x = self.norm(x).reshape(B, -1, C)
x = self.proj_in(x) x = self.proj_in(x)
# Apply the transformer # Apply the transformer
@ -156,12 +156,12 @@ class ResnetBlock2D(nn.Module):
if temb is not None: if temb is not None:
temb = self.time_emb_proj(nn.silu(temb)) temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x.astype(mx.float32)).astype(dtype) y = self.norm1(x)
y = nn.silu(y) y = nn.silu(y)
y = self.conv1(y) y = self.conv1(y)
if temb is not None: if temb is not None:
y = y + temb[:, None, None, :] y = y + temb[:, None, None, :]
y = self.norm2(y.astype(mx.float32)).astype(dtype) y = self.norm2(y)
y = nn.silu(y) y = nn.silu(y)
y = self.conv2(y) y = self.conv2(y)
@ -453,8 +453,7 @@ class UNetModel(nn.Module):
) )
# Postprocess the output # Postprocess the output
dtype = x.dtype x = self.conv_norm_out(x)
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x) x = nn.silu(x)
x = self.conv_out(x) x = self.conv_out(x)