diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 2a49ee37..b7be6a17 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -16,6 +16,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
+    head_dim: Optional[int] = None
     num_key_value_heads: Optional[int] = None
     attention_bias: bool = False
     mlp_bias: bool = False
@@ -45,7 +46,8 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        head_dim = args.hidden_size // n_heads
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
         self.scale = head_dim**-0.5
         if hasattr(args, "attention_bias"):
             attention_bias = args.attention_bias
@@ -213,7 +215,9 @@ class Model(nn.Module):
 
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return (
+            self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
+        )
 
     @property
     def n_kv_heads(self):
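For context, a minimal standalone sketch (not part of the diff; the config values below are hypothetical) of the fallback the patch introduces: an explicit `head_dim` from the model config wins, otherwise the head dimension is derived as `hidden_size // num_attention_heads` exactly as before.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    hidden_size: int
    num_attention_heads: int
    head_dim: Optional[int] = None  # new optional override


# Config without head_dim: falls back to the derived value.
a = ModelArgs(hidden_size=2048, num_attention_heads=16)
print(a.head_dim or a.hidden_size // a.num_attention_heads)  # 128

# Config with an explicit head_dim: the override is used directly.
b = ModelArgs(hidden_size=2048, num_attention_heads=16, head_dim=64)
print(b.head_dim or b.hidden_size // b.num_attention_heads)  # 64
```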