mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
LM head
This commit is contained in:
parent
d12db65eeb
commit
31da1b0dab
@ -9,6 +9,7 @@ SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".q.", ".query_proj."),
|
||||
(".v.", ".value_proj."),
|
||||
("shared.", "wte."),
|
||||
("lm_head.", "lm_head.linear."),
|
||||
(".layer.0.layer_norm.", ".ln1."),
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
|
23
t5/t5.py
23
t5/t5.py
@ -1,5 +1,4 @@
|
||||
import argparse
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
@ -259,32 +258,38 @@ class TransformerDecoder(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
self.linear = nn.Linear(config.d_model, config.vocab_size)
|
||||
|
||||
def __call__(self, inputs):
|
||||
return self.linear(inputs)
|
||||
|
||||
|
||||
class T5(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
self.encoder = TransformerEncoder(config)
|
||||
self.decoder = TransformerDecoder(config)
|
||||
# self.lm_head = OutputHead(config)
|
||||
self.lm_head = OutputHead(config)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
decoder_inputs: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
x = self.wte(inputs)
|
||||
|
||||
|
||||
y = self.encoder(x, mask=None) #, cache)
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
if x.shape[1] > 1 and mask is None:
|
||||
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||
mask = mask.astype(x.dtype)
|
||||
|
||||
# y, cache = self.decoder(x, mask, cache)
|
||||
# return self.lm_head(y), cache
|
||||
return y #, cache
|
||||
decoder_inputs = self.wte(decoder_inputs)
|
||||
y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache)
|
||||
return self.lm_head(y), cache
|
||||
|
||||
|
||||
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
|
||||
|
Loading…
Reference in New Issue
Block a user