Awni Hannun 2023-12-17 21:30:28 -08:00
parent c468edc4e3
commit 34843ddeb2
2 changed files with 25 additions and 32 deletions

View File

@@ -14,8 +14,10 @@ SHARED_REPLACEMENT_PATTERNS = [
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
("layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [

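The conversion code that consumes these patterns is not shown in the hunk above, but the pairs read as plain substring substitutions over the Hugging Face weight names. A minimal sketch of that assumption follows; the replace_key helper and the example key are illustrative, not taken from the repo.

SHARED_REPLACEMENT_PATTERNS = [
    (".layer.1.layer_norm.", ".ln2."),
    (".layer.2.layer_norm.", ".ln3."),
    (".final_layer_norm.", ".ln."),
    (
        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
        "relative_attention_bias.embeddings.",
    ),
]

def replace_key(key: str) -> str:
    # Apply every (old, new) pair as a simple substring replacement.
    for old, new in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)
    return key

# A relative-attention-bias weight collapses onto the single embedding table
# the MLX model expects:
print(replace_key("encoder.layers.0.layer.0.SelfAttention.relative_attention_bias.weight"))
# -> encoder.relative_attention_bias.embeddings.weight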
View File

@@ -58,7 +58,9 @@ def _relative_position_bucket(
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position)
else:
relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
@@ -66,10 +68,9 @@ def _relative_position_bucket(
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact)
/ np.log(max_distance / max_exact)
* (num_buckets - max_exact)
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where(
@@ -80,7 +81,7 @@ def _relative_position_bucket(
class RelativePositionBias(nn.Module):
def __init__(self, config: ModelArgs, bidirectional: bool):
self.bidirectional = False #bidirectional
self.bidirectional = False # bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
@@ -93,11 +94,8 @@ class RelativePositionBias(nn.Module):
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = (
memory_position - context_position
)
relative_position = memory_position - context_position
relative_position_bucket = _relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
@@ -106,9 +104,7 @@ class RelativePositionBias(nn.Module):
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(
relative_position_bucket
)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
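As a quick sanity check on the bucketing hunks above (not part of the diff): with T5's usual defaults of 32 buckets and a maximum distance of 128, the lower half of the buckets index exact offsets and the upper half are log-spaced, and the precomputed scale factor reproduces the old log(...) / log(...) * (num_buckets - max_exact) expression. A small NumPy sketch; the defaults and max_exact = num_buckets // 2 are assumptions, since those lines fall outside the hunk.

import numpy as np

num_buckets, max_distance = 32, 128   # assumed T5 defaults
max_exact = num_buckets // 2          # 16: distances below this keep their own bucket

# Fold the two constants of the old expression into a single multiply per element.
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)

d = np.array([1, 8, 15, 16, 40, 127, 500])            # |relative positions|
large = max_exact + (np.log(d / max_exact) * scale).astype(np.int16)
large = np.minimum(large, num_buckets - 1)             # clamp to the last bucket
buckets = np.where(d < max_exact, d, large)
print(buckets)  # 1, 8, 15, 16, 23, 31, 31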
@@ -130,8 +126,7 @@ class MultiHeadAttention(nn.Module):
values: mx.array,
mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
@@ -203,7 +198,7 @@ class TransformerEncoder(nn.Module):
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x):
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers:
x = layer(x, mask=pos_bias)
@@ -228,8 +223,8 @@ class TransformerDecoderLayer(nn.Module):
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
@@ -278,7 +273,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
def __init__(self, config: ModelArgs):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
@@ -292,17 +287,14 @@ class T5(nn.Module):
self.decoder = TransformerDecoder(config)
self.lm_head = OutputHead(config)
def encode(
self,
inputs: mx.array
) -> mx.array:
def encode(self, inputs: mx.array):
return self.encoder(self.wte(inputs))
def decode(
self,
inputs: mx.array,
memory: mx.array,
cache = None,
cache=None,
):
inputs = self.wte(inputs)
T = inputs.shape[1]
@@ -321,7 +313,7 @@ class T5(nn.Module):
self,
inputs: mx.array,
decoder_inputs: mx.array,
) -> mx.array:
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
@@ -375,7 +367,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--encode-only",
action='store_true',
action="store_true",
default=False,
help="Whether to decode or not",
)
@@ -425,14 +417,13 @@ if __name__ == "__main__":
tokens = []
for token, n_tokens in zip(
generate(prompt, decoder_inputs, model, args.temp),
range(args.max_tokens)
generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
):
if token.item() == tokenizer.eos_token_id:
break
tokens.append(token.item())
# For some reason using the following line doesn't give spaces
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
# For some reason using the following line doesn't give spaces
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
print(tokenizer.decode(tokens), end="", flush=True)
end = perf_counter_ns()
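A usage note, outside the diff: the (keys, values) pairs threaded through decode as cache are what make the loop above incremental, so each step only feeds the newest token back in. A rough sketch, assuming model.decode returns (logits, cache) as the [0] indexing in __call__ suggests, and that decoding starts from T5's pad id 0; the greedy_generate helper is illustrative, not the repo's generate.

import mlx.core as mx

def greedy_generate(model, prompt_ids: mx.array, max_tokens: int, eos_id: int):
    # Run the encoder once and reuse its output for every decoding step.
    memory = model.encode(prompt_ids)
    y = mx.array([[0]])  # assumed decoder start token (T5 conventionally uses the pad id)
    cache = None         # per-layer (keys, values) pairs, filled on the first step
    tokens = []
    for _ in range(max_tokens):
        logits, cache = model.decode(y, memory, cache=cache)
        next_token = mx.argmax(logits[:, -1, :], axis=-1)  # greedy pick, shape (1,)
        token_id = mx.squeeze(next_token).item()
        if token_id == eos_id:
            break
        tokens.append(token_id)
        y = next_token.reshape(1, 1)  # feed only the newest token back in
    return tokens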