mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-18 10:26:58 +08:00
Merge branch 'ml-explore:main' into adding-support-for-mamba2
This commit is contained in:
@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
return cache
|
||||
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
@@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
@@ -291,7 +291,7 @@ class MoEGate(nn.Module):
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
|
Reference in New Issue
Block a user