Fix llama cache check (#763)

* fix llama cache check

* add test
This commit is contained in:
Awni Hannun
2024-05-08 08:35:54 -07:00
committed by GitHub
parent ee60e2a9d5
commit fad9598372
2 changed files with 5 additions and 1 deletions

View File

@@ -159,7 +159,7 @@ class LlamaModel(nn.Module):
mask = None
if h.shape[1] > 1:
mask = create_additive_causal_mask(
h.shape[1], cache[0][0].shape[2] if cache is not None else 0
h.shape[1], cache[0].offset if cache is not None else 0
)
mask = mask.astype(h.dtype)