mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
revert layer_idx
This commit is contained in:
parent
ad93729dce
commit
9d7f38abb8
@ -42,13 +42,13 @@ def create_causal_mask(
|
|||||||
return mask * -1e9
|
return mask * -1e9
|
||||||
|
|
||||||
|
|
||||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||||
T = h.shape[1]
|
T = h.shape[1]
|
||||||
if T > 1:
|
if T > 1:
|
||||||
window_size = None
|
window_size = None
|
||||||
offset = 0
|
offset = 0
|
||||||
if cache is not None and cache[layer_idx] is not None:
|
if cache is not None and cache[0] is not None:
|
||||||
c = cache[layer_idx]
|
c = cache[0]
|
||||||
if hasattr(c, "max_size"):
|
if hasattr(c, "max_size"):
|
||||||
offset = min(c.max_size, c.offset)
|
offset = min(c.max_size, c.offset)
|
||||||
window_size = c.max_size
|
window_size = c.max_size
|
||||||
|
Loading…
Reference in New Issue
Block a user