Enable unit testing in Circle and start some MLX LM tests (#545)

* add a few tests for mlx lm

* add a few tests for mlx lm

* add a few tests for mlx lm

* more tests / cleanup
This commit is contained in:
Awni Hannun
2024-03-07 09:31:57 -08:00
committed by GitHub
parent ef32379bc6
commit 7cdd1b69ac
12 changed files with 294 additions and 20 deletions

View File

@@ -141,8 +141,7 @@ class QwenModel(nn.Module):
for e, layer in enumerate(self.h):
x, cache[e] = layer(x, mask, cache[e])
x = self.ln_f(x[:, T - 1 : T, :])
return x, cache
return self.ln_f(x), cache
class Model(nn.Module):
@@ -162,3 +161,7 @@ class Model(nn.Module):
) -> Tuple[mx.array, mx.array]:
y, cache = self.transformer(x, mask, cache)
return self.lm_head(y), cache
@property
def layers(self):
return self.transformer.h