fix(mlx-lm): olmo 1b model (#417)

Anchen 2024-02-07 00:27:05 +11:00 committed by GitHub
parent aa7447efa2
commit a7d139f484


@@ -24,6 +24,15 @@ class ModelArgs(BaseModelArgs):
     rope_theta: float = 10000
     rope_traditional: bool = False
     model_type: str = None
+    mlp_ratio: int = 4
+    weight_tying: bool = False
+
+    def __post_init__(self):
+        self.mlp_hidden_size = (
+            self.mlp_hidden_size
+            if self.mlp_hidden_size is not None
+            else self.mlp_ratio * self.d_model
+        )
 
 
 class LayerNorm(nn.LayerNorm):
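For context on the __post_init__ fallback above: when a config omits mlp_hidden_size (presumably the case for the 1B variant this commit fixes), the hidden size is derived as mlp_ratio * d_model. A minimal runnable sketch of that logic, with made-up sizes rather than real OLMo config values:

from dataclasses import dataclass
from typing import Optional

@dataclass
class SketchArgs:
    # Made-up sizes for illustration; not the real OLMo-1B values.
    d_model: int = 2048
    mlp_hidden_size: Optional[int] = None  # some configs omit this field
    mlp_ratio: int = 4

    def __post_init__(self):
        # Same fallback as the patch: derive the hidden size when absent.
        if self.mlp_hidden_size is None:
            self.mlp_hidden_size = self.mlp_ratio * self.d_model

print(SketchArgs().mlp_hidden_size)  # 8192 = 4 * 2048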
@@ -99,6 +108,7 @@ class TransformerBlock(nn.Module):
         h = x + r
 
         x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
+
         out = h + self.ff_out(nn.silu(x2) * x1)
         return out, cache
 
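The lines above are the block's SwiGLU feed-forward: ff_proj produces both halves in one matmul, mx.split separates them, and silu gates their product. A self-contained sketch with toy dimensions (names and sizes here are illustrative, not taken from the checkpoint):

import mlx.core as mx
import mlx.nn as nn

d_model, hidden = 8, 32  # toy sizes
ff_proj = nn.Linear(d_model, 2 * hidden, bias=False)  # one matmul for both halves
ff_out = nn.Linear(hidden, d_model, bias=False)

x = mx.random.normal([2, d_model])             # (batch, d_model)
x1, x2 = mx.split(ff_proj(x), 2, axis=-1)      # two (batch, hidden) halves
out = ff_out(nn.silu(x2) * x1)                 # SwiGLU: silu-gated product
print(out.shape)                               # (2, 8)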
@@ -107,8 +117,11 @@ class Transformer(nn.Module):
 
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_layers = args.n_layers
+        self.weight_tying = args.weight_tying
+
         self.wte = nn.Embedding(args.embedding_size, args.d_model)
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
-        self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
+        if not self.weight_tying:
+            self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
         self.norm = LayerNorm(args.d_model, affine=False)
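Creating ff_out only when weights are untied keeps the module's parameter tree aligned with the checkpoint, which stores no separate output projection for tied models. A rough illustration of the difference (toy sizes, hypothetical Head class):

import mlx.nn as nn
from mlx.utils import tree_flatten

class Head(nn.Module):
    # Hypothetical miniature of the pattern in the patch.
    def __init__(self, tied: bool, d_model: int = 4, vocab: int = 10):
        super().__init__()
        self.wte = nn.Embedding(vocab, d_model)
        if not tied:
            self.ff_out = nn.Linear(d_model, vocab, bias=False)

print(len(tree_flatten(Head(tied=True).parameters())))   # 1: just wte.weight
print(len(tree_flatten(Head(tied=False).parameters())))  # 2: wte.weight + ff_out.weight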
@@ -130,7 +143,12 @@ class Transformer(nn.Module):
 
         for e, block in enumerate(self.blocks):
             h, cache[e] = block(h, mask, cache[e])
-        return self.ff_out(self.norm(h)), cache
+        h = self.norm(h)
+
+        if self.weight_tying:
+            return h @ self.wte.weight.T, cache
+
+        return self.ff_out(h), cache
 
 
 class OlmoModel(nn.Module):
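With tying enabled, the logits reuse the embedding matrix: h @ self.wte.weight.T acts as an output projection whose weight is exactly wte.weight, so no extra parameters are needed. A quick sketch (toy sizes):

import mlx.core as mx
import mlx.nn as nn

vocab, d_model = 10, 4  # toy sizes
wte = nn.Embedding(vocab, d_model)
h = mx.random.normal([2, d_model])  # stand-in for normalized hidden states

logits = h @ wte.weight.T           # (2, vocab); same matrix as the input embedding
print(logits.shape)

# The untied form would be nn.Linear(d_model, vocab, bias=False) with its
# weight set to wte.weight; tying just shares that one tensor.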