mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
translate pytorch to mx
This commit is contained in:
parent
330f024d1c
commit
392b7a2f98
@ -13,6 +13,10 @@ SHARED_REPLACEMENT_PATTERNS = [
|
||||
(".layer.1.layer_norm.", ".ln2."),
|
||||
(".layer.2.layer_norm.", ".ln3."),
|
||||
(".final_layer_norm.", ".ln."),
|
||||
(
|
||||
".relative_attention_bias.",
|
||||
".relative_attention_bias.embeddings."
|
||||
),
|
||||
]
|
||||
|
||||
ENCODER_REPLACEMENT_PATTERNS = [
|
||||
|
41
t5/t5.py
41
t5/t5.py
@ -2,6 +2,7 @@ import argparse
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
@ -49,7 +50,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
||||
relative_buckets = 0
|
||||
if bidirectional:
|
||||
num_buckets //= 2
|
||||
relative_buckets += (relative_position > 0).to(mx.long) * num_buckets
|
||||
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||
relative_position = mx.abs(relative_position)
|
||||
else:
|
||||
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
||||
@ -60,13 +61,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
||||
is_small = relative_position < max_exact
|
||||
|
||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||
print("relative_position", relative_position)
|
||||
relative_position_if_large = max_exact + (
|
||||
mx.log(relative_position.float() / max_exact)
|
||||
/ mx.log(max_distance / max_exact)
|
||||
mx.log(relative_position.astype(mx.float32) / max_exact)
|
||||
/ np.log(max_distance / max_exact)
|
||||
* (num_buckets - max_exact)
|
||||
).to(mx.long)
|
||||
relative_position_if_large = mx.min(
|
||||
relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1)
|
||||
).astype(mx.int16)
|
||||
relative_position_if_large = mx.minimum(
|
||||
relative_position_if_large, num_buckets - 1
|
||||
)
|
||||
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
|
||||
return relative_buckets
|
||||
@ -78,20 +80,23 @@ class RelativePositionBias(nn.Module):
|
||||
self.num_buckets = config.relative_attention_num_buckets
|
||||
self.max_distance = config.n_positions
|
||||
self.n_heads = config.num_heads
|
||||
self.embeddings = nn.Embedding(
|
||||
config.relative_attention_num_buckets,
|
||||
config.num_heads)
|
||||
|
||||
def compute_bias(self, query_length, key_length):
|
||||
def __call__(self, query_length, key_length):
|
||||
"""Compute binned relative position bias"""
|
||||
context_position = mx.arange(query_length, dtype=mx.long)[:, None]
|
||||
memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
|
||||
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
||||
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||
relative_position_bucket = _relative_position_bucket(
|
||||
relative_position, # shape (query_length, key_length)
|
||||
bidirectional=self.bidirectional,
|
||||
num_buckets=self.relative_attention_num_buckets,
|
||||
max_distance=self.relative_attention_max_distance,
|
||||
num_buckets=self.num_buckets,
|
||||
max_distance=self.max_distance,
|
||||
)
|
||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||
values = self.embeddings(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||
values = mx.expand_dims(values.transpose(2, 0, 1), 0) # shape (1, num_heads, query_length, key_length)
|
||||
return values
|
||||
|
||||
|
||||
@ -103,10 +108,9 @@ class MultiHeadAttention(nn.Module):
|
||||
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||
self.has_relative_attention_bias = has_relative_attention_bias
|
||||
if has_relative_attention_bias:
|
||||
self.relative_attention_bias = nn.Embedding(
|
||||
config.relative_attention_num_buckets,
|
||||
config.num_heads)
|
||||
self.relative_attention_bias = RelativePositionBias(config)
|
||||
|
||||
def __call__(self, queries, keys, values, mask=None):
|
||||
queries = self.query_proj(queries)
|
||||
@ -125,6 +129,11 @@ class MultiHeadAttention(nn.Module):
|
||||
scores = (queries * scale) @ keys
|
||||
if mask is not None:
|
||||
scores = scores + mask.astype(scores.dtype)
|
||||
|
||||
if self.has_relative_attention_bias:
|
||||
position_bias = self.relative_attention_bias(L, S)
|
||||
scores += position_bias
|
||||
|
||||
scores = mx.softmax(scores, axis=-1)
|
||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user