commit b8fee34f24
parent 8f8f9b6991
Author: paramthakkar123
Date:   2025-04-09 23:13:43 +05:30

@@ -91,10 +91,8 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        if args.moe is None:
-            raise ValueError("args.moe must not be None for MOEFeedForward")
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
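
With the explicit None check removed above, the constructor now assumes `args.moe` is always a populated dict. A minimal sketch of the shape it expects (the numeric values are illustrative assumptions, not taken from this repository):

# Hypothetical example of the moe config consumed by MOEFeedForward.__init__.
moe = {"num_experts": 8, "num_experts_per_tok": 2}
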
@@ -103,23 +101,27 @@ class MOEFeedForward(nn.Module):
     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
         orig_shape = x.shape
-        x = x.reshape(-1, x.shape[-1])
+        x_flat = x.reshape(-1, x.shape[-1])
+        batch_size = x_flat.shape[0]

-        gates = self.gate(x)
+        gates = self.gate(x_flat)
         inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)

-        y = []
-        for xt, st, it in zip(x, scores, inds.tolist()):
-            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
-            yt = (yt * st).sum(axis=-1)
-            y.append(yt[None, :])
-        y = mx.concatenate(y)
+        final_output = mx.zeros((batch_size, x.shape[-1]), dtype=x.dtype)

-        return y.reshape(orig_shape)
+        for i in range(batch_size):
+            item_experts = inds[i].tolist()
+            item_scores = scores[i]
+            for j, expert_idx in enumerate(item_experts):
+                expert_output = self.experts[expert_idx](x_flat[i])
+                final_output = final_output.at[i].add(expert_output * item_scores[j])
+
+        return final_output.reshape(orig_shape)

 class MOETransformerBlock(nn.Module):
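
The new `__call__` routes each flattened token through its top-`num_experts_per_tok` experts and accumulates the score-weighted expert outputs into a zero-initialized buffer via MLX's `at[...]` update syntax. Below is a self-contained sketch of that routing loop using plain `mlx.core` arrays; the dimensions, the stand-in linear "experts", and the random gate weights are assumptions made for illustration, not code from this repository.

import mlx.core as mx

# Illustrative sizes -- assumptions for the sketch, not values from this repo.
num_experts, num_experts_per_tok, dim, n_tokens = 4, 2, 8, 3

# Stand-in "experts": plain linear maps instead of FeedForward modules.
expert_weights = [mx.random.normal((dim, dim)) for _ in range(num_experts)]
experts = [lambda v, w=w: v @ w for w in expert_weights]

x_flat = mx.random.normal((n_tokens, dim))               # tokens, already flattened
gates = x_flat @ mx.random.normal((dim, num_experts))    # stand-in gating logits

# Pick the top-k experts per token and normalize their scores, as in the diff.
inds = mx.argpartition(-gates, kth=num_experts_per_tok - 1, axis=-1)[:, :num_experts_per_tok]
scores = mx.softmax(
    mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), axis=-1
).astype(gates.dtype)

# Accumulate each token's weighted expert outputs into a zero buffer,
# mirroring the final_output.at[i].add(...) pattern used in the new __call__.
final_output = mx.zeros((n_tokens, dim), dtype=x_flat.dtype)
for i in range(n_tokens):
    for j, expert_idx in enumerate(inds[i].tolist()):
        expert_output = experts[expert_idx](x_flat[i])
        final_output = final_output.at[i].add(expert_output * scores[i][j])

print(final_output.shape)  # (3, 8)

Note that this keeps the token-by-token loop the commit introduces: instead of concatenating per-token expert outputs and summing them as the replaced version did, it adds each score-weighted expert output into a preallocated result buffer.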