2024-01-12 04:29:12 +08:00
|
|
|
import inspect
|
|
|
|
from dataclasses import dataclass
|
2024-07-26 07:45:22 +08:00
|
|
|
from typing import List, Optional
|
2024-01-12 04:29:12 +08:00
|
|
|
|
2024-05-08 23:18:13 +08:00
|
|
|
import mlx.core as mx
|
2024-07-26 07:45:22 +08:00
|
|
|
import mlx.nn as nn
|
2024-05-22 11:16:31 +08:00
|
|
|
|
|
|
|
|
2024-05-08 23:18:13 +08:00
|
|
|
class KVCache:
    """Growable key/value cache for attention layers.

    Storage for keys and values is allocated in chunks of ``step`` positions
    along the sequence axis (axis 2), so repeated single-token updates do not
    trigger a reallocation on every call.
    """

    def __init__(self, head_dim, n_kv_heads):
        """Create an empty cache.

        Args:
            head_dim: Either a single int (shared per-head width for keys and
                values) or a 2-tuple ``(key_dim, value_dim)``.
            n_kv_heads: Number of key/value heads.

        Raises:
            ValueError: If ``head_dim`` is neither an int nor a 2-tuple.
        """
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            # One width shared by both keys and values.
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            # Separate widths: (key_dim, value_dim).
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        self.keys = None
        self.values = None
        # Number of sequence positions currently filled.
        self.offset = 0
        # Allocation granularity along the sequence axis.
        self.step = 256

    def update_and_fetch(self, keys, values):
        """Append new keys/values and return everything cached so far.

        Args:
            keys: Array shaped ``(B, n_kv_heads, S, k_head_dim)``.
            values: Array shaped ``(B, n_kv_heads, S, v_head_dim)``.

        Returns:
            Tuple ``(keys, values)`` of the cache contents up to the new
            offset (views along axis 2).
        """
        start = self.offset
        added = keys.shape[2]

        # Grow the backing storage when the incoming chunk does not fit.
        if self.keys is None or start + added > self.keys.shape[2]:
            batch = keys.shape[0]
            # Ceil-divide so the fresh allocation covers the incoming chunk.
            chunks = (self.step + added - 1) // self.step
            span = chunks * self.step
            grown_k = mx.zeros(
                (batch, self.n_kv_heads, span, self.k_head_dim), keys.dtype
            )
            grown_v = mx.zeros(
                (batch, self.n_kv_heads, span, self.v_head_dim), values.dtype
            )
            if self.keys is None:
                self.keys, self.values = grown_k, grown_v
            else:
                if start % self.step != 0:
                    # Drop the unused tail so concatenation stays step-aligned.
                    self.keys = self.keys[..., :start, :]
                    self.values = self.values[..., :start, :]
                self.keys = mx.concatenate([self.keys, grown_k], axis=2)
                self.values = mx.concatenate([self.values, grown_v], axis=2)

        self.offset = start + added
        self.keys[..., start : self.offset, :] = keys
        self.values[..., start : self.offset, :] = values
        return (
            self.keys[..., : self.offset, :],
            self.values[..., : self.offset, :],
        )
|
|
|
|
|
2024-01-12 04:29:12 +08:00
|
|
|
|
|
|
|
@dataclass
class BaseModelArgs:
    """Base class for model configuration dataclasses.

    Subclasses declare their fields; ``from_dict`` builds an instance from a
    raw config dict, silently dropping keys the subclass does not declare.
    """

    @classmethod
    def from_dict(cls, params):
        """Construct an instance from *params*, ignoring unknown keys."""
        accepted = inspect.signature(cls).parameters
        filtered = {name: value for name, value in params.items() if name in accepted}
        return cls(**filtered)
|
2024-07-26 07:45:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
def create_additive_causal_mask(N: int, offset: int = 0):
    """Build an additive causal mask for ``N`` query positions.

    Args:
        N: Number of query positions (rows of the mask).
        offset: Number of already-cached key positions preceding the queries.

    Returns:
        An ``(N, offset + N)`` array that is ``-1e9`` where a query would
        attend to a later key position and ``0`` elsewhere.
    """
    key_positions = mx.arange(offset + N)
    if offset:
        query_positions = mx.arange(offset, offset + N)
    else:
        query_positions = key_positions
    # True (i.e. masked) wherever the key position is after the query.
    blocked = query_positions[:, None] < key_positions[None]
    return blocked * -1e9
|
|
|
|
|
|
|
|
|
|
|
|
def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
    """Return an additive causal mask for *h*, or ``None`` for single tokens.

    Args:
        h: Hidden states of shape ``(batch, seq_len, ...)``.
        cache: Optional list of per-layer KV caches; the first entry's
            ``offset`` shifts the mask to account for already-cached context
            (e.g. prompt reuse).

    Returns:
        A mask in ``h``'s dtype when ``seq_len > 1``, otherwise ``None``
        (a lone decode token attends to everything cached, no mask needed).
    """
    seq_len = h.shape[1]
    if seq_len <= 1:
        return None
    # Multiple tokens at once: build a causal mask so earlier positions
    # cannot attend to later ones, shifted by any cached context length.
    offset = 0
    if cache is not None and cache[0] is not None:
        offset = cache[0].offset
    return create_additive_causal_mask(seq_len, offset).astype(h.dtype)
|