mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 09:56:24 +08:00
formated
This commit is contained in:
parent
e4c56625f0
commit
a1ff1bf72a
@ -91,7 +91,7 @@ class Attention(nn.Module):
|
|||||||
)
|
)
|
||||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
return self.o_proj(output)
|
return self.o_proj(output)
|
||||||
|
|
||||||
|
|
||||||
class OlmoeSparseMoeBlock(nn.Module):
|
class OlmoeSparseMoeBlock(nn.Module):
|
||||||
def __init__(self, args: ModelArgs):
|
def __init__(self, args: ModelArgs):
|
||||||
@ -99,13 +99,13 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|||||||
self.num_experts = args.num_experts
|
self.num_experts = args.num_experts
|
||||||
self.top_k = args.num_experts_per_tok
|
self.top_k = args.num_experts_per_tok
|
||||||
self.norm_topk_prob = args.norm_topk_prob
|
self.norm_topk_prob = args.norm_topk_prob
|
||||||
|
|
||||||
self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
|
self.gate = nn.Linear(args.hidden_size, self.num_experts, bias=False)
|
||||||
self.switch_mlp = SwitchGLU(
|
self.switch_mlp = SwitchGLU(
|
||||||
args.hidden_size,
|
args.hidden_size,
|
||||||
args.intermediate_size,
|
args.intermediate_size,
|
||||||
self.num_experts,
|
self.num_experts,
|
||||||
bias=args.mlp_bias
|
bias=args.mlp_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
@ -114,7 +114,9 @@ class OlmoeSparseMoeBlock(nn.Module):
|
|||||||
router_logits = self.gate(x_flat)
|
router_logits = self.gate(x_flat)
|
||||||
routing_weights = mx.softmax(router_logits, axis=1, precise=True)
|
routing_weights = mx.softmax(router_logits, axis=1, precise=True)
|
||||||
k = self.top_k
|
k = self.top_k
|
||||||
indices = mx.stop_gradient(mx.argpartition(-routing_weights, kth=k-1, axis=-1)[..., :k])
|
indices = mx.stop_gradient(
|
||||||
|
mx.argpartition(-routing_weights, kth=k - 1, axis=-1)[..., :k]
|
||||||
|
)
|
||||||
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
|
scores = mx.take_along_axis(routing_weights, indices, axis=-1)
|
||||||
if self.norm_topk_prob:
|
if self.norm_topk_prob:
|
||||||
scores = scores / scores.sum(axis=-1, keepdims=True)
|
scores = scores / scores.sum(axis=-1, keepdims=True)
|
||||||
@ -129,7 +131,9 @@ class TransformerBlock(nn.Module):
|
|||||||
self.self_attn = Attention(args)
|
self.self_attn = Attention(args)
|
||||||
self.mlp = OlmoeSparseMoeBlock(args)
|
self.mlp = OlmoeSparseMoeBlock(args)
|
||||||
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
||||||
self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
|
self.post_attention_layernorm = nn.RMSNorm(
|
||||||
|
args.hidden_size, eps=args.rms_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -192,7 +196,7 @@ class Model(nn.Module):
|
|||||||
else:
|
else:
|
||||||
out = self.lm_head(out)
|
out = self.lm_head(out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def sanitize(self, weights):
|
def sanitize(self, weights):
|
||||||
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
|
||||||
return weights
|
return weights
|
||||||
|
Loading…
Reference in New Issue
Block a user