Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-11 06:04:36 +08:00
Load all encoder weights
t5/t5.py: 69 changed lines
@@ -1,10 +1,11 @@
 import argparse
 from typing import Optional
 from dataclasses import dataclass
-from mlx.utils import tree_flatten, tree_unflatten
-from transformers import AutoTokenizer
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+from transformers import AutoTokenizer
+
 
 @dataclass
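For context on the commit title: the usual MLX pattern for loading a converted checkpoint is to read a flat dict of arrays and push it into the module tree with `tree_unflatten` plus `Module.update`. A minimal sketch, assuming a hypothetical `weights.npz` whose dotted keys already match the module paths; neither the file name nor the helper below is defined by this commit:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten


def load_weights(model: nn.Module, path: str = "weights.npz") -> nn.Module:
    # mx.load on an .npz gives a flat {dotted key -> mx.array} mapping;
    # tree_unflatten rebuilds the nested structure that Module.update expects.
    weights = mx.load(path)
    model.update(tree_unflatten(list(weights.items())))
    return model
```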
@@ -25,10 +26,72 @@ class ModelArgs:
 
 
 
+
+class LayerNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
+        super().__init__()
+        if affine:
+            self.weight = mx.ones((dims,))
+        self.eps = eps
+        self.dims = dims
+
+    def _extra_repr(self):
+        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
+
+    def __call__(self, x):
+        means = mx.mean(x, axis=-1, keepdims=True)
+        var = mx.var(x, axis=-1, keepdims=True)
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x) if "weight" in self else x
+
+
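As a quick sanity check of the new LayerNorm (scale only, no bias): normalizing a random activation should give roughly zero mean and unit variance along the feature axis while the learned weight is still at its all-ones initialization. A small sketch, with arbitrary shapes and assuming the class above is in scope:

```python
import mlx.core as mx

x = mx.random.normal(shape=(2, 8, 512))  # (batch, sequence, d_model)
ln = LayerNorm(512)                      # the class defined in this diff
y = ln(x)

# With self.weight still at its ones() initialization, the output is just
# the standardized input.
print(mx.mean(y, axis=-1))  # ~0 for every token
print(mx.var(y, axis=-1))   # ~1 for every token
```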
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None):
+        super().__init__()
+        mlp_dims = mlp_dims or dims * 4
+        self.attention = nn.MultiHeadAttention(dims, num_heads)
+        self.ln1 = LayerNorm(dims)
+        self.ln2 = LayerNorm(dims)
+        self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
+        self.linear2 = nn.Linear(mlp_dims, dims, bias=False)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y = self.attention(y, y, y, mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.linear1(y)
+        y = mx.maximum(y, 0)
+        y = self.linear2(y)
+        x = x + y
+
+        return x
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(
+        self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
+    ):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(dims, num_heads, mlp_dims)
+            for _ in range(num_layers)
+        ]
+        self.ln = LayerNorm(dims)
+
+    def __call__(self, x, mask):
+        for layer in self.layers:
+            x = layer(x, mask)
+        x = self.ln(x)
+
+        return x
+
+
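The encoder layers take an additive attention mask that is broadcast onto the (batch, heads, query, key) score tensor. One common way to build a padding mask from per-sequence lengths is sketched below; the helper name and the -1e9 fill value are illustrative choices, not something this commit defines:

```python
import mlx.core as mx


def padding_mask(lengths: mx.array, max_len: int) -> mx.array:
    # True wherever a position is past the end of its sequence.
    pad = mx.arange(max_len)[None, :] >= lengths[:, None]  # (batch, max_len)
    # Large negative bias on padded keys, zero elsewhere; the extra axes
    # broadcast over heads and query positions.
    return mx.where(pad[:, None, None, :], -1e9, 0.0)


lengths = mx.array([5, 8])
mask = padding_mask(lengths, max_len=8)
# encoder = TransformerEncoder(num_layers=2, dims=512, num_heads=8)
# y = encoder(x, mask)  # x: (2, 8, 512)
```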
 class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
-        self.encoder = nn.TransformerEncoder(
+        self.encoder = TransformerEncoder(
             num_layers=config.num_layers,
             dims=config.d_model,
             num_heads=config.num_heads,
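Once the encoder is constructed and its weights applied, `tree_flatten` (imported above) gives a flat view of the parameter tree, which is a convenient way to confirm that every layer actually received weights. A brief sketch; `model` is assumed to be an initialized T5 instance and the printed names and shapes are only indicative:

```python
from mlx.utils import tree_flatten

for name, param in tree_flatten(model.parameters()):
    print(name, param.shape)
# Expected entries look roughly like:
#   encoder.layers.0.attention.query_proj.weight (512, 512)
#   encoder.layers.0.ln1.weight (512,)
#   ...
```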