mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
LM head
This commit is contained in:
parent
d12db65eeb
commit
31da1b0dab
@ -9,6 +9,7 @@ SHARED_REPLACEMENT_PATTERNS = [
|
|||||||
(".q.", ".query_proj."),
|
(".q.", ".query_proj."),
|
||||||
(".v.", ".value_proj."),
|
(".v.", ".value_proj."),
|
||||||
("shared.", "wte."),
|
("shared.", "wte."),
|
||||||
|
("lm_head.", "lm_head.linear."),
|
||||||
(".layer.0.layer_norm.", ".ln1."),
|
(".layer.0.layer_norm.", ".ln1."),
|
||||||
(".layer.1.layer_norm.", ".ln2."),
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
|
23
t5/t5.py
23
t5/t5.py
@ -1,5 +1,4 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import math
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -259,32 +258,38 @@ class TransformerDecoder(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class OutputHead(nn.Module):
|
||||||
|
def __init__(self, config: ModelArgs) -> None:
|
||||||
|
self.linear = nn.Linear(config.d_model, config.vocab_size)
|
||||||
|
|
||||||
|
def __call__(self, inputs):
|
||||||
|
return self.linear(inputs)
|
||||||
|
|
||||||
|
|
||||||
class T5(nn.Module):
|
class T5(nn.Module):
|
||||||
def __init__(self, config: ModelArgs):
|
def __init__(self, config: ModelArgs):
|
||||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||||
self.encoder = TransformerEncoder(config)
|
self.encoder = TransformerEncoder(config)
|
||||||
self.decoder = TransformerDecoder(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
# self.lm_head = OutputHead(config)
|
self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
|
decoder_inputs: mx.array,
|
||||||
mask: mx.array = None,
|
mask: mx.array = None,
|
||||||
cache: mx.array = None,
|
cache: mx.array = None,
|
||||||
) -> tuple[mx.array, mx.array]:
|
) -> tuple[mx.array, mx.array]:
|
||||||
x = self.wte(inputs)
|
x = self.wte(inputs)
|
||||||
|
|
||||||
|
|
||||||
y = self.encoder(x, mask=None) #, cache)
|
y = self.encoder(x, mask=None) #, cache)
|
||||||
|
|
||||||
mask = None
|
if x.shape[1] > 1 and mask is None:
|
||||||
if x.shape[1] > 1:
|
|
||||||
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
|
||||||
mask = mask.astype(x.dtype)
|
mask = mask.astype(x.dtype)
|
||||||
|
|
||||||
# y, cache = self.decoder(x, mask, cache)
|
decoder_inputs = self.wte(decoder_inputs)
|
||||||
# return self.lm_head(y), cache
|
y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache)
|
||||||
return y #, cache
|
return self.lm_head(y), cache
|
||||||
|
|
||||||
|
|
||||||
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
|
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):
|
||||||
|
Loading…
Reference in New Issue
Block a user