Enable unit testing in Circle and start some MLX LM tests (#545)

* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
Author: Awni Hannun
Date: 2024-03-07 09:31:57 -08:00
Committed by: GitHub
Parent: ef32379bc6
Commit: 7cdd1b69ac
12 changed files with 294 additions and 20 deletions
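
The commit messages are terse, so a sketch of the kind of model test this PR introduces may help. Everything below is illustrative: the import path, the `ModelArgs` fields beyond those visible in the diff (num_attention_heads, rms_norm_eps, vocab_size, n_shared_head), and the `model(inputs)` call signature are assumptions, not copied from the new test files.

import unittest

import mlx.core as mx


class TestModels(unittest.TestCase):
    def test_plamo(self):
        # Hypothetical import path and argument names; only a few of these
        # fields actually appear in the diff below.
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            hidden_size=1024,
            num_hidden_layers=2,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)

        inputs = mx.array([[0, 1, 2, 3]])
        logits, cache = model(inputs)
        # A tiny smoke test: one logit vector per input position.
        self.assertEqual(logits.shape, (1, 4, args.vocab_size))


if __name__ == "__main__":
    unittest.main()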


@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
     num_attention_heads: int
     rms_norm_eps: float
     vocab_size: int
-    n_shared_head: int = (8,)
+    n_shared_head: int = 8
     rope_theta: float = 10000
     rope_traditional: bool = False
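
The fix in this hunk is worth spelling out: the trailing comma in the old default made `n_shared_head` a one-element tuple rather than an int, a classic Python slip that the type annotation does not catch at runtime.

n_shared_head = (8,)  # trailing comma: this is a tuple, not an int
print(type(n_shared_head))  # <class 'tuple'>
# Any use as an integer, e.g. num_heads // n_shared_head, raises a TypeError.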
@@ -80,16 +80,11 @@ class Attention(nn.Module):
             bsz, q_len, self.v_num_heads, self.v_dim
         ).transpose(0, 2, 1, 3)

-        def _expand_kv(a: mx.array) -> mx.array:
-            a = mx.concatenate(
-                [mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
-            )
-            return a.reshape([bsz, self.q_num_heads, q_len, -1])
-
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        key_states = _expand_kv(key_states)
-        value_states = _expand_kv(value_states)
+        repeats = self.config.n_shared_head
+        key_states = mx.repeat(key_states, repeats, axis=1)
+        value_states = mx.repeat(value_states, repeats, axis=1)

         kv_seq_len = 0
         if cache is not None:
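
This hunk is more than cleanup. The removed `_expand_kv` stacks copies along a new axis and then reshapes, which tiles the KV heads ([h0, h1, h0, h1, ...]), whereas `mx.repeat` duplicates each head consecutively ([h0, h0, h1, h1, ...]), the layout used by common grouped-query attention helpers such as Hugging Face's repeat_kv. A small script showing the difference (shapes chosen purely for illustration):

import mlx.core as mx

# Two KV heads, one position, head_dim 2: heads are h0 = [0, 1], h1 = [2, 3].
kv = mx.arange(4).reshape(1, 2, 1, 2)

# mx.repeat duplicates each head consecutively: [h0, h0, h1, h1].
rep = mx.repeat(kv, 2, axis=1)
print(rep.reshape(4, 2))
# [[0, 1], [0, 1], [2, 3], [2, 3]]

# The removed concatenate/reshape route tiles instead: [h0, h1, h0, h1].
tiled = mx.concatenate([mx.expand_dims(kv, 1)] * 2, axis=1).reshape(1, 4, 1, 2)
print(tiled.reshape(4, 2))
# [[0, 1], [2, 3], [0, 1], [2, 3]]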
@@ -222,3 +217,7 @@ class Model(nn.Module):
     ) -> Tuple[mx.array, mx.array]:
         out, cache = self.model(inputs, cache)
         return self.lm_head(out), cache
+
+    @property
+    def layers(self):
+        return self.model.layers.layers
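
The new `layers` property hides PLaMo's `model.layers.layers` nesting behind a uniform attribute, presumably so that tests and generic utilities can walk the transformer blocks without model-specific knowledge. A hypothetical example of the kind of code this enables:

def init_cache(model):
    # One cache slot per transformer block; relies only on the `layers`
    # property added in this commit, not on any model's internal layout.
    return [None] * len(model.layers)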