From b8fee34f24c2fa4ba32ead16e379008b103d656f Mon Sep 17 00:00:00 2001
From: paramthakkar123
Date: Wed, 9 Apr 2025 23:13:43 +0530
Subject: [PATCH] Fixes

---
 llms/mixtral/mixtral.py | 24 +++++++++++++-----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/llms/mixtral/mixtral.py b/llms/mixtral/mixtral.py
index 807d3b23..6ddc09b8 100644
--- a/llms/mixtral/mixtral.py
+++ b/llms/mixtral/mixtral.py
@@ -91,10 +91,8 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-
         if args.moe is None:
             raise ValueError("args.moe must not be None for MOEFeedForward")
-
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -103,23 +101,27 @@ class MOEFeedForward(nn.Module):
     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
         orig_shape = x.shape
-        x = x.reshape(-1, x.shape[-1])
+        x_flat = x.reshape(-1, x.shape[-1])
+        batch_size = x_flat.shape[0]
 
-        gates = self.gate(x)
+        gates = self.gate(x_flat)
         inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)
 
-        y = []
-        for xt, st, it in zip(x, scores, inds.tolist()):
-            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
-            yt = (yt * st).sum(axis=-1)
-            y.append(yt[None, :])
-        y = mx.concatenate(y)
+        final_output = mx.zeros((batch_size, x.shape[-1]), dtype=x.dtype)
 
-        return y.reshape(orig_shape)
+        for i in range(batch_size):
+            item_experts = inds[i].tolist()
+            item_scores = scores[i]
+
+            for j, expert_idx in enumerate(item_experts):
+                expert_output = self.experts[expert_idx](x_flat[i])
+                final_output = final_output.at[i].add(expert_output * item_scores[j])
+
+        return final_output.reshape(orig_shape)
 
 
 class MOETransformerBlock(nn.Module):
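
For reference, the rewritten __call__ selects the top-ne experts per token with mx.argpartition, normalizes the selected gate logits with a softmax, and accumulates each token's weighted expert outputs with MLX's array.at[...].add scatter update. Below is a minimal standalone sketch of that routing pattern, assuming only that mlx.core is installed; the experts are stubbed out as plain weight matrices and names like gate_w are illustrative stand-ins, not identifiers from mixtral.py.

import mlx.core as mx

dim, num_experts, ne = 8, 4, 2                 # hidden size, expert count, experts per token
x_flat = mx.random.normal((5, dim))            # 5 flattened tokens
gate_w = mx.random.normal((dim, num_experts))  # stand-in for the gating Linear
experts = [mx.random.normal((dim, dim)) for _ in range(num_experts)]  # stand-in experts

gates = x_flat @ gate_w
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]      # top-ne expert ids per token
scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

out = mx.zeros_like(x_flat)
for i in range(x_flat.shape[0]):
    for j, e in enumerate(inds[i].tolist()):
        # scatter-add the weighted expert output into row i, as the patch's inner loop does
        out = out.at[i].add((x_flat[i] @ experts[e]) * scores[i, j])

print(out.shape)  # (5, 8)

The patch itself additionally computes the softmax in float32 and casts the scores back to the gate dtype; the sketch omits that detail for brevity.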