mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
chore: remove the default rope_scaling dict and use .get() to access "type" and "factor", avoiding a KeyError
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
@@ -26,20 +26,19 @@ class ModelArgs:
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
model_type: str = None
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = field(default_factory=dict)
|
||||
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
|
||||
if self.rope_scaling:
|
||||
required_keys = {"factor", "type"}
|
||||
if not all(key in self.rope_scaling for key in required_keys):
|
||||
raise ValueError(f"rope_scaling must contain keys {required_keys}")
|
||||
|
||||
|
||||
if self.rope_scaling["type"] != "linear":
|
||||
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, params):
|
||||
@@ -83,9 +82,17 @@ class Attention(nn.Module):
|
||||
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
|
||||
rope_scale = 1 / args.rope_scaling['factor'] if args.rope_scaling is not None and args.rope_scaling['type'] == 'linear' else 1
|
||||
rope_scale = (
|
||||
1 / args.rope_scaling.get("factor", 1.0)
|
||||
if args.rope_scaling is not None
|
||||
and args.rope_scaling.get("type") == "linear"
|
||||
else 1
|
||||
)
|
||||
self.rope = nn.RoPE(
|
||||
head_dim, traditional=args.rope_traditional, base=args.rope_theta, scale= rope_scale
|
||||
head_dim,
|
||||
traditional=args.rope_traditional,
|
||||
base=args.rope_theta,
|
||||
scale=rope_scale,
|
||||
)
|
||||
|
||||
def __call__(
|
||||
|
Reference in New Issue
Block a user