mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
add mask to mlx_lm model interface
This commit is contained in:
parent
c5ce9a31f2
commit
cd9dcf0383
@ -10,7 +10,6 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
|
|||||||
def make_prompt_cache(
|
def make_prompt_cache(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
lengths: Optional[mx.array] = None,
|
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Construct the model's cache for use when cgeneration.
|
Construct the model's cache for use when cgeneration.
|
||||||
@ -23,22 +22,17 @@ def make_prompt_cache(
|
|||||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||||
size of ``max_kv_size``
|
size of ``max_kv_size``
|
||||||
lengths (Optional[array]): If provided these sequence lengths will be
|
|
||||||
used mask the KV cache. Useful for batch inputs.
|
|
||||||
"""
|
"""
|
||||||
if hasattr(model, "make_cache"):
|
if hasattr(model, "make_cache"):
|
||||||
return model.make_cache()
|
return model.make_cache()
|
||||||
|
|
||||||
num_layers = len(model.layers)
|
num_layers = len(model.layers)
|
||||||
if max_kv_size is not None:
|
if max_kv_size is not None:
|
||||||
cache = [
|
return [
|
||||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
cache = [KVCache() for _ in range(num_layers)]
|
return [KVCache() for _ in range(num_layers)]
|
||||||
|
|
||||||
cache[0].lengths = lengths
|
|
||||||
return cache
|
|
||||||
|
|
||||||
|
|
||||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||||
|
@ -155,11 +155,12 @@ class CohereModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -180,9 +181,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = out * self.model.args.logit_scale
|
out = out * self.model.args.logit_scale
|
||||||
return out
|
return out
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional, Tuple
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
|
||||||
from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention
|
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
|
||||||
from .cache import KVCache, RotatingKVCache
|
from .cache import KVCache, RotatingKVCache
|
||||||
|
|
||||||
|
|
||||||
@ -151,16 +151,12 @@ class CohereModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
T = h.shape[1]
|
mask = mask or create_attention_mask(h, cache)
|
||||||
if T > 1:
|
|
||||||
offset = cache[0].offset if cache else 0
|
|
||||||
mask = create_causal_mask(T, offset).astype(h.dtype)
|
|
||||||
else:
|
|
||||||
mask = None
|
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -181,9 +177,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = out * self.model.args.logit_scale
|
out = out * self.model.args.logit_scale
|
||||||
return out
|
return out
|
||||||
|
@ -197,11 +197,12 @@ class DBRX(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.blocks)
|
cache = [None] * len(self.blocks)
|
||||||
@ -223,9 +224,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -211,9 +211,10 @@ class DeepseekModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -236,8 +237,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -370,9 +370,10 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -395,8 +396,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -123,10 +123,11 @@ class ExaoneModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
@ -149,9 +150,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.transformer.wte.as_linear(out)
|
out = self.transformer.wte.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -138,12 +138,13 @@ class GemmaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
h = h * (self.args.hidden_size**0.5)
|
h = h * (self.args.hidden_size**0.5)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -164,9 +165,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -160,12 +160,13 @@ class GemmaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
h = h * (self.args.hidden_size**0.5)
|
h = h * (self.args.hidden_size**0.5)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -187,9 +188,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
out = mx.tanh(out / self.final_logit_softcapping)
|
out = mx.tanh(out / self.final_logit_softcapping)
|
||||||
out = out * self.final_logit_softcapping
|
out = out * self.final_logit_softcapping
|
||||||
|
@ -126,6 +126,7 @@ class GPT2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
_, L = inputs.shape
|
_, L = inputs.shape
|
||||||
@ -138,7 +139,7 @@ class GPT2Model(nn.Module):
|
|||||||
position_ids = mx.array(np.arange(L))
|
position_ids = mx.array(np.arange(L))
|
||||||
hidden_states += self.wpe(position_ids)
|
hidden_states += self.wpe(position_ids)
|
||||||
|
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = mask or create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
@ -159,9 +160,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.wte.as_linear(out)
|
out = self.model.wte.as_linear(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
B, L = inputs.shape
|
B, L = inputs.shape
|
||||||
@ -149,7 +150,7 @@ class GPTBigCodeModel(nn.Module):
|
|||||||
position_ids = mx.array(np.arange(L))
|
position_ids = mx.array(np.arange(L))
|
||||||
hidden_states += self.wpe(position_ids)
|
hidden_states += self.wpe(position_ids)
|
||||||
|
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = mask or create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
@ -172,9 +173,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.transformer.wte.as_linear(out)
|
out = self.transformer.wte.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -146,13 +146,14 @@ class GPTNeoXModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
_, L = inputs.shape
|
_, L = inputs.shape
|
||||||
|
|
||||||
hidden_states = self.embed_in(inputs)
|
hidden_states = self.embed_in(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(hidden_states, cache)
|
mask = mask or create_attention_mask(hidden_states, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
@ -176,9 +177,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -239,11 +239,12 @@ class HunYuanModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -266,9 +267,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.model.embed_tokens.as_linear(out)
|
return self.model.embed_tokens.as_linear(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -193,11 +193,12 @@ class InternLM2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.tok_embeddings(inputs)
|
h = self.tok_embeddings(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -220,9 +221,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.tok_embeddings.as_linear(out)
|
out = self.model.tok_embeddings.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -155,11 +155,12 @@ class LlamaModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -182,9 +183,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -158,11 +158,12 @@ class MiniCPMModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs) * self.args.scale_emb
|
h = self.embed_tokens(inputs) * self.args.scale_emb
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -186,9 +187,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
|
|
||||||
if not self.args.tie_word_embeddings:
|
if not self.args.tie_word_embeddings:
|
||||||
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
|
||||||
|
@ -162,11 +162,12 @@ class MixtralModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -188,9 +189,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -176,11 +176,12 @@ class NemotronModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -203,9 +204,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -124,11 +124,12 @@ class Transformer(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.wte(inputs)
|
h = self.wte(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.blocks)
|
cache = [None] * len(self.blocks)
|
||||||
@ -152,9 +153,10 @@ class OlmoModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
return self.transformer(inputs, cache)
|
return self.transformer(inputs, mask, cache)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
@ -167,9 +169,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
return self.model(inputs, cache)
|
return self.model(inputs, mask, cache)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
|
@ -163,10 +163,11 @@ class LlamaModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache=None,
|
cache=None,
|
||||||
|
mask=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -190,8 +191,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache=None,
|
cache=None,
|
||||||
|
mask=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -178,11 +178,12 @@ class OpenELMModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.token_embeddings(inputs)
|
h = self.token_embeddings(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -205,9 +206,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.transformer(inputs, cache)
|
out = self.transformer(inputs, mask, cache)
|
||||||
if self.args.share_input_output_layers:
|
if self.args.share_input_output_layers:
|
||||||
out = self.transformer.token_embeddings.as_linear(out)
|
out = self.transformer.token_embeddings.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -143,10 +143,10 @@ class PhiModel(nn.Module):
|
|||||||
config.hidden_size, eps=config.layer_norm_eps
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x, cache):
|
def __call__(self, x, mask, cache):
|
||||||
x = self.embed_tokens(x)
|
x = self.embed_tokens(x)
|
||||||
|
|
||||||
mask = create_attention_mask(x, cache)
|
mask = mask or create_attention_mask(x, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -167,9 +167,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
y = self.model(x, cache)
|
y = self.model(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -168,11 +168,12 @@ class Phi3Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -194,9 +195,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -258,13 +258,14 @@ class Phi3Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
if self.mup_embedding_multiplier:
|
if self.mup_embedding_multiplier:
|
||||||
h = self.mup_embedding_multiplier * h
|
h = self.mup_embedding_multiplier * h
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -290,9 +291,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
if self.mup_width_multiplier:
|
if self.mup_width_multiplier:
|
||||||
out = out / self.mup_width_multiplier
|
out = out / self.mup_width_multiplier
|
||||||
|
@ -155,11 +155,12 @@ class PhiMoEModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -181,9 +182,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -175,7 +175,7 @@ class Model(nn.Module):
|
|||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = mask or create_attention_mask(x, cache)
|
||||||
|
|
||||||
y = self.transformer(x, mask, cache)
|
y = self.transformer(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
@ -174,10 +174,11 @@ class PlamoModel(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None for _ in range(len(self.layers.layers))]
|
cache = [None for _ in range(len(self.layers.layers))]
|
||||||
@ -202,8 +203,9 @@ class Model(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
cache: Optional[Any] = None,
|
cache: Optional[Any] = None,
|
||||||
|
mask: Optional[mx.array] = None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, cache, mask)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -123,7 +123,7 @@ class QwenModel(nn.Module):
|
|||||||
def __call__(self, inputs, mask=None, cache=None):
|
def __call__(self, inputs, mask=None, cache=None):
|
||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(x, cache)
|
mask = mask or create_attention_mask(x, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.h)
|
cache = [None] * len(self.h)
|
||||||
|
@ -149,11 +149,12 @@ class Qwen2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -176,9 +177,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -187,11 +187,12 @@ class Qwen2MoeModel(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -213,9 +214,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
return self.lm_head(out)
|
return self.lm_head(out)
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
|
@ -389,6 +389,7 @@ class Griffin(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
tokens,
|
tokens,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
x = self.embed_tokens(tokens)
|
x = self.embed_tokens(tokens)
|
||||||
@ -402,7 +403,7 @@ class Griffin(nn.Module):
|
|||||||
if block.temporal_block_type != "recurrent":
|
if block.temporal_block_type != "recurrent":
|
||||||
mask_cache = [cache[i]]
|
mask_cache = [cache[i]]
|
||||||
|
|
||||||
mask = create_attention_mask(x, mask_cache)
|
mask = mask or create_attention_mask(x, mask_cache)
|
||||||
|
|
||||||
for i, block in enumerate(self.layers):
|
for i, block in enumerate(self.layers):
|
||||||
x = block(x, mask=mask, cache=cache[i])
|
x = block(x, mask=mask, cache=cache[i])
|
||||||
@ -418,12 +419,12 @@ class Model(nn.Module):
|
|||||||
self.model_type = config.model_type
|
self.model_type = config.model_type
|
||||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
|
def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
tokens: Sequence of input tokens.
|
tokens: Sequence of input tokens.
|
||||||
"""
|
"""
|
||||||
logits = self.model(tokens, cache=cache)
|
logits = self.model(tokens, mask=mask, cache=cache)
|
||||||
if "lm_head" in self:
|
if "lm_head" in self:
|
||||||
logits = self.lm_head(logits)
|
logits = self.lm_head(logits)
|
||||||
else:
|
else:
|
||||||
|
@ -199,7 +199,7 @@ class Model(nn.Module):
|
|||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
mask = create_attention_mask(x, cache)
|
mask = mask or create_attention_mask(x, cache)
|
||||||
y = self.model(x, mask, cache)
|
y = self.model(x, mask, cache)
|
||||||
return self.lm_head(y)
|
return self.lm_head(y)
|
||||||
|
|
||||||
|
@ -125,11 +125,12 @@ class Starcoder2Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
h = self.embed_tokens(inputs)
|
h = self.embed_tokens(inputs)
|
||||||
|
|
||||||
mask = create_attention_mask(h, cache)
|
mask = mask or create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * len(self.layers)
|
||||||
@ -152,9 +153,10 @@ class Model(nn.Module):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
mask: mx.array = None,
|
||||||
cache=None,
|
cache=None,
|
||||||
):
|
):
|
||||||
out = self.model(inputs, cache)
|
out = self.model(inputs, mask, cache)
|
||||||
if self.args.tie_word_embeddings:
|
if self.args.tie_word_embeddings:
|
||||||
out = self.model.embed_tokens.as_linear(out)
|
out = self.model.embed_tokens.as_linear(out)
|
||||||
else:
|
else:
|
||||||
|
@ -182,7 +182,7 @@ class TestModels(unittest.TestCase):
|
|||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
cache = make_prompt_cache(model)
|
cache = make_prompt_cache(model)
|
||||||
outputs = model(inputs, cache)
|
outputs = model(inputs, cache=cache)
|
||||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||||
self.assertEqual(outputs.dtype, t)
|
self.assertEqual(outputs.dtype, t)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user