diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index b4068679..6b46ddc6 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -48,8 +48,12 @@ class Attention(nn.Module):
         dim = args.hidden_size
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
-
-        head_dim = args.hidden_size // args.num_attention_heads
+        self.head_dim = head_dim = args.head_dim
+        if (head_dim * n_heads) != dim:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}"
+                f" and `num_heads`: {n_heads})."
+            )
         self.scale = head_dim**-0.5
 
         attetion_bias = args.attention_bias
@@ -77,11 +81,8 @@ class Attention(nn.Module):
 
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
-        queries = queries.reshape(B, L, self.n_heads, -1)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1)
-
-        queries = queries.transpose(0, 2, 1, 3)
-        keys = keys.transpose(0, 2, 1, 3)
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
@@ -94,10 +95,10 @@ class Attention(nn.Module):
 
         # sliding window attention
         if self.sliding_window is not None:
-            keys = keys[:, : -self.sliding_window :, :]
-            values = values[:, : -self.sliding_window :, :]
+            keys = keys[:, :, -self.sliding_window :, :]
+            values = values[:, :, -self.sliding_window :, :]
             if mask is not None:
-                mask = mask[:, : -self.sliding_window, :]
+                mask = mask[:, -self.sliding_window :]
 
         output = mx.fast.scaled_dot_product_attention(
             queries, keys, values, scale=self.scale, mask=mask
@@ -200,7 +201,7 @@ class Model(nn.Module):
 
     @property
     def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
+        return self.args.head_dim
 
     @property
    def n_kv_heads(self):
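
Context on the first hunk and the head_dim property (not part of the patch itself): head_dim is now read directly from the model config instead of being re-derived as hidden_size // num_attention_heads, and the new guard fails fast on an inconsistent config instead of silently mis-sizing the heads. A minimal sketch of the guard's behavior, using made-up config values:

# Mirrors the consistency check added in Attention.__init__;
# the config values below are hypothetical, for illustration only.
def check_head_dim(hidden_size: int, n_heads: int, head_dim: int) -> None:
    if (head_dim * n_heads) != hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size}"
            f" and `num_heads`: {n_heads})."
        )

check_head_dim(4096, 32, 128)  # ok: 32 * 128 == 4096
check_head_dim(4096, 40, 128)  # raises ValueError: 40 * 128 == 5120 != 4096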
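
The sliding-window hunk is the substantive bug fix: after the transpose, keys and values have shape (B, n_kv_heads, S, head_dim), so the old slice `[:, : -self.sliding_window :, :]` truncated the heads axis (axis 1) and left the sequence axis untouched. The corrected slice keeps only the last sliding_window positions along axis 2, and the mask is likewise trimmed to the matching key columns. A minimal sketch with made-up dimensions:

import mlx.core as mx

B, n_kv_heads, S, head_dim, window = 1, 8, 16, 64, 4
keys = mx.zeros((B, n_kv_heads, S, head_dim))

# Old (buggy) slice: `:-window` lands on axis 1, dropping KV heads
# while keeping all 16 sequence positions.
print(keys[:, : -window :, :].shape)  # (1, 4, 16, 64)

# Fixed slice: keep the last `window` positions on the sequence axis.
print(keys[:, :, -window:, :].shape)  # (1, 8, 4, 64)

# The mask fix is the same idea: keep the last `window` key columns.
mask = mx.zeros((S, S))
print(mask[:, -window:].shape)  # (16, 4)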