revert some stuff

This commit is contained in:
Awni Hannun 2025-04-23 14:18:43 -07:00
parent 3992ca5554
commit d5443aa13a
3 changed files with 11 additions and 23 deletions

View File

@@ -23,7 +23,7 @@ class ModelArgs:
n_kv_heads: int n_kv_heads: int
norm_eps: float norm_eps: float
vocab_size: int vocab_size: int
moe: Optional[dict] = None moe: dict
class Attention(nn.Module): class Attention(nn.Module):
@@ -91,8 +91,6 @@ class FeedForward(nn.Module):
class MOEFeedForward(nn.Module): class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
super().__init__() super().__init__()
if args.moe is None:
raise ValueError("args.moe must not be None for MOEFeedForward")
self.num_experts = args.moe["num_experts"] self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"] self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)] self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -101,27 +99,22 @@ class MOEFeedForward(nn.Module):
def __call__(self, x) -> mx.array: def __call__(self, x) -> mx.array:
ne = self.num_experts_per_tok ne = self.num_experts_per_tok
orig_shape = x.shape orig_shape = x.shape
x_flat = x.reshape(-1, x.shape[-1]) x = x.reshape(-1, x.shape[-1])
batch_size = x_flat.shape[0]
gates = self.gate(x_flat) gates = self.gate(x)
inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne] inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
scores = mx.softmax( scores = mx.softmax(
mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32), mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
axis=-1, axis=-1,
).astype(gates.dtype) ).astype(gates.dtype)
final_output = mx.zeros((batch_size, x.shape[-1]), dtype=x.dtype) y = []
for xt, st, it in zip(x, scores, inds.tolist()):
for i in range(batch_size): yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
item_experts = inds[i].tolist() yt = (yt * st).sum(axis=-1)
item_scores = scores[i] y.append(yt[None, :])
y = mx.concatenate(y)
for j, expert_idx in enumerate(item_experts): return y.reshape(orig_shape)
expert_output = self.experts[expert_idx](x_flat[i])
final_output = final_output.at[i].add(expert_output * item_scores[j])
return final_output.reshape(orig_shape)
class MOETransformerBlock(nn.Module): class MOETransformerBlock(nn.Module):

View File

@@ -169,7 +169,7 @@ class SpeculativeDecoder:
n_steps += 1 n_steps += 1
for t in list(new_tokens): for t in new_tokens.tolist():
if t == self.tokenizer.eos_id or ntoks >= max_tokens: if t == self.tokenizer.eos_id or ntoks >= max_tokens:
break break
outputs.append(t) outputs.append(t)

View File

@@ -136,11 +136,6 @@ class Attention(nn.Module):
self.n_heads = n_heads = args.num_attention_heads self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_kv_heads = n_kv_heads = args.num_key_value_heads
if n_heads is None or n_kv_heads is None:
raise ValueError(
"num_attention_heads and num_key_value_heads must not be None"
)
self.repeats = n_heads // n_kv_heads self.repeats = n_heads // n_kv_heads
head_dim = args.hidden_size // n_heads head_dim = args.hidden_size // n_heads