revert layer_idx

This commit is contained in:
Prince Canuma 2025-01-12 15:25:02 +01:00
parent ad93729dce
commit 9d7f38abb8

View File

@ -42,13 +42,13 @@ def create_causal_mask(
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, layer_idx: int = 0):
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[layer_idx] is not None:
c = cache[layer_idx]
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size