mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-24 06:28:07 +08:00
Enable unit testing in Circle and start some MLX LM tests (#545)
* add a few tests for mlx lm * add a few tests for mlx lm * add a few tests for mlx lm * more tests / cleanup
This commit is contained in:
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
|
||||
head_dim: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: int = None
|
||||
num_key_value_heads: int
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
|
@@ -13,7 +13,6 @@ from .layers import RMSNorm
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int = 32000
|
||||
max_position_embeddings: int = 4096 * 32
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 14336
|
||||
num_hidden_layers: int = 32
|
||||
@@ -38,7 +37,6 @@ class MixtralAttention(nn.Module):
|
||||
self.num_heads = args.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.num_key_value_heads = args.num_key_value_heads
|
||||
self.max_position_embeddings = args.max_position_embeddings
|
||||
self.rope_theta = args.rope_theta
|
||||
|
||||
self.repeats = self.num_heads // self.num_key_value_heads
|
||||
|
@@ -24,7 +24,6 @@ class ModelArgs(BaseModelArgs):
|
||||
n_heads: int
|
||||
vocab_size: int
|
||||
embedding_size: int
|
||||
model_type: str
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
mlp_ratio: int = 4
|
||||
|
@@ -11,7 +11,7 @@ from .layers import LayerNorm
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
model_type: str = "phi"
|
||||
max_position_embeddings: int = 2048
|
||||
vocab_size: int = 51200
|
||||
hidden_size: int = 2560
|
||||
|
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = (8,)
|
||||
n_shared_head: int = 8
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
@@ -80,16 +80,11 @@ class Attention(nn.Module):
|
||||
bsz, q_len, self.v_num_heads, self.v_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
def _expand_kv(a: mx.array) -> mx.array:
|
||||
a = mx.concatenate(
|
||||
[mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
|
||||
)
|
||||
return a.reshape([bsz, self.q_num_heads, q_len, -1])
|
||||
|
||||
# expand shared kv
|
||||
assert self.k_num_heads == self.v_num_heads
|
||||
key_states = _expand_kv(key_states)
|
||||
value_states = _expand_kv(value_states)
|
||||
repeats = self.config.n_shared_head
|
||||
key_states = mx.repeat(key_states, repeats, axis=1)
|
||||
value_states = mx.repeat(value_states, repeats, axis=1)
|
||||
|
||||
kv_seq_len = 0
|
||||
if cache is not None:
|
||||
@@ -222,3 +217,7 @@ class Model(nn.Module):
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
|
@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
|
||||
for e, layer in enumerate(self.h):
|
||||
x, cache[e] = layer(x, mask, cache[e])
|
||||
|
||||
x = self.ln_f(x[:, T - 1 : T, :])
|
||||
return x, cache
|
||||
return self.ln_f(x), cache
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
@@ -162,3 +161,7 @@ class Model(nn.Module):
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
y, cache = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y), cache
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
@@ -11,7 +11,6 @@ from .layers import LayerNorm
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
max_position_embeddings: int
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
|
@@ -15,10 +15,8 @@ class ModelArgs(BaseModelArgs):
|
||||
num_hidden_layers: int
|
||||
intermediate_size: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int = None
|
||||
max_position_embeddings: int = 16384
|
||||
num_key_value_heads: int
|
||||
norm_epsilon: float = 1e-5
|
||||
norm_type: str = "layer_norm"
|
||||
vocab_size: int = 49152
|
||||
rope_theta: float = 100000
|
||||
tie_word_embeddings: bool = True
|
||||
|
Reference in New Issue
Block a user