From 9f3531c8148ef9ca0ad29938d84ef3752de92bad Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 9 Dec 2024 07:44:55 -0800 Subject: [PATCH] clean up and fix rope --- llms/mlx_lm/models/exaone.py | 7 ++++++- llms/mlx_lm/models/llama.py | 7 ++++++- llms/mlx_lm/models/olmo2.py | 8 ++++++-- llms/mlx_lm/models/rope_utils.py | 13 +++++++++++-- llms/tests/test_models.py | 16 +++++++++++++--- 5 files changed, 42 insertions(+), 9 deletions(-) diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py index 0df08118..eaed5dd8 100644 --- a/llms/mlx_lm/models/exaone.py +++ b/llms/mlx_lm/models/exaone.py @@ -45,7 +45,12 @@ class AttentionModule(nn.Module): self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) self.rope = initialize_rope( - self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings) + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) def __call__( self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index eaca5a9f..290cb83e 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -56,7 +56,12 @@ class Attention(nn.Module): self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) self.rope = initialize_rope( - self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings) + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) def __call__( self, diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index 6496295c..64d7e116 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -56,8 +56,12 @@ class Attention(nn.Module): self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) self.rope = initialize_rope( - self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings) - + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py index 12da2002..d30b432d 100644 --- a/llms/mlx_lm/models/rope_utils.py +++ b/llms/mlx_lm/models/rope_utils.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from typing import Optional + import mlx.core as mx import mlx.nn as nn @@ -60,9 +61,17 @@ class Llama3RoPE(nn.Module): ) -def initialize_rope(dims, base, traditional, scaling_config: Optional[dict] = None, max_position_embeddings: Optional[int] = None): +def initialize_rope( + dims, + base, + traditional, + scaling_config: Optional[dict] = None, + max_position_embeddings: Optional[int] = None, +): if scaling_config is not None: - rope_type = scaling_config.get("type") or scaling_config.get("rope_type", "default") + rope_type = scaling_config.get("type") or scaling_config.get( + "rope_type", "default" + ) else: rope_type = "default" diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 2322f37e..374a5113 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -4,8 +4,8 @@ import unittest import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_map -from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache from mlx_lm.models import rope_utils +from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache class TestModels(unittest.TestCase): @@ -132,10 +132,20 @@ class TestModels(unittest.TestCase): rope = rope_utils.initialize_rope(32, base=100, traditional=False) self.assertTrue(isinstance(rope, nn.RoPE)) - rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "linear", "factor": 10.0}) + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "linear", "factor": 10.0}, + ) self.assertTrue(isinstance(rope, nn.RoPE)) - rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "llama3", "factor": 2.0}) + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "llama3", "factor": 2.0}, + ) self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE)) def model_test_runner(self, model, model_type, vocab_size, num_layers):