Add Mistral NeMo (fix) (#895)

* fix head_dim

* Update llms/mlx_lm/models/llama.py

* fix kv error

* formatting

* Delete test.py

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Prince Canuma 2024-07-22 15:09:24 +02:00 committed by GitHub
parent 3d365b612a
commit 3f337e0f0a

llms/mlx_lm/models/llama.py

@@ -16,6 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    head_dim: Optional[int] = None
     num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
@@ -45,7 +46,8 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
         self.scale = head_dim**-0.5
         if hasattr(args, "attention_bias"):
             attention_bias = args.attention_bias
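
Why the "or" fallback works: args.head_dim defaults to None, which is falsy, so configs that omit the field keep the old derived value, while Mistral NeMo's explicit value overrides it. A minimal sketch of the resolution, with a hypothetical helper name and values taken from the published Mistral NeMo 12B config:

from typing import Optional

def resolve_head_dim(
    hidden_size: int, num_attention_heads: int, head_dim: Optional[int] = None
) -> int:
    # Mirrors the fallback above: an explicit head_dim wins; None (falsy)
    # falls through to the old derivation.
    return head_dim or hidden_size // num_attention_heads

# Mistral NeMo 12B: 5120 // 32 == 160, but the model's real head_dim is 128,
# so the explicit config value must take precedence.
assert resolve_head_dim(5120, 32, head_dim=128) == 128

# A Llama-style config that omits head_dim keeps the old behavior.
assert resolve_head_dim(4096, 32) == 128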
@@ -213,7 +215,9 @@ class Model(nn.Module):
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )

     @property
     def n_kv_heads(self):
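
The head_dim property change is plausibly what the "fix kv error" bullet refers to: the KV cache sizes its buffers from model.head_dim, so deriving 5120 // 32 == 160 for Mistral NeMo produced cache buffers that could not hold the 128-wide keys and values coming out of attention. A rough sketch of that mismatch with illustrative shapes (not the actual mlx_lm cache code):

import mlx.core as mx

# Hypothetical shapes; n_kv_heads=8 and head_dim=128 match Mistral NeMo 12B.
batch, n_kv_heads, seq = 1, 8, 16

keys = mx.zeros((batch, n_kv_heads, seq, 128))   # what attention produces
cache = mx.zeros((batch, n_kv_heads, seq, 160))  # sized from the wrong head_dim

# Appending along the sequence axis fails: the last axis disagrees.
# mx.concatenate([cache, keys], axis=2)  # raises a shape-mismatch error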