mlx-examples/llms/mlx_lm/models/base.py

# Copyright © 2023-2024 Apple Inc.

import inspect
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that match the dataclass's constructor
        # parameters; extra keys in the config dict are ignored.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
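
# Illustrative usage (a sketch, not part of the original module): a
# hypothetical subclass shows how from_dict silently drops unknown keys,
# so loading a config dict with extra entries is safe.
#
#     @dataclass
#     class DummyArgs(BaseModelArgs):
#         hidden_size: int = 512
#
#     args = DummyArgs.from_dict({"hidden_size": 1024, "not_a_field": True})
#     # args.hidden_size == 1024; "not_a_field" is discarded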


def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
    # Row indices cover the N new query positions (shifted by the cache
    # offset); column indices cover all offset + N key positions seen so far.
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    linds = linds[:, None]
    rinds = rinds[None]
    # Mask future positions and, for sliding-window attention, positions
    # more than window_size tokens in the past.
    mask = linds < rinds
    if window_size is not None:
        mask = mask | (linds > rinds + window_size)
    return mask * -1e9
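
# Illustrative example (not part of the original module): for N=3 new
# tokens with no cache offset, the result is a lower-triangular additive
# bias; adding it to attention scores before the softmax zeroes out the
# weights on future positions.
#
#     mask = create_causal_mask(3)
#     # [[   0, -1e9, -1e9],
#     #  [   0,    0, -1e9],
#     #  [   0,    0,    0]]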


def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
    T = h.shape[1]
    if T > 1:
        window_size = None
        offset = 0
        if cache is not None and cache[0] is not None:
            c = cache[0]
            if hasattr(c, "max_size"):
                # Rotating KV cache: it holds at most max_size entries, so
                # cap the offset and apply a sliding window of that size.
                offset = min(c.max_size - 1, c.offset)
                window_size = c.max_size
            else:
                offset = c.offset
        mask = create_causal_mask(T, offset, window_size=window_size)
        mask = mask.astype(h.dtype)
    else:
        # Single-token decoding needs no mask: the new token may attend to
        # everything already in the cache.
        mask = None
    return mask
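
# Illustrative usage (a sketch under assumed names, not part of the
# original module): inside a model's forward pass, the mask is built once
# from the hidden states and the per-layer KV cache, then passed to each
# attention layer. "self.layers" and the layer call signature are
# hypothetical here.
#
#     mask = create_attention_mask(h, cache)
#     if cache is None:
#         cache = [None] * len(self.layers)
#     for layer, c in zip(self.layers, cache):
#         h = layer(h, mask, cache=c)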