mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Load all encoder weights
This commit is contained in:
parent
bca5ca4f98
commit
c0001a94f2
@ -8,7 +8,12 @@ def replace_key(key: str) -> str:
|
|||||||
key = key.replace(".o.", ".out_proj.")
|
key = key.replace(".o.", ".out_proj.")
|
||||||
key = key.replace(".q.", ".query_proj.")
|
key = key.replace(".q.", ".query_proj.")
|
||||||
key = key.replace(".v.", ".value_proj.")
|
key = key.replace(".v.", ".value_proj.")
|
||||||
|
key = key.replace(".layer.0.layer_norm.", ".ln1.")
|
||||||
|
key = key.replace(".layer.1.layer_norm.", ".ln2.")
|
||||||
key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
|
key = key.replace(".layer.1.DenseReluDense.wi.", ".linear1.")
|
||||||
|
key = key.replace(".layer.1.DenseReluDense.wo.", ".linear2.")
|
||||||
|
key = key.replace(".final_layer_norm.", ".ln.")
|
||||||
|
key = key.replace("shared.", "wte.")
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def convert():
|
def convert():
|
||||||
|
69
t5/t5.py
69
t5/t5.py
@ -1,10 +1,11 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -25,10 +26,72 @@ class ModelArgs:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
    """Layer normalization over the last axis with an optional scale.

    The input is centered with its mean and divided by the square root of
    its variance (plus ``eps``), both computed along the last axis. When
    ``affine`` is true a learnable ``weight`` of shape ``(dims,)`` multiplies
    the result; this module has no bias term.

    NOTE(review): the reference T5 layer norm is believed to skip the mean
    subtraction (RMS-style) -- confirm against the upstream implementation
    before relying on numerical parity.
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        # The scale is only registered when requested; its presence is what
        # `_extra_repr` and `__call__` check later on.
        if affine:
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        """Return the dims / eps / affine summary used in the module repr."""
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        """Normalize ``x`` along its last axis and apply the scale if present."""
        centered = x - mx.mean(x, axis=-1, keepdims=True)
        inv_std = mx.rsqrt(mx.var(x, axis=-1, keepdims=True) + self.eps)
        normed = centered * inv_std
        if "weight" in self:
            return self.weight * normed
        return normed
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention then a two-layer ReLU MLP.

    Each sub-layer is applied to a normalized copy of its input and the
    result is added back onto the un-normalized input (residual connection).
    Both MLP projections are created without bias.

    Args:
        dims: Model width, used for attention, both norms and the MLP
            input/output size.
        num_heads: Number of heads passed to ``nn.MultiHeadAttention``.
        mlp_dims: Hidden width of the MLP; falls back to ``dims * 4`` when
            omitted (or otherwise falsy).
    """

    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
        super().__init__()
        hidden_dims = mlp_dims if mlp_dims else dims * 4
        self.attention = nn.MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = nn.Linear(dims, hidden_dims, bias=False)
        self.linear2 = nn.Linear(hidden_dims, dims, bias=False)

    def __call__(self, x, mask):
        """Run the block on ``x``, forwarding ``mask`` to the attention call."""
        # Attention sub-layer: queries, keys and values are all the same
        # normalized input.
        normed = self.ln1(x)
        x = x + self.attention(normed, normed, normed, mask)

        # Feed-forward sub-layer: linear -> ReLU (max with 0) -> linear.
        hidden = self.linear1(self.ln2(x))
        hidden = mx.maximum(hidden, 0)
        x = x + self.linear2(hidden)

        return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
    """A stack of ``TransformerEncoderLayer`` blocks followed by a final norm.

    Args:
        num_layers: How many encoder blocks to create.
        dims: Model width shared by every block and the final norm.
        num_heads: Attention head count for every block.
        mlp_dims: Optional MLP hidden width forwarded to every block.
    """

    def __init__(
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
    ):
        super().__init__()
        self.layers = [
            TransformerEncoderLayer(dims, num_heads, mlp_dims)
            for _ in range(num_layers)
        ]
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        """Pass ``x`` through every block in order, then the final norm."""
        for block in self.layers:
            x = block(x, mask)
        return self.ln(x)
|
||||||
|
|
||||||
|
|
||||||
class T5(nn.Module):
|
class T5(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = nn.TransformerEncoder(
|
self.encoder = TransformerEncoder(
|
||||||
num_layers=config.num_layers,
|
num_layers=config.num_layers,
|
||||||
dims=config.d_model,
|
dims=config.d_model,
|
||||||
num_heads=config.num_heads,
|
num_heads=config.num_heads,
|
||||||
|
Loading…
Reference in New Issue
Block a user