Made llama and mistral files mypy compatible (#1359)
* Made mypy compatible
* reformatted
* Added more fixes
* Added fixes to speculative-decoding
* Fixes
* fix circle
* revert some stuff

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -17,10 +17,10 @@ class ModelArgs:
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    num_key_value_heads: int = None
+    num_key_value_heads: Optional[int] = None
     rope_theta: float = 10000
     rope_traditional: bool = False
-    model_type: str = None
+    model_type: Optional[str] = None
     rope_scaling: Optional[Dict[str, Union[float, str]]] = None
 
     def __post_init__(self):
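For context on this hunk: mypy rejects a `None` default on a field annotated as a plain `int` or `str`, so the annotations are widened to `Optional[...]`. A minimal, self-contained sketch of the pattern; the `__post_init__` body here is an assumed illustration, since the diff only shows its signature:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    num_attention_heads: int
    # `num_key_value_heads: int = None` fails mypy with "Incompatible types
    # in assignment"; Optional[int] declares that None is a valid value.
    num_key_value_heads: Optional[int] = None
    model_type: Optional[str] = None

    def __post_init__(self):
        # Assumed fallback for illustration: default the key/value head
        # count to the attention head count when it is left unset.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
```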
@@ -146,7 +146,7 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
         rope_scale = (
-            1 / args.rope_scaling["factor"]
+            1 / float(args.rope_scaling["factor"])
             if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
             else 1
         )
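The second hunk follows from the dict's type: with `rope_scaling: Optional[Dict[str, Union[float, str]]]`, the expression `args.rope_scaling["factor"]` has type `Union[float, str]`, and `1 / ...` is a type error on the `str` branch. Wrapping the lookup in `float(...)` satisfies mypy and fails fast at runtime on a non-numeric config value. A standalone repro of the pattern, using a made-up config value:

```python
from typing import Dict, Optional, Union

# Hypothetical config value for illustration only.
rope_scaling: Optional[Dict[str, Union[float, str]]] = {"type": "linear", "factor": 2.0}

# Before the fix, `1 / rope_scaling["factor"]` fails mypy because the value
# may be a str; float() narrows it to a number the division accepts.
rope_scale = (
    1 / float(rope_scaling["factor"])
    if rope_scaling is not None and rope_scaling["type"] == "linear"
    else 1.0
)
```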