Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 12:49:50 +08:00)
Length masking for batch inputs (#1173)
* length masking
* add mask to mlx_lm model interface
* remove lengths
* fix test
* comment + fix
@@ -5,6 +5,7 @@ import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_map
 from mlx_lm.models import rope_utils
+from mlx_lm.models.base import create_causal_mask
 from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
 
 
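The import above is the core of the change: create_causal_mask now takes a lengths argument, so a single mask can enforce causality and block padded key positions across a batch. A minimal sketch of the idea, assuming the additive convention (0 where attention is allowed, a large negative value where it is not); length_causal_mask is an illustrative name, not the actual mlx_lm implementation:

import mlx.core as mx

def length_causal_mask(T, lengths):
    # Block future positions (j > i) and padded positions (j >= lengths[b]).
    idx = mx.arange(T)
    causal = idx[:, None] >= idx[None, :]          # (T, T)
    valid = idx[None, :] < lengths[:, None]        # (B, T)
    allowed = causal[None] & valid[:, None, :]     # (B, T, T)
    # Add a head axis so the mask broadcasts over all attention heads.
    return mx.where(allowed[:, None], 0.0, -1e9)   # (B, 1, T, T)

With lengths = mx.array([1, 2, 3, 1]) and T = 3, for example, batch item 1 can never attend to key position 2, so whatever occupies that padded slot cannot influence its valid outputs.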
@@ -128,6 +129,22 @@ class TestModels(unittest.TestCase):
         self.assertEqual(cache.offset, 22)
         self.assertTrue(mx.allclose(x, k[..., -2:, :]))
 
+    def test_causal_mask_lengths(self):
+        mx.random.seed(8)
+        B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
+        lengths = mx.array([1, 2, 3, 1])
+        q = mx.random.uniform(shape=(B, N_q, T_q, D))
+        k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
+        v = k
+        mask = create_causal_mask(T_q, 0, lengths=lengths)
+
+        out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
+        k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
+        v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
+        out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
+        self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
+
     def test_rope(self):
         rope = rope_utils.initialize_rope(32, base=100, traditional=False)
         self.assertTrue(isinstance(rope, nn.RoPE))
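The test verifies exactly that invariant: lengths[1] == 2, so after computing out1, overwriting position 2 of q, k, and v for batch item 1 must leave the outputs at the two valid positions unchanged. Note it also uses 8 query heads against 2 key/value heads; the (B, 1, T, T) mask broadcasts across both. In practice the lengths come from packing variable-length sequences into one rectangular batch; a hypothetical pad_batch helper (illustrative only, not part of mlx_lm) sketches that step:

import mlx.core as mx

def pad_batch(seqs, pad_id=0):
    # Hypothetical helper: right-pad token lists to a common length and
    # record each sequence's true length for create_causal_mask.
    T = max(len(s) for s in seqs)
    batch = mx.array([s + [pad_id] * (T - len(s)) for s in seqs])
    lengths = mx.array([len(s) for s in seqs])
    return batch, lengths

batch, lengths = pad_batch([[5, 6, 7], [5], [5, 6]])
# batch.shape == (3, 3); lengths == [3, 1, 2], ready to pass as
# create_causal_mask(T, 0, lengths=lengths)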
@@ -162,10 +179,16 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.dtype, t)
 
             cache = make_prompt_cache(model)
-            outputs = model(inputs, cache)
+            outputs = model(inputs, cache=cache)
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)
 
+            if model_type != "mamba":
+                mask = create_causal_mask(inputs.shape[1], 0).astype(t)
+                outputs = model(inputs, mask=mask)
+                self.assertEqual(outputs.shape, (1, 2, vocab_size))
+                self.assertEqual(outputs.dtype, t)
+
             outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
             self.assertEqual(outputs.shape, (1, 1, vocab_size))
             self.assertEqual(outputs.dtype, t)
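The last hunk exercises the widened model interface: the cache is now passed by keyword, and models accept an optional precomputed mask (the check is skipped for mamba, whose recurrent layers have no attention mask to consume). A minimal sketch of an attention block honoring that calling convention, assuming the KVCache.update_and_fetch API imported above; the class and its internals are illustrative, not the actual mlx_lm layer code:

import mlx.core as mx
import mlx.nn as nn

class TinyAttention(nn.Module):
    def __init__(self, dims, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dims, 3 * dims)
        self.out = nn.Linear(dims, dims)

    def __call__(self, x, mask=None, cache=None):
        B, T, D = x.shape
        q, k, v = mx.split(self.qkv(x), 3, axis=-1)
        # (B, T, D) -> (B, num_heads, T, head_dim)
        q = q.reshape(B, T, self.num_heads, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, T, self.num_heads, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, self.num_heads, -1).transpose(0, 2, 1, 3)
        if cache is not None:
            k, v = cache.update_and_fetch(k, v)
        o = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=(D // self.num_heads) ** -0.5, mask=mask
        )
        return self.out(o.transpose(0, 2, 1, 3).reshape(B, T, D))

Because the mask is additive and broadcastable, one (B, 1, T_q, T_kv) array from create_causal_mask can serve every such layer in the model.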