Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
fix(mlx-lm): olmo 1b model (#417)
parent aa7447efa2
commit a7d139f484
```diff
@@ -24,6 +24,15 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     model_type: str = None
+    mlp_ratio: int = 4
+    weight_tying: bool = False
+
+    def __post_init__(self):
+        self.mlp_hidden_size = (
+            self.mlp_hidden_size
+            if self.mlp_hidden_size is not None
+            else self.mlp_ratio * self.d_model
+        )
 
 
 class LayerNorm(nn.LayerNorm):
```
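For reference, the fallback that the new `__post_init__` implements can be exercised on its own. The sketch below is a standalone illustration, not code from mlx-lm; the `MLPSizing` dataclass and the concrete dimensions are made up for the example:

```python
from dataclasses import dataclass
from typing import Optional

# Standalone sketch of the fallback added to ModelArgs (illustrative only).
@dataclass
class MLPSizing:
    d_model: int
    mlp_hidden_size: Optional[int] = None
    mlp_ratio: int = 4

    def __post_init__(self):
        # If the checkpoint's config does not provide mlp_hidden_size,
        # derive it from the ratio, mirroring the diff above.
        self.mlp_hidden_size = (
            self.mlp_hidden_size
            if self.mlp_hidden_size is not None
            else self.mlp_ratio * self.d_model
        )

# Hypothetical values for illustration only.
print(MLPSizing(d_model=2048).mlp_hidden_size)                         # 8192 (derived: 4 * 2048)
print(MLPSizing(d_model=4096, mlp_hidden_size=22016).mlp_hidden_size)  # 22016 (taken as given)
```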
```diff
@@ -99,6 +108,7 @@ class TransformerBlock(nn.Module):
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+
         out = h + self.ff_out(nn.silu(x2) * x1)
         return out, cache
 
```
```diff
@@ -107,9 +117,12 @@ class Transformer(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_layers = args.n_layers
+        self.weight_tying = args.weight_tying
+
         self.wte = nn.Embedding(args.embedding_size, args.d_model)
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        if not self.weight_tying:
+            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
         self.norm = LayerNorm(args.d_model, affine=False)
 
     def __call__(
```
```diff
@@ -130,7 +143,12 @@ class Transformer(nn.Module):
         for e, block in enumerate(self.blocks):
             h, cache[e] = block(h, mask, cache[e])
 
-        return self.ff_out(self.norm(h)), cache
+        h = self.norm(h)
+
+        if self.weight_tying:
+            return h @ self.wte.weight.T, cache
+
+        return self.ff_out(h), cache
 
 
 class OlmoModel(nn.Module):
```
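Taken together, the `__init__` and `__call__` changes make the output projection reuse the embedding matrix whenever `weight_tying` is set, and only allocate a separate `ff_out` otherwise. A minimal standalone sketch of that behavior, using MLX's `nn.Embedding` and `nn.Linear`; the `TiedHead` class and the toy sizes are hypothetical, not part of the commit:

```python
import mlx.core as mx
import mlx.nn as nn

class TiedHead(nn.Module):
    """Sketch of the output-projection logic: reuse the embedding matrix when tied."""

    def __init__(self, vocab_size: int, d_model: int, weight_tying: bool):
        super().__init__()
        self.weight_tying = weight_tying
        self.wte = nn.Embedding(vocab_size, d_model)
        if not weight_tying:
            # Untied variant keeps a separate unembedding matrix.
            self.ff_out = nn.Linear(d_model, vocab_size, bias=False)

    def __call__(self, h: mx.array) -> mx.array:
        if self.weight_tying:
            # Tied variant: logits = h @ E^T, so no extra parameters are allocated
            # and the checkpoint ships no ff_out weights (presumably the OLMo 1B
            # case this commit fixes).
            return h @ self.wte.weight.T
        return self.ff_out(h)

# Toy usage with made-up sizes.
h = mx.zeros((1, 4, 32))                                # (batch, seq, d_model)
print(TiedHead(100, 32, weight_tying=True)(h).shape)    # (1, 4, 100)
print(TiedHead(100, 32, weight_tying=False)(h).shape)   # (1, 4, 100)
```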