Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Kv cache (#643)

* in place kv_cache
* fix
* fix kv cache size
* partially fix kv cache dtype
* step kv cache
* multiple of step size
* more tests + kv cache
* more kv cache
* update all models to use kv cache

This commit is contained in:
parent bfbc0e434a
commit ee60e2a9d5
@@ -1,6 +1,37 @@
 import inspect
 from dataclasses import dataclass
 
+import mlx.core as mx
+
 
+class KVCache:
+
+    def __init__(self, head_dim, n_kv_heads):
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = head_dim
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.step = 256
+
+    def update_and_fetch(self, keys, values):
+        prev = self.offset
+        if prev % self.step == 0:
+            n_steps = (self.step + keys.shape[2] - 1) // self.step
+            shape = (1, self.n_kv_heads, n_steps * self.step, self.head_dim)
+            new_k = mx.zeros(shape, keys.dtype)
+            new_v = mx.zeros(shape, values.dtype)
+            if self.keys is not None:
+                self.keys = mx.concatenate([self.keys, new_k], axis=2)
+                self.values = mx.concatenate([self.values, new_v], axis=2)
+            else:
+                self.keys, self.values = new_k, new_v
+
+        self.offset += keys.shape[2]
+        self.keys[..., prev : self.offset, :] = keys
+        self.values[..., prev : self.offset, :] = values
+        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+
+
 @dataclass
 class BaseModelArgs:
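Note (not part of the commit): the KVCache above is meant to be allocated once per transformer layer by the caller and threaded through the model, which is why every Model class below gains head_dim and n_kv_heads properties. A minimal sketch of a greedy decode loop under that assumption; generate_step, max_tokens, and the argmax sampling are illustrative and the import path is assumed — only KVCache, layers, head_dim and n_kv_heads come from this commit.

import mlx.core as mx
from mlx_lm.models.base import KVCache  # assumed import path for the new class

def generate_step(model, prompt, max_tokens=32):
    # One cache per layer, sized from the properties added in this commit.
    cache = [KVCache(model.head_dim, model.n_kv_heads) for _ in model.layers]
    y = prompt[None]  # (1, prompt_len) token ids
    out = []
    for _ in range(max_tokens):
        logits = model(y, cache=cache)  # the caches grow in place
        y = mx.argmax(logits[:, -1, :], axis=-1)[None]  # greedy next token
        out.append(y.item())
    return out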
@@ -84,11 +84,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -98,7 +96,7 @@ class Attention(nn.Module):
         )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -132,9 +130,9 @@ class TransformerBlock(nn.Module):
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
         h = self.input_layernorm(x)
-        attn_h, cache = self.self_attn(h, mask, cache)
+        attn_h = self.self_attn(h, mask, cache)
         ff_h = self.mlp(h)
-        return attn_h + ff_h + x, cache
+        return attn_h + ff_h + x
 
 
 class CohereModel(nn.Module):
@@ -167,10 +165,10 @@ class CohereModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -178,17 +176,26 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = CohereModel(args)
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
+        out = self.model(inputs, cache)
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
-        return out, cache
+        return out
 
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -65,11 +65,9 @@ class Attention(nn.Module):
         )
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -78,7 +76,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(output), (keys, values)
+        return self.out_proj(output)
 
 
 class NormAttnNorm(nn.Module):
@@ -94,9 +92,9 @@ class NormAttnNorm(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        h, cache = self.attn(self.norm_1(x), mask=mask, cache=cache)
+        h = self.attn(self.norm_1(x), mask=mask, cache=cache)
         x = h + x
-        return x, self.norm_2(x), cache
+        return x, self.norm_2(x)
 
 
 class MLP(nn.Module):
@@ -181,9 +179,9 @@ class DecoderLayer(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, h, cache = self.norm_attn_norm(x, mask, cache)
+        r, h = self.norm_attn_norm(x, mask, cache)
         out = self.ffn(h) + r
-        return out, cache
+        return out
 
 
 class DBRX(nn.Module):
@@ -210,10 +208,10 @@ class DBRX(nn.Module):
         if cache is None:
             cache = [None] * len(self.blocks)
 
-        for e, layer in enumerate(self.blocks):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.blocks, cache):
+            h = layer(h, mask, c)
 
-        return self.norm_f(h), cache
+        return self.norm_f(h)
 
 
 class Model(nn.Module):
@@ -229,8 +227,8 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.transformer(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.transformer(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
@@ -253,3 +251,11 @@ class Model(nn.Module):
             experts = [(s, sv.T) for s, sv in experts]
             new_weights.update(experts)
         return new_weights
+
+    @property
+    def head_dim(self):
+        return self.args.d_model // self.args.n_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.attn_config["kv_n_heads"]
@@ -70,11 +70,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -84,7 +82,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -115,11 +113,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class GemmaModel(nn.Module):
@@ -151,10 +149,10 @@ class GemmaModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -162,16 +160,25 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = GemmaModel(args)
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
+        out = self.model(inputs, cache)
         out = self.model.embed_tokens.as_linear(out)
-        return out, cache
+        return out
 
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.head_dim
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -78,11 +78,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -91,7 +89,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -124,11 +122,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 def create_additive_causal_mask(N: int, offset: int = 0):
@@ -168,10 +166,10 @@ class LlamaModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, cache=c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -180,14 +178,15 @@ class Model(nn.Module):
         self.model_type = args.model_type
         self.model = LlamaModel(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     def sanitize(self, weights):
         # Remove unused precomputed rotary freqs
@@ -198,3 +197,11 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -19,7 +19,6 @@ class ModelArgs(BaseModelArgs):
     rms_norm_eps: float
     vocab_size: int
     num_key_value_heads: int
-    max_position_embeddings: int
     scale_depth: float
     scale_emb: float
     rope_theta: float = 1000000.0
@@ -47,7 +46,6 @@ class Attention(nn.Module):
         self.hidden_size = args.hidden_size
         self.num_heads = n_heads = args.num_attention_heads
         self.rope_theta = args.rope_theta
-        self.max_position_embeddings = args.max_position_embeddings
 
         self.head_dim = head_dim = args.hidden_size // n_heads
         self.scale = head_dim**-0.5
@@ -98,11 +96,9 @@ class Attention(nn.Module):
         )
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -113,7 +109,7 @@ class Attention(nn.Module):
 
         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
 
-        return self.o_proj(attn_output), (keys, values)
+        return self.o_proj(attn_output)
 
 
 class DecoderLayer(nn.Module):
@@ -139,11 +135,11 @@ class DecoderLayer(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
-        return out, cache
+        return out
 
 
 class MiniCPMModel(nn.Module):
@@ -172,10 +168,10 @@ class MiniCPMModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -193,14 +189,14 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
+        out = self.model(inputs, cache)
 
         if not self.args.tie_word_embeddings:
             out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
         else:
             out = out @ self.model.embed_tokens.weight.T
 
-        return out, cache
+        return out
 
     def sanitize(self, weights):
         if "lm_head.weight" not in weights:
@@ -210,3 +206,11 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -77,11 +77,9 @@ class MixtralAttention(nn.Module):
         )
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -90,7 +88,7 @@ class MixtralAttention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MixtralBLockSparseTop2MLP(nn.Module):
@@ -180,11 +178,11 @@ class MixtralDecoderLayer(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.block_sparse_moe(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class MixtralModel(nn.Module):
@@ -215,10 +213,10 @@ class MixtralModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -227,15 +225,24 @@ class Model(nn.Module):
         self.model_type = args.model_type
         self.model = MixtralModel(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -78,11 +78,9 @@ class TransformerBlock(nn.Module):
         values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -92,7 +90,7 @@ class TransformerBlock(nn.Module):
             scores += mask
         scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
         output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.attn_out(output), (keys, values)
+        return self.attn_out(output)
 
     def __call__(
         self,
@@ -100,13 +98,13 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.attend(self.att_norm(x), mask, cache)
+        r = self.attend(self.att_norm(x), mask, cache)
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
 
         out = h + self.ff_out(nn.silu(x2) * x1)
-        return out, cache
+        return out
 
 
 class Transformer(nn.Module):
@@ -136,15 +134,15 @@ class Transformer(nn.Module):
         if cache is None:
             cache = [None] * len(self.blocks)
 
-        for e, block in enumerate(self.blocks):
-            h, cache[e] = block(h, mask, cache[e])
+        for block, c in zip(self.blocks, cache):
+            h = block(h, mask, c)
 
         h = self.norm(h)
 
         if self.weight_tying:
             return self.wte.as_linear(h), cache
 
-        return self.ff_out(h), cache
+        return self.ff_out(h)
 
 
 class OlmoModel(nn.Module):
@@ -165,6 +163,7 @@ class Model(nn.Module):
         super().__init__()
         self.model_type = args.model_type
         self.model = OlmoModel(args)
+        self.args = args
 
     def __call__(
         self,
@@ -176,3 +175,11 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.transformer.blocks
+
+    @property
+    def head_dim(self):
+        return self.args.d_model // self.args.n_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.n_heads
@@ -22,8 +22,7 @@ class ModelArgs(BaseModelArgs):
     normalize_qk_projections: bool = True
     share_input_output_layers: bool = True
     rms_norm_eps: float = 1e-6
-    rope_theta: float = 10000
-    rope_traditional: bool = False
+    rope_freq_constant: float = 10000
 
 
 def make_divisible(
@@ -73,9 +72,7 @@ class Attention(nn.Module):
         self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
         self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
 
-        self.rope = nn.RoPE(
-            head_dim, traditional=args.rope_traditional, base=args.rope_theta
-        )
+        self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
 
     def __call__(
         self,
@@ -87,12 +84,10 @@ class Attention(nn.Module):
 
         qkv = self.qkv_proj(x)
 
-        # [B, S, (q_h + k_h + v_h) * h] --> [B, S, (q_h + k_h + v_h), h] -> [B, (q_h + k_h + v_h), S, h]
         qkv = qkv.reshape(
             B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
         ).transpose(0, 2, 1, 3)
 
-        # [B, (q_h + k_h + v_h), S, h] --> [B, q_h, S h], [B, k_h, S, h], [B, v_h, S, h]
         queries, keys, values = mx.split(
             qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
         )
@@ -103,11 +98,9 @@ class Attention(nn.Module):
         keys = self.k_norm(keys)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -118,7 +111,7 @@ class Attention(nn.Module):
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
 
-        return self.out_proj(output), (keys, values)
+        return self.out_proj(output)
 
 
 class MLP(nn.Module):
@@ -159,11 +152,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.attn(self.attn_norm(x), mask, cache)
+        r = self.attn(self.attn_norm(x), mask, cache)
         h = x + r
         r = self.ffn(self.ffn_norm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class OpenELMModel(nn.Module):
@@ -195,10 +188,10 @@ class OpenELMModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, cache=c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -215,14 +208,22 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.transformer(inputs, cache)
+        out = self.transformer(inputs, cache)
         if self.args.share_input_output_layers:
             out = self.transformer.token_embeddings.as_linear(out)
         else:
             out = self.lm_head(out)
 
-        return out, cache
+        return out
 
     @property
     def layers(self):
         return self.transformer.layers
+
+    @property
+    def head_dim(self):
+        return self.args.head_dim
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_kv_heads
@@ -75,35 +75,29 @@ class PhiAttention(nn.Module):
         queries = queries.reshape(
             B,
             L,
-            n_kv_heads,
-            n_heads // n_kv_heads,
+            n_heads,
             -1,
-        ).moveaxis(1, 3)
-        keys = keys.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
-        values = values.reshape(B, L, n_kv_heads, 1, -1).moveaxis(1, 3)
+        ).moveaxis(1, 2)
+        keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
+        values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[-2])
-            keys = self.rope(keys, offset=key_cache.shape[-2])
-            keys = mx.concatenate([key_cache, keys], axis=-2)
-            values = mx.concatenate([value_cache, values], axis=-2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        queries = queries.astype(mx.float32)
-
-        # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.swapaxes(-1, -2)
-        if mask is not None:
-            scores = scores + mask
-        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        output = (scores @ values).moveaxis(3, 1).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
+        ).astype(values.dtype)
+
+        output = output.moveaxis(2, 1).reshape(B, L, -1)
 
-        return self.dense(output), (keys, values)
+        return self.dense(output)
 
 
 class PhiMLP(nn.Module):
@@ -128,9 +122,9 @@ class PhiDecoderLayer(nn.Module):
 
     def __call__(self, x, mask, cache):
         h = self.input_layernorm(x)
-        attn_h, cache = self.self_attn(h, mask, cache)
+        attn_h = self.self_attn(h, mask, cache)
         ff_h = self.mlp(h)
-        return attn_h + ff_h + x, cache
+        return attn_h + ff_h + x
 
 
 class PhiModel(nn.Module):
@@ -152,9 +146,9 @@ class PhiModel(nn.Module):
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
 
-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, mask, cache[e])
-        return self.final_layernorm(x), cache
+        for layer, c in zip(self.layers, cache):
+            x = layer(x, mask, c)
+        return self.final_layernorm(x)
 
 
 class Model(nn.Module):
@@ -163,15 +157,24 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.model = PhiModel(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
+        self.args = config
 
     def __call__(
        self,
         x: mx.array,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        y, cache = self.model(x, cache)
-        return self.lm_head(y), cache
+        y = self.model(x, cache)
+        return self.lm_head(y)
 
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
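Note (not part of the commit): besides switching to KVCache, the phi, phixtral, and qwen diffs also replace the hand-written softmax attention with mx.fast.scaled_dot_product_attention. The two formulations compute the same result up to dtype handling; a small self-contained sketch with made-up shapes:

import math
import mlx.core as mx

def manual_attention(q, k, v, mask=None):
    # The pattern removed by this commit: explicit scores + softmax.
    scale = math.sqrt(1 / q.shape[-1])
    scores = (q * scale) @ k.swapaxes(-1, -2)
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v

q = mx.random.normal((1, 8, 16, 64))  # (batch, heads, seq, head_dim)
k = mx.random.normal((1, 8, 16, 64))
v = mx.random.normal((1, 8, 16, 64))
fused = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=math.sqrt(1 / q.shape[-1]), mask=None
)
print(mx.allclose(manual_attention(q, k, v), fused, atol=1e-4))  # True, within tolerance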
@@ -81,11 +81,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -94,7 +92,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
        )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -128,11 +126,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class Phi3Model(nn.Module):
@@ -163,10 +161,10 @@ class Phi3Model(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -175,15 +173,24 @@ class Model(nn.Module):
         self.model_type = args.model_type
         self.model = Phi3Model(args)
         self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@@ -11,7 +11,6 @@ import numpy as np
 @dataclass
 class ModelArgs:
     model_type: str
-    max_sequence_length: int = 2048
     num_vocab: int = 51200
     model_dim: int = 2560
     num_heads: int = 32
@@ -56,11 +55,9 @@ class RoPEAttention(nn.Module):
 
         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -69,14 +66,13 @@
 
         # Finally perform the attention computation
         scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores, axis=-1).astype(values.dtype)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries.astype(mx.float32), keys, values, scale=scale, mask=mask
+        ).astype(values.dtype)
+
+        output = output.moveaxis(2, 1).reshape(B, L, -1)
 
-        return self.out_proj(values_hat), (keys, values)
+        return self.out_proj(output)
 
 
 class MLP(nn.Module):
@@ -144,9 +140,9 @@ class ParallelBlock(nn.Module):
 
     def __call__(self, x, mask, cache):
         h = self.ln(x)
-        attn_h, cache = self.mixer(h, mask, cache)
+        attn_h = self.mixer(h, mask, cache)
         ff_h = self.moe(h)
-        return attn_h + ff_h + x, cache
+        return attn_h + ff_h + x
 
 
 class TransformerDecoder(nn.Module):
@@ -160,9 +156,9 @@ class TransformerDecoder(nn.Module):
         if cache is None:
             cache = [None] * len(self.h)
 
-        for e, layer in enumerate(self.h):
-            x, cache[e] = layer(x, mask, cache[e])
-        return x, cache
+        for layer, c in zip(self.h, cache):
+            x = layer(x, mask, c)
+        return x
 
 
 class Embd(nn.Module):
@@ -190,6 +186,7 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.transformer = TransformerDecoder(config)
         self.lm_head = OutputHead(config)
+        self.args = config
 
     def __call__(
         self,
@@ -202,9 +199,17 @@ class Model(nn.Module):
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)
 
-        y, cache = self.transformer(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.transformer(x, mask, cache)
+        return self.lm_head(y)
 
     @property
     def layers(self):
         return self.transformer.h
+
+    @property
+    def head_dim(self):
+        return self.args.model_dim // self.args.num_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_heads
@@ -64,44 +64,38 @@ class Attention(nn.Module):
     ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
         bsz, q_len, _ = hidden_states.shape
 
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        queries = self.q_proj(hidden_states)
+        keys = self.k_proj(hidden_states)
+        values = self.v_proj(hidden_states)
 
         # Prepare the queries, keys and values for the attention computation
-        query_states = query_states.reshape(
-            bsz, q_len, self.q_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        key_states = key_states.reshape(
-            bsz, q_len, self.k_num_heads, self.qk_dim
-        ).transpose(0, 2, 1, 3)
-        value_states = value_states.reshape(
-            bsz, q_len, self.v_num_heads, self.v_dim
-        ).transpose(0, 2, 1, 3)
+        queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
+            0, 2, 1, 3
+        )
+        values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
+            0, 2, 1, 3
+        )
 
-        # expand shared kv
-        assert self.k_num_heads == self.v_num_heads
-
-        kv_seq_len = 0
-        if cache is not None:
-            kv_seq_len += cache[0].shape[-2]
-        query_states = self.rotary_emb(query_states, offset=kv_seq_len)
-        key_states = self.rotary_emb(key_states, offset=kv_seq_len)
-
         if cache is not None:
-            # reuse k, v, self_attention
-            key_states = mx.concatenate([cache[0], key_states], axis=2)
-            value_states = mx.concatenate([cache[1], value_states], axis=2)
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)
 
         output = mx.fast.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
+            queries,
+            keys,
+            values,
             scale=self.scale,
             mask=attention_mask,
         )
         output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
-        return self.o_proj(output), (key_states, value_states)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -139,7 +133,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states = self.norm(hidden_states)
 
         # Self Attention
-        hidden_states_sa, cache = self.self_attn(
+        hidden_states_sa = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             cache=cache,
@@ -149,7 +143,7 @@ class PlamoDecoderLayer(nn.Module):
         hidden_states_mlp = self.mlp(hidden_states)
 
         hidden_states = residual + hidden_states_sa + hidden_states_mlp
-        return hidden_states, cache
+        return hidden_states
 
 
 class PlamoDecoder(nn.Module):
@@ -185,14 +179,10 @@ class PlamoModel(nn.Module):
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
 
-        for e, layer in enumerate(self.layers.layers):
-            h, c = layer(h, mask, cache[e])
-            if cache is not None:
-                cache[e] = c
-            else:
-                cache.append(c)
+        for layer, c in zip(self.layers.layers, cache):
+            h = layer(h, mask, cache=c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -203,15 +193,24 @@ class Model(nn.Module):
         self.lm_head: nn.Module = nn.Linear(
             args.hidden_size, args.vocab_size, bias=False
         )
+        self.args = args
 
     def __call__(
         self,
         inputs: mx.array,
         cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
     ) -> Tuple[mx.array, mx.array]:
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)
 
     @property
     def layers(self):
         return self.model.layers.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads // self.args.n_shared_head
@@ -51,30 +51,24 @@ class Attention(nn.Module):
 
         B, L, _ = q.shape
 
-        q = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        k = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        v = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            k_cache, v_cache = cache
-            q = self.rotary_emb(q, offset=k_cache.shape[2])
-            k = self.rotary_emb(k, offset=k_cache.shape[2])
-            k = mx.concatenate([k_cache, k], axis=2)
-            v = mx.concatenate([v_cache, v], axis=2)
-
+            queries = self.rotary_emb(queries, offset=cache.offset)
+            keys = self.rotary_emb(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
-            q = self.rotary_emb(q)
-            k = self.rotary_emb(k)
+            queries = self.rotary_emb(queries)
+            keys = self.rotary_emb(keys)
 
-        scores = (q * self.scale) @ k.transpose(0, 1, 3, 2)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
 
-        if mask is not None:
-            scores = scores + mask
-
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        v_hat = (scores @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.c_proj(v_hat), (k, v)
+        return self.c_proj(output)
 
 
 class MLP(nn.Module):
@@ -109,13 +103,13 @@ class TransformerBlock(nn.Module):
     def __call__(self, x, mask=None, cache=None):
         residual = x
         x = self.ln_1(x)
-        x, cache = self.attn(x, mask=mask, cache=cache)
+        x = self.attn(x, mask=mask, cache=cache)
         residual = x + residual
         x = self.ln_2(residual)
         x = self.mlp(x)
         x = x + residual
 
-        return x, cache
+        return x
 
 
 class QwenModel(nn.Module):
@@ -137,10 +131,10 @@ class QwenModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.h)
 
-        for e, layer in enumerate(self.h):
-            x, cache[e] = layer(x, mask, cache[e])
+        for layer, c in zip(self.h, cache):
+            x = layer(x, mask, c)
 
-        return self.ln_f(x), cache
+        return self.ln_f(x)
 
 
 class Model(nn.Module):
@@ -151,6 +145,7 @@ class Model(nn.Module):
         self.lm_head = nn.Linear(
             config.hidden_size, config.vocab_size, bias=not config.no_bias
        )
+        self.args = config
 
     def __call__(
         self,
@@ -158,9 +153,17 @@
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> Tuple[mx.array, mx.array]:
-        y, cache = self.transformer(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.transformer(x, mask, cache)
+        return self.lm_head(y)
 
     @property
     def layers(self):
         return self.transformer.h
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_attention_heads
@@ -79,11 +79,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@@ -92,7 +90,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
        )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)
 
 
 class MLP(nn.Module):
@@ -125,11 +123,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out
 
 
 class Qwen2Model(nn.Module):
@@ -160,10 +158,10 @@ class Qwen2Model(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)
 
-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
 
-        return self.norm(h), cache
+        return self.norm(h)
 
 
 class Model(nn.Module):
@@ -180,12 +178,12 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
+        out = self.model(inputs, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
             out = self.lm_head(out)
-        return out, cache
+        return out
 
     def sanitize(self, weights):
         if self.args.tie_word_embeddings:
@@ -198,3 +196,11 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
|
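The `head_dim` and `n_kv_heads` properties also make the cache's memory cost easy to estimate: at a given `cache.offset` there is one key and one value tensor per layer. A back-of-the-envelope example with hypothetical sizes (none of these numbers come from the diff, and the cache may preallocate capacity in blocks, so treat this as a lower bound):

```python
# All values below are assumptions chosen for illustration.
n_layers = 32        # hypothetical decoder depth
n_kv_heads = 8       # hypothetical grouped-query KV heads per layer
head_dim = 128       # hypothetical per-head dimension
offset = 4096        # cached positions so far (cache.offset)
bytes_per_el = 2     # float16 keys/values

kv_bytes = 2 * n_layers * n_kv_heads * head_dim * offset * bytes_per_el
print(f"{kv_bytes / 2**20:.0f} MiB")  # 512 MiB for these numbers
```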
@ -78,11 +78,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@ -91,7 +89,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=self.scale, mask=mask
         )
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)


 class Qwen2MoeMLP(nn.Module):
@ -187,11 +185,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out


 class Qwen2MoeModel(nn.Module):
@ -222,10 +220,10 @@ class Qwen2MoeModel(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)

-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)

-        return self.norm(h), cache
+        return self.norm(h)


 class Model(nn.Module):
@ -241,8 +239,8 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
-        return self.lm_head(out), cache
+        out = self.model(inputs, cache)
+        return self.lm_head(out)

     def sanitize(self, weights):
         if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
@ -255,3 +253,11 @@ class Model(nn.Module):
     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
@ -107,11 +107,9 @@ class Attention(nn.Module):

         # Add RoPE to the queries and keys and combine them with the cache
         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@ -125,7 +123,7 @@ class Attention(nn.Module):
             queries, keys, values, scale=scale, mask=mask
         ).astype(values.dtype)
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)


 class MLP(nn.Module):
@ -157,7 +155,7 @@ class DecoderLayer(nn.Module):

     def __call__(self, x, mask, cache):
         h = self.input_layernorm(x)
-        r, cache = self.self_attn(h, mask, cache)
+        r = self.self_attn(h, mask, cache)

         if self.use_parallel_residual:
             out = x + r + self.mlp(h)
@ -165,7 +163,7 @@ class DecoderLayer(nn.Module):
             h = x + r
             r = self.mlp(self.post_attention_layernorm(h))
             out = h + r
-        return out, cache
+        return out


 class StableLM(nn.Module):
@ -180,9 +178,10 @@ class StableLM(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)

-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, mask, cache[e])
-        return self.norm(x), cache
+        for layer, c in zip(self.layers, cache):
+            x = layer(x, mask, cache=c)
+
+        return self.norm(x)


 class Model(nn.Module):
@ -191,6 +190,7 @@ class Model(nn.Module):
         self.model_type = config.model_type
         self.model = StableLM(config)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.args = config

     def __call__(
         self,
@ -203,9 +203,17 @@ class Model(nn.Module):
             mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
             mask = mask.astype(x.dtype)

-        y, cache = self.model(x, mask, cache)
-        return self.lm_head(y), cache
+        y = self.model(x, mask, cache)
+        return self.lm_head(y)

     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
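Beyond the cache plumbing, the `DecoderLayer` above keeps its `use_parallel_residual` switch. A small standalone sketch (not from the repo) of the two residual layouts it selects between, with `attn`, `mlp`, and the norms standing in for the layer's submodules:

```python
# Illustrative only: the two residual layouts behind use_parallel_residual.

def parallel_residual(x, attn, mlp, norm1):
    # Attention and MLP both read the same pre-norm activations; their outputs
    # are added to the input in a single step.
    h = norm1(x)
    return x + attn(h) + mlp(h)

def sequential_residual(x, attn, mlp, norm1, norm2):
    # The MLP sees the post-attention residual stream through a second norm.
    h = x + attn(norm1(x))
    return h + mlp(norm2(h))
```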
@ -55,11 +55,9 @@ class Attention(nn.Module):
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

         if cache is not None:
-            key_cache, value_cache = cache
-            queries = self.rope(queries, offset=key_cache.shape[2])
-            keys = self.rope(keys, offset=key_cache.shape[2])
-            keys = mx.concatenate([key_cache, keys], axis=2)
-            values = mx.concatenate([value_cache, values], axis=2)
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
         else:
             queries = self.rope(queries)
             keys = self.rope(keys)
@ -69,7 +67,7 @@ class Attention(nn.Module):
         )

         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.o_proj(output), (keys, values)
+        return self.o_proj(output)


 class MLP(nn.Module):
@ -102,11 +100,11 @@ class TransformerBlock(nn.Module):
         mask: Optional[mx.array] = None,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
-        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
         h = x + r
         r = self.mlp(self.post_attention_layernorm(h))
         out = h + r
-        return out, cache
+        return out


 class Starcoder2Model(nn.Module):
@ -137,10 +135,10 @@ class Starcoder2Model(nn.Module):
         if cache is None:
             cache = [None] * len(self.layers)

-        for e, layer in enumerate(self.layers):
-            h, cache[e] = layer(h, mask, cache[e])
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)

-        return self.norm(h), cache
+        return self.norm(h)


 class Model(nn.Module):
@ -157,13 +155,21 @@ class Model(nn.Module):
         inputs: mx.array,
         cache=None,
     ):
-        out, cache = self.model(inputs, cache)
+        out = self.model(inputs, cache)
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
         else:
             out = self.lm_head(out)
-        return out, cache
+        return out

     @property
     def layers(self):
         return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
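When `tie_word_embeddings` is set, this model (like the Qwen2 one above) projects hidden states back to the vocabulary through the token embedding itself via `embed_tokens.as_linear` instead of a separate `lm_head`. A toy sketch of what that reuse amounts to, assuming `mlx.nn.Embedding.as_linear` applies the transposed embedding table (the sizes are made up):

```python
import mlx.core as mx
import mlx.nn as nn

vocab_size, hidden_size = 100, 16             # made-up sizes
emb = nn.Embedding(vocab_size, hidden_size)
h = mx.random.normal((1, 4, hidden_size))     # stand-in for final hidden states

logits_tied = emb.as_linear(h)                # reuse the embedding matrix as the output head
logits_manual = h @ mx.transpose(emb.weight)  # the same projection spelled out
print(logits_tied.shape, mx.allclose(logits_tied, logits_manual))
```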
@ -314,7 +314,8 @@ def load_tokenizer(model_path, tokenizer_config_extra={}):

     tokenizer_file = model_path / "tokenizer.json"
     if tokenizer_file.exists():
-        tokenizer_content = json.load(tokenizer_file.open())
+        with open(tokenizer_file, "r") as fid:
+            tokenizer_content = json.load(fid)
         if "decoder" in tokenizer_content:
             if _is_spm_decoder(tokenizer_content["decoder"]):
                 detokenizer_class = SPMStreamingDetokenizer
@ -18,6 +18,7 @@ from mlx.utils import tree_flatten
 from transformers import AutoTokenizer, PreTrainedTokenizer

 # Local imports
+from .models.base import KVCache
 from .sample_utils import top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@ -160,7 +161,12 @@ def generate_step(
     )

     y = prompt
-    cache = None
+    kv_heads = (
+        [model.n_kv_heads] * len(model.layers)
+        if isinstance(model.n_kv_heads, int)
+        else model.n_kv_heads
+    )
+    cache = [KVCache(model.head_dim, n) for n in kv_heads]

     repetition_context = prompt.tolist()

@ -168,8 +174,8 @@ def generate_step(
         repetition_context = repetition_context[-repetition_context_size:]

     def _step(y):
-        nonlocal cache, repetition_context
-        logits, cache = model(y[None], cache=cache)
+        nonlocal repetition_context
+        logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]

         if repetition_penalty:
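`generate_step` now builds the cache itself, and the `kv_heads` expression has to handle models that report a single KV-head count as well as models that report one per layer (the OpenELM configuration added to the tests below is an example of the latter). A small illustration of both shapes (the numbers are made up):

```python
def per_layer_kv_heads(n_kv_heads, n_layers):
    # Mirrors the kv_heads expression above: broadcast an int, pass lists through.
    return [n_kv_heads] * n_layers if isinstance(n_kv_heads, int) else n_kv_heads

print(per_layer_kv_heads(8, 4))             # [8, 8, 8, 8]
print(per_layer_kv_heads([3, 3, 4, 5], 4))  # [3, 3, 4, 5]
```

Because each `KVCache` grows in place through `update_and_fetch`, `_step` only needs `nonlocal` for `repetition_context`; the cache list is read, never rebound.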
@ -445,9 +451,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
     card.text = dedent(
         f"""
         # {upload_repo}

         The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.

         ## Use with mlx

         ```bash
@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.12.0"
+__version__ = "0.13.0"
@ -4,6 +4,7 @@ import unittest

 import mlx.core as mx
 from mlx.utils import tree_map
+from mlx_lm.models.base import KVCache


 class TestModels(unittest.TestCase):
@ -17,13 +18,18 @@ class TestModels(unittest.TestCase):
         model.update(tree_map(lambda p: p.astype(t), model.parameters()))

         inputs = mx.array([[0, 1]])
-        outputs, cache = model(inputs)
+        outputs = model(inputs)
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)

-        outputs, cache = model(
-            mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
-        )
+        kv_heads = (
+            [model.n_kv_heads] * len(model.layers)
+            if isinstance(model.n_kv_heads, int)
+            else model.n_kv_heads
+        )
+        cache = [KVCache(model.head_dim, n) for n in kv_heads]
+
+        outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
         self.assertEqual(outputs.shape, (1, 1, vocab_size))
         self.assertEqual(outputs.dtype, t)

@ -53,6 +59,15 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_phixtral(self):
+        from mlx_lm.models import phixtral
+
+        args = phixtral.ModelArgs(
+            "phixtral", num_vocab=1000, num_layers=4, model_dim=1024
+        )
+        model = phixtral.Model(args)
+        self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)
+
     def test_phi3(self):
         from mlx_lm.models import phi3

@ -264,6 +279,100 @@ class TestModels(unittest.TestCase):
             model, args.model_type, args.vocab_size, args.num_hidden_layers
         )

+    def test_dbrx(self):
+        from mlx_lm.models import dbrx
+
+        args = dbrx.ModelArgs(
+            model_type="dbrx",
+            d_model=1024,
+            ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
+            attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
+            n_layers=4,
+            n_heads=4,
+            vocab_size=10_000,
+        )
+        model = dbrx.Model(args)
+        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)
+
+    def test_minicpm(self):
+        from mlx_lm.models import minicpm
+
+        args = minicpm.ModelArgs(
+            model_type="minicpm",
+            hidden_size=1024,
+            dim_model_base=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-4,
+            vocab_size=10000,
+            num_key_value_heads=2,
+            scale_depth=1.0,
+            scale_emb=1.0,
+        )
+        model = minicpm.Model(args)
+        self.model_test_runner(
+            model, args.model_type, args.vocab_size, args.num_hidden_layers
+        )
+
+    def test_openelm(self):
+        from mlx_lm.models import openelm
+
+        args = openelm.ModelArgs(
+            model_type="openelm",
+            ffn_dim_divisor=256,
+            ffn_multipliers=[
+                0.5,
+                0.73,
+                0.97,
+                1.2,
+                1.43,
+                1.67,
+                1.9,
+                2.13,
+                2.37,
+                2.6,
+                2.83,
+                3.07,
+                3.3,
+                3.53,
+                3.77,
+                4.0,
+            ],
+            head_dim=64,
+            model_dim=1280,
+            normalize_qk_projections=True,
+            num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
+            num_query_heads=[
+                12,
+                12,
+                12,
+                12,
+                12,
+                16,
+                16,
+                16,
+                16,
+                16,
+                16,
+                16,
+                20,
+                20,
+                20,
+                20,
+            ],
+            num_transformer_layers=16,
+            vocab_size=32000,
+        )
+
+        model = openelm.Model(args)
+        self.model_test_runner(
+            model,
+            args.model_type,
+            args.vocab_size,
+            len(args.ffn_multipliers),
+        )


 if __name__ == "__main__":
     unittest.main()