mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
translate pytorch to mx
This commit is contained in:
parent
330f024d1c
commit
392b7a2f98
@ -13,6 +13,10 @@ SHARED_REPLACEMENT_PATTERNS = [
|
|||||||
(".layer.1.layer_norm.", ".ln2."),
|
(".layer.1.layer_norm.", ".ln2."),
|
||||||
(".layer.2.layer_norm.", ".ln3."),
|
(".layer.2.layer_norm.", ".ln3."),
|
||||||
(".final_layer_norm.", ".ln."),
|
(".final_layer_norm.", ".ln."),
|
||||||
|
(
|
||||||
|
".relative_attention_bias.",
|
||||||
|
".relative_attention_bias.embeddings."
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
ENCODER_REPLACEMENT_PATTERNS = [
|
ENCODER_REPLACEMENT_PATTERNS = [
|
||||||
|
41
t5/t5.py
41
t5/t5.py
@ -2,6 +2,7 @@ import argparse
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from mlx.utils import tree_flatten, tree_unflatten
|
from mlx.utils import tree_flatten, tree_unflatten
|
||||||
@ -49,7 +50,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
|||||||
relative_buckets = 0
|
relative_buckets = 0
|
||||||
if bidirectional:
|
if bidirectional:
|
||||||
num_buckets //= 2
|
num_buckets //= 2
|
||||||
relative_buckets += (relative_position > 0).to(mx.long) * num_buckets
|
relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
|
||||||
relative_position = mx.abs(relative_position)
|
relative_position = mx.abs(relative_position)
|
||||||
else:
|
else:
|
||||||
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
|
||||||
@ -60,13 +61,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
|
|||||||
is_small = relative_position < max_exact
|
is_small = relative_position < max_exact
|
||||||
|
|
||||||
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
|
||||||
|
print("relative_position", relative_position)
|
||||||
relative_position_if_large = max_exact + (
|
relative_position_if_large = max_exact + (
|
||||||
mx.log(relative_position.float() / max_exact)
|
mx.log(relative_position.astype(mx.float32) / max_exact)
|
||||||
/ mx.log(max_distance / max_exact)
|
/ np.log(max_distance / max_exact)
|
||||||
* (num_buckets - max_exact)
|
* (num_buckets - max_exact)
|
||||||
).to(mx.long)
|
).astype(mx.int16)
|
||||||
relative_position_if_large = mx.min(
|
relative_position_if_large = mx.minimum(
|
||||||
relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1)
|
relative_position_if_large, num_buckets - 1
|
||||||
)
|
)
|
||||||
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
|
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
|
||||||
return relative_buckets
|
return relative_buckets
|
||||||
@ -78,20 +80,23 @@ class RelativePositionBias(nn.Module):
|
|||||||
self.num_buckets = config.relative_attention_num_buckets
|
self.num_buckets = config.relative_attention_num_buckets
|
||||||
self.max_distance = config.n_positions
|
self.max_distance = config.n_positions
|
||||||
self.n_heads = config.num_heads
|
self.n_heads = config.num_heads
|
||||||
|
self.embeddings = nn.Embedding(
|
||||||
|
config.relative_attention_num_buckets,
|
||||||
|
config.num_heads)
|
||||||
|
|
||||||
def compute_bias(self, query_length, key_length):
|
def __call__(self, query_length, key_length):
|
||||||
"""Compute binned relative position bias"""
|
"""Compute binned relative position bias"""
|
||||||
context_position = mx.arange(query_length, dtype=mx.long)[:, None]
|
context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
|
||||||
memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
|
memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
|
||||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||||
relative_position_bucket = _relative_position_bucket(
|
relative_position_bucket = _relative_position_bucket(
|
||||||
relative_position, # shape (query_length, key_length)
|
relative_position, # shape (query_length, key_length)
|
||||||
bidirectional=self.bidirectional,
|
bidirectional=self.bidirectional,
|
||||||
num_buckets=self.relative_attention_num_buckets,
|
num_buckets=self.num_buckets,
|
||||||
max_distance=self.relative_attention_max_distance,
|
max_distance=self.max_distance,
|
||||||
)
|
)
|
||||||
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
values = self.embeddings(relative_position_bucket) # shape (query_length, key_length, num_heads)
|
||||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
values = mx.expand_dims(values.transpose(2, 0, 1), 0) # shape (1, num_heads, query_length, key_length)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
@ -103,10 +108,9 @@ class MultiHeadAttention(nn.Module):
|
|||||||
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
|
||||||
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
if has_relative_attention_bias:
|
if has_relative_attention_bias:
|
||||||
self.relative_attention_bias = nn.Embedding(
|
self.relative_attention_bias = RelativePositionBias(config)
|
||||||
config.relative_attention_num_buckets,
|
|
||||||
config.num_heads)
|
|
||||||
|
|
||||||
def __call__(self, queries, keys, values, mask=None):
|
def __call__(self, queries, keys, values, mask=None):
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
@ -125,6 +129,11 @@ class MultiHeadAttention(nn.Module):
|
|||||||
scores = (queries * scale) @ keys
|
scores = (queries * scale) @ keys
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
scores = scores + mask.astype(scores.dtype)
|
scores = scores + mask.astype(scores.dtype)
|
||||||
|
|
||||||
|
if self.has_relative_attention_bias:
|
||||||
|
position_bias = self.relative_attention_bias(L, S)
|
||||||
|
scores += position_bias
|
||||||
|
|
||||||
scores = mx.softmax(scores, axis=-1)
|
scores = mx.softmax(scores, axis=-1)
|
||||||
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user