mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add support for Llama-3.1 (#907)
* add dynamicNTK scaling rope * remove unused var * fix rope base * llama3.1 fixes * TODO for rope eval * vectorise llama3 base freq calculation * removed the arbitrary 2.0 rope_scale default case * fix slow llama3.1 generation by evaluating stateless part of DynamicNTKScalingRoPE in init * nits + format * use mx.pi * fix tests and add test for 3.1 --------- Co-authored-by: Prince Canuma <prince.gdt@gmail.com> Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
47060a8130
commit
cd8efc7fbc
@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
|
|||||||
rms_norm_eps: float
|
rms_norm_eps: float
|
||||||
vocab_size: int
|
vocab_size: int
|
||||||
head_dim: Optional[int] = None
|
head_dim: Optional[int] = None
|
||||||
|
max_position_embeddings: Optional[int] = None
|
||||||
num_key_value_heads: Optional[int] = None
|
num_key_value_heads: Optional[int] = None
|
||||||
attention_bias: bool = False
|
attention_bias: bool = False
|
||||||
mlp_bias: bool = False
|
mlp_bias: bool = False
|
||||||
@ -30,12 +31,126 @@ class ModelArgs(BaseModelArgs):
|
|||||||
self.num_key_value_heads = self.num_attention_heads
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
if self.rope_scaling:
|
if self.rope_scaling:
|
||||||
required_keys = {"factor", "type"}
|
if not "factor" in self.rope_scaling:
|
||||||
if not all(key in self.rope_scaling for key in required_keys):
|
raise ValueError(f"rope_scaling must contain 'factor'")
|
||||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
||||||
|
"rope_type"
|
||||||
|
)
|
||||||
|
if rope_type is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"rope_scaling must contain either 'type' or 'rope_type'"
|
||||||
|
)
|
||||||
|
if rope_type not in ["linear", "dynamic", "llama3"]:
|
||||||
|
raise ValueError(
|
||||||
|
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
|
||||||
|
)
|
||||||
|
|
||||||
if self.rope_scaling["type"] != "linear":
|
|
||||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
class DynamicNTKScalingRoPE(nn.Module):
|
||||||
|
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
max_position_embeddings: int = 2048,
|
||||||
|
traditional: bool = False,
|
||||||
|
base: float = 10000,
|
||||||
|
scale: float = 1.0,
|
||||||
|
rope_type: str = "default",
|
||||||
|
rope_scaling: dict = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.traditional = traditional
|
||||||
|
self.original_base = base
|
||||||
|
self.scale = scale
|
||||||
|
self.rope_type = rope_type
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
self.base = self.compute_base_freq()
|
||||||
|
|
||||||
|
def compute_base_freq(self):
|
||||||
|
if self.rope_type == "llama3":
|
||||||
|
return self.compute_llama3_base_freq()
|
||||||
|
return self.original_base
|
||||||
|
|
||||||
|
# source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
|
||||||
|
def compute_llama3_base_freq(self):
|
||||||
|
factor = self.rope_scaling["factor"]
|
||||||
|
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
||||||
|
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
||||||
|
old_context_len = self.rope_scaling.get(
|
||||||
|
"original_max_position_embeddings",
|
||||||
|
8192,
|
||||||
|
)
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
|
||||||
|
freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
|
||||||
|
wavelens = 2 * mx.pi * freqs
|
||||||
|
new_base_freqs = []
|
||||||
|
|
||||||
|
smooths = (wavelens - high_freq_wavelen) / (
|
||||||
|
low_freq_wavelen - high_freq_wavelen
|
||||||
|
)
|
||||||
|
new_base_freqs = freqs * (1 - smooths) * factor + smooths
|
||||||
|
new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
|
||||||
|
new_base_freqs = mx.where(
|
||||||
|
wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
|
||||||
|
)
|
||||||
|
return new_base_freqs.mean().item()
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return (
|
||||||
|
f"{self.dims}, traditional={self.traditional}, "
|
||||||
|
f"max_position_embeddings={self.max_position_embeddings}, "
|
||||||
|
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, offset: int = 0):
|
||||||
|
seq_len = x.shape[1] + offset
|
||||||
|
base = self.base
|
||||||
|
if self.max_position_embeddings and seq_len > self.max_position_embeddings:
|
||||||
|
base *= (
|
||||||
|
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
|
||||||
|
) ** (self.dims / (self.dims - 2))
|
||||||
|
|
||||||
|
return mx.fast.rope(
|
||||||
|
x,
|
||||||
|
self.dims,
|
||||||
|
traditional=self.traditional,
|
||||||
|
base=base,
|
||||||
|
scale=self.scale,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_rope(args: ModelArgs):
|
||||||
|
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
|
||||||
|
|
||||||
|
rope_scaling = args.rope_scaling
|
||||||
|
rope_type = "default"
|
||||||
|
rope_scale = 1.0
|
||||||
|
|
||||||
|
if rope_scaling is not None:
|
||||||
|
rope_type = (
|
||||||
|
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
|
||||||
|
)
|
||||||
|
if rope_type == "linear":
|
||||||
|
rope_scale = 1 / rope_scaling["factor"]
|
||||||
|
elif rope_type == "llama3":
|
||||||
|
rope_scale = 1.0 # The scaling is handled internally for llama3
|
||||||
|
|
||||||
|
return DynamicNTKScalingRoPE(
|
||||||
|
dims=head_dim,
|
||||||
|
max_position_embeddings=args.max_position_embeddings,
|
||||||
|
traditional=args.rope_traditional,
|
||||||
|
base=args.rope_theta,
|
||||||
|
scale=rope_scale,
|
||||||
|
rope_type=rope_type,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
@ -59,17 +174,7 @@ class Attention(nn.Module):
|
|||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
rope_scale = (
|
self.rope = initialize_rope(args)
|
||||||
1 / args.rope_scaling["factor"]
|
|
||||||
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
|
|
||||||
else 1
|
|
||||||
)
|
|
||||||
self.rope = nn.RoPE(
|
|
||||||
head_dim,
|
|
||||||
traditional=args.rope_traditional,
|
|
||||||
base=args.rope_theta,
|
|
||||||
scale=rope_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -449,6 +449,33 @@ class TestModels(unittest.TestCase):
|
|||||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_llama3_1(self):
|
||||||
|
from mlx_lm.models import llama
|
||||||
|
|
||||||
|
args = llama.ModelArgs(
|
||||||
|
model_type="llama",
|
||||||
|
hidden_size=1024,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
intermediate_size=2048,
|
||||||
|
num_attention_heads=4,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
vocab_size=10_000,
|
||||||
|
max_position_embeddings=128,
|
||||||
|
mlp_bias=False,
|
||||||
|
num_key_value_heads=2,
|
||||||
|
rope_scaling={
|
||||||
|
"factor": 8.0,
|
||||||
|
"low_freq_factor": 1.0,
|
||||||
|
"high_freq_factor": 4.0,
|
||||||
|
"original_max_position_embeddings": 8192,
|
||||||
|
"rope_type": "llama3",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
model = llama.Model(args)
|
||||||
|
self.model_test_runner(
|
||||||
|
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user