mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-08 18:06:37 +08:00
parent
ee60e2a9d5
commit
fad9598372
@ -159,7 +159,7 @@ class LlamaModel(nn.Module):
|
||||
mask = None
|
||||
if h.shape[1] > 1:
|
||||
mask = create_additive_causal_mask(
|
||||
h.shape[1], cache[0][0].shape[2] if cache is not None else 0
|
||||
h.shape[1], cache[0].offset if cache is not None else 0
|
||||
)
|
||||
mask = mask.astype(h.dtype)
|
||||
|
||||
|
@ -29,6 +29,10 @@ class TestModels(unittest.TestCase):
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
outputs = model(inputs, cache)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
|
||||
self.assertEqual(outputs.shape, (1, 1, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
Loading…
Reference in New Issue
Block a user