mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 17:37:56 +08:00
length masking
This commit is contained in:
parent
845efddc8c
commit
c5ce9a31f2
@ -23,7 +23,12 @@ class BaseModelArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
|
def create_causal_mask(
|
||||||
|
N: int,
|
||||||
|
offset: int = 0,
|
||||||
|
window_size: Optional[int] = None,
|
||||||
|
lengths: Optional[mx.array] = None,
|
||||||
|
):
|
||||||
rinds = mx.arange(offset + N)
|
rinds = mx.arange(offset + N)
|
||||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||||
linds = linds[:, None]
|
linds = linds[:, None]
|
||||||
@ -31,12 +36,18 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
|
|||||||
mask = linds < rinds
|
mask = linds < rinds
|
||||||
if window_size is not None:
|
if window_size is not None:
|
||||||
mask = mask | (linds > rinds + window_size)
|
mask = mask | (linds > rinds + window_size)
|
||||||
|
if lengths is not None:
|
||||||
|
mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
|
||||||
|
lengths = lengths[:, None, None]
|
||||||
|
mask = mask | (rinds[None] >= lengths)
|
||||||
|
mask = mask[:, None]
|
||||||
return mask * -1e9
|
return mask * -1e9
|
||||||
|
|
||||||
|
|
||||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||||
T = h.shape[1]
|
T = h.shape[1]
|
||||||
if T > 1:
|
if T > 1:
|
||||||
|
lengths = None
|
||||||
window_size = None
|
window_size = None
|
||||||
offset = 0
|
offset = 0
|
||||||
if cache is not None and cache[0] is not None:
|
if cache is not None and cache[0] is not None:
|
||||||
@ -46,7 +57,8 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
|||||||
window_size = c.max_size
|
window_size = c.max_size
|
||||||
else:
|
else:
|
||||||
offset = c.offset
|
offset = c.offset
|
||||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
lengths = getattr(c, "lengths", None)
|
||||||
|
mask = create_causal_mask(T, offset, window_size=window_size, lengths=lengths)
|
||||||
mask = mask.astype(h.dtype)
|
mask = mask.astype(h.dtype)
|
||||||
else:
|
else:
|
||||||
mask = None
|
mask = None
|
||||||
|
@ -10,6 +10,7 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|||||||
def make_prompt_cache(
|
def make_prompt_cache(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
|
lengths: Optional[mx.array] = None,
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Construct the model's cache for use during generation.
|
Construct the model's cache for use during generation.
|
||||||
@ -22,17 +23,22 @@ def make_prompt_cache(
|
|||||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||||
size of ``max_kv_size``
|
size of ``max_kv_size``
|
||||||
|
lengths (Optional[array]): If provided these sequence lengths will be
|
||||||
|
used to mask the KV cache. Useful for batch inputs.
|
||||||
"""
|
"""
|
||||||
if hasattr(model, "make_cache"):
|
if hasattr(model, "make_cache"):
|
||||||
return model.make_cache()
|
return model.make_cache()
|
||||||
|
|
||||||
num_layers = len(model.layers)
|
num_layers = len(model.layers)
|
||||||
if max_kv_size is not None:
|
if max_kv_size is not None:
|
||||||
return [
|
cache = [
|
||||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [KVCache() for _ in range(num_layers)]
|
cache = [KVCache() for _ in range(num_layers)]
|
||||||
|
|
||||||
|
cache[0].lengths = lengths
|
||||||
|
return cache
|
||||||
|
|
||||||
|
|
||||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||||
@ -135,6 +141,7 @@ class QuantizedKVCache(_BaseCache):
|
|||||||
self.values = None
|
self.values = None
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.step = 256
|
self.step = 256
|
||||||
|
self.lengths = None
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
|
|
||||||
@ -217,6 +224,7 @@ class KVCache(_BaseCache):
|
|||||||
self.values = None
|
self.values = None
|
||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.step = 256
|
self.step = 256
|
||||||
|
self.lengths = None
|
||||||
|
|
||||||
def update_and_fetch(self, keys, values):
|
def update_and_fetch(self, keys, values):
|
||||||
prev = self.offset
|
prev = self.offset
|
||||||
@ -285,6 +293,7 @@ class RotatingKVCache(_BaseCache):
|
|||||||
self.offset = 0
|
self.offset = 0
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.step = step
|
self.step = step
|
||||||
|
self.lengths = None
|
||||||
self._idx = 0
|
self._idx = 0
|
||||||
|
|
||||||
def _trim(self, trim_size, v, append=None):
|
def _trim(self, trim_size, v, append=None):
|
||||||
|
@ -5,6 +5,7 @@ import mlx.core as mx
|
|||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_map
|
from mlx.utils import tree_map
|
||||||
from mlx_lm.models import rope_utils
|
from mlx_lm.models import rope_utils
|
||||||
|
from mlx_lm.models.base import create_attention_mask
|
||||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||||
|
|
||||||
|
|
||||||
@ -128,6 +129,25 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(cache.offset, 22)
|
self.assertEqual(cache.offset, 22)
|
||||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||||
|
|
||||||
|
def test_cache_lengths(self):
|
||||||
|
mx.random.seed(8)
|
||||||
|
B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2)
|
||||||
|
lengths = mx.array([1, 2, 3, 1])
|
||||||
|
h = mx.random.uniform(shape=(B, T_q, D))
|
||||||
|
q = mx.random.uniform(shape=(B, N_q, T_q, D))
|
||||||
|
k = mx.random.uniform(shape=(B, N_kv, T_kv, D))
|
||||||
|
v = k
|
||||||
|
cache = [KVCache()]
|
||||||
|
cache[0].lengths = lengths
|
||||||
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
|
out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
q[1, :, 2:] = mx.ones_like(q[1, :, 2:])
|
||||||
|
k[1, :, 2:] = mx.ones_like(k[1, :, 2:])
|
||||||
|
v[1, :, 2:] = mx.ones_like(v[1, :, 2:])
|
||||||
|
out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask)
|
||||||
|
self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2]))
|
||||||
|
|
||||||
def test_rope(self):
|
def test_rope(self):
|
||||||
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
rope = rope_utils.initialize_rope(32, base=100, traditional=False)
|
||||||
self.assertTrue(isinstance(rope, nn.RoPE))
|
self.assertTrue(isinstance(rope, nn.RoPE))
|
||||||
|
Loading…
Reference in New Issue
Block a user