translate pytorch to mx

This commit is contained in:
Juarez Bochi 2023-12-15 16:51:01 -05:00
parent 330f024d1c
commit 392b7a2f98
No known key found for this signature in database
GPG Key ID: 34CCBB77DC8BEBB6
2 changed files with 29 additions and 16 deletions

View File

@@ -13,6 +13,10 @@ SHARED_REPLACEMENT_PATTERNS = [
(".layer.1.layer_norm.", ".ln2."), (".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."), (".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."), (".final_layer_norm.", ".ln."),
(
".relative_attention_bias.",
".relative_attention_bias.embeddings."
),
] ]
ENCODER_REPLACEMENT_PATTERNS = [ ENCODER_REPLACEMENT_PATTERNS = [

View File

@@ -2,6 +2,7 @@ import argparse
import math import math
from dataclasses import dataclass from dataclasses import dataclass
import numpy as np
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten from mlx.utils import tree_flatten, tree_unflatten
@@ -49,7 +50,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets = 0 relative_buckets = 0
if bidirectional: if bidirectional:
num_buckets //= 2 num_buckets //= 2
relative_buckets += (relative_position > 0).to(mx.long) * num_buckets relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
relative_position = mx.abs(relative_position) relative_position = mx.abs(relative_position)
else: else:
relative_position = -mx.min(relative_position, mx.zeros_like(relative_position)) relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
@@ -60,13 +61,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
is_small = relative_position < max_exact is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
print("relative_position", relative_position)
relative_position_if_large = max_exact + ( relative_position_if_large = max_exact + (
mx.log(relative_position.float() / max_exact) mx.log(relative_position.astype(mx.float32) / max_exact)
/ mx.log(max_distance / max_exact) / np.log(max_distance / max_exact)
* (num_buckets - max_exact) * (num_buckets - max_exact)
).to(mx.long) ).astype(mx.int16)
relative_position_if_large = mx.min( relative_position_if_large = mx.minimum(
relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1) relative_position_if_large, num_buckets - 1
) )
relative_buckets += mx.where(is_small, relative_position, relative_position_if_large) relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
@@ -78,20 +80,23 @@ class RelativePositionBias(nn.Module):
self.num_buckets = config.relative_attention_num_buckets self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.n_positions self.max_distance = config.n_positions
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets,
config.num_heads)
def compute_bias(self, query_length, key_length): def __call__(self, query_length, key_length):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
context_position = mx.arange(query_length, dtype=mx.long)[:, None] context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
memory_position = mx.arange(key_length, dtype=mx.long)[None, :] memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = _relative_position_bucket( relative_position_bucket = _relative_position_bucket(
relative_position, # shape (query_length, key_length) relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional, bidirectional=self.bidirectional,
num_buckets=self.relative_attention_num_buckets, num_buckets=self.num_buckets,
max_distance=self.relative_attention_max_distance, max_distance=self.max_distance,
) )
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) values = self.embeddings(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) values = mx.expand_dims(values.transpose(2, 0, 1), 0) # shape (1, num_heads, query_length, key_length)
return values return values
@@ -103,10 +108,9 @@ class MultiHeadAttention(nn.Module):
self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False) self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
self.has_relative_attention_bias = has_relative_attention_bias
if has_relative_attention_bias: if has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding( self.relative_attention_bias = RelativePositionBias(config)
config.relative_attention_num_buckets,
config.num_heads)
def __call__(self, queries, keys, values, mask=None): def __call__(self, queries, keys, values, mask=None):
queries = self.query_proj(queries) queries = self.query_proj(queries)
@@ -125,6 +129,11 @@ class MultiHeadAttention(nn.Module):
scores = (queries * scale) @ keys scores = (queries * scale) @ keys
if mask is not None: if mask is not None:
scores = scores + mask.astype(scores.dtype) scores = scores + mask.astype(scores.dtype)
if self.has_relative_attention_bias:
position_bias = self.relative_attention_bias(L, S)
scores += position_bias
scores = mx.softmax(scores, axis=-1) scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)