This commit is contained in:
Awni Hannun 2023-12-17 21:30:28 -08:00
parent c468edc4e3
commit 34843ddeb2
2 changed files with 25 additions and 32 deletions

View File

@ -14,8 +14,10 @@ SHARED_REPLACEMENT_PATTERNS = [
(".layer.1.layer_norm.", ".ln2."), (".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."), (".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."), (".final_layer_norm.", ".ln."),
("layers.0.layer.0.SelfAttention.relative_attention_bias.", (
"relative_attention_bias.embeddings."), "layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
] ]
ENCODER_REPLACEMENT_PATTERNS = [ ENCODER_REPLACEMENT_PATTERNS = [

View File

@ -58,7 +58,9 @@ def _relative_position_bucket(
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position) relative_position = mx.abs(relative_position)
else: else:
relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position)) relative_position = -mx.minimum(
relative_position, mx.zeros_like(relative_position)
)
# now relative_position is in the range [0, inf) # now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions # half of the buckets are for exact increments in positions
@ -66,10 +68,9 @@ def _relative_position_bucket(
is_small = relative_position < max_exact is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
scale = np.log(max_distance / max_exact) * (num_buckets - max_exact)
relative_position_if_large = max_exact + ( relative_position_if_large = max_exact + (
mx.log(relative_position.astype(mx.float32) / max_exact) mx.log(relative_position.astype(mx.float32) / max_exact) * scale
/ np.log(max_distance / max_exact)
* (num_buckets - max_exact)
).astype(mx.int16) ).astype(mx.int16)
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1) relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
relative_buckets += mx.where( relative_buckets += mx.where(
@ -80,7 +81,7 @@ def _relative_position_bucket(
class RelativePositionBias(nn.Module): class RelativePositionBias(nn.Module):
def __init__(self, config: ModelArgs, bidirectional: bool): def __init__(self, config: ModelArgs, bidirectional: bool):
self.bidirectional = False #bidirectional self.bidirectional = False # bidirectional
self.num_buckets = config.relative_attention_num_buckets self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads self.n_heads = config.num_heads
@ -93,11 +94,8 @@ class RelativePositionBias(nn.Module):
context_position = mx.arange(offset, query_length)[:, None] context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :] memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length) # shape (query_length, key_length)
relative_position = ( relative_position = memory_position - context_position
memory_position - context_position
)
relative_position_bucket = _relative_position_bucket( relative_position_bucket = _relative_position_bucket(
relative_position, relative_position,
bidirectional=self.bidirectional, bidirectional=self.bidirectional,
@ -106,9 +104,7 @@ class RelativePositionBias(nn.Module):
) )
# shape (query_length, key_length, num_heads) # shape (query_length, key_length, num_heads)
values = self.embeddings( values = self.embeddings(relative_position_bucket)
relative_position_bucket
)
# shape (num_heads, query_length, key_length) # shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1) return values.transpose(2, 0, 1)
@ -130,8 +126,7 @@ class MultiHeadAttention(nn.Module):
values: mx.array, values: mx.array,
mask: mx.array, mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array: ) -> [mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries) queries = self.query_proj(queries)
keys = self.key_proj(keys) keys = self.key_proj(keys)
values = self.value_proj(values) values = self.value_proj(values)
@ -191,7 +186,7 @@ class TransformerEncoderLayer(nn.Module):
y = self.linear1(y) y = self.linear1(y)
y = mx.maximum(y, 0) y = mx.maximum(y, 0)
y = self.linear2(y) y = self.linear2(y)
return x + y return x + y
class TransformerEncoder(nn.Module): class TransformerEncoder(nn.Module):
@ -203,7 +198,7 @@ class TransformerEncoder(nn.Module):
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x): def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
for layer in self.layers: for layer in self.layers:
x = layer(x, mask=pos_bias) x = layer(x, mask=pos_bias)
@ -228,8 +223,8 @@ class TransformerDecoderLayer(nn.Module):
memory: mx.array, memory: mx.array,
mask: mx.array, mask: mx.array,
memory_mask: mx.array, memory_mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None, cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> mx.array: ):
y = self.ln1(x) y = self.ln1(x)
y, cache = self.self_attention(y, y, y, mask, cache) y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y x = x + y
@ -278,7 +273,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module): class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None: def __init__(self, config: ModelArgs):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs): def __call__(self, inputs):
@ -292,17 +287,14 @@ class T5(nn.Module):
self.decoder = TransformerDecoder(config) self.decoder = TransformerDecoder(config)
self.lm_head = OutputHead(config) self.lm_head = OutputHead(config)
def encode( def encode(self, inputs: mx.array):
self,
inputs: mx.array
) -> mx.array:
return self.encoder(self.wte(inputs)) return self.encoder(self.wte(inputs))
def decode( def decode(
self, self,
inputs: mx.array, inputs: mx.array,
memory: mx.array, memory: mx.array,
cache = None, cache=None,
): ):
inputs = self.wte(inputs) inputs = self.wte(inputs)
T = inputs.shape[1] T = inputs.shape[1]
@ -321,7 +313,7 @@ class T5(nn.Module):
self, self,
inputs: mx.array, inputs: mx.array,
decoder_inputs: mx.array, decoder_inputs: mx.array,
) -> mx.array: ):
return decode(decoder_inputs, encode(inputs))[0] return decode(decoder_inputs, encode(inputs))[0]
@ -375,7 +367,7 @@ if __name__ == "__main__":
) )
parser.add_argument( parser.add_argument(
"--encode-only", "--encode-only",
action='store_true', action="store_true",
default=False, default=False,
help="Whether to decode or not", help="Whether to decode or not",
) )
@ -425,14 +417,13 @@ if __name__ == "__main__":
tokens = [] tokens = []
for token, n_tokens in zip( for token, n_tokens in zip(
generate(prompt, decoder_inputs, model, args.temp), generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
range(args.max_tokens)
): ):
if token.item() == tokenizer.eos_token_id: if token.item() == tokenizer.eos_token_id:
break break
tokens.append(token.item()) tokens.append(token.item())
# For some reason using the following line doesn't give spaces # For some reason using the following line doesn't give spaces
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True) # print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
print(tokenizer.decode(tokens), end="", flush=True) print(tokenizer.decode(tokens), end="", flush=True)
end = perf_counter_ns() end = perf_counter_ns()