Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-25 01:41:19 +08:00)
fix(mlx-lm): olmo 1b model (#417)
This commit is contained in:
parent aa7447efa2
commit a7d139f484
@@ -24,6 +24,15 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     model_type: str = None
+    mlp_ratio: int = 4
+    weight_tying: bool = False
+
+    def __post_init__(self):
+        self.mlp_hidden_size = (
+            self.mlp_hidden_size
+            if self.mlp_hidden_size is not None
+            else self.mlp_ratio * self.d_model
+        )
 
 
 class LayerNorm(nn.LayerNorm):
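For reference, a minimal sketch (not from the repository) of what the new __post_init__ fallback does: when a config leaves mlp_hidden_size unset, it is derived as mlp_ratio * d_model, while an explicit value is kept as-is. The dataclass and numbers below are hypothetical.

```python
# Hedged sketch of the mlp_hidden_size fallback, with illustrative values.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyArgs:
    d_model: int = 2048
    mlp_ratio: int = 4
    mlp_hidden_size: Optional[int] = None  # some configs only provide a ratio

    def __post_init__(self):
        # Fall back to mlp_ratio * d_model when no explicit hidden size is given.
        self.mlp_hidden_size = (
            self.mlp_hidden_size
            if self.mlp_hidden_size is not None
            else self.mlp_ratio * self.d_model
        )


print(TinyArgs().mlp_hidden_size)                       # 8192 (derived: 4 * 2048)
print(TinyArgs(mlp_hidden_size=11008).mlp_hidden_size)  # 11008 (explicit value wins)
```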
@@ -99,6 +108,7 @@ class TransformerBlock(nn.Module):
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+
         out = h + self.ff_out(nn.silu(x2) * x1)
         return out, cache
 
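The (unchanged) feed-forward around this hunk is a gated projection: ff_proj expands to twice the hidden size, the result is split in two, and silu(x2) * x1 gates the output before ff_out projects back to d_model. A hedged, self-contained sketch with made-up sizes:

```python
# Gated feed-forward sketch; sizes are illustrative, not from a real OLMo checkpoint.
import mlx.core as mx
import mlx.nn as nn

d_model, mlp_hidden_size = 64, 256
ff_proj = nn.Linear(d_model, 2 * mlp_hidden_size, bias=False)
ff_out = nn.Linear(mlp_hidden_size, d_model, bias=False)

h = mx.random.normal((1, 8, d_model))        # (batch, seq, d_model)
x1, x2 = mx.split(ff_proj(h), 2, axis=-1)    # two (1, 8, mlp_hidden_size) halves
out = h + ff_out(nn.silu(x2) * x1)           # gated activation plus residual
print(out.shape)                             # (1, 8, 64)
```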
@@ -107,9 +117,12 @@ class Transformer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_layers = args.n_layers
+        self.weight_tying = args.weight_tying
+
         self.wte = nn.Embedding(args.embedding_size, args.d_model)
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        if not self.weight_tying:
+            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
         self.norm = LayerNorm(args.d_model, affine=False)
 
     def __call__(
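A hedged sketch of the new constructor logic: when the checkpoint ties the output projection to the embedding, no separate ff_out layer is allocated. Names and sizes below are illustrative, not taken from an actual OLMo config.

```python
# Weight-tying switch sketch; embedding_size and d_model are made up.
import mlx.nn as nn

embedding_size, d_model, weight_tying = 32000, 64, True

wte = nn.Embedding(embedding_size, d_model)
ff_out = None if weight_tying else nn.Linear(d_model, embedding_size, bias=False)

# Tied models have one fewer parameter matrix to create and load.
print("separate output projection:", ff_out is not None)
```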
@@ -130,7 +143,12 @@ class Transformer(nn.Module):
         for e, block in enumerate(self.blocks):
             h, cache[e] = block(h, mask, cache[e])
 
-        return self.ff_out(self.norm(h)), cache
+        h = self.norm(h)
+
+        if self.weight_tying:
+            return h @ self.wte.weight.T, cache
+
+        return self.ff_out(h), cache
 
 
 class OlmoModel(nn.Module):
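To illustrate the new output path: under weight tying, logits are computed as h @ wte.weight.T, which is equivalent to an untied nn.Linear whose weight happens to equal the embedding matrix. A small hedged check with made-up shapes:

```python
# Check that the tied projection matches an explicit Linear sharing the embedding weight.
import mlx.core as mx
import mlx.nn as nn

embedding_size, d_model = 100, 16
wte = nn.Embedding(embedding_size, d_model)
h = mx.random.normal((2, 5, d_model))

tied_logits = h @ wte.weight.T                       # (2, 5, embedding_size)

untied = nn.Linear(d_model, embedding_size, bias=False)
untied.weight = wte.weight                           # share the same matrix
print(mx.allclose(tied_logits, untied(h)).item())    # True
```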