mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

revert some stuff

This commit is contained in:
parent 3992ca5554
commit d5443aa13a
@@ -23,7 +23,7 @@ class ModelArgs:
     n_kv_heads: int
     norm_eps: float
     vocab_size: int
-    moe: Optional[dict] = None
+    moe: dict


 class Attention(nn.Module):
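
Note: with the field reverted to a plain `moe: dict`, callers must always supply it. A hypothetical value carrying the two keys that `MOEFeedForward` reads below (the numbers are illustrative, not from this commit):

    # Hypothetical MoE config; only these two keys are read by MOEFeedForward.
    moe_config = {"num_experts": 8, "num_experts_per_tok": 2}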
@@ -91,8 +91,6 @@ class FeedForward(nn.Module):
 class MOEFeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        if args.moe is None:
-            raise ValueError("args.moe must not be None for MOEFeedForward")
         self.num_experts = args.moe["num_experts"]
         self.num_experts_per_tok = args.moe["num_experts_per_tok"]
         self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -101,27 +99,22 @@ class MOEFeedForward(nn.Module):
     def __call__(self, x) -> mx.array:
         ne = self.num_experts_per_tok
         orig_shape = x.shape
-        x_flat = x.reshape(-1, x.shape[-1])
-        batch_size = x_flat.shape[0]
+        x = x.reshape(-1, x.shape[-1])

-        gates = self.gate(x_flat)
+        gates = self.gate(x)
         inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)

-        final_output = mx.zeros((batch_size, x.shape[-1]), dtype=x.dtype)
-
-        for i in range(batch_size):
-            item_experts = inds[i].tolist()
-            item_scores = scores[i]
-
-            for j, expert_idx in enumerate(item_experts):
-                expert_output = self.experts[expert_idx](x_flat[i])
-                final_output = final_output.at[i].add(expert_output * item_scores[j])
-
-        return final_output.reshape(orig_shape)
+        y = []
+        for xt, st, it in zip(x, scores, inds.tolist()):
+            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+            yt = (yt * st).sum(axis=-1)
+            y.append(yt[None, :])
+        y = mx.concatenate(y)
+        return y.reshape(orig_shape)


 class MOETransformerBlock(nn.Module):
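
A minimal, self-contained sketch of the top-k gating restored above, with a random matrix standing in for `self.gate(x)` (shapes and values are illustrative only):

    import mlx.core as mx

    T, E, ne = 4, 8, 2                # tokens, experts, experts per token (assumed)
    gates = mx.random.normal((T, E))  # stand-in for self.gate(x)

    # Indices of the ne largest gate values per token.
    inds = mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne]
    # Renormalize just the selected gates so each token's weights sum to 1.
    scores = mx.softmax(mx.take_along_axis(gates, inds, axis=-1), axis=-1)

    print(inds.shape, scores.shape)   # both (T, ne)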
@@ -169,7 +169,7 @@ class SpeculativeDecoder:

             n_steps += 1

-            for t in list(new_tokens):
+            for t in new_tokens.tolist():
                 if t == self.tokenizer.eos_id or ntoks >= max_tokens:
                     break
                 outputs.append(t)
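
The `new_tokens.tolist()` form matters because iterating an `mx.array` yields `mx.array` elements, while `.tolist()` yields plain Python ints, so the `eos_id` comparison and `outputs.append(t)` stay in ordinary scalars. A quick illustration (values are arbitrary):

    import mlx.core as mx

    new_tokens = mx.array([5, 7, 2])
    print(type(list(new_tokens)[0]))     # mlx.core.array
    print(type(new_tokens.tolist()[0]))  # int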
@@ -136,11 +136,6 @@ class Attention(nn.Module):
         self.n_heads = n_heads = args.num_attention_heads
         self.n_kv_heads = n_kv_heads = args.num_key_value_heads

-        if n_heads is None or n_kv_heads is None:
-            raise ValueError(
-                "num_attention_heads and num_key_value_heads must not be None"
-            )
-
         self.repeats = n_heads // n_kv_heads

         head_dim = args.hidden_size // n_heads
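
The restored `self.repeats = n_heads // n_kv_heads` is the usual grouped-query-attention ratio: each key/value head is shared by `repeats` query heads. A tiny sketch with hypothetical head counts:

    n_heads, n_kv_heads = 32, 8       # hypothetical counts
    assert n_heads % n_kv_heads == 0  # the integer division assumes this
    repeats = n_heads // n_kv_heads   # here, 4 query heads per kv head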