Enable unit testing in Circle and start some MLX LM tests (#545)

* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
This commit is contained in:
Awni Hannun
2024-03-07 09:31:57 -08:00
committed by GitHub
parent ef32379bc6
commit 7cdd1b69ac
12 changed files with 294 additions and 20 deletions

View File

@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int = None
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False

View File

@@ -13,7 +13,6 @@ from .layers import RMSNorm
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int = 32000
max_position_embeddings: int = 4096 * 32
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
@@ -38,7 +37,6 @@ class MixtralAttention(nn.Module):
self.num_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = args.num_key_value_heads
self.max_position_embeddings = args.max_position_embeddings
self.rope_theta = args.rope_theta
self.repeats = self.num_heads // self.num_key_value_heads

View File

@@ -24,7 +24,6 @@ class ModelArgs(BaseModelArgs):
n_heads: int
vocab_size: int
embedding_size: int
model_type: str
rope_theta: float = 10000
rope_traditional: bool = False
mlp_ratio: int = 4

View File

@@ -11,7 +11,7 @@ from .layers import LayerNorm
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
model_type: str = "phi"
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560

View File

@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = (8,)
n_shared_head: int = 8
rope_theta: float = 10000
rope_traditional: bool = False
@@ -80,16 +80,11 @@ class Attention(nn.Module):
bsz, q_len, self.v_num_heads, self.v_dim
).transpose(0, 2, 1, 3)
def _expand_kv(a: mx.array) -> mx.array:
a = mx.concatenate(
[mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
)
return a.reshape([bsz, self.q_num_heads, q_len, -1])
# expand shared kv
assert self.k_num_heads == self.v_num_heads
key_states = _expand_kv(key_states)
value_states = _expand_kv(value_states)
repeats = self.config.n_shared_head
key_states = mx.repeat(key_states, repeats, axis=1)
value_states = mx.repeat(value_states, repeats, axis=1)
kv_seq_len = 0
if cache is not None:
@@ -222,3 +217,7 @@ class Model(nn.Module):
) -> Tuple[mx.array, mx.array]:
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
@property
def layers(self):
return self.model.layers.layers

View File

@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x[:, T - 1 : T, :])
return x, cache
return self.ln_f(x), cache
class Model(nn.Module):
@@ -162,3 +161,7 @@ class Model(nn.Module):
) -> Tuple[mx.array, mx.array]:
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
@property
def layers(self):
return self.transformer.h

View File

@@ -11,7 +11,6 @@ from .layers import LayerNorm
@dataclass
class ModelArgs(BaseModelArgs):
max_position_embeddings: int
model_type: str
vocab_size: int
hidden_size: int

View File

@@ -15,10 +15,8 @@ class ModelArgs(BaseModelArgs):
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int = None
max_position_embeddings: int = 16384
num_key_value_heads: int
norm_epsilon: float = 1e-5
norm_type: str = "layer_norm"
vocab_size: int = 49152
rope_theta: float = 100000
tie_word_embeddings: bool = True