mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00

commit 34843ddeb2
parent c468edc4e3

    format
@@ -14,8 +14,10 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
-    ("layers.0.layer.0.SelfAttention.relative_attention_bias.",
-     "relative_attention_bias.embeddings."),
+    (
+        "layers.0.layer.0.SelfAttention.relative_attention_bias.",
+        "relative_attention_bias.embeddings.",
+    ),
 ]

 ENCODER_REPLACEMENT_PATTERNS = [
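Aside: these (old, new) pairs are plain substring substitutions over checkpoint
weight names. A minimal sketch of how such patterns might be applied when
remapping a weights dict; the replace_key helper and the sample key are
hypothetical illustrations, not code from this commit:

# Hypothetical sketch: apply (old, new) substring patterns to rename weight keys.
SHARED_REPLACEMENT_PATTERNS = [
    (".layer.1.layer_norm.", ".ln2."),
    (".final_layer_norm.", ".ln."),
]

def replace_key(key: str) -> str:
    # Apply every pattern in order; keys without a match pass through unchanged.
    for old, new in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)
    return key

weights = {"encoder.block.0.layer.1.layer_norm.weight": 0}
print({replace_key(k): v for k, v in weights.items()})
# {'encoder.block.0.ln2.weight': 0}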
t5/t5.py (41 changed lines)
@@ -58,7 +58,9 @@ def _relative_position_bucket(
         relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
-        relative_position = -mx.minimum(relative_position, mx.zeros_like(relative_position))
+        relative_position = -mx.minimum(
+            relative_position, mx.zeros_like(relative_position)
+        )
     # now relative_position is in the range [0, inf)

     # half of the buckets are for exact increments in positions
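Note: the wrapped expression is the usual clamp to [0, inf): -min(x, 0) maps a
negative position to its magnitude and a non-negative one to zero. A quick
NumPy check (illustrative only, not the commit's code):

import numpy as np

relative_position = np.array([-3, -1, 0, 2])
print(-np.minimum(relative_position, 0))  # [3 1 0 0]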
@@ -66,10 +68,9 @@ def _relative_position_bucket(
     is_small = relative_position < max_exact

     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
     relative_position_if_large = max_exact + (
-        mx.log(relative_position.astype(mx.float32) / max_exact)
-        / np.log(max_distance / max_exact)
-        * (num_buckets - max_exact)
+        mx.log(relative_position.astype(mx.float32) / max_exact) * scale
     ).astype(mx.int16)
     relative_position_if_large = mx.minimum(relative_position_if_large, num_buckets - 1)
     relative_buckets += mx.where(
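Note: the refactor only hoists the constant factor out of the elementwise
expression, using log(x) / log(b) * n == log(x) * (n / log(b)). A standalone
NumPy check of that identity with toy bucket parameters (not the commit's
code):

import numpy as np

num_buckets, max_exact, max_distance = 32, 16, 128
relative_position = np.arange(max_exact, max_distance, dtype=np.float32)

old = (
    np.log(relative_position / max_exact)
    / np.log(max_distance / max_exact)
    * (num_buckets - max_exact)
)
scale = (num_buckets - max_exact) / np.log(max_distance / max_exact)
new = np.log(relative_position / max_exact) * scale

assert np.allclose(old, new)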
@@ -93,11 +94,8 @@ class RelativePositionBias(nn.Module):
         context_position = mx.arange(offset, query_length)[:, None]
         memory_position = mx.arange(key_length)[None, :]

-
         # shape (query_length, key_length)
-        relative_position = (
-            memory_position - context_position
-        )
+        relative_position = memory_position - context_position
         relative_position_bucket = _relative_position_bucket(
             relative_position,
             bidirectional=self.bidirectional,
@@ -106,9 +104,7 @@ class RelativePositionBias(nn.Module):
         )

         # shape (query_length, key_length, num_heads)
-        values = self.embeddings(
-            relative_position_bucket
-        )
+        values = self.embeddings(relative_position_bucket)

         # shape (num_heads, query_length, key_length)
         return values.transpose(2, 0, 1)
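For reference, the shape flow the two comments describe: a
(query_length, key_length) matrix of bucket indices gathers rows from the
embedding table, and the transpose moves the head axis to the front. An
illustrative NumPy sketch with toy sizes (not the commit's code):

import numpy as np

query_length, key_length, num_heads = 5, 7, 3
table = np.random.randn(32, num_heads)            # (num_buckets, num_heads)
buckets = np.random.randint(0, 32, (query_length, key_length))

values = table[buckets]                           # (query_length, key_length, num_heads)
bias = values.transpose(2, 0, 1)                  # (num_heads, query_length, key_length)
assert bias.shape == (num_heads, query_length, key_length)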
@@ -130,8 +126,7 @@ class MultiHeadAttention(nn.Module):
         values: mx.array,
         mask: mx.array,
         cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
-
+    ) -> [mx.array, Tuple[mx.array, mx.array]]:
         queries = self.query_proj(queries)
         keys = self.key_proj(keys)
         values = self.value_proj(values)
@@ -203,7 +198,7 @@ class TransformerEncoder(nn.Module):
         self.ln = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
         self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)

-    def __call__(self, x):
+    def __call__(self, x: mx.array):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         for layer in self.layers:
             x = layer(x, mask=pos_bias)
@@ -228,8 +223,8 @@ class TransformerDecoderLayer(nn.Module):
         memory: mx.array,
         mask: mx.array,
         memory_mask: mx.array,
-        cache: Optional[Tuple[mx.array, mx.array]] = None,
-    ) -> mx.array:
+        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
+    ):
         y = self.ln1(x)
         y, cache = self.self_attention(y, y, y, mask, cache)
         x = x + y
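Note on the new annotation: the decoder threads one (keys, values) tuple per
layer, hence List[Tuple[mx.array, mx.array]], with each step's K/V concatenated
onto the cached ones along the sequence axis. A minimal sketch of that pattern
under assumed toy shapes (not the commit's code):

import mlx.core as mx
from typing import List, Optional, Tuple

num_layers = 2
cache: Optional[List[Tuple[mx.array, mx.array]]] = None

for _ in range(3):  # three decoding steps of one token each
    keys = mx.zeros((1, 1, 4))    # (batch, seq, dims), toy sizes
    values = mx.zeros((1, 1, 4))
    if cache is None:
        cache = [(keys, values)] * num_layers
    else:
        cache = [
            (mx.concatenate([k, keys], axis=1), mx.concatenate([v, values], axis=1))
            for k, v in cache
        ]

print([k.shape for k, _ in cache])  # each layer's keys now span 3 positions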
@@ -278,7 +273,7 @@ class TransformerDecoder(nn.Module):


 class OutputHead(nn.Module):
-    def __init__(self, config: ModelArgs) -> None:
+    def __init__(self, config: ModelArgs):
         self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)

     def __call__(self, inputs):
@@ -292,10 +287,7 @@ class T5(nn.Module):
         self.decoder = TransformerDecoder(config)
         self.lm_head = OutputHead(config)

-    def encode(
-        self,
-        inputs: mx.array
-    ) -> mx.array:
+    def encode(self, inputs: mx.array):
         return self.encoder(self.wte(inputs))

     def decode(
@@ -321,7 +313,7 @@ class T5(nn.Module):
         self,
         inputs: mx.array,
         decoder_inputs: mx.array,
-    ) -> mx.array:
+    ):
         return self.decode(decoder_inputs, self.encode(inputs))[0]


@@ -375,7 +367,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--encode-only",
-        action='store_true',
+        action="store_true",
         default=False,
         help="Whether to decode or not",
     )
@@ -425,8 +417,7 @@ if __name__ == "__main__":

     tokens = []
     for token, n_tokens in zip(
-        generate(prompt, decoder_inputs, model, args.temp),
-        range(args.max_tokens)
+        generate(prompt, decoder_inputs, model, args.temp), range(args.max_tokens)
     ):
         if token.item() == tokenizer.eos_token_id:
             break
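The reflowed zip keeps the same trick: pairing the (potentially endless) token
generator with range(args.max_tokens) caps generation without an explicit
counter. A toy standalone illustration (stand-in generator and EOS value, not
the script's code):

import itertools

def generate_tokens():
    yield from itertools.count()      # stand-in for an endless token stream

max_tokens, eos_token_id = 4, -1
tokens = []
for token, n_tokens in zip(generate_tokens(), range(max_tokens)):
    if token == eos_token_id:         # stand-in for the EOS check
        break
    tokens.append(token)

print(tokens)  # [0, 1, 2, 3]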