From 392b7a2f983b98168d5661e841968d9941db0d0f Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Fri, 15 Dec 2023 16:51:01 -0500
Subject: [PATCH] translate pytorch to mx

---
 t5/convert.py |  4 ++++
 t5/t5.py      | 41 +++++++++++++++++++++++++----------------
 2 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/t5/convert.py b/t5/convert.py
index 7531ec4d..1de73150 100644
--- a/t5/convert.py
+++ b/t5/convert.py
@@ -13,6 +13,10 @@ SHARED_REPLACEMENT_PATTERNS = [
     (".layer.1.layer_norm.", ".ln2."),
     (".layer.2.layer_norm.", ".ln3."),
     (".final_layer_norm.", ".ln."),
+    (
+        ".relative_attention_bias.",
+        ".relative_attention_bias.embeddings."
+    ),
 ]
 
 ENCODER_REPLACEMENT_PATTERNS = [
diff --git a/t5/t5.py b/t5/t5.py
index c4048c86..206ad972 100644
--- a/t5/t5.py
+++ b/t5/t5.py
@@ -2,6 +2,7 @@ import argparse
 import math
 from dataclasses import dataclass
 
+import numpy as np
 import mlx.core as mx
 import mlx.nn as nn
 from mlx.utils import tree_flatten, tree_unflatten
@@ -49,7 +50,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     relative_buckets = 0
     if bidirectional:
         num_buckets //= 2
-        relative_buckets += (relative_position > 0).to(mx.long) * num_buckets
+        relative_buckets += (relative_position > 0).astype(mx.int16) * num_buckets
         relative_position = mx.abs(relative_position)
     else:
         relative_position = -mx.min(relative_position, mx.zeros_like(relative_position))
@@ -60,13 +61,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
     is_small = relative_position < max_exact
 
     # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+    print("relative_position", relative_position)
     relative_position_if_large = max_exact + (
-        mx.log(relative_position.float() / max_exact)
-        / mx.log(max_distance / max_exact)
+        mx.log(relative_position.astype(mx.float32) / max_exact)
+        / np.log(max_distance / max_exact)
         * (num_buckets - max_exact)
-    ).to(mx.long)
-    relative_position_if_large = mx.min(
-        relative_position_if_large, mx.full_like(relative_position_if_large, num_buckets - 1)
+    ).astype(mx.int16)
+    relative_position_if_large = mx.minimum(
+        relative_position_if_large, num_buckets - 1
     )
     relative_buckets += mx.where(is_small, relative_position, relative_position_if_large)
     return relative_buckets
@@ -78,20 +80,23 @@ class RelativePositionBias(nn.Module):
         self.num_buckets = config.relative_attention_num_buckets
         self.max_distance = config.n_positions
         self.n_heads = config.num_heads
+        self.embeddings = nn.Embedding(
+            config.relative_attention_num_buckets,
+            config.num_heads)
 
-    def compute_bias(self, query_length, key_length):
+    def __call__(self, query_length, key_length):
         """Compute binned relative position bias"""
-        context_position = mx.arange(query_length, dtype=mx.long)[:, None]
-        memory_position = mx.arange(key_length, dtype=mx.long)[None, :]
+        context_position = mx.arange(query_length, dtype=mx.int32)[:, None]
+        memory_position = mx.arange(key_length, dtype=mx.int32)[None, :]
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = _relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
             bidirectional=self.bidirectional,
-            num_buckets=self.relative_attention_num_buckets,
-            max_distance=self.relative_attention_max_distance,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
-        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+        values = self.embeddings(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = mx.expand_dims(values.transpose(2, 0, 1), 0)  # shape (1, num_heads, query_length, key_length)
         return values
 
 
@@ -103,10 +108,9 @@ class MultiHeadAttention(nn.Module):
         self.key_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.value_proj = nn.Linear(config.d_model, config.d_model, bias=False)
         self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)
+        self.has_relative_attention_bias = has_relative_attention_bias
         if has_relative_attention_bias:
-            self.relative_attention_bias = nn.Embedding(
-                config.relative_attention_num_buckets,
-                config.num_heads)
+            self.relative_attention_bias = RelativePositionBias(config)
 
     def __call__(self, queries, keys, values, mask=None):
         queries = self.query_proj(queries)
@@ -125,6 +129,11 @@ class MultiHeadAttention(nn.Module):
         scores = (queries * scale) @ keys
         if mask is not None:
             scores = scores + mask.astype(scores.dtype)
+
+        if self.has_relative_attention_bias:
+            position_bias = self.relative_attention_bias(L, S)
+            scores += position_bias
+
         scores = mx.softmax(scores, axis=-1)
         values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
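
Reviewer note (not part of the patch): a quick smoke test for the translated
bucketing function. This is a minimal sketch; it assumes the patched t5/t5.py
is importable as `t5` (scripts in this example run from the t5 directory) and
that `_relative_position_bucket` keeps the signature shown in the diff. The
switch to np.log for log(max_distance / max_exact) makes sense because that
ratio is a plain Python float, while mx.log operates on arrays.

import mlx.core as mx
from t5 import _relative_position_bucket

# Build a small grid of relative positions exactly as RelativePositionBias
# does: memory (key) index minus context (query) index.
context_position = mx.arange(4, dtype=mx.int32)[:, None]
memory_position = mx.arange(4, dtype=mx.int32)[None, :]
relative_position = memory_position - context_position  # shape (4, 4)

buckets = _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
)
print(buckets.shape)  # (4, 4); integer bucket ids in [0, 32)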
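
Reviewer note (not part of the patch): a shape check for the new
RelativePositionBias module. The config stub below is hypothetical; its
fields are the ones the module reads in this diff, but the hunk does not show
the full __init__ (e.g., whatever sets self.bidirectional), so the real class
may need more than this.

from dataclasses import dataclass
from t5 import RelativePositionBias

@dataclass
class StubConfig:
    relative_attention_num_buckets: int = 32
    n_positions: int = 512
    num_heads: int = 8
    d_model: int = 512

bias = RelativePositionBias(StubConfig())  # constructed as MultiHeadAttention does
values = bias(10, 10)                      # __call__(query_length, key_length)
print(values.shape)                        # (1, 8, 10, 10) = (1, num_heads, L, S)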
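
Reviewer note (not part of the patch): the permute/unsqueeze change is the
one purely mechanical PyTorch-to-MLX translation here. In isolation, torch's
x.permute(2, 0, 1).unsqueeze(0) maps to the MLX calls below.

import mlx.core as mx

x = mx.zeros((10, 10, 8))                    # (query_length, key_length, num_heads)
y = mx.expand_dims(x.transpose(2, 0, 1), 0)  # reorder axes, then add a leading axis
print(y.shape)                               # (1, 8, 10, 10)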