Fix Cohere2: mask shape error (long context) (#1202)

* fix mask shape error (long context)

* Update llms/mlx_lm/models/cohere2.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* revert layer_idx

* black formatting

* Update cohere2.py

* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Prince Canuma 2025-01-12 21:58:08 +01:00 committed by GitHub
parent 514502da22
commit bf2da36fc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -156,12 +156,13 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        if mask is None:
-            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
+        if mask is None:
+            j = self.args.sliding_window_pattern
+            mask = create_attention_mask(h, cache[j - 1 : j])
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)