mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Enable unit testing in Circle and start some MLX LM tests (#545)
* add a few tests for mlx lm * add a few tests for mlx lm * add a few tests for mlx lm * more tests / cleanup
This commit is contained in:
@@ -18,7 +18,7 @@ class ModelArgs(BaseModelArgs):
|
||||
num_attention_heads: int
|
||||
rms_norm_eps: float
|
||||
vocab_size: int
|
||||
n_shared_head: int = (8,)
|
||||
n_shared_head: int = 8
|
||||
rope_theta: float = 10000
|
||||
rope_traditional: bool = False
|
||||
|
||||
@@ -80,16 +80,11 @@ class Attention(nn.Module):
|
||||
bsz, q_len, self.v_num_heads, self.v_dim
|
||||
).transpose(0, 2, 1, 3)
|
||||
|
||||
def _expand_kv(a: mx.array) -> mx.array:
|
||||
a = mx.concatenate(
|
||||
[mx.expand_dims(a, 1)] * self.config.n_shared_head, axis=1
|
||||
)
|
||||
return a.reshape([bsz, self.q_num_heads, q_len, -1])
|
||||
|
||||
# expand shared kv
|
||||
assert self.k_num_heads == self.v_num_heads
|
||||
key_states = _expand_kv(key_states)
|
||||
value_states = _expand_kv(value_states)
|
||||
repeats = self.config.n_shared_head
|
||||
key_states = mx.repeat(key_states, repeats, axis=1)
|
||||
value_states = mx.repeat(value_states, repeats, axis=1)
|
||||
|
||||
kv_seq_len = 0
|
||||
if cache is not None:
|
||||
@@ -222,3 +217,7 @@ class Model(nn.Module):
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
out, cache = self.model(inputs, cache)
|
||||
return self.lm_head(out), cache
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
|
Reference in New Issue
Block a user