From c0001a94f239c80103574f6726cdbd8e67764242 Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Thu, 14 Dec 2023 15:38:41 -0500
Subject: [PATCH] Load all encoder weights

---
 t5/convert.py |  5 ++++
 t5/t5.py      | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 71 insertions(+), 3 deletions(-)

diff --git a/t5/convert.py b/t5/convert.py
index 0e6a51b4..0977c917 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -8,7 +8,12 @@ def replace_key(key: str) -> str:
     key = key.replace(".o.", ".out_proj.")
     key = key.replace(".q.", ".query_proj.")
     key = key.replace(".v.", ".value_proj.")
+    key = key.replace(".layer.0.layer_norm.", ".ln1.")
+    key = key.replace(".layer.1.layer_norm.", ".ln2.")
     key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
+    key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
+    key = key.replace(".final_layer_norm.", ".ln.")
+    key = key.replace("shared.", "wte.")
     return key

 def convert():
diff --git a/t5/t5.py b/t5/t5.py
index ffdd0dc2..0774a1fa 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,10 +1,11 @@
 import argparse
+from typing import Optional
 from dataclasses import dataclass

-from mlx.utils import tree_flatten, tree_unflatten
-from transformers import AutoTokenizer
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+from transformers import AutoTokenizer


 @dataclass
@@ -25,10 +26,72 @@ class ModelArgs:



+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x):
+        means = mx.mean(x, axis=-1, keepdims=True)
+        var = mx.var(x, axis=-1, keepdims=True)
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x) if "weight" in self else x
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = nn.MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
+        self.linear2 = nn.Linear(mlp_dims, dims, bias=False)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y = self.attention(y, y, y, mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            for _ in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for layer in self.layers:
+            x = layer(x, mask)
+        x = self.ln(x)
+
+        return x
+
+
 class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
-        self.encoder = nn.TransformerEncoder(
+        self.encoder = TransformerEncoder(
             num_layers=config.num_layers,
             dims=config.d_model,
             num_heads=config.num_heads,
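
A note for reviewers on the convert.py side: with the five new replacements, every
encoder weight in the HF checkpoint (both per-block layer norms, the second
feed-forward projection wo, the final layer norm, and the shared embedding) now maps
onto a module path defined in t5.py. Below is a minimal sketch of the load path these
names feed into, assuming the converter still writes a weights.npz of renamed arrays
as in the existing example; the filename and the update/tree_unflatten pattern are
illustrative, not part of this patch:

    import mlx.core as mx
    from mlx.utils import tree_unflatten

    from t5 import T5, ModelArgs

    # Assumption: ModelArgs defaults describe the converted checkpoint.
    model = T5(ModelArgs())

    # Every key emitted by replace_key() must resolve to a parameter path in
    # the module tree (e.g. ...ln1.weight, ...linear2.weight, wte.weight) for
    # the update below to load it rather than drop it.
    weights = mx.load("weights.npz")  # assumed output of t5/convert.py
    model.update(tree_unflatten(list(weights.items())))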
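
The custom LayerNorm carries only a weight, no bias, which matches the single weight
array T5 checkpoints store per layer norm (mlx.nn.LayerNorm adds a bias parameter
when affine=True). One way to sanity-check it, a sketch rather than part of the
patch: at initialization, while the built-in norm's bias is still zero, the two
should agree:

    import mlx.core as mx
    import mlx.nn as nn

    from t5 import LayerNorm  # the class added by this patch

    x = mx.random.normal((2, 8, 16))

    custom = LayerNorm(16)
    reference = nn.LayerNorm(16)  # weight starts at ones, bias at zeros

    # Both compute (x - mean) / sqrt(var + eps) scaled by weight, so with the
    # reference bias still zero the outputs should match to float precision.
    print(mx.allclose(custom(x), reference(x), atol=1e-6))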
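
Finally, a forward-pass smoke test for the new encoder stack. Everything here is
illustrative: the token ids are arbitrary, mask=None simply skips masking inside
nn.MultiHeadAttention, and ModelArgs() is assumed to have usable defaults:

    import mlx.core as mx

    from t5 import T5, ModelArgs

    config = ModelArgs()  # assumed to carry default hyperparameters
    model = T5(config)

    tokens = mx.array([[37, 423, 12, 117]])  # arbitrary token ids
    x = model.wte(tokens)                    # (batch, seq_len, d_model)
    y = model.encoder(x, mask=None)          # run all layers + final norm
    print(y.shape)                           # (1, 4, config.d_model)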