mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
clean up and fix rope
This commit is contained in:
parent
d90c6af11a
commit
9f3531c814
@ -45,7 +45,12 @@ class AttentionModule(nn.Module):
|
|||||||
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(
|
self.rope = initialize_rope(
|
||||||
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
self.head_dim,
|
||||||
|
args.rope_theta,
|
||||||
|
args.rope_traditional,
|
||||||
|
args.rope_scaling,
|
||||||
|
args.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
||||||
|
@ -56,7 +56,12 @@ class Attention(nn.Module):
|
|||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(
|
self.rope = initialize_rope(
|
||||||
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
self.head_dim,
|
||||||
|
args.rope_theta,
|
||||||
|
args.rope_traditional,
|
||||||
|
args.rope_scaling,
|
||||||
|
args.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -56,8 +56,12 @@ class Attention(nn.Module):
|
|||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(
|
self.rope = initialize_rope(
|
||||||
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
self.head_dim,
|
||||||
|
args.rope_theta,
|
||||||
|
args.rope_traditional,
|
||||||
|
args.rope_scaling,
|
||||||
|
args.max_position_embeddings,
|
||||||
|
)
|
||||||
|
|
||||||
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
|
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
|
||||||
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
|
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
@ -60,9 +61,17 @@ class Llama3RoPE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def initialize_rope(dims, base, traditional, scaling_config: Optional[dict] = None, max_position_embeddings: Optional[int] = None):
|
def initialize_rope(
|
||||||
|
dims,
|
||||||
|
base,
|
||||||
|
traditional,
|
||||||
|
scaling_config: Optional[dict] = None,
|
||||||
|
max_position_embeddings: Optional[int] = None,
|
||||||
|
):
|
||||||
if scaling_config is not None:
|
if scaling_config is not None:
|
||||||
rope_type = scaling_config.get("type") or scaling_config.get("rope_type", "default")
|
rope_type = scaling_config.get("type") or scaling_config.get(
|
||||||
|
"rope_type", "default"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
rope_type = "default"
|
rope_type = "default"
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import unittest
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
|
||||||
from mlx_lm.models import rope_utils
|
from mlx_lm.models import rope_utils
|
||||||
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
|
|
||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
@ -132,10 +132,20 @@ class TestModels(unittest.TestCase):
|
|||||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
||||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
|
|
||||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "linear", "factor": 10.0})
|
rope = rope_utils.initialize_rope(
|
||||||
|
32,
|
||||||
|
base=100,
|
||||||
|
traditional=False,
|
||||||
|
scaling_config={"rope_type": "linear", "factor": 10.0},
|
||||||
|
)
|
||||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
|
|
||||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "llama3", "factor": 2.0})
|
rope = rope_utils.initialize_rope(
|
||||||
|
32,
|
||||||
|
base=100,
|
||||||
|
traditional=False,
|
||||||
|
scaling_config={"rope_type": "llama3", "factor": 2.0},
|
||||||
|
)
|
||||||
self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
|
self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
Loading…
Reference in New Issue
Block a user