This commit is contained in:
Juarez Bochi 2023-12-16 14:44:15 -05:00
parent d12db65eeb
commit 31da1b0dab
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 15 additions and 9 deletions

View File

@ -9,6 +9,7 @@ SHARED_REPLACEMENT_PATTERNS = [
(".q.", ".query_proj."), (".q.", ".query_proj."),
(".v.", ".value_proj."), (".v.", ".value_proj."),
("shared.", "wte."), ("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."), (".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."), (".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."), (".layer.2.layer_norm.", ".ln3."),

View File

@ -1,5 +1,4 @@
import argparse import argparse
import math
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np import numpy as np
@ -259,32 +258,38 @@ class TransformerDecoder(nn.Module):
return x return x
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
self.linear = nn.Linear(config.d_model, config.vocab_size)
def __call__(self, inputs):
return self.linear(inputs)
class T5(nn.Module): class T5(nn.Module):
def __init__(self, config: ModelArgs): def __init__(self, config: ModelArgs):
self.wte = nn.Embedding(config.vocab_size, config.d_model) self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config) self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config) self.decoder = TransformerDecoder(config)
# self.lm_head = OutputHead(config) self.lm_head = OutputHead(config)
def __call__( def __call__(
self, self,
inputs: mx.array, inputs: mx.array,
decoder_inputs: mx.array,
mask: mx.array = None, mask: mx.array = None,
cache: mx.array = None, cache: mx.array = None,
) -> tuple[mx.array, mx.array]: ) -> tuple[mx.array, mx.array]:
x = self.wte(inputs) x = self.wte(inputs)
y = self.encoder(x, mask=None) #, cache) y = self.encoder(x, mask=None) #, cache)
mask = None if x.shape[1] > 1 and mask is None:
if x.shape[1] > 1:
mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1]) mask = MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(x.dtype) mask = mask.astype(x.dtype)
# y, cache = self.decoder(x, mask, cache) decoder_inputs = self.wte(decoder_inputs)
# return self.lm_head(y), cache y, cache = self.decoder(x=decoder_inputs, x_mask=mask, memory=y) #, cache)
return y #, cache return self.lm_head(y), cache
# def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0): # def generate(prompt: mx.array, model: T5, temp: Optional[float] = 0.0):