diff --git a/bert/model.py b/bert/model.py
index 446919b1..d4dccfac 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -7,7 +7,6 @@ import mlx.core as mx
 import mlx.nn as nn
 import argparse
 import numpy
-import math
 
 
 @dataclass
@@ -34,74 +33,6 @@ model_configs = {
 }
 
 
-class MultiHeadAttention(nn.Module):
-    """
-    Minor update to the MultiHeadAttention module to ensure that the
-    projections use bias.
-    """
-
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                f"The input feature dimensions should be divisible by the number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        self.query_proj = nn.Linear(query_input_dims, dims, True)
-        self.key_proj = nn.Linear(key_input_dims, dims, True)
-        self.value_proj = nn.Linear(value_input_dims, value_dims, True)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, True)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.query_proj(queries)
-        keys = self.key_proj(keys)
-        values = self.value_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        # Dimensions are [batch x num heads x sequence x hidden dim]
-        scale = math.sqrt(1 / queries.shape[-1])
-        scores = (queries * scale) @ keys
-        if mask is not None:
-            mask = self.convert_mask_to_additive_causal_mask(mask)
-            mask = mx.expand_dims(mask, (1, 2))
-            mask = mx.broadcast_to(mask, scores.shape)
-            scores = scores + mask.astype(scores.dtype)
-        scores = mx.softmax(scores, axis=-1)
-        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(values_hat)
-
-    def convert_mask_to_additive_causal_mask(
-        self, mask: mx.array, dtype: mx.Dtype = mx.float32
-    ) -> mx.array:
-        mask = mask == 0
-        mask = mask.astype(dtype) * -1e9
-        return mask
-
-
 class TransformerEncoderLayer(nn.Module):
     """
     A transformer encoder layer with (the original BERT) post-normalization.
@@ -116,7 +47,7 @@ class TransformerEncoderLayer(nn.Module):
     ):
         super().__init__()
         mlp_dims = mlp_dims or dims * 4
-        self.attention = MultiHeadAttention(dims, num_heads)
+        self.attention = nn.MultiHeadAttention(dims, num_heads, bias=True)
         self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
         self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
         self.linear1 = nn.Linear(dims, mlp_dims)
@@ -186,11 +117,26 @@ class Bert(nn.Module):
         self,
         input_ids: mx.array,
         token_type_ids: mx.array,
-        attention_mask: Optional[mx.array] = None,
+        attention_mask: mx.array = None,
     ) -> tuple[mx.array, mx.array]:
         x = self.embeddings(input_ids, token_type_ids)
+
+        if attention_mask is not None:
+            # convert 0's to -infs, 1's to 0's, and make it broadcastable
+            attention_mask = self.convert_mask_to_additive_causal_mask(attention_mask)
+            attention_mask = mx.expand_dims(attention_mask, (1, 2))
+
         y = self.encoder(x, attention_mask)
         return y, mx.tanh(self.pooler(y[:, 0]))
+
+
+    def convert_mask_to_additive_causal_mask(
+        self, mask: mx.array, dtype: mx.Dtype = mx.float32
+    ) -> mx.array:
+        mask = mask == 0
+        mask = mask.astype(dtype) * -1e9
+        return mask
+
 
 
 def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
@@ -214,7 +160,7 @@ def run(bert_model: str, mlx_model: str):
         "A second string",
         "This is another string.",
     ]
-    
+
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
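For reference, here is a minimal standalone sketch (not part of the diff) of what the relocated mask handling does. The helper mirrors the `convert_mask_to_additive_causal_mask` method this diff adds to `Bert`; the toy `padding_mask` values are made up for illustration. Real tokens (1) map to 0.0, padded positions (0) map to -1e9, and the two singleton axes make the result broadcastable over per-head attention scores, matching the diff's "convert 0's to -infs, 1's to 0's, and make it broadcastable" comment.

```python
import mlx.core as mx


def convert_mask_to_additive_causal_mask(mask, dtype=mx.float32):
    # Same logic as the method moved into Bert: 0 -> -1e9 (ignore), 1 -> 0.0 (attend).
    mask = mask == 0
    return mask.astype(dtype) * -1e9


# Hypothetical padding mask for two sequences of length 3; the second
# sequence has one padded position at the end.
padding_mask = mx.array([[1, 1, 1], [1, 1, 0]])

additive = convert_mask_to_additive_causal_mask(padding_mask)
# Add singleton head and query axes so the mask broadcasts against attention
# scores shaped [batch, num_heads, query_len, key_len].
additive = mx.expand_dims(additive, (1, 2))

print(additive.shape)  # (2, 1, 1, 3)
print(additive)        # 0.0 for real tokens, -1e9 at the padded position
```

Doing this conversion once in `Bert.__call__` is what allows the custom attention class to be dropped: the converted mask is passed straight through the encoder to the stock `nn.MultiHeadAttention` layers, which apply it additively to the attention scores.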