diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 6ddc09b8..653dad57 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: Optional[dict] = None
+    moe: dict
 
 
 class Attention(nn.Module):
@@ -91,8 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        if args.moe is None:
-            raise ValueError("args.moe must not be None for MOEFeedForward")
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -101,27 +99,22 @@ class MOEFeedForward(nn.Module):
     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
         orig_shape = x.shape
-        x_flat = x.reshape(-1, x.shape[-1])
-        batch_size = x_flat.shape[0]
+        x = x.reshape(-1, x.shape[-1])
 
-        gates = self.gate(x_flat)
+        gates = self.gate(x)
         inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)
 
-        final_output = mx.zeros((batch_size, x.shape[-1]), dtype=x.dtype)
-
-        for i in range(batch_size):
-            item_experts = inds[i].tolist()
-            item_scores = scores[i]
-
-            for j, expert_idx in enumerate(item_experts):
-                expert_output = self.experts[expert_idx](x_flat[i])
-                final_output = final_output.at[i].add(expert_output * item_scores[j])
-
-        return final_output.reshape(orig_shape)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)
+        return y.reshape(orig_shape)
 
 
 class MOETransformerBlock(nn.Module):
diff --git a/llms/speculative_decoding/decoder.py b/llms/speculative_decoding/decoder.py
index d2b97716..39cf5b92 100644
--- a/llms/speculative_decoding/decoder.py
+++ b/llms/speculative_decoding/decoder.py
@@ -169,7 +169,7 @@ class SpeculativeDecoder:
 
             n_steps += 1
 
-            for t in list(new_tokens):
+            for t in new_tokens.tolist():
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
diff --git a/lora/models.py b/lora/models.py
index ddec473c..acafbc61 100644
--- a/lora/models.py
+++ b/lora/models.py
@@ -136,11 +136,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
 
-        if n_heads is None or n_kv_heads is None:
-            raise ValueError(
-                "num_attention_heads and num_key_value_heads must not be None"
-            )
-
         self.repeats = n_heads // n_kv_heads
 
         head_dim = args.hidden_size // n_heads
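
Not part of the patch: a minimal, runnable sketch of the vectorized top-k routing that the rewritten MOEFeedForward.__call__ implements, i.e. pick the top num_experts_per_tok experts per token, softmax only their logits, and take the score-weighted sum of their outputs. The toy sizes and the nn.Linear stand-ins for the FeedForward experts are assumptions for illustration; only the gating calls themselves come from the diff.

# Standalone sketch of the routing math in MOEFeedForward.__call__ above.
# Assumptions (not in the diff): toy dimensions, nn.Linear experts, random input.
import mlx.core as mx
import mlx.nn as nn

num_experts, ne, dim = 4, 2, 8
experts = [nn.Linear(dim, dim) for _ in range(num_experts)]
gate = nn.Linear(dim, num_experts, bias=False)

x = mx.random.normal((3, dim))  # three already-flattened tokens

gates = gate(x)
# Top-ne expert indices per token; argpartition avoids a full sort.
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
# Normalize only the selected logits, in float32 for numerical stability.
scores = mx.softmax(
    mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
    axis=-1,
).astype(gates.dtype)

y = []
for xt, st, it in zip(x, scores, inds.tolist()):
    # Stack the chosen experts' outputs on a new trailing axis, then take
    # the gate-score-weighted sum over that axis.
    yt = mx.concatenate([experts[e](xt)[:, None] for e in it], axis=-1)
    y.append((yt * st).sum(axis=-1)[None, :])
print(mx.concatenate(y).shape)  # (3, 8)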