mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
More cache improvements (#1015)
* fix rotating kv cache for chat use case * reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat * nit in chat * fix tests * fix tests * fix tests * docs * chat command * comments + docs * Define meta_state on all Cache implementations * fixes + trim_prompt_cache api * fix default model --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from sys import exit
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -13,7 +13,7 @@ try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attend(self.att_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -174,11 +174,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.transformer.blocks
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.n_heads
|
||||
|
Reference in New Issue
Block a user