* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more tests + kv cache

* more kv cache

* update all models to use kv cache
Awni Hannun 2024-05-08 08:18:13 -07:00 committed by GitHub
parent bfbc0e434a
commit ee60e2a9d5
22 changed files with 534 additions and 298 deletions

View File

@ -1,6 +1,37 @@
import inspect
from dataclasses import dataclass
import mlx.core as mx

class KVCache:

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if prev % self.step == 0:
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
            new_k = mx.zeros(shape, keys.dtype)
            new_v = mx.zeros(shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

@dataclass
class BaseModelArgs:
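The KVCache introduced above preallocates key/value storage in blocks of step (256) positions and writes new entries in place, tracking how many positions are valid via offset. A minimal usage sketch (not part of the diff; shapes chosen arbitrarily):

import mlx.core as mx
from mlx_lm.models.base import KVCache

head_dim, n_kv_heads = 64, 4
cache = KVCache(head_dim, n_kv_heads)

# Prompt step: 10 positions are written into a freshly allocated 256-slot buffer.
k = mx.zeros((1, n_kv_heads, 10, head_dim))
v = mx.zeros((1, n_kv_heads, 10, head_dim))
keys, values = cache.update_and_fetch(k, v)
print(cache.offset, keys.shape[2], cache.keys.shape[2])  # 10 10 256

# Decode step: one more position is written in place, no reallocation.
keys, values = cache.update_and_fetch(
    mx.zeros((1, n_kv_heads, 1, head_dim)), mx.zeros((1, n_kv_heads, 1, head_dim))
)
print(cache.offset, keys.shape[2], cache.keys.shape[2])  # 11 11 256

Attention layers consume only the returned slices, so each decoding step costs an in-place slice assignment rather than a concatenation of the full history; a concatenation happens only once every 256 positions.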

View File

@ -84,11 +84,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -98,7 +96,7 @@ class Attention(nn.Module):
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -132,9 +130,9 @@ class TransformerBlock(nn.Module):
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h, cache = self.self_attn(h, mask, cache)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
return attn_h + ff_h + x
class CohereModel(nn.Module):
@ -167,10 +165,10 @@ class CohereModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -178,17 +176,26 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out, cache
return out
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads
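The head_dim and n_kv_heads properties added to each Model exist so that callers can size per-layer caches without reaching into model internals. A sketch of the intended pattern, mirroring the generate_step change later in this diff (model and prompt are assumed to already exist):

from mlx_lm.models.base import KVCache

kv_heads = (
    [model.n_kv_heads] * len(model.layers)
    if isinstance(model.n_kv_heads, int)
    else model.n_kv_heads  # some models (e.g. OpenELM) report a per-layer list
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
logits = model(prompt[None], cache=cache)  # the caches are mutated in place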

View File

@ -65,11 +65,9 @@ class Attention(nn.Module):
)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -78,7 +76,7 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output), (keys, values)
return self.out_proj(output)
class NormAttnNorm(nn.Module):
@ -94,9 +92,9 @@ class NormAttnNorm(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
h, cache = self.attn(self.norm_1(x), mask=mask, cache=cache)
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
return x, self.norm_2(x), cache
return x, self.norm_2(x)
class MLP(nn.Module):
@ -181,9 +179,9 @@ class DecoderLayer(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, h, cache = self.norm_attn_norm(x, mask, cache)
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
return out, cache
return out
class DBRX(nn.Module):
@ -210,10 +208,10 @@ class DBRX(nn.Module):
if cache is None:
cache = [None] * len(self.blocks)
for e, layer in enumerate(self.blocks):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.blocks, cache):
h = layer(h, mask, c)
return self.norm_f(h), cache
return self.norm_f(h)
class Model(nn.Module):
@ -229,8 +227,8 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.transformer(inputs, cache)
return self.lm_head(out), cache
out = self.transformer(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
@ -253,3 +251,11 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.attn_config["kv_n_heads"]

View File

@ -70,11 +70,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -84,7 +82,7 @@ class Attention(nn.Module):
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -115,11 +113,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class GemmaModel(nn.Module):
@ -151,10 +149,10 @@ class GemmaModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -162,16 +160,25 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
return out, cache
return out
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -78,11 +78,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -91,7 +89,7 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -124,11 +122,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
def create_additive_causal_mask(N: int, offset: int = 0):
@ -168,10 +166,10 @@ class LlamaModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -180,14 +178,15 @@ class Model(nn.Module):
self.model_type = args.model_type
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
@ -198,3 +197,11 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -19,7 +19,6 @@ class ModelArgs(BaseModelArgs):
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
max_position_embeddings: int
scale_depth: float
scale_emb: float
rope_theta: float = 1000000.0
@ -47,7 +46,6 @@ class Attention(nn.Module):
self.hidden_size = args.hidden_size
self.num_heads = n_heads = args.num_attention_heads
self.rope_theta = args.rope_theta
self.max_position_embeddings = args.max_position_embeddings
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
@ -98,11 +96,9 @@ class Attention(nn.Module):
)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -113,7 +109,7 @@ class Attention(nn.Module):
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(attn_output), (keys, values)
return self.o_proj(attn_output)
class DecoderLayer(nn.Module):
@ -139,11 +135,11 @@ class DecoderLayer(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
return out, cache
return out
class MiniCPMModel(nn.Module):
@ -172,10 +168,10 @@ class MiniCPMModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -193,14 +189,14 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = self.model(inputs, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
else:
out = out @ self.model.embed_tokens.weight.T
return out, cache
return out
def sanitize(self, weights):
if "lm_head.weight" not in weights:
@ -210,3 +206,11 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -77,11 +77,9 @@ class MixtralAttention(nn.Module):
)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -90,7 +88,7 @@ class MixtralAttention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MixtralBLockSparseTop2MLP(nn.Module):
@ -180,11 +178,11 @@ class MixtralDecoderLayer(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.block_sparse_moe(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class MixtralModel(nn.Module):
@ -215,10 +213,10 @@ class MixtralModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -227,15 +225,24 @@ class Model(nn.Module):
self.model_type = args.model_type
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -78,11 +78,9 @@ class TransformerBlock(nn.Module):
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -92,7 +90,7 @@ class TransformerBlock(nn.Module):
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.attn_out(output), (keys, values)
return self.attn_out(output)
def __call__(
self,
@ -100,13 +98,13 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attend(self.att_norm(x), mask, cache)
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
out = h + self.ff_out(nn.silu(x2) * x1)
return out, cache
return out
class Transformer(nn.Module):
@ -136,15 +134,15 @@ class Transformer(nn.Module):
if cache is None:
cache = [None] * len(self.blocks)
for e, block in enumerate(self.blocks):
h, cache[e] = block(h, mask, cache[e])
for block, c in zip(self.blocks, cache):
h = block(h, mask, c)
h = self.norm(h)
if self.weight_tying:
return self.wte.as_linear(h), cache
return self.ff_out(h), cache
return self.ff_out(h)
class OlmoModel(nn.Module):
@ -165,6 +163,7 @@ class Model(nn.Module):
super().__init__()
self.model_type = args.model_type
self.model = OlmoModel(args)
self.args = args
def __call__(
self,
@ -176,3 +175,11 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads

View File

@ -22,8 +22,7 @@ class ModelArgs(BaseModelArgs):
normalize_qk_projections: bool = True
share_input_output_layers: bool = True
rms_norm_eps: float = 1e-6
rope_theta: float = 10000
rope_traditional: bool = False
rope_freq_constant: float = 10000
def make_divisible(
@ -73,9 +72,7 @@ class Attention(nn.Module):
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.rope = nn.RoPE(
head_dim, traditional=args.rope_traditional, base=args.rope_theta
)
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
def __call__(
self,
@ -87,12 +84,10 @@ class Attention(nn.Module):
qkv = self.qkv_proj(x)
# [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h] -> [B, (q_h + k_h + v_h), S, h]
qkv = qkv.reshape(
B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
).transpose(0, 2, 1, 3)
# [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h]
queries, keys, values = mx.split(
qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
)
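For reference, the fused QKV projection above is split back into queries, keys, and values along the head axis with mx.split, using cumulative indices. A toy example with invented shapes:

import mlx.core as mx

n_heads, n_kv_heads, head_dim = 12, 3, 64
B, L = 1, 5
# After the reshape/transpose the head axis is axis 1: (B, q_h + k_h + v_h, L, h)
qkv = mx.zeros((B, n_heads + 2 * n_kv_heads, L, head_dim))

queries, keys, values = mx.split(qkv, [n_heads, n_heads + n_kv_heads], axis=1)
print(queries.shape[1], keys.shape[1], values.shape[1])  # 12 3 3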
@ -103,11 +98,9 @@ class Attention(nn.Module):
keys = self.k_norm(keys)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -118,7 +111,7 @@ class Attention(nn.Module):
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output), (keys, values)
return self.out_proj(output)
class MLP(nn.Module):
@ -159,11 +152,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.attn(self.attn_norm(x), mask, cache)
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
r = self.ffn(self.ffn_norm(h))
out = h + r
return out, cache
return out
class OpenELMModel(nn.Module):
@ -195,10 +188,10 @@ class OpenELMModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -215,14 +208,22 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.transformer(inputs, cache)
out = self.transformer(inputs, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
return out
@property
def layers(self):
return self.transformer.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_kv_heads

View File

@ -75,35 +75,29 @@ class PhiAttention(nn.Module):
queries = queries.reshape(
B,
L,
n_kv_heads,
n_heads // n_kv_heads,
n_heads,
-1,
).moveaxis(1, 3)
keys = keys.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
values = values.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
).moveaxis(1, 2)
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[-2])
keys = self.rope(keys, offset=key_cache.shape[-2])
keys = mx.concatenate([key_cache, keys], axis=-2)
values = mx.concatenate([value_cache, values], axis=-2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.swapaxes(-1, -2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
output = (scores @ values).moveaxis(3, 1).reshape(B, L, -1)
output = mx.fast.scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
).astype(values.dtype)
return self.dense(output), (keys, values)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.dense(output)
class PhiMLP(nn.Module):
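Besides adopting cache.update_and_fetch, the hunk above also replaces phi's hand-rolled softmax attention with the fused mx.fast.scaled_dot_product_attention kernel. The two agree up to numerical tolerance; a quick sketch with random tensors (illustrative only):

import mlx.core as mx

B, H, L, D = 1, 4, 8, 32
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
scale = D ** -0.5

# Hand-rolled attention, as phi.py computed it before this commit.
scores = (q * scale) @ k.swapaxes(-1, -2)
manual = mx.softmax(scores, axis=-1) @ v

# Fused kernel used after this commit.
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

print(mx.allclose(manual, fused, atol=1e-4))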
@ -128,9 +122,9 @@ class PhiDecoderLayer(nn.Module):
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h, cache = self.self_attn(h, mask, cache)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x, cache
return attn_h + ff_h + x
class PhiModel(nn.Module):
@ -152,9 +146,9 @@ class PhiModel(nn.Module):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
return self.final_layernorm(x), cache
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)
class Model(nn.Module):
@ -163,15 +157,24 @@ class Model(nn.Module):
self.model_type = config.model_type
self.model = PhiModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.args = config
def __call__(
self,
x: mx.array,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
y, cache = self.model(x, cache)
return self.lm_head(y), cache
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -81,11 +81,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -94,7 +92,7 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -128,11 +126,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class Phi3Model(nn.Module):
@ -163,10 +161,10 @@ class Phi3Model(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -175,15 +173,24 @@ class Model(nn.Module):
self.model_type = args.model_type
self.model = Phi3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -11,7 +11,6 @@ import numpy as np
@dataclass
class ModelArgs:
model_type: str
max_sequence_length: int = 2048
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
@ -56,11 +55,9 @@ class RoPEAttention(nn.Module):
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -69,14 +66,13 @@ class RoPEAttention(nn.Module):
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1).astype(values.dtype)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
output = mx.fast.scaled_dot_product_attention(
queries.astype(mx.float32), keys, values, scale=scale, mask=mask
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
return self.out_proj(output)
class MLP(nn.Module):
@ -144,9 +140,9 @@ class ParallelBlock(nn.Module):
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h, cache = self.mixer(h, mask, cache)
attn_h = self.mixer(h, mask, cache)
ff_h = self.moe(h)
return attn_h + ff_h + x, cache
return attn_h + ff_h + x
class TransformerDecoder(nn.Module):
@ -160,9 +156,9 @@ class TransformerDecoder(nn.Module):
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
return x, cache
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return x
class Embd(nn.Module):
@ -190,6 +186,7 @@ class Model(nn.Module):
self.model_type = config.model_type
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
self.args = config
def __call__(
self,
@ -202,9 +199,17 @@ class Model(nn.Module):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.model_dim // self.args.num_heads
@property
def n_kv_heads(self):
return self.args.num_heads

View File

@ -64,44 +64,38 @@ class Attention(nn.Module):
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
bsz, q_len, _ = hidden_states.shape
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
# Prepare the queries, keys and values for the attention computation
query_states = query_states.reshape(
bsz, q_len, self.q_num_heads, self.qk_dim
).transpose(0, 2, 1, 3)
key_states = key_states.reshape(
bsz, q_len, self.k_num_heads, self.qk_dim
).transpose(0, 2, 1, 3)
value_states = value_states.reshape(
bsz, q_len, self.v_num_heads, self.v_dim
).transpose(0, 2, 1, 3)
# expand shared kv
assert self.k_num_heads == self.v_num_heads
kv_seq_len = 0
if cache is not None:
kv_seq_len += cache[0].shape[-2]
query_states = self.rotary_emb(query_states, offset=kv_seq_len)
key_states = self.rotary_emb(key_states, offset=kv_seq_len)
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
0, 2, 1, 3
)
if cache is not None:
# reuse k, v, self_attention
key_states = mx.concatenate([cache[0], key_states], axis=2)
value_states = mx.concatenate([cache[1], value_states], axis=2)
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
output = mx.fast.scaled_dot_product_attention(
query_states,
key_states,
value_states,
queries,
keys,
values,
scale=self.scale,
mask=attention_mask,
)
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
return self.o_proj(output), (key_states, value_states)
return self.o_proj(output)
class MLP(nn.Module):
@ -139,7 +133,7 @@ class PlamoDecoderLayer(nn.Module):
hidden_states = self.norm(hidden_states)
# Self Attention
hidden_states_sa, cache = self.self_attn(
hidden_states_sa = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
cache=cache,
@ -149,7 +143,7 @@ class PlamoDecoderLayer(nn.Module):
hidden_states_mlp = self.mlp(hidden_states)
hidden_states = residual + hidden_states_sa + hidden_states_mlp
return hidden_states, cache
return hidden_states
class PlamoDecoder(nn.Module):
@ -185,14 +179,10 @@ class PlamoModel(nn.Module):
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
for e, layer in enumerate(self.layers.layers):
h, c = layer(h, mask, cache[e])
if cache is not None:
cache[e] = c
else:
cache.append(c)
for layer, c in zip(self.layers.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -203,15 +193,24 @@ class Model(nn.Module):
self.lm_head: nn.Module = nn.Linear(
args.hidden_size, args.vocab_size, bias=False
)
self.args = args
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Tuple[mx.array, mx.array]:
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads // self.args.n_shared_head

View File

@ -51,30 +51,24 @@ class Attention(nn.Module):
B, L, _ = q.shape
q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
k_cache, v_cache = cache
q = self.rotary_emb(q, offset=k_cache.shape[2])
k = self.rotary_emb(k, offset=k_cache.shape[2])
k = mx.concatenate([k_cache, k], axis=2)
v = mx.concatenate([v_cache, v], axis=2)
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
q = self.rotary_emb(q)
k = self.rotary_emb(k)
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(v_hat), (k, v)
return self.c_proj(output)
class MLP(nn.Module):
@ -109,13 +103,13 @@ class TransformerBlock(nn.Module):
def __call__(self, x, mask=None, cache=None):
residual = x
x = self.ln_1(x)
x, cache = self.attn(x, mask=mask, cache=cache)
x = self.attn(x, mask=mask, cache=cache)
residual = x + residual
x = self.ln_2(residual)
x = self.mlp(x)
x = x + residual
return x, cache
return x
class QwenModel(nn.Module):
@ -137,10 +131,10 @@ class QwenModel(nn.Module):
if cache is None:
cache = [None] * len(self.h)
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return self.ln_f(x), cache
return self.ln_f(x)
class Model(nn.Module):
@ -151,6 +145,7 @@ class Model(nn.Module):
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=not config.no_bias
)
self.args = config
def __call__(
self,
@ -158,9 +153,17 @@ class Model(nn.Module):
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads

View File

@ -79,11 +79,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -92,7 +90,7 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -125,11 +123,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class Qwen2Model(nn.Module):
@ -160,10 +158,10 @@ class Qwen2Model(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -180,12 +178,12 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
return out
def sanitize(self, weights):
if self.args.tie_word_embeddings:
@ -198,3 +196,11 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -78,11 +78,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -91,7 +89,7 @@ class Attention(nn.Module):
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class Qwen2MoeMLP(nn.Module):
@ -187,11 +185,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class Qwen2MoeModel(nn.Module):
@ -222,10 +220,10 @@ class Qwen2MoeModel(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -241,8 +239,8 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
@ -255,3 +253,11 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -107,11 +107,9 @@ class Attention(nn.Module):
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -125,7 +123,7 @@ class Attention(nn.Module):
queries, keys, values, scale=scale, mask=mask
).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -157,7 +155,7 @@ class DecoderLayer(nn.Module):
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
r, cache = self.self_attn(h, mask, cache)
r = self.self_attn(h, mask, cache)
if self.use_parallel_residual:
out = x + r + self.mlp(h)
@ -165,7 +163,7 @@ class DecoderLayer(nn.Module):
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class StableLM(nn.Module):
@ -180,9 +178,10 @@ class StableLM(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
x, cache[e] = layer(x, mask, cache[e])
return self.norm(x), cache
for layer, c in zip(self.layers, cache):
x = layer(x, mask, cache=c)
return self.norm(x)
class Model(nn.Module):
@ -191,6 +190,7 @@ class Model(nn.Module):
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.args = config
def __call__(
self,
@ -203,9 +203,17 @@ class Model(nn.Module):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype)
y, cache = self.model(x, mask, cache)
return self.lm_head(y), cache
y = self.model(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -55,11 +55,9 @@ class Attention(nn.Module):
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
@ -69,7 +67,7 @@ class Attention(nn.Module):
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)
class MLP(nn.Module):
@ -102,11 +100,11 @@ class TransformerBlock(nn.Module):
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out
class Starcoder2Model(nn.Module):
@ -137,10 +135,10 @@ class Starcoder2Model(nn.Module):
if cache is None:
cache = [None] * len(self.layers)
for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h), cache
return self.norm(h)
class Model(nn.Module):
@ -157,13 +155,21 @@ class Model(nn.Module):
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache
return out
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -314,7 +314,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):
tokenizer_file = model_path / "tokenizer.json"
if tokenizer_file.exists():
tokenizer_content = json.load(tokenizer_file.open())
with open(tokenizer_file, "r") as fid:
tokenizer_content = json.load(fid)
if "decoder" in tokenizer_content:
if _is_spm_decoder(tokenizer_content["decoder"]):
detokenizer_class = SPMStreamingDetokenizer

View File

@ -18,6 +18,7 @@ from mlx.utils import tree_flatten
from transformers import AutoTokenizer, PreTrainedTokenizer
# Local imports
from .models.base import KVCache
from .sample_utils import top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import apply_lora_layers
@ -160,7 +161,12 @@ def generate_step(
)
y = prompt
cache = None
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
repetition_context = prompt.tolist()
@ -168,8 +174,8 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
def _step(y):
nonlocal cache, repetition_context
logits, cache = model(y[None], cache=cache)
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
if repetition_penalty:
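Putting the pieces together, a minimal greedy decoding loop over the new interface might look like the sketch below. It assumes model is an mlx_lm model, cache is the per-layer KVCache list built as in the hunk above, and prompt is an mx.array of token ids; the library's generate_step additionally handles sampling, temperature, and the repetition penalty.

import mlx.core as mx

def greedy_decode(model, cache, prompt, max_tokens=32):
    y = prompt[None]  # (1, prompt_len)
    out = []
    for _ in range(max_tokens):
        logits = model(y, cache=cache)  # caches grow in place; only logits are returned
        y = mx.argmax(logits[:, -1, :], axis=-1, keepdims=True)
        out.append(y.item())
    return out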

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.12.0"
__version__ = "0.13.0"

View File

@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache
class TestModels(unittest.TestCase):
@ -17,13 +18,18 @@ class TestModels(unittest.TestCase):
model.update(tree_map(lambda p: p.astype(t), model.parameters()))
inputs = mx.array([[0, 1]])
outputs, cache = model(inputs)
outputs = model(inputs)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs, cache = model(
mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)
@ -53,6 +59,15 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phixtral(self):
from mlx_lm.models import phixtral
args = phixtral.ModelArgs(
"phixtral", num_vocab=1000, num_layers=4, model_dim=1024
)
model = phixtral.Model(args)
self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)
def test_phi3(self):
from mlx_lm.models import phi3
@ -264,6 +279,100 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_dbrx(self):
from mlx_lm.models import dbrx
args = dbrx.ModelArgs(
model_type="dbrx",
d_model=1024,
ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
n_layers=4,
n_heads=4,
vocab_size=10_000,
)
model = dbrx.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)
def test_minicpm(self):
from mlx_lm.models import minicpm
args = minicpm.ModelArgs(
model_type="minicpm",
hidden_size=1024,
dim_model_base=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-4,
vocab_size=10000,
num_key_value_heads=2,
scale_depth=1.0,
scale_emb=1.0,
)
model = minicpm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_openelm(self):
from mlx_lm.models import openelm
args = openelm.ModelArgs(
model_type="openelm",
ffn_dim_divisor=256,
ffn_multipliers=[
0.5,
0.73,
0.97,
1.2,
1.43,
1.67,
1.9,
2.13,
2.37,
2.6,
2.83,
3.07,
3.3,
3.53,
3.77,
4.0,
],
head_dim=64,
model_dim=1280,
normalize_qk_projections=True,
num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
num_query_heads=[
12,
12,
12,
12,
12,
16,
16,
16,
16,
16,
16,
16,
20,
20,
20,
20,
],
num_transformer_layers=16,
vocab_size=32000,
)
model = openelm.Model(args)
self.model_test_runner(
model,
args.model_type,
args.vocab_size,
len(args.ffn_multipliers),
)
if __name__ == "__main__":
unittest.main()
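The new model tests exercise the cache only indirectly through model_test_runner; a direct unit test for KVCache is not part of this commit, but a hypothetical sketch could look roughly like this:

import mlx.core as mx
from mlx_lm.models.base import KVCache

def test_kv_cache_growth():
    head_dim, n_kv_heads = 32, 2
    cache = KVCache(head_dim, n_kv_heads)

    # The first update allocates one full step (256 positions) and fills 10 of them.
    k = mx.ones((1, n_kv_heads, 10, head_dim))
    keys, values = cache.update_and_fetch(k, k)
    assert cache.offset == 10
    assert keys.shape[2] == 10
    assert cache.keys.shape[2] == cache.step

    # Subsequent single-token updates write in place without reallocating.
    k1 = mx.ones((1, n_kv_heads, 1, head_dim))
    keys, values = cache.update_and_fetch(k1, k1)
    assert cache.offset == 11
    assert keys.shape[2] == 11
    assert cache.keys.shape[2] == cache.step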