mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:17:07 +08:00
clean up and fix rope
This commit is contained in:
parent
b2b16500fb
commit
d90c6af11a
@ -7,6 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
from .rope_utils import initialize_rope
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -28,76 +29,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
attention_bias: bool = False
|
attention_bias: bool = False
|
||||||
mlp_bias: bool = False
|
mlp_bias: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.rope_scaling:
|
|
||||||
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
|
||||||
"rope_type"
|
|
||||||
)
|
|
||||||
if rope_type is None:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling must contain either 'type' or 'rope_type'"
|
|
||||||
)
|
|
||||||
if rope_type not in ["linear", "dynamic", "llama3", "default"]:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling 'type' currently only supports 'linear', 'dynamic', 'llama3', or 'default'"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExaoneRotaryEmbedding(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dims: int,
|
|
||||||
max_position_embeddings: int = 2048,
|
|
||||||
traditional: bool = False,
|
|
||||||
base: float = 10000,
|
|
||||||
scale: float = 1.0,
|
|
||||||
rope_type: str = "default",
|
|
||||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dims = dims
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.traditional = traditional
|
|
||||||
self.scale = scale
|
|
||||||
self.rope_type = rope_type
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
self.base = base
|
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
|
||||||
return mx.fast.rope(
|
|
||||||
x,
|
|
||||||
self.dims,
|
|
||||||
traditional=self.traditional,
|
|
||||||
base=self.base,
|
|
||||||
scale=self.scale,
|
|
||||||
offset=offset,
|
|
||||||
freqs=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_rope(args: ModelArgs):
|
|
||||||
head_dim = args.head_dim or (args.hidden_size // args.num_attention_heads)
|
|
||||||
rope_scaling = args.rope_scaling
|
|
||||||
rope_type = "default"
|
|
||||||
rope_scale = 1.0
|
|
||||||
|
|
||||||
if rope_scaling is not None:
|
|
||||||
rope_type = rope_scaling.get("type") or rope_scaling.get("rope_type", "default")
|
|
||||||
if rope_type == "linear":
|
|
||||||
rope_scale = 1 / rope_scaling["factor"]
|
|
||||||
elif rope_type in ["llama3", "dynamic"]:
|
|
||||||
rope_scale = 1.0
|
|
||||||
|
|
||||||
return ExaoneRotaryEmbedding(
|
|
||||||
dims=head_dim,
|
|
||||||
max_position_embeddings=args.max_position_embeddings or 2048,
|
|
||||||
traditional=args.rope_traditional,
|
|
||||||
base=args.rope_theta,
|
|
||||||
scale=rope_scale,
|
|
||||||
rope_type=rope_type,
|
|
||||||
rope_scaling=rope_scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionModule(nn.Module):
|
class AttentionModule(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
@ -113,7 +44,8 @@ class AttentionModule(nn.Module):
|
|||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
|
||||||
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(args)
|
self.rope = initialize_rope(
|
||||||
|
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None
|
||||||
|
@ -7,6 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
from .rope_utils import initialize_rope
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
if self.num_key_value_heads is None:
|
if self.num_key_value_heads is None:
|
||||||
self.num_key_value_heads = self.num_attention_heads
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
if self.rope_scaling:
|
|
||||||
if not "factor" in self.rope_scaling:
|
|
||||||
raise ValueError(f"rope_scaling must contain 'factor'")
|
|
||||||
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
|
||||||
"rope_type"
|
|
||||||
)
|
|
||||||
if rope_type is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"rope_scaling must contain either 'type' or 'rope_type'"
|
|
||||||
)
|
|
||||||
if rope_type not in ["linear", "dynamic", "llama3"]:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicNTKScalingRoPE(nn.Module):
|
|
||||||
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dims: int,
|
|
||||||
max_position_embeddings: int = 2048,
|
|
||||||
traditional: bool = False,
|
|
||||||
base: float = 10000,
|
|
||||||
scale: float = 1.0,
|
|
||||||
rope_type: str = "default",
|
|
||||||
rope_scaling: dict = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dims = dims
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.traditional = traditional
|
|
||||||
self.scale = scale
|
|
||||||
self.rope_type = rope_type
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
self.base = base
|
|
||||||
self.compute_freqs()
|
|
||||||
|
|
||||||
def compute_freqs(self):
|
|
||||||
if self.rope_type != "llama3":
|
|
||||||
self._freqs = None
|
|
||||||
return
|
|
||||||
factor = self.rope_scaling["factor"]
|
|
||||||
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
|
||||||
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
|
||||||
old_context_len = self.rope_scaling.get(
|
|
||||||
"original_max_position_embeddings",
|
|
||||||
8192,
|
|
||||||
)
|
|
||||||
|
|
||||||
low_freq_wavelen = old_context_len / low_freq_factor
|
|
||||||
high_freq_wavelen = old_context_len / high_freq_factor
|
|
||||||
|
|
||||||
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
|
|
||||||
wavelens = 2 * mx.pi * freqs
|
|
||||||
|
|
||||||
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
|
||||||
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
|
||||||
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
|
||||||
high_freq_factor - low_freq_factor
|
|
||||||
)
|
|
||||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
|
||||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
|
||||||
self.base = None
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return (
|
|
||||||
f"{self.dims}, traditional={self.traditional}, "
|
|
||||||
f"max_position_embeddings={self.max_position_embeddings}, "
|
|
||||||
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
|
||||||
return mx.fast.rope(
|
|
||||||
x,
|
|
||||||
self.dims,
|
|
||||||
traditional=self.traditional,
|
|
||||||
base=self.base,
|
|
||||||
scale=self.scale,
|
|
||||||
offset=offset,
|
|
||||||
freqs=self._freqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_rope(args: ModelArgs):
|
|
||||||
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
|
|
||||||
|
|
||||||
rope_scaling = args.rope_scaling
|
|
||||||
rope_type = "default"
|
|
||||||
rope_scale = 1.0
|
|
||||||
|
|
||||||
if rope_scaling is not None:
|
|
||||||
rope_type = (
|
|
||||||
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
|
|
||||||
)
|
|
||||||
if rope_type == "linear":
|
|
||||||
rope_scale = 1 / rope_scaling["factor"]
|
|
||||||
elif rope_type == "llama3":
|
|
||||||
rope_scale = 1.0 # The scaling is handled internally for llama3
|
|
||||||
|
|
||||||
return DynamicNTKScalingRoPE(
|
|
||||||
dims=head_dim,
|
|
||||||
max_position_embeddings=args.max_position_embeddings,
|
|
||||||
traditional=args.rope_traditional,
|
|
||||||
base=args.rope_theta,
|
|
||||||
scale=rope_scale,
|
|
||||||
rope_type=rope_type,
|
|
||||||
rope_scaling=rope_scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
@ -165,7 +55,8 @@ class Attention(nn.Module):
|
|||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(args)
|
self.rope = initialize_rope(
|
||||||
|
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
@ -7,6 +7,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
|
from .rope_utils import initialize_rope
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs):
|
|||||||
if self.num_key_value_heads is None:
|
if self.num_key_value_heads is None:
|
||||||
self.num_key_value_heads = self.num_attention_heads
|
self.num_key_value_heads = self.num_attention_heads
|
||||||
|
|
||||||
if self.rope_scaling:
|
|
||||||
if not "factor" in self.rope_scaling:
|
|
||||||
raise ValueError(f"rope_scaling must contain 'factor'")
|
|
||||||
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
|
|
||||||
"rope_type"
|
|
||||||
)
|
|
||||||
if rope_type is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"rope_scaling must contain either 'type' or 'rope_type'"
|
|
||||||
)
|
|
||||||
if rope_type not in ["linear", "dynamic", "llama3"]:
|
|
||||||
raise ValueError(
|
|
||||||
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicNTKScalingRoPE(nn.Module):
|
|
||||||
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dims: int,
|
|
||||||
max_position_embeddings: int = 2048,
|
|
||||||
traditional: bool = False,
|
|
||||||
base: float = 10000,
|
|
||||||
scale: float = 1.0,
|
|
||||||
rope_type: str = "default",
|
|
||||||
rope_scaling: dict = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.dims = dims
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.traditional = traditional
|
|
||||||
self.scale = scale
|
|
||||||
self.rope_type = rope_type
|
|
||||||
self.rope_scaling = rope_scaling
|
|
||||||
self.base = base
|
|
||||||
self.compute_freqs()
|
|
||||||
|
|
||||||
def compute_freqs(self):
|
|
||||||
if self.rope_type != "llama3":
|
|
||||||
self._freqs = None
|
|
||||||
return
|
|
||||||
factor = self.rope_scaling["factor"]
|
|
||||||
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
|
|
||||||
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
|
|
||||||
old_context_len = self.rope_scaling.get(
|
|
||||||
"original_max_position_embeddings",
|
|
||||||
8192,
|
|
||||||
)
|
|
||||||
|
|
||||||
low_freq_wavelen = old_context_len / low_freq_factor
|
|
||||||
high_freq_wavelen = old_context_len / high_freq_factor
|
|
||||||
|
|
||||||
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
|
|
||||||
wavelens = 2 * mx.pi * freqs
|
|
||||||
|
|
||||||
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
|
||||||
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
|
||||||
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
|
||||||
high_freq_factor - low_freq_factor
|
|
||||||
)
|
|
||||||
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
|
||||||
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
|
||||||
self.base = None
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return (
|
|
||||||
f"{self.dims}, traditional={self.traditional}, "
|
|
||||||
f"max_position_embeddings={self.max_position_embeddings}, "
|
|
||||||
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, x, offset: int = 0):
|
|
||||||
return mx.fast.rope(
|
|
||||||
x,
|
|
||||||
self.dims,
|
|
||||||
traditional=self.traditional,
|
|
||||||
base=self.base,
|
|
||||||
scale=self.scale,
|
|
||||||
offset=offset,
|
|
||||||
freqs=self._freqs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def initialize_rope(args: ModelArgs):
|
|
||||||
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
|
|
||||||
|
|
||||||
rope_scaling = args.rope_scaling
|
|
||||||
rope_type = "default"
|
|
||||||
rope_scale = 1.0
|
|
||||||
|
|
||||||
if rope_scaling is not None:
|
|
||||||
rope_type = (
|
|
||||||
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
|
|
||||||
)
|
|
||||||
if rope_type == "linear":
|
|
||||||
rope_scale = 1 / rope_scaling["factor"]
|
|
||||||
elif rope_type == "llama3":
|
|
||||||
rope_scale = 1.0 # The scaling is handled internally for llama3
|
|
||||||
|
|
||||||
return DynamicNTKScalingRoPE(
|
|
||||||
dims=head_dim,
|
|
||||||
max_position_embeddings=args.max_position_embeddings,
|
|
||||||
traditional=args.rope_traditional,
|
|
||||||
base=args.rope_theta,
|
|
||||||
scale=rope_scale,
|
|
||||||
rope_type=rope_type,
|
|
||||||
rope_scaling=rope_scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
@ -165,7 +55,10 @@ class Attention(nn.Module):
|
|||||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
|
||||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
|
||||||
|
|
||||||
self.rope = initialize_rope(args)
|
self.rope = initialize_rope(
|
||||||
|
self.head_dim, args.rope_theta, args.rope_traditional, args.rope_scaling, args.max_position_embeddings)
|
||||||
|
|
||||||
|
|
||||||
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
|
self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps)
|
||||||
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
|
self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps)
|
||||||
|
|
||||||
|
82
llms/mlx_lm/models/rope_utils.py
Normal file
82
llms/mlx_lm/models/rope_utils.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class Llama3RoPE(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dims: int,
|
||||||
|
max_position_embeddings: int = 2048,
|
||||||
|
traditional: bool = False,
|
||||||
|
base: float = 10000,
|
||||||
|
scaling_config: dict = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dims = dims
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.traditional = traditional
|
||||||
|
|
||||||
|
factor = scaling_config["factor"]
|
||||||
|
low_freq_factor = scaling_config.get("low_freq_factor", 1.0)
|
||||||
|
high_freq_factor = scaling_config.get("high_freq_factor", 4.0)
|
||||||
|
old_context_len = scaling_config.get(
|
||||||
|
"original_max_position_embeddings",
|
||||||
|
8192,
|
||||||
|
)
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
|
||||||
|
freqs = base ** (mx.arange(0, dims, 2) / dims)
|
||||||
|
wavelens = 2 * mx.pi * freqs
|
||||||
|
|
||||||
|
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
|
||||||
|
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
|
||||||
|
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
|
||||||
|
high_freq_factor - low_freq_factor
|
||||||
|
)
|
||||||
|
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
|
||||||
|
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return (
|
||||||
|
f"{self.dims}, traditional={self.traditional}, "
|
||||||
|
f"max_position_embeddings={self.max_position_embeddings}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x, offset: int = 0):
|
||||||
|
return mx.fast.rope(
|
||||||
|
x,
|
||||||
|
self.dims,
|
||||||
|
traditional=self.traditional,
|
||||||
|
base=None,
|
||||||
|
scale=1.0,
|
||||||
|
offset=offset,
|
||||||
|
freqs=self._freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_rope(dims, base, traditional, scaling_config: Optional[dict] = None, max_position_embeddings: Optional[int] = None):
|
||||||
|
if scaling_config is not None:
|
||||||
|
rope_type = scaling_config.get("type") or scaling_config.get("rope_type", "default")
|
||||||
|
else:
|
||||||
|
rope_type = "default"
|
||||||
|
|
||||||
|
if rope_type in ["default", "linear"]:
|
||||||
|
scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0
|
||||||
|
return nn.RoPE(dims, traditional=traditional, base=base, scale=scale)
|
||||||
|
|
||||||
|
elif rope_type == "llama3":
|
||||||
|
return Llama3RoPE(
|
||||||
|
dims=dims,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
traditional=traditional,
|
||||||
|
base=base,
|
||||||
|
scaling_config=scaling_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported RoPE type {rope_type}")
|
@ -2,8 +2,10 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
|
from mlx_lm.models import rope_utils
|
||||||
|
|
||||||
|
|
||||||
class TestModels(unittest.TestCase):
|
class TestModels(unittest.TestCase):
|
||||||
@ -126,6 +128,16 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(cache.offset, 22)
|
self.assertEqual(cache.offset, 22)
|
||||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
|
def test_rope(self):
|
||||||
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
||||||
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
|
|
||||||
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "linear", "factor": 10.0})
|
||||||
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
|
|
||||||
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False, scaling_config={"rope_type": "llama3", "factor": 2.0})
|
||||||
|
self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE))
|
||||||
|
|
||||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||||
|
|
||||||
self.assertEqual(len(model.layers), num_layers)
|
self.assertEqual(len(model.layers), num_layers)
|
||||||
|
Loading…
Reference in New Issue
Block a user