Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-08-29 03:01:34 +08:00

comment + fix

Commit 5b414dddf2 (parent eb9452beb9)
@@ -37,10 +37,8 @@ def create_causal_mask(
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
     if lengths is not None:
-        mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
-        lengths = lengths[:, None, None]
-        mask = mask | (rinds[None] >= lengths)
-        mask = mask[:, None]
+        lengths = lengths[:, None, None, None]
+        mask = mask | (rinds >= lengths)
     return mask * -1e9
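The first hunk simplifies how create_causal_mask applies per-sequence lengths: instead of repeating the mask across the batch with mx.repeat and then re-inserting a head axis, the new code broadcasts lengths[:, None, None, None] directly against the key indices. A minimal sketch of the resulting behavior, assuming the usual linds/rinds index setup from the rest of the function (not shown in this hunk):

import mlx.core as mx

def causal_mask_sketch(N, offset=0, window_size=None, lengths=None):
    # Assumed index setup: linds are query positions, rinds are key positions.
    rinds = mx.arange(offset + N)[None]             # (1, offset + N)
    linds = mx.arange(offset, offset + N)[:, None]  # (N, 1)
    mask = linds < rinds                            # mask future keys
    if window_size is not None:
        mask = mask | (linds > rinds + window_size)
    if lengths is not None:
        # New approach: broadcast (B, 1, 1, 1) lengths against the key axis,
        # yielding a (B, 1, N, offset + N) mask without mx.repeat.
        lengths = lengths[:, None, None, None]
        mask = mask | (rinds >= lengths)            # mask out padding keys
    return mask * -1e9

# Two sequences padded to length 4; the second only has 2 valid tokens.
m = causal_mask_sketch(4, lengths=mx.array([4, 2]))
print(m.shape)  # one mask per sequence: (2, 1, 4, 4)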
@@ -160,7 +160,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
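Every model hunk below makes the same change: the shortcut `mask = mask or create_attention_mask(...)` becomes an explicit `is None` check, so a caller-supplied mx.array mask is passed through instead of being pushed into a boolean context (which, as with NumPy, is ambiguous for arrays with more than one element). A minimal sketch of the pattern, with a hypothetical build_default_mask standing in for the library's create_attention_mask:

import mlx.core as mx

def build_default_mask(h):
    # Stand-in for create_attention_mask: a plain additive causal mask
    # over the sequence axis of h (assumed shape: batch, seq_len, dims).
    n = h.shape[1]
    return (mx.arange(n)[:, None] < mx.arange(n)[None]) * -1e9

def forward_sketch(h, mask=None):
    # Old pattern: `mask = mask or build_default_mask(h)`.
    # That only behaves when mask is None; a caller-supplied mx.array would be
    # evaluated for truthiness, which is ambiguous for multi-element arrays.
    if mask is None:  # the fix: treat only "not provided" as "use the default"
        mask = build_default_mask(h)
    # ... attention layers would consume `mask` here ...
    return mask

h = mx.zeros((1, 4, 8))
forward_sketch(h)                          # falls back to the default causal mask
forward_sketch(h, mask=mx.zeros((4, 4)))   # explicit mask is now passed through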
@@ -156,7 +156,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -202,7 +202,8 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.blocks)
@@ -214,7 +214,8 @@ class DeepseekModel(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -373,7 +373,9 @@ class DeepseekV2Model(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -127,7 +127,8 @@ class ExaoneModel(nn.Module):
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.h)
@@ -144,7 +144,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -166,7 +166,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -139,7 +139,8 @@ class GPT2Model(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
@@ -150,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
@@ -153,7 +153,8 @@ class GPTNeoXModel(nn.Module):

         hidden_states = self.embed_in(inputs)

-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)

         if cache is None:
             cache = [None] * len(self.h)
@@ -244,7 +244,8 @@ class HunYuanModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -198,7 +198,8 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -160,7 +160,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -163,7 +163,8 @@ class MiniCPMModel(nn.Module):
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -167,7 +167,8 @@ class MixtralModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -181,7 +181,8 @@ class NemotronModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -129,7 +129,8 @@ class Transformer(nn.Module):
     ):
         h = self.wte(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.blocks)
@@ -167,7 +167,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -183,7 +183,8 @@ class OpenELMModel(nn.Module):
     ):
         h = self.token_embeddings(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -146,7 +146,8 @@ class PhiModel(nn.Module):
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)

-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -173,7 +173,8 @@ class Phi3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -265,7 +265,8 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -160,7 +160,8 @@ class PhiMoEModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -175,7 +175,9 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         y = self.transformer(x, mask, cache)
         return self.lm_head(y)
@@ -178,7 +178,8 @@ class PlamoModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]
@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)

-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)

         if cache is None:
             cache = [None] * len(self.h)
@@ -154,7 +154,8 @@ class Qwen2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -192,7 +192,8 @@ class Qwen2MoeModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -403,7 +403,8 @@ class Griffin(nn.Module):
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]

-        mask = mask or create_attention_mask(x, mask_cache)
+        if mask is None:
+            mask = create_attention_mask(x, mask_cache)

         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])
@@ -199,7 +199,10 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+
+        if mask is None:
+            mask = create_attention_mask(x, cache)
+
         y = self.model(x, mask, cache)
         return self.lm_head(y)
@@ -130,7 +130,8 @@ class Starcoder2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)

-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)

         if cache is None:
             cache = [None] * len(self.layers)
@@ -183,6 +183,12 @@ class TestModels(unittest.TestCase):
         self.assertEqual(outputs.shape, (1, 2, vocab_size))
         self.assertEqual(outputs.dtype, t)

+        if model_type != "mamba":
+            mask = create_causal_mask(inputs.shape[1], 0).astype(t)
+            outputs = model(inputs, mask=mask)
+            self.assertEqual(outputs.shape, (1, 2, vocab_size))
+            self.assertEqual(outputs.dtype, t)
+
         outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
         self.assertEqual(outputs.shape, (1, 1, vocab_size))
         self.assertEqual(outputs.dtype, t)
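The new test lines feed an explicit mask through every non-mamba model and cast it with .astype(t), so the additive mask matches the dtype under test and the output-dtype assertion still holds. A tiny sketch of why the cast matters when the mask is added to attention scores (the toy shapes and scores here are illustrative, not the test's internals):

import mlx.core as mx

t = mx.float16
scores = mx.zeros((1, 1, 2, 2), dtype=t)    # (batch, heads, queries, keys)
mask = (mx.arange(2)[:, None] < mx.arange(2)[None]) * -1e9
scores = scores + mask.astype(t)            # cast keeps the scores in dtype t
weights = mx.softmax(scores, axis=-1)
print(weights.dtype)                        # float16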