translate pytorch to mx

Juarez Bochi 2023-12-15 16:51:01 -05:00
parent 330f024d1c
commit 392b7a2f98
2 changed files with 29 additions and 16 deletions


@@ -13,6 +13,10 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
+    (
+        ".relative_attention_bias.",
+        ".relative_attention_bias.embeddings."
+    ),
 ]
 ENCODER_REPLACEMENT_PATTERNS = [

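The pattern added above maps the checkpoint key segment `.relative_attention_bias.` onto the new `.embeddings` submodule introduced in the model file below. A minimal sketch of how such (old, new) pairs are typically applied to parameter names during conversion; the `replace_key` helper and the plain `str.replace` loop are assumptions for illustration, not necessarily the repo's actual convert code:

# Hypothetical helper: rewrite one checkpoint key using the replacement patterns.
SHARED_REPLACEMENT_PATTERNS = [
    (".final_layer_norm.", ".ln."),
    (".relative_attention_bias.", ".relative_attention_bias.embeddings."),
]

def replace_key(key: str) -> str:
    for old, new in SHARED_REPLACEMENT_PATTERNS:
        key = key.replace(old, new)
    return key

# Example: the bias table now lands on the nested Embedding's weight.
print(replace_key("encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"))
# encoder.block.0.layer.0.SelfAttention.relative_attention_bias.embeddings.weight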

@@ -2,6 +2,7 @@ import argparse
 import math
 from dataclasses import dataclass
+import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten
@@ -49,7 +50,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     relative_buckets = 0
     if bidirectional:
         num_buckets //= 2
-        relative_buckets += (relative_position > 0).to(mx.long) * num_buckets
+        relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
         relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
@@ -60,13 +61,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     is_small = relative_position < max_exact
     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    print("relative_position", relative_position)
     relative_position_if_large = max_exact + (
-        mx.log(relative_position.float() / max_exact)
-        / mx.log(max_distance / max_exact)
+        mx.log(relative_position.astype(mx.float32) / max_exact)
+        / np.log(max_distance / max_exact)
         * (num_buckets - max_exact)
-    ).to(mx.long)
-    relative_position_if_large = mx.min(
-        relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1)
+    ).astype(mx.int16)
+    relative_position_if_large = mx.minimum(
+        relative_position_if_large, num_buckets - 1
     )
     relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
     return relative_buckets
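For reference, a self-contained sketch of the MLX calls this hunk switches to (`astype`, `np.log` on the Python scalar, `mx.minimum`, `mx.where`), run on made-up positions and bucket sizes; the numbers are illustrative only and do not come from the model config:

import numpy as np
import mlx.core as mx

relative_position = mx.array([[1, 2, 5, 40]])  # toy distances
num_buckets, max_distance = 16, 128
max_exact = num_buckets // 2

is_small = relative_position < max_exact
# Larger distances fall into logarithmic buckets, capped at num_buckets - 1.
large = max_exact + (
    mx.log(relative_position.astype(mx.float32) / max_exact)
    / np.log(max_distance / max_exact)
    * (num_buckets - max_exact)
).astype(mx.int16)
large = mx.minimum(large, num_buckets - 1)

buckets = mx.where(is_small, relative_position, large)
print(buckets)  # small distances keep their own bucket; 40 maps to a log bin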
@@ -78,20 +80,23 @@ class RelativePositionBias(nn.Module):
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.n_positions
         self.n_heads = config.num_heads
+        self.embeddings = nn.Embedding(
+            config.relative_attention_num_buckets,
+            config.num_heads)
 
-    def compute_bias(self, query_length, key_length):
+    def __call__(self, query_length, key_length):
         """Compute binned relative position bias"""
-        context_position = mx.arange(query_length, dtype=mx.long)[:, None]
-        memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
+        context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
+        memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = _relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
             bidirectional=self.bidirectional,
-            num_buckets=self.relative_attention_num_buckets,
-            max_distance=self.relative_attention_max_distance,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
-        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        values = self.embeddings(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
         return values
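As a quick shape check (a standalone sketch with made-up sizes, not code from the commit): the embedding lookup returns `(query_length, key_length, num_heads)`, and the `transpose`/`mx.expand_dims` pair plays the role of PyTorch's `permute`/`unsqueeze`:

import mlx.core as mx
import mlx.nn as nn

num_buckets, num_heads, L, S = 32, 8, 5, 7   # illustrative sizes
embeddings = nn.Embedding(num_buckets, num_heads)

buckets = mx.zeros((L, S), dtype=mx.int16)   # stand-in for the bucket matrix
values = embeddings(buckets)                 # (L, S, num_heads)
bias = mx.expand_dims(values.transpose(2, 0, 1), 0)
print(bias.shape)                            # (1, num_heads, L, S) -> (1, 8, 5, 7)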
@@ -103,10 +108,9 @@ class MultiHeadAttention(nn.Module):
         self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.has_relative_attention_bias = has_relative_attention_bias
         if has_relative_attention_bias:
-            self.relative_attention_bias = nn.Embedding(
-                config.relative_attention_num_buckets,
-                config.num_heads)
+            self.relative_attention_bias = RelativePositionBias(config)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)
@@ -125,6 +129,11 @@ class MultiHeadAttention(nn.Module):
         scores = (queries * scale) @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
+        if self.has_relative_attention_bias:
+            position_bias = self.relative_attention_bias(L, S)
+            scores += position_bias
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
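Since the bias from `RelativePositionBias` has shape `(1, num_heads, L, S)`, adding it to the `(B, num_heads, L, S)` score tensor broadcasts over the batch. A tiny standalone illustration with dummy arrays; the shapes are assumptions for the example, not taken from a real run:

import mlx.core as mx

B, num_heads, L, S = 2, 8, 5, 7
scores = mx.zeros((B, num_heads, L, S))        # stand-in for (queries * scale) @ keys
position_bias = mx.ones((1, num_heads, L, S))  # stand-in for relative_attention_bias(L, S)

scores = scores + position_bias                # broadcasts over the batch dimension
probs = mx.softmax(scores, axis=-1)
print(probs.shape)                             # (2, 8, 5, 7)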