Load all encoder weights

Juarez Bochi 2023-12-14 15:38:41 -05:00
parent bca5ca4f98
commit c0001a94f2
2 changed files with 71 additions and 3 deletions


@@ -8,7 +8,12 @@ def replace_key(key: str) -> str:
     key = key.replace(".o.", ".out_proj.")
     key = key.replace(".q.", ".query_proj.")
     key = key.replace(".v.", ".value_proj.")
+    key = key.replace(".layer.0.layer_norm.", ".ln1.")
+    key = key.replace(".layer.1.layer_norm.", ".ln2.")
     key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
+    key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
+    key = key.replace(".final_layer_norm.", ".ln.")
+    key = key.replace("shared.", "wte.")
     return key

 def convert():
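For reference, a quick sketch of what the new mapping rules do. The parameter names below are illustrative Hugging Face-style T5 keys, not taken from a real checkpoint, and each rule is applied with a bare str.replace so the results do not depend on the parts of replace_key outside this hunk:

    # Hypothetical keys, used only to illustrate the new rules.
    print("shared.weight".replace("shared.", "wte."))
    # -> wte.weight
    print("encoder.final_layer_norm.weight".replace(".final_layer_norm.", ".ln."))
    # -> encoder.ln.weight
    print("encoder.block.0.layer.0.layer_norm.weight"
          .replace(".layer.0.layer_norm.", ".ln1."))
    # -> encoder.block.0.ln1.weight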


@@ -1,10 +1,11 @@
 import argparse
+from typing import Optional
 from dataclasses import dataclass
-from mlx.utils import tree_flatten, tree_unflatten
-from transformers import AutoTokenizer

 import mlx.core as mx
 import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+from transformers import AutoTokenizer


 @dataclass
@@ -25,10 +26,72 @@ class ModelArgs:
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x):
+        means = mx.mean(x, axis=-1, keepdims=True)
+        var = mx.var(x, axis=-1, keepdims=True)
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x) if "weight" in self else x
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = nn.MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
+        self.linear2 = nn.Linear(mlp_dims, dims, bias=False)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y = self.attention(y, y, y, mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            for _ in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for layer in self.layers:
+            x = layer(x, mask)
+        x = self.ln(x)
+        return x
+
+
 class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
-        self.encoder = nn.TransformerEncoder(
+        self.encoder = TransformerEncoder(
             num_layers=config.num_layers,
             dims=config.d_model,
             num_heads=config.num_heads,
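To see how the new modules fit together, here is a hedged usage sketch: build the model defined in this diff and load the converted weights with tree_unflatten. The weights filename, the assumption that ModelArgs supplies default values, and the toy token ids are illustrative, not part of the commit:

    # Hypothetical usage sketch; assumes ModelArgs has defaults and that
    # the conversion script saved the remapped weights to weights.npz.
    model = T5(ModelArgs())
    weights = mx.load("weights.npz")
    model.update(tree_unflatten(list(weights.items())))

    tokens = mx.array([[37, 423, 1]])   # toy token ids (illustrative)
    x = model.wte(tokens)               # (1, 3, d_model) embeddings
    y = model.encoder(x, mask=None)     # encoder output, same shape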