Add support for Llama-3.1 (#907)

* add dynamicNTK scaling rope

* remove unused var

* fix rope base

* llama3.1 fixes

* TODO for rope eval

* vectorise llama3 base freq calculation

* removed the arbitrary 2.0 rope_scale default case

* fix slow llama3.1 generation by evaluating stateless part of DynamicNTKScalingRoPE in init

* nits + format

* use mx.pi

* fix tests and add test for 3.1

---------

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Author: Alex Cheema, 2024-07-23 13:21:32 -07:00, committed by GitHub
Commit: cd8efc7fbc (parent: 47060a8130)
2 changed files with 148 additions and 16 deletions
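
For reference, the rope_scaling block that routes a model through the new llama3 code path looks like the dict used in the added test; the exact numbers are read from a checkpoint's config.json and may differ per model:

    # Llama-3.1-style rope_scaling entry (values mirror the new test case below,
    # not any specific checkpoint's config.json).
    rope_scaling = {
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    }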


@@ -17,6 +17,7 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     vocab_size: int
     head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = None
     num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
@@ -30,12 +31,126 @@ class ModelArgs(BaseModelArgs):
             self.num_key_value_heads = self.num_attention_heads
 
         if self.rope_scaling:
-            required_keys = {"factor", "type"}
-            if not all(key in self.rope_scaling for key in required_keys):
-                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+            if not "factor" in self.rope_scaling:
+                raise ValueError(f"rope_scaling must contain 'factor'")
 
-            if self.rope_scaling["type"] != "linear":
-                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+            rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
+                "rope_type"
+            )
+            if rope_type is None:
+                raise ValueError(
+                    f"rope_scaling must contain either 'type' or 'rope_type'"
+                )
+
+            if rope_type not in ["linear", "dynamic", "llama3"]:
+                raise ValueError(
+                    "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
+                )
+
+
+class DynamicNTKScalingRoPE(nn.Module):
+    """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
+
+    def __init__(
+        self,
+        dims: int,
+        max_position_embeddings: int = 2048,
+        traditional: bool = False,
+        base: float = 10000,
+        scale: float = 1.0,
+        rope_type: str = "default",
+        rope_scaling: dict = None,
+    ):
+        super().__init__()
+        self.dims = dims
+        self.max_position_embeddings = max_position_embeddings
+        self.traditional = traditional
+        self.original_base = base
+        self.scale = scale
+        self.rope_type = rope_type
+        self.rope_scaling = rope_scaling
+        self.base = self.compute_base_freq()
+
+    def compute_base_freq(self):
+        if self.rope_type == "llama3":
+            return self.compute_llama3_base_freq()
+        return self.original_base
+
+    # source: https://github.com/huggingface/transformers/blob/d5a99dfcee6e94065cb7c83cc8ab6fc5daa0cc4e/src/transformers/modeling_rope_utils.py#L318
+    def compute_llama3_base_freq(self):
+        factor = self.rope_scaling["factor"]
+        low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
+        high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
+        old_context_len = self.rope_scaling.get(
+            "original_max_position_embeddings",
+            8192,
+        )
+
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+
+        freqs = self.original_base ** (mx.arange(0, self.dims, 2) / self.dims)
+        wavelens = 2 * mx.pi * freqs
+        new_base_freqs = []
+
+        smooths = (wavelens - high_freq_wavelen) / (
+            low_freq_wavelen - high_freq_wavelen
+        )
+        new_base_freqs = freqs * (1 - smooths) * factor + smooths
+        new_base_freqs = mx.where(wavelens < high_freq_wavelen, freqs, new_base_freqs)
+        new_base_freqs = mx.where(
+            wavelens > low_freq_wavelen, freqs * factor, new_base_freqs
+        )
+        return new_base_freqs.mean().item()
+
+    def extra_repr(self):
+        return (
+            f"{self.dims}, traditional={self.traditional}, "
+            f"max_position_embeddings={self.max_position_embeddings}, "
+            f"scaling_factor={self.scale}, rope_type={self.rope_type}"
+        )
+
+    def __call__(self, x, offset: int = 0):
+        seq_len = x.shape[1] + offset
+        base = self.base
+        if self.max_position_embeddings and seq_len > self.max_position_embeddings:
+            base *= (
+                (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
+            ) ** (self.dims / (self.dims - 2))
+
+        return mx.fast.rope(
+            x,
+            self.dims,
+            traditional=self.traditional,
+            base=base,
+            scale=self.scale,
+            offset=offset,
+        )
+
+
+def initialize_rope(args: ModelArgs):
+    head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
+
+    rope_scaling = args.rope_scaling
+    rope_type = "default"
+    rope_scale = 1.0
+
+    if rope_scaling is not None:
+        rope_type = (
+            rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
+        )
+        if rope_type == "linear":
+            rope_scale = 1 / rope_scaling["factor"]
+        elif rope_type == "llama3":
+            rope_scale = 1.0  # The scaling is handled internally for llama3
+
+    return DynamicNTKScalingRoPE(
+        dims=head_dim,
+        max_position_embeddings=args.max_position_embeddings,
+        traditional=args.rope_traditional,
+        base=args.rope_theta,
+        scale=rope_scale,
+        rope_type=rope_type,
+        rope_scaling=rope_scaling,
+    )
 
 
 class Attention(nn.Module):
@@ -59,17 +174,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
 
-        rope_scale = (
-            1 / args.rope_scaling["factor"]
-            if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
-            else 1
-        )
-        self.rope = nn.RoPE(
-            head_dim,
-            traditional=args.rope_traditional,
-            base=args.rope_theta,
-            scale=rope_scale,
-        )
+        self.rope = initialize_rope(args)
 
     def __call__(
         self,
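
To make the llama3 branch above easier to follow, here is a pure-Python mirror of compute_llama3_base_freq (no mlx required). The function name and defaults are illustrative; the committed code vectorises the same arithmetic with mx.where and then averages the result into a single scalar base:

    import math

    # Pure-Python mirror of compute_llama3_base_freq, for illustration only.
    # The llama3 defaults below match the values in the new test; `dims` and
    # `base` come from the model config (head_dim and rope_theta) in practice.
    def llama3_base_freq(dims, base, factor=8.0, low_freq_factor=1.0,
                         high_freq_factor=4.0, old_context_len=8192):
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        new_base_freqs = []
        for i in range(0, dims, 2):
            freq = base ** (i / dims)
            wavelen = 2 * math.pi * freq
            if wavelen < high_freq_wavelen:
                new_base_freqs.append(freq)           # high-frequency band: unchanged
            elif wavelen > low_freq_wavelen:
                new_base_freqs.append(freq * factor)  # low-frequency band: stretched
            else:
                # mid band: smooth interpolation, same expression as the diff
                smooth = (wavelen - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
                new_base_freqs.append(freq * (1 - smooth) * factor + smooth)
        return sum(new_base_freqs) / len(new_base_freqs)  # mean -> single scalar base

    # e.g. head_dim=128 and rope_theta=500000.0 (illustrative Llama-3.1-like values)
    print(llama3_base_freq(128, 500000.0))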


@@ -449,6 +449,33 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )
 
+    def test_llama3_1(self):
+        from mlx_lm.models import llama
+
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+            max_position_embeddings=128,
+            mlp_bias=False,
+            num_key_value_heads=2,
+            rope_scaling={
+                "factor": 8.0,
+                "low_freq_factor": 1.0,
+                "high_freq_factor": 4.0,
+                "original_max_position_embeddings": 8192,
+                "rope_type": "llama3",
+            },
+        )
+        model = llama.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
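
The dynamic NTK branch in __call__ only kicks in once the sequence length (plus offset) exceeds max_position_embeddings. A small self-contained sketch of just that adjustment, with made-up example numbers:

    # Sketch of the base adjustment in DynamicNTKScalingRoPE.__call__.
    # The numbers in the example calls are illustrative, not taken from any config.
    def dynamic_ntk_base(base, dims, scale, seq_len, max_position_embeddings):
        if max_position_embeddings and seq_len > max_position_embeddings:
            base *= (
                (scale * seq_len / max_position_embeddings) - (scale - 1)
            ) ** (dims / (dims - 2))
        return base

    print(dynamic_ntk_base(10000.0, 128, 1.0, 1024, 2048))  # within the limit: 10000.0
    print(dynamic_ntk_base(10000.0, 128, 1.0, 4096, 2048))  # past the limit: ~20221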