fix mask shape error (long context)

This commit is contained in:
Prince Canuma 2025-01-11 21:56:00 +01:00
parent 514502da22
commit 1107364c3a
2 changed files with 4 additions and 4 deletions

View File

@@ -42,13 +42,13 @@ def create_causal_mask(
     return mask * -1e9


-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
-        if cache is not None and cache[0] is not None:
-            c = cache[0]
+        if cache is not None and cache[layer_idx] is not None:
+            c = cache[layer_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size

View File

@@ -157,7 +157,7 @@ class CohereModel(nn.Module):
         h = self.embed_tokens(inputs)

         if mask is None:
-            mask = create_attention_mask(h, cache)
+            mask = create_attention_mask(h, cache, layer_idx=self.args.sliding_window_pattern - 1)

         if cache is None:
             cache = [None] * len(self.layers)