mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

Updating BERT model to take advantage of bias param in MultiHeadAttention

This commit is contained in:
parent 46c6bbe0a1
commit 5d4838b02e
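The gist of the diff below: the example's hand-rolled MultiHeadAttention class is deleted, and TransformerEncoderLayer now builds its attention block from mlx.nn directly, turning on the projection biases through the module's bias parameter. A minimal before/after sketch (the 768/12 sizes are illustrative, roughly the bert-base configuration, not taken from the diff):

```python
import mlx.nn as nn

dims, num_heads = 768, 12  # illustrative sizes, roughly bert-base

# Before: a local class wrapping four nn.Linear(..., bias=True) projections
# attention = MultiHeadAttention(dims, num_heads)

# After: the built-in module, with biased projections enabled explicitly
attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
```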
@@ -7,7 +7,6 @@ import mlx.core as mx
 import mlx.nn as nn
 import argparse
 import numpy
-import math
 
 
 @dataclass
@@ -34,74 +33,6 @@ model_configs = {
 }
 
 
-class MultiHeadAttention(nn.Module):
-    """
-    Minor update to the MultiHeadAttention module to ensure that the
-    projections use bias.
-    """
-
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.query_proj = nn.Linear(query_input_dims, dims, True)
-        self.key_proj = nn.Linear(key_input_dims, dims, True)
-        self.value_proj = nn.Linear(value_input_dims, value_dims, True)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, True)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            mask = self.convert_mask_to_additive_causal_mask(mask)
-            mask = mx.expand_dims(mask, (1, 2))
-            mask = mx.broadcast_to(mask, scores.shape)
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
-
-    def convert_mask_to_additive_causal_mask(
-        self, mask: mx.array, dtype: mx.Dtype = mx.float32
-    ) -> mx.array:
-        mask = mask == 0
-        mask = mask.astype(dtype) * -1e9
-        return mask
-
-
 class TransformerEncoderLayer(nn.Module):
     """
     A transformer encoder layer with (the original BERT) post-normalization.
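For reference, the class removed above implemented standard scaled dot-product attention with an additive mask, which the built-in module is now expected to cover. A condensed sketch of that math, with shapes following the removed code's comment ([batch x num heads x sequence x hidden dim]):

```python
import math
import mlx.core as mx

def scaled_dot_product_attention(queries, keys, values, additive_mask=None):
    # queries/values: [B, H, L, d]; keys already transposed to [B, H, d, S]
    scale = math.sqrt(1 / queries.shape[-1])
    scores = (queries * scale) @ keys                      # [B, H, L, S]
    if additive_mask is not None:
        # 0.0 keeps a position, -1e9 effectively removes it after softmax
        scores = scores + additive_mask.astype(scores.dtype)
    weights = mx.softmax(scores, axis=-1)
    return weights @ values                                # [B, H, L, d]
```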
@@ -116,7 +47,7 @@ class TransformerEncoderLayer(nn.Module):
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
-        self.attention = MultiHeadAttention(dims, num_heads)
+        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
         self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
         self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
         self.linear1 = nn.Linear(dims, mlp_dims)
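The one-line swap above assumes nn.MultiHeadAttention keeps the same (queries, keys, values, mask) call convention as the class it replaces, so the rest of the encoder layer is untouched. A small hedged usage sketch under that assumption:

```python
import mlx.core as mx
import mlx.nn as nn

dims, num_heads = 768, 12
attention = nn.MultiHeadAttention(dims, num_heads, bias=True)

x = mx.zeros((2, 16, dims))   # [batch, sequence, dims]
y = attention(x, x, x, None)  # self-attention: queries = keys = values = x, no mask
print(y.shape)                # expected (2, 16, 768)
```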
@@ -186,11 +117,26 @@ class Bert(nn.Module):
         self,
         input_ids: mx.array,
         token_type_ids: mx.array,
-        attention_mask: Optional[mx.array] = None,
+        attention_mask: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
 
+        if attention_mask is not None:
+            # convert 0's to -infs, 1's to 0's, and make it broadcastable
+            attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask)
+            attention_mask = mx.expand_dims(attention_mask, (1, 2))
+
         y = self.encoder(x, attention_mask)
         return y, mx.tanh(self.pooler(y[:, 0]))
 
+    def convert_mask_to_additive_causal_mask(
+        self, mask: mx.array, dtype: mx.Dtype = mx.float32
+    ) -> mx.array:
+        mask = mask == 0
+        mask = mask.astype(dtype) * -1e9
+        return mask
+
 
 def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
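The new helper turns the tokenizer's 1/0 padding mask into additive bias values: real tokens map to 0.0 and padded positions to -1e9, which softmax then drives toward zero weight, and expand_dims reshapes the result so it broadcasts over heads and query positions. A small worked example of the same arithmetic:

```python
import mlx.core as mx

mask = mx.array([[1, 1, 1, 0, 0]])                # 1 = real token, 0 = padding
additive = (mask == 0).astype(mx.float32) * -1e9  # [[0, 0, 0, -1e9, -1e9]]
additive = mx.expand_dims(additive, (1, 2))       # shape (1, 1, 1, 5), broadcastable to [B, H, L, S] scores
```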
@@ -214,7 +160,7 @@ def run(bert_model: str, mlx_model: str):
         "A second string",
         "This is another string.",
     ]
 
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
 
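The context above shows the tokenizer producing a padded batch; with padding=True a BertTokenizer returns input_ids, token_type_ids, and attention_mask, which line up with the updated Bert.__call__ signature. A hedged sketch of pushing a batch through, assuming run() obtains the model/tokenizer pair from load_model and passes the tokenizer outputs straight in (the batch string is illustrative):

```python
model, tokenizer = load_model(bert_model, mlx_model)

batch = ["This is an example sentence."]  # illustrative input
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}

# dict keys match Bert.__call__(input_ids, token_type_ids, attention_mask)
sequence_output, pooled_output = model(**tokens)
```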