Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Fix Cohere2: mask shape error (long context) (#1202)
* fix mask shape error (long context)
* Update llms/mlx_lm/models/cohere2.py

  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* revert layer_idx
* black formatting
* Update cohere2.py
* format

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent: 514502da22
commit: bf2da36fc6
@ -156,12 +156,13 @@ class CohereModel(nn.Module):
|
|||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
if mask is None:
|
|
||||||
mask = create_attention_mask(h, cache)
|
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
|
|
||||||
|
if mask is None:
|
||||||
|
j = self.args.sliding_window_pattern
|
||||||
|
mask = create_attention_mask(h, cache[j - 1 : j])
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
for layer, c in zip(self.layers, cache):
|
||||||
h = layer(h, mask, c)
|
h = layer(h, mask, c)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user