Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Handle longer prompt/generation (#931)
* rebase
* nits
* nit
* fix rotating cache with step prefill
* update version
```diff
@@ -1,6 +1,8 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import inspect
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
```
```diff
@@ -44,6 +46,100 @@ class KVCache:
         self.values[..., prev : self.offset, :] = values
         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
 
     def state(self):
         return self.keys, self.values
 
 
+class RotatingKVCache:
+
+    def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
+        self.n_kv_heads = n_kv_heads
+        if isinstance(head_dim, int):
+            self.k_head_dim = self.v_head_dim = head_dim
+        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
+            self.k_head_dim, self.v_head_dim = head_dim
+        else:
+            raise ValueError("head_dim must be an int or a tuple of two ints")
+        self.keep = keep
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.max_size = max_size
+        self.step = step
+        self._idx = 0
+
+    def _trim(self, trim_size, v, append=None):
+        to_cat = []
+        if trim_size > 0:
+            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
+        else:
+            to_cat = [v]
+        if append is not None:
+            to_cat.append(append)
+        return mx.concatenate(to_cat, axis=2)
```
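To make `_trim` concrete: it keeps the first `keep` positions, drops the next `trim_size` positions, and concatenates the rest (plus any `append`) along the sequence axis. A minimal standalone sketch of that slicing, with toy shapes of my own choosing rather than anything from the diff:

```python
import mlx.core as mx

# Toy cache of 6 positions, shaped (batch, n_kv_heads, seq, head_dim).
v = mx.arange(6, dtype=mx.float32).reshape(1, 1, 6, 1)
keep, trim_size = 1, 2

# The same slicing _trim performs: keep position 0, drop positions 1-2.
trimmed = mx.concatenate([v[..., :keep, :], v[..., trim_size + keep :, :]], axis=2)
print(trimmed.reshape(-1).tolist())  # [0.0, 3.0, 4.0, 5.0]
```

The hunk continues with `update_and_fetch`, which handles prefill first: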
```diff
+    def update_and_fetch(self, keys, values):
+        prev = self.offset
+        B, _, S = keys.shape[:3]
+
+        # Prefill mode
+        if S > 1:
+            if self.keys is None:
+                self.keys = keys
+                self.values = values
+            else:
+                # The largest size is self.max_size + S - 1 to ensure
+                # every token gets at least self.max_size context
+                trim_size = self.keys.shape[2] - self.max_size + 1
+                self.keys = self._trim(trim_size, self.keys, keys)
+                self.values = self._trim(trim_size, self.values, values)
+            self.offset += S
+            self._idx = self.keys.shape[2]
+            return self.keys, self.values
```
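A quick check of the prefill arithmetic: trimming `old_len - max_size + 1` positions and appending `S` new ones leaves `max_size - 1 + S` entries, so even the first of the `S` incoming tokens still sees `max_size` positions of context (the `max_size - 1` retained entries plus itself). With illustrative numbers of my own:

```python
# A cache already holding old_len entries receives a chunk of S prompt tokens.
old_len, max_size, S = 12, 8, 4
trim_size = old_len - max_size + 1   # 5
new_len = old_len - trim_size + S    # 11 == max_size - 1 + S
```

The generation path follows: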
```diff
+        # Generation mode
+        # May not have hit the max size yet, so potentially
+        # keep growing the cache
+        if self.keys is None or (
+            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
+        ):
+            new_size = min(self.step, self.max_size - prev)
+            k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
+            v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
+            new_k = mx.zeros(k_shape, keys.dtype)
+            new_v = mx.zeros(v_shape, values.dtype)
+            if self.keys is not None:
+                self.keys = mx.concatenate([self.keys, new_k], axis=2)
+                self.values = mx.concatenate([self.values, new_v], axis=2)
+            else:
+                self.keys, self.values = new_k, new_v
+            self._idx = prev
+
+        # Trim if needed
+        trim_size = self.keys.shape[2] - self.max_size
+        if trim_size > 0:
+            self.keys = self._trim(trim_size, self.keys)
+            self.values = self._trim(trim_size, self.values)
+            self._idx = self.max_size
+
+        # Rotate
+        if self._idx == self.max_size:
+            self._idx = self.keep
+
+        # Assign
+        self.keys[..., self._idx : self._idx + 1, :] = keys
+        self.values[..., self._idx : self._idx + 1, :] = values
+        self.offset += 1
+        self._idx += 1
+
+        # If the buffer is not full, slice off the end
+        if self.offset < self.max_size:
+            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+        return self.keys, self.values
+
+    def state(self):
+        return self.keys, self.values
+
+
 @dataclass
 class BaseModelArgs:
```
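One way to sanity-check the rotating behavior is to drive the cache by hand. A minimal sketch, assuming `mlx` is installed and the `RotatingKVCache` above is in scope; the shapes, sizes, and random inputs here are my own:

```python
import mlx.core as mx

cache = RotatingKVCache(head_dim=4, n_kv_heads=1, max_size=8, keep=2)

# Prefill with a 6-token prompt: the cache just stores it.
k = mx.random.normal((1, 1, 6, 4))
keys, values = cache.update_and_fetch(k, k)
print(keys.shape, cache.offset)  # (1, 1, 6, 4) 6

# Generate one token at a time. Once offset passes max_size, the buffer
# stops growing and new tokens overwrite old ones, rotating through
# positions keep..max_size-1 while the first `keep` entries are preserved.
for _ in range(10):
    k1 = mx.random.normal((1, 1, 1, 4))
    keys, values = cache.update_and_fetch(k1, k1)
print(keys.shape, cache.offset)  # (1, 1, 8, 4) 16
```

The remaining hunk teaches `create_attention_mask` about the new cache type: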
```diff
@@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
         # Input consists of multiple tokens, create a causal mask so that prior
         # tokens do not give attention to later tokens. If a cache is in place
         # (because e.g. prompt reuse), offset the mask accordingly.
-        offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+        if cache is not None and cache[0] is not None:
+            c = cache[0]
+            if isinstance(c, RotatingKVCache):
+                offset = min(c.max_size - 1, c.offset)
+            else:
+                offset = c.offset
+        else:
+            offset = 0
         mask = create_additive_causal_mask(T, offset)
         mask = mask.astype(h.dtype)
     else:
```
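The clamp matters because a `RotatingKVCache` reports an `offset` equal to the total number of tokens processed, while its buffer never holds more than about `max_size` entries; without it, a multi-token step after long generation would build a mask far wider than the cache. A small sketch of the offset arithmetic, my own illustration assuming the file's `create_additive_causal_mask` helper is in scope:

```python
# Say 100 tokens have flowed through a rotating cache with max_size=8.
max_size, offset = 8, 100
clamped = min(max_size - 1, offset)  # 7 — consistent with the cache keeping
                                     # at most max_size - 1 + T entries

# A 4-token step then gets a (4, 4 + 7) additive causal mask.
mask = create_additive_causal_mask(4, clamped)
print(mask.shape)  # (4, 11)
```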