mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
@@ -159,7 +159,7 @@ class LlamaModel(nn.Module):
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_additive_causal_mask(
|
||||
h.shape[1], cache[0][0].shape[2] if cache is not None else 0
|
||||
h.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
|
Reference in New Issue
Block a user