Reviewer notes (free text ahead of the first "diff --git" header is skipped
by git apply and patch).

Purpose: wire the T5 decoder and LM head into the forward pass, and map the
checkpoint's "lm_head." weights onto OutputHead.linear in the converter.

Fixes folded into this revision of the patch:
  * OutputHead.linear is created with bias=False. The converter only renames
    "lm_head." keys, and the T5 LM head is a bias-free projection, so a
    default nn.Linear would carry a bias that no checkpoint weight fills.
  * The encoder call is kept and its output is passed to the decoder as
    memory. The posted patch appears to delete the `y = self.encoder(...)`
    assignment while still passing `memory=y`.
  * The causal mask is sized from decoder_inputs, because it is handed to the
    decoder as x_mask; the posted patch sized it from the encoder input.

NOTE(review): the posted patch had been collapsed onto a single line, which
erased blank context lines. The hunk counts for t5/t5.py below are recomputed
from the lines that could be recovered, and the post-image blob id on its
index line predates the fixes above. Regenerate with `git diff` against the
working tree before applying.

diff --git a/t5/convert.py b/t5/convert.py
index 1de73150..e10e4e45 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -9,6 +9,7 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".q.", ".query_proj."),
     (".v.", ".value_proj."),
     ("shared.", "wte."),
+    ("lm_head.", "lm_head.linear."),
     (".layer.0.layer_norm.", ".ln1."),
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
diff --git a/t5/t5.py b/t5/t5.py
index 6e3374a6..cbecbb62 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -1,5 +1,4 @@
 import argparse
-import math
 from dataclasses import dataclass
 
 import numpy as np
@@ -259,29 +258,43 @@ class TransformerDecoder(nn.Module):
         return x
 
 
+class OutputHead(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        # Bias-free projection: the converter maps only "lm_head." weights here.
+        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+    def __call__(self, inputs):
+        return self.linear(inputs)
+
+
 class T5(nn.Module):
     def __init__(self, config: ModelArgs):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
         self.encoder = TransformerEncoder(config)
         self.decoder = TransformerDecoder(config)
-        # self.lm_head = OutputHead(config)
+        self.lm_head = OutputHead(config)
 
     def __call__(
         self,
         inputs: mx.array,
+        decoder_inputs: mx.array,
         mask: mx.array = None,
         cache: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.wte(inputs)
-
-        y = self.encoder(x, mask=None) #, cache)
-        mask = None
-        if x.shape[1] > 1:
-            mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-            mask = mask.astype(x.dtype)
-        # y, cache = self.decoder(x, mask, cache)
-        # return self.lm_head(y), cache
-        return y #, cache
+        # Encode the source once; the decoder receives the result as memory.
+        memory = self.encoder(x, mask=None)
+
+        decoder_inputs = self.wte(decoder_inputs)
+        # The mask is passed to the decoder as x_mask, so size it from the
+        # decoder sequence rather than from the encoder input.
+        target_len = decoder_inputs.shape[1]
+        if mask is None and target_len > 1:
+            mask = MultiHeadAttention.create_additive_causal_mask(target_len)
+            mask = mask.astype(decoder_inputs.dtype)
+
+        y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=memory)
+        return self.lm_head(y), cache
 
 
 # def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):