Reorg and fixes to caching: unify prompt caching across cache types and use cases, e.g. caching during a chat

Awni Hannun
2024-10-05 14:49:39 -07:00
parent ed060a7c5c
commit 782f5a71b7
40 changed files with 824 additions and 691 deletions

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         B, L, D = x.shape
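
The type loosens from a raw (keys, values) tuple to an opaque object, so attention layers can call methods on the cache instead of unpacking and rebuilding tuples each step. A minimal sketch of the interface this implies, assuming the `update_and_fetch` method name used by mlx-lm's cache classes; this is illustrative, not the exact implementation:

    import mlx.core as mx

    class KVCache:
        # Sketch of the unified cache object the new
        # `cache: Optional[Any]` signature assumes.
        def __init__(self):
            self.keys = None
            self.values = None
            self.offset = 0

        def update_and_fetch(self, keys: mx.array, values: mx.array):
            # keys/values have shape (B, n_kv_heads, L, head_dim);
            # append along the sequence axis and return the full history.
            if self.keys is None:
                self.keys, self.values = keys, values
            else:
                self.keys = mx.concatenate([self.keys, keys], axis=2)
                self.values = mx.concatenate([self.values, values], axis=2)
            self.offset += keys.shape[2]
            return self.keys, self.values

Inside attention, a layer would then do `keys, values = cache.update_and_fetch(keys, values)`, letting rotating, quantized, or other cache variants share one call site.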
@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
         self,
         x: mx.array,
         mask: Optional[mx.array] = None,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        cache: Optional[Any] = None,
     ) -> mx.array:
         residual = x
         # NeoX runs attention and feedforward network in parallel.
@@ -214,11 +214,3 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.h
-
-    @property
-    def head_dim(self):
-        return self.args.hidden_size // self.args.num_attention_heads
-
-    @property
-    def n_kv_heads(self):
-        return self.args.num_key_value_heads
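
With the per-model `head_dim`/`n_kv_heads` properties removed, cache construction moves out of the model into a shared helper. A hedged sketch of the use case named in the commit message, reusing one prompt cache across chat turns; the helper and keyword names (`make_prompt_cache`, `prompt_cache`) reflect mlx-lm's post-reorg API, and the model path is just an example:

    from mlx_lm import load, generate
    # Assumed post-reorg location of the unified cache helpers.
    from mlx_lm.models.cache import make_prompt_cache

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # Build the cache once; it holds one unified cache object per layer.
    cache = make_prompt_cache(model)

    # Reuse the same cache across turns so earlier context
    # is not re-processed on each new prompt.
    for prompt in ["Hello!", "What did I just say?"]:
        response = generate(model, tokenizer, prompt=prompt, prompt_cache=cache)
        print(response)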