clean up and fix rope

Awni Hannun 2024-12-09 07:44:55 -08:00
parent d90c6af11a
commit 9f3531c814
5 changed files with 42 additions and 9 deletions

View File

@@ -45,7 +45,12 @@ class AttentionModule(nn.Module):
         self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)

         self.rope = initialize_rope(
-            self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )

     def __call__(
         self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None

View File

@@ -56,7 +56,12 @@ class Attention(nn.Module):
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

         self.rope = initialize_rope(
-            self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )

     def __call__(
         self,

View File

@@ -56,8 +56,12 @@ class Attention(nn.Module):
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)

         self.rope = initialize_rope(
-            self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
+            self.head_dim,
+            args.rope_theta,
+            args.rope_traditional,
+            args.rope_scaling,
+            args.max_position_embeddings,
+        )
         self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
         self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
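
Note, not part of the diff: the three hunks above only reformat the initialize_rope call sites. For context, the rope module they construct is applied to the query and key projections inside __call__, which the diff truncates. A minimal sketch of that pattern, assuming the usual mlx-lm attribute names (q_proj, k_proj, v_proj, n_heads, n_kv_heads, scale) and the KVCache interface from mlx_lm.models.cache:

def __call__(self, x, mask=None, cache=None):
    B, L, _ = x.shape
    # Project and split heads: (B, L, dim) -> (B, heads, L, head_dim)
    queries = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
    keys = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
    values = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
    if cache is not None:
        # Offset the rotary embedding by the number of tokens already cached
        queries = self.rope(queries, offset=cache.offset)
        keys = self.rope(keys, offset=cache.offset)
        keys, values = cache.update_and_fetch(keys, values)
    else:
        queries = self.rope(queries)
        keys = self.rope(keys)
    out = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=self.scale, mask=mask
    )
    out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)
    return self.o_proj(out)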

View File

@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

+from typing import Optional
+
 import mlx.core as mx
 import mlx.nn as nn
@@ -60,9 +61,17 @@ class Llama3RoPE(nn.Module):
         )


-def initialize_rope(dims, base, traditional, scaling_config: Optional[dict] = None, max_position_embeddings: Optional[int] = None):
+def initialize_rope(
+    dims,
+    base,
+    traditional,
+    scaling_config: Optional[dict] = None,
+    max_position_embeddings: Optional[int] = None,
+):
     if scaling_config is not None:
-        rope_type = scaling_config.get("type") or scaling_config.get("rope_type", "default")
+        rope_type = scaling_config.get("type") or scaling_config.get(
+            "rope_type", "default"
+        )
     else:
         rope_type = "default"
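
The hunk above carries the actual fix: the RoPE variant is read from either the "type" key or the "rope_type" key of the scaling config (Hugging Face configs use both spellings), falling back to "default". The part of initialize_rope that consumes rope_type sits below this hunk and is not shown; a rough sketch of the dispatch it implies, consistent with the tests further down (everything other than nn.RoPE and Llama3RoPE is an assumption):

    # Assumed continuation of initialize_rope; the real body is not shown in this diff.
    if rope_type in ("default", "linear"):
        scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
        return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
    elif rope_type == "llama3":
        # The Llama3RoPE argument order here is assumed, not taken from the diff.
        return Llama3RoPE(dims, max_position_embeddings, traditional, base, scaling_config)
    else:
        raise ValueError(f"Unsupported RoPE type {rope_type}")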

View File

@@ -4,8 +4,8 @@ import unittest

 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map
-from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 from mlx_lm.models import rope_utils
+from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache

 class TestModels(unittest.TestCase):
@@ -132,10 +132,20 @@ class TestModels(unittest.TestCase):
         rope = rope_utils.initialize_rope(32, base=100, traditional=False)
         self.assertTrue(isinstance(rope, nn.RoPE))

-        rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "linear", "factor": 10.0})
+        rope = rope_utils.initialize_rope(
+            32,
+            base=100,
+            traditional=False,
+            scaling_config={"rope_type": "linear", "factor": 10.0},
+        )
         self.assertTrue(isinstance(rope, nn.RoPE))

-        rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "llama3", "factor": 2.0})
+        rope = rope_utils.initialize_rope(
+            32,
+            base=100,
+            traditional=False,
+            scaling_config={"rope_type": "llama3", "factor": 2.0},
+        )
         self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))

     def model_test_runner(self, model, model_type, vocab_size, num_layers):
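
A quick usage illustration, not part of the diff, of the behavior the fix and the tests above pin down: both spellings of the scaling key resolve to the same implementation (the config values here are made up):

from mlx_lm.models import rope_utils

rope_a = rope_utils.initialize_rope(
    32,
    base=10000.0,
    traditional=False,
    scaling_config={"type": "linear", "factor": 4.0},
)
rope_b = rope_utils.initialize_rope(
    32,
    base=10000.0,
    traditional=False,
    scaling_config={"rope_type": "linear", "factor": 4.0},
)
assert type(rope_a) is type(rope_b)  # both resolve to the "linear" nn.RoPE path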