2024-08-17 06:28:39 +08:00
|
|
|
# Copyright © 2023-2024 Apple Inc.
|
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
import inspect
|
|
|
|
from dataclasses import dataclass
|
2024-10-08 11:45:51 +08:00
|
|
|
from typing import Any, Optional
|
2024-01-12 04:29:12 +08:00
|
|
|
|
2024-05-08 23:18:13 +08:00
|
|
|
import mlx.core as mx
|
2024-08-17 06:28:39 +08:00
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
|
|
|
|
@dataclass
class BaseModelArgs:
    """Base class for model argument dataclasses.

    Subclasses declare their arguments as dataclass fields and can then be
    built from a configuration mapping with ``from_dict``.
    """

    @classmethod
    def from_dict(cls, params):
        """Construct ``cls`` from ``params``, dropping unknown keys.

        Only entries whose key names a parameter of ``cls``'s constructor are
        forwarded. Every other entry in ``params`` is ignored, so a mapping
        with extra keys can be passed without raising ``TypeError``.

        Args:
            params: Mapping from argument name to value.

        Returns:
            A new instance of ``cls``.
        """
        # The constructor signature does not depend on the key being tested,
        # so compute it once instead of once per entry in ``params``.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
|
2024-07-26 07:45:22 +08:00
|
|
|
|
|
|
|
|
2024-10-08 11:45:51 +08:00
|
|
|
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
    """Create an additive causal attention mask.

    Rows correspond to the ``N`` query positions ``offset .. offset + N - 1``.
    Columns correspond to all ``offset + N`` key positions. Entry ``(i, j)``
    is ``0`` where query ``i`` may attend to key ``j`` and ``-1e9`` where it
    may not.

    Args:
        N: Number of query positions being processed.
        offset: Number of key positions that precede the first query.
        window_size: If given, each query may attend to at most
            ``window_size`` keys: itself and the ``window_size - 1``
            positions immediately before it.

    Returns:
        An ``(N, offset + N)`` array to be added to the attention scores.
    """
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    linds = linds[:, None]
    rinds = rinds[None]
    # Hide every key that lies in the future of its query.
    mask = linds < rinds
    if window_size is not None:
        # Also hide keys that are ``window_size`` or more positions in the
        # past. The visible span (the query plus its history) is then exactly
        # ``window_size`` keys. A strict ``>`` here would expose
        # ``window_size + 1`` keys.
        mask = mask | (linds >= rinds + window_size)
    return mask * -1e9
|
|
|
|
|
|
|
|
|
2024-08-17 06:28:39 +08:00
|
|
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
    """Return the attention mask for the hidden states ``h``, or ``None``.

    Args:
        h: Hidden states. ``h.shape[1]`` is used as the number of query
            positions (presumably laid out as ``(batch, seq, dim)`` — confirm
            with callers). The mask is cast to ``h.dtype``.
        cache: Optional indexable collection of caches. Only ``cache[0]`` is
            consulted. Its ``offset`` gives the number of preceding key
            positions. If it also has a ``max_size`` attribute, the offset is
            clamped to ``max_size - 1`` and attention is limited to a window
            of ``max_size`` keys.

    Returns:
        ``None`` when ``h`` holds at most one query position. Otherwise the
        result of ``create_causal_mask`` converted to ``h.dtype``.
    """
    num_queries = h.shape[1]
    if num_queries <= 1:
        # A single query needs no causal masking.
        return None

    offset, window_size = 0, None
    first = None if cache is None else cache[0]
    if first is not None:
        if hasattr(first, "max_size"):
            # Bounded cache: presumably it retains at most ``max_size`` past
            # keys, hence the clamp and the sliding window (verify against
            # the cache implementation).
            window_size = first.max_size
            offset = min(window_size - 1, first.offset)
        else:
            offset = first.offset

    causal = create_causal_mask(num_queries, offset, window_size=window_size)
    return causal.astype(h.dtype)
|