mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
format
This commit is contained in:
parent
c468edc4e3
commit
34843ddeb2
@ -14,8 +14,10 @@ SHARED_REPLACEMENT_PATTERNS = [
|
|||||||
(".layer.1.layer_norm.", ".ln2."),
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
(".final_layer_norm.", ".ln."),
|
(".final_layer_norm.", ".ln."),
|
||||||
("layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
(
|
||||||
"relative_attention_bias.embeddings."),
|
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
|
||||||
|
"relative_attention_bias.embeddings.",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
ENCODER_REPLACEMENT_PATTERNS = [
|
ENCODER_REPLACEMENT_PATTERNS = [
|
||||||
|
51
t5/t5.py
51
t5/t5.py
@ -58,7 +58,9 @@ def _relative_position_bucket(
|
|||||||
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||||
relative_position = mx.abs(relative_position)
|
relative_position = mx.abs(relative_position)
|
||||||
else:
|
else:
|
||||||
relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
|
relative_position = -mx.minimum(
|
||||||
|
relative_position, mx.zeros_like(relative_position)
|
||||||
|
)
|
||||||
# now relative_position is in the range [0, inf)
|
# now relative_position is in the range [0, inf)
|
||||||
|
|
||||||
# half of the buckets are for exact increments in positions
|
# half of the buckets are for exact increments in positions
|
||||||
@ -66,10 +68,9 @@ def _relative_position_bucket(
|
|||||||
is_small = relative_position < max_exact
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
scale = np.log(max_distance / max_exact) * (num_buckets - max_exact)
|
||||||
relative_position_if_large = max_exact + (
|
relative_position_if_large = max_exact + (
|
||||||
mx.log(relative_position.astype(mx.float32) / max_exact)
|
mx.log(relative_position.astype(mx.float32) / max_exact) * scale
|
||||||
/ np.log(max_distance / max_exact)
|
|
||||||
* (num_buckets - max_exact)
|
|
||||||
).astype(mx.int16)
|
).astype(mx.int16)
|
||||||
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
|
relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
|
||||||
relative_buckets += mx.where(
|
relative_buckets += mx.where(
|
||||||
@ -80,7 +81,7 @@ def _relative_position_bucket(
|
|||||||
|
|
||||||
class RelativePositionBias(nn.Module):
|
class RelativePositionBias(nn.Module):
|
||||||
def __init__(self, config: ModelArgs, bidirectional: bool):
|
def __init__(self, config: ModelArgs, bidirectional: bool):
|
||||||
self.bidirectional = False #bidirectional
|
self.bidirectional = False # bidirectional
|
||||||
self.num_buckets = config.relative_attention_num_buckets
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
self.max_distance = config.relative_attention_max_distance
|
self.max_distance = config.relative_attention_max_distance
|
||||||
self.n_heads = config.num_heads
|
self.n_heads = config.num_heads
|
||||||
@ -93,11 +94,8 @@ class RelativePositionBias(nn.Module):
|
|||||||
context_position = mx.arange(offset, query_length)[:, None]
|
context_position = mx.arange(offset, query_length)[:, None]
|
||||||
memory_position = mx.arange(key_length)[None, :]
|
memory_position = mx.arange(key_length)[None, :]
|
||||||
|
|
||||||
|
|
||||||
# shape (query_length, key_length)
|
# shape (query_length, key_length)
|
||||||
relative_position = (
|
relative_position = memory_position - context_position
|
||||||
memory_position - context_position
|
|
||||||
)
|
|
||||||
relative_position_bucket = _relative_position_bucket(
|
relative_position_bucket = _relative_position_bucket(
|
||||||
relative_position,
|
relative_position,
|
||||||
bidirectional=self.bidirectional,
|
bidirectional=self.bidirectional,
|
||||||
@ -106,9 +104,7 @@ class RelativePositionBias(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# shape (query_length, key_length, num_heads)
|
# shape (query_length, key_length, num_heads)
|
||||||
values = self.embeddings(
|
values = self.embeddings(relative_position_bucket)
|
||||||
relative_position_bucket
|
|
||||||
)
|
|
||||||
|
|
||||||
# shape (num_heads, query_length, key_length)
|
# shape (num_heads, query_length, key_length)
|
||||||
return values.transpose(2, 0, 1)
|
return values.transpose(2, 0, 1)
|
||||||
@ -130,8 +126,7 @@ class MultiHeadAttention(nn.Module):
|
|||||||
values: mx.array,
|
values: mx.array,
|
||||||
mask: mx.array,
|
mask: mx.array,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||||
) -> mx.array:
|
) -> [mx.array, Tuple[mx.array, mx.array]]:
|
||||||
|
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
keys = self.key_proj(keys)
|
keys = self.key_proj(keys)
|
||||||
values = self.value_proj(values)
|
values = self.value_proj(values)
|
||||||
@ -191,7 +186,7 @@ class TransformerEncoderLayer(nn.Module):
|
|||||||
y = self.linear1(y)
|
y = self.linear1(y)
|
||||||
y = mx.maximum(y, 0)
|
y = mx.maximum(y, 0)
|
||||||
y = self.linear2(y)
|
y = self.linear2(y)
|
||||||
return x + y
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(nn.Module):
|
class TransformerEncoder(nn.Module):
|
||||||
@ -203,7 +198,7 @@ class TransformerEncoder(nn.Module):
|
|||||||
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
|
||||||
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x: mx.array):
|
||||||
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, mask=pos_bias)
|
x = layer(x, mask=pos_bias)
|
||||||
@ -228,8 +223,8 @@ class TransformerDecoderLayer(nn.Module):
|
|||||||
memory: mx.array,
|
memory: mx.array,
|
||||||
mask: mx.array,
|
mask: mx.array,
|
||||||
memory_mask: mx.array,
|
memory_mask: mx.array,
|
||||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||||
) -> mx.array:
|
):
|
||||||
y = self.ln1(x)
|
y = self.ln1(x)
|
||||||
y, cache = self.self_attention(y, y, y, mask, cache)
|
y, cache = self.self_attention(y, y, y, mask, cache)
|
||||||
x = x + y
|
x = x + y
|
||||||
@ -278,7 +273,7 @@ class TransformerDecoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class OutputHead(nn.Module):
|
class OutputHead(nn.Module):
|
||||||
def __init__(self, config: ModelArgs) -> None:
|
def __init__(self, config: ModelArgs):
|
||||||
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
@ -292,17 +287,14 @@ class T5(nn.Module):
|
|||||||
self.decoder = TransformerDecoder(config)
|
self.decoder = TransformerDecoder(config)
|
||||||
self.lm_head = OutputHead(config)
|
self.lm_head = OutputHead(config)
|
||||||
|
|
||||||
def encode(
|
def encode(self, inputs: mx.array):
|
||||||
self,
|
|
||||||
inputs: mx.array
|
|
||||||
) -> mx.array:
|
|
||||||
return self.encoder(self.wte(inputs))
|
return self.encoder(self.wte(inputs))
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
memory: mx.array,
|
memory: mx.array,
|
||||||
cache = None,
|
cache=None,
|
||||||
):
|
):
|
||||||
inputs = self.wte(inputs)
|
inputs = self.wte(inputs)
|
||||||
T = inputs.shape[1]
|
T = inputs.shape[1]
|
||||||
@ -321,7 +313,7 @@ class T5(nn.Module):
|
|||||||
self,
|
self,
|
||||||
inputs: mx.array,
|
inputs: mx.array,
|
||||||
decoder_inputs: mx.array,
|
decoder_inputs: mx.array,
|
||||||
) -> mx.array:
|
):
|
||||||
return decode(decoder_inputs, encode(inputs))[0]
|
return decode(decoder_inputs, encode(inputs))[0]
|
||||||
|
|
||||||
|
|
||||||
@ -375,7 +367,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encode-only",
|
"--encode-only",
|
||||||
action='store_true',
|
action="store_true",
|
||||||
default=False,
|
default=False,
|
||||||
help="Whether to decode or not",
|
help="Whether to decode or not",
|
||||||
)
|
)
|
||||||
@ -425,14 +417,13 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
tokens = []
|
tokens = []
|
||||||
for token, n_tokens in zip(
|
for token, n_tokens in zip(
|
||||||
generate(prompt, decoder_inputs, model, args.temp),
|
generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
|
||||||
range(args.max_tokens)
|
|
||||||
):
|
):
|
||||||
if token.item() == tokenizer.eos_token_id:
|
if token.item() == tokenizer.eos_token_id:
|
||||||
break
|
break
|
||||||
tokens.append(token.item())
|
tokens.append(token.item())
|
||||||
# For some reason using the following line doesn't give spaces
|
# For some reason using the following line doesn't give spaces
|
||||||
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
|
# print(tokenizer.decode(token.item(), clean_up_tokenization_spaces=False), end="", flush=True)
|
||||||
print(tokenizer.decode(tokens), end="", flush=True)
|
print(tokenizer.decode(tokens), end="", flush=True)
|
||||||
|
|
||||||
end = perf_counter_ns()
|
end = perf_counter_ns()
|
||||||
|
Loading…
Reference in New Issue
Block a user