comment + fix

Alex Barron 2024-12-18 18:39:14 -08:00
parent eb9452beb9
commit 5b414dddf2
34 changed files with 76 additions and 36 deletions

View File

@@ -37,10 +37,8 @@ def create_causal_mask(
     if window_size is not None:
         mask = mask | (linds > rinds + window_size)
     if lengths is not None:
-        mask = mx.repeat(mask[None], lengths.shape[0], axis=0)
-        lengths = lengths[:, None, None]
-        mask = mask | (rinds[None] >= lengths)
-        mask = mask[:, None]
+        lengths = lengths[:, None, None, None]
+        mask = mask | (rinds >= lengths)
     return mask * -1e9
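
Note on the create_causal_mask change: reshaping lengths to (B, 1, 1, 1) lets plain broadcasting produce a per-sequence mask of shape (B, 1, N, offset + N), which already lines up with multi-head attention scores, so the old mx.repeat over the batch and the trailing mask[:, None] are no longer needed. A minimal sketch of the shape behaviour, with illustrative values and assuming mx is mlx.core (not the library code verbatim):

    import mlx.core as mx

    N, offset = 4, 0
    rinds = mx.arange(offset + N)           # (offset + N,)
    linds = rinds[:, None]                  # (N, 1)
    rinds = rinds[None]                     # (1, offset + N)
    mask = linds < rinds                    # (N, offset + N) causal pattern

    lengths = mx.array([2, 4])              # valid lengths for a batch of 2
    lengths = lengths[:, None, None, None]  # (B, 1, 1, 1)
    mask = mask | (rinds >= lengths)        # broadcasts to (B, 1, N, offset + N)
    print(mask.shape)                       # (2, 1, 4, 4)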

View File

@@ -160,7 +160,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)
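
The same edit repeats across the model files below: mask = mask or create_attention_mask(h, cache) becomes an explicit None check. The or form asks Python for the truthiness of the caller's mask, which is at best ambiguous for an array, so the default mask is now built only when no mask was supplied. A hedged sketch of the pattern, with a hypothetical make_default_mask standing in for create_attention_mask:

    import mlx.core as mx

    def make_default_mask(n):
        # hypothetical stand-in for create_attention_mask
        idx = mx.arange(n)
        return (idx[:, None] < idx[None]) * -1e9

    def forward(x, mask=None):
        # new pattern: only build a default mask when the caller supplied none
        if mask is None:
            mask = make_default_mask(x.shape[1])
        return mask

    user_mask = make_default_mask(4).astype(mx.float16)
    out = forward(mx.zeros((1, 4, 8)), mask=user_mask)
    print(out.dtype)  # float16 -- the caller's mask is used as-is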

View File

@@ -156,7 +156,8 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -202,7 +202,8 @@ class DBRX(nn.Module):
     ):
         h = self.wte(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.blocks)

View File

@@ -214,7 +214,8 @@ class DeepseekModel(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -373,7 +373,9 @@ class DeepseekV2Model(nn.Module):
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -127,7 +127,8 @@ class ExaoneModel(nn.Module):
         cache=None,
     ):
         h = self.wte(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.h)

View File

@@ -144,7 +144,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -166,7 +166,8 @@ class GemmaModel(nn.Module):
         h = self.embed_tokens(inputs)
         h = h * (self.args.hidden_size**0.5)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -139,7 +139,8 @@ class GPT2Model(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)

View File

@@ -150,7 +150,8 @@ class GPTBigCodeModel(nn.Module):
         position_ids = mx.array(np.arange(L))
         hidden_states += self.wpe(position_ids)
-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)

View File

@@ -153,7 +153,8 @@ class GPTNeoXModel(nn.Module):
         hidden_states = self.embed_in(inputs)
-        mask = mask or create_attention_mask(hidden_states, cache)
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
         if cache is None:
             cache = [None] * len(self.h)

View File

@@ -244,7 +244,8 @@ class HunYuanModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -198,7 +198,8 @@ class InternLM2Model(nn.Module):
     ):
         h = self.tok_embeddings(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -160,7 +160,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -163,7 +163,8 @@ class MiniCPMModel(nn.Module):
     ):
         h = self.embed_tokens(inputs) * self.args.scale_emb
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -167,7 +167,8 @@ class MixtralModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -181,7 +181,8 @@ class NemotronModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -129,7 +129,8 @@ class Transformer(nn.Module):
     ):
         h = self.wte(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.blocks)

View File

@@ -167,7 +167,8 @@ class LlamaModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -183,7 +183,8 @@ class OpenELMModel(nn.Module):
     ):
         h = self.token_embeddings(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -146,7 +146,8 @@ class PhiModel(nn.Module):
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -173,7 +173,8 @@ class Phi3Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -265,7 +265,8 @@ class Phi3Model(nn.Module):
         if self.mup_embedding_multiplier:
             h = self.mup_embedding_multiplier * h
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -160,7 +160,8 @@ class PhiMoEModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -175,7 +175,9 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         y = self.transformer(x, mask, cache)
         return self.lm_head(y)

View File

@@ -178,7 +178,8 @@ class PlamoModel(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None for _ in range(len(self.layers.layers))]

View File

@@ -123,7 +123,8 @@ class QwenModel(nn.Module):
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         if cache is None:
             cache = [None] * len(self.h)

View File

@@ -154,7 +154,8 @@ class Qwen2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -192,7 +192,8 @@ class Qwen2MoeModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -403,7 +403,8 @@ class Griffin(nn.Module):
             if block.temporal_block_type != "recurrent":
                 mask_cache = [cache[i]]
-        mask = mask or create_attention_mask(x, mask_cache)
+        if mask is None:
+            mask = create_attention_mask(x, mask_cache)
         for i, block in enumerate(self.layers):
             x = block(x, mask=mask, cache=cache[i])

View File

@@ -199,7 +199,10 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
     ) -> mx.array:
-        mask = mask or create_attention_mask(x, cache)
+        if mask is None:
+            mask = create_attention_mask(x, cache)
         y = self.model(x, mask, cache)
         return self.lm_head(y)

View File

@@ -130,7 +130,8 @@ class Starcoder2Model(nn.Module):
     ):
         h = self.embed_tokens(inputs)
-        mask = mask or create_attention_mask(h, cache)
+        if mask is None:
+            mask = create_attention_mask(h, cache)
         if cache is None:
             cache = [None] * len(self.layers)

View File

@@ -183,6 +183,12 @@ class TestModels(unittest.TestCase):
             self.assertEqual(outputs.shape, (1, 2, vocab_size))
             self.assertEqual(outputs.dtype, t)
+            if model_type != "mamba":
+                mask = create_causal_mask(inputs.shape[1], 0).astype(t)
+                outputs = model(inputs, mask=mask)
+                self.assertEqual(outputs.shape, (1, 2, vocab_size))
+                self.assertEqual(outputs.dtype, t)
             outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
             self.assertEqual(outputs.shape, (1, 1, vocab_size))
             self.assertEqual(outputs.dtype, t)
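
The new test branch feeds an explicit causal mask, cast to the current dtype, through mask= for every model except mamba, which takes no attention mask. A hedged usage sketch of the same call pattern outside the test suite, assuming create_causal_mask is importable from mlx_lm.models.base; the token values and the commented-out model call are illustrative only:

    import mlx.core as mx
    from mlx_lm.models.base import create_causal_mask

    tokens = mx.array([[0, 1, 2, 3]])              # (B=1, N=4) token ids
    mask = create_causal_mask(tokens.shape[1], 0)  # (4, 4) additive causal mask
    mask = mask.astype(mx.float16)                 # match the model's compute dtype
    # logits = model(tokens, mask=mask)            # as in the test above
    print(mask.shape, mask.dtype)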