Update cohere2.py

This commit is contained in:
Awni Hannun 2025-01-12 10:58:18 -08:00 committed by GitHub
parent d7638e029c
commit 0e908bddff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -156,13 +156,13 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
if cache is None:
cache = [None] * len(self.layers)
if mask is None:
j = self.args.sliding_window_pattern
mask = create_attention_mask(h, cache[j - 1 : j])
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)