mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-10 03:06:43 +08:00
Cleaning implementation for merge
This commit is contained in:
parent
e05ee57bab
commit
7320456226
@ -1,6 +1,6 @@
|
|||||||
# mlxbert
|
# BERT
|
||||||
|
|
||||||
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within mlx.
|
An implementation of BERT [(Devlin, et al., 2019)](https://aclanthology.org/N19-1423/) within MLX.
|
||||||
|
|
||||||
## Downloading and Converting Weights
|
## Downloading and Converting Weights
|
||||||
|
|
||||||
|
@ -26,14 +26,13 @@ def convert(bert_model: str, mlx_model: str) -> None:
|
|||||||
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
|
replace_key(key): tensor.numpy() for key, tensor in model.state_dict().items()
|
||||||
}
|
}
|
||||||
numpy.savez(mlx_model, **tensors)
|
numpy.savez(mlx_model, **tensors)
|
||||||
# save the tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
|
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bert-model",
|
"--bert-model",
|
||||||
type=str,
|
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
|
||||||
default="bert-base-uncased",
|
default="bert-base-uncased",
|
||||||
help="The huggingface name of the BERT model to save.",
|
help="The huggingface name of the BERT model to save.",
|
||||||
)
|
)
|
||||||
|
@ -24,10 +24,10 @@ def run(bert_model: str):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bert-model",
|
"--bert-model",
|
||||||
type=str,
|
choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
|
||||||
default="bert-base-uncased",
|
default="bert-base-uncased",
|
||||||
help="The huggingface name of the BERT model to save.",
|
help="The huggingface name of the BERT model to save.",
|
||||||
)
|
)
|
||||||
|
@ -1,10 +1,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from mlx.utils import tree_unflatten, tree_map
|
from transformers import BertTokenizer
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.utils import tree_unflatten
|
||||||
from mlx.nn.layers.linear import Linear
|
|
||||||
from mlx.nn.layers.normalization import LayerNorm
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -37,7 +34,7 @@ model_configs = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
"""
|
"""
|
||||||
Minor update to the MultiHeadAttention module to ensure that the
|
Minor update to the MultiHeadAttention module to ensure that the
|
||||||
projections use bias.
|
projections use bias.
|
||||||
@ -67,10 +64,10 @@ class MultiHeadAttention(Module):
|
|||||||
value_output_dims = value_output_dims or dims
|
value_output_dims = value_output_dims or dims
|
||||||
|
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.query_proj = Linear(query_input_dims, dims, True)
|
self.query_proj = nn.Linear(query_input_dims, dims, True)
|
||||||
self.key_proj = Linear(key_input_dims, dims, True)
|
self.key_proj = nn.Linear(key_input_dims, dims, True)
|
||||||
self.value_proj = Linear(value_input_dims, value_dims, True)
|
self.value_proj = nn.Linear(value_input_dims, value_dims, True)
|
||||||
self.out_proj = Linear(value_dims, value_output_dims, True)
|
self.out_proj = nn.Linear(value_dims, value_output_dims, True)
|
||||||
|
|
||||||
def __call__(self, queries, keys, values, mask=None):
|
def __call__(self, queries, keys, values, mask=None):
|
||||||
queries = self.query_proj(queries)
|
queries = self.query_proj(queries)
|
||||||
@ -105,7 +102,7 @@ class MultiHeadAttention(Module):
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoderLayer(Module):
|
class TransformerEncoderLayer(nn.Module):
|
||||||
"""
|
"""
|
||||||
A transformer encoder layer with (the original BERT) post-normalization.
|
A transformer encoder layer with (the original BERT) post-normalization.
|
||||||
"""
|
"""
|
||||||
@ -120,10 +117,10 @@ class TransformerEncoderLayer(Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
mlp_dims = mlp_dims or dims * 4
|
mlp_dims = mlp_dims or dims * 4
|
||||||
self.attention = MultiHeadAttention(dims, num_heads)
|
self.attention = MultiHeadAttention(dims, num_heads)
|
||||||
self.ln1 = LayerNorm(dims, eps=layer_norm_eps)
|
self.ln1 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
self.ln2 = LayerNorm(dims, eps=layer_norm_eps)
|
self.ln2 = nn.LayerNorm(dims, eps=layer_norm_eps)
|
||||||
self.linear1 = Linear(dims, mlp_dims)
|
self.linear1 = nn.Linear(dims, mlp_dims)
|
||||||
self.linear2 = Linear(mlp_dims, dims)
|
self.linear2 = nn.Linear(mlp_dims, dims)
|
||||||
self.gelu = nn.GELU()
|
self.gelu = nn.GELU()
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
@ -138,7 +135,7 @@ class TransformerEncoderLayer(Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class TransformerEncoder(Module):
|
class TransformerEncoder(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
|
self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None
|
||||||
):
|
):
|
||||||
@ -149,8 +146,8 @@ class TransformerEncoder(Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def __call__(self, x, mask):
|
def __call__(self, x, mask):
|
||||||
for l in self.layers:
|
for layer in self.layers:
|
||||||
x = l(x, mask)
|
x = layer(x, mask)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -196,23 +193,28 @@ class Bert(nn.Module):
|
|||||||
return y, mx.tanh(self.pooler(y[:, 0]))
|
return y, mx.tanh(self.pooler(y[:, 0]))
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(bert_model: str, weights_path: str) -> tuple[Bert, BertTokenizer]:
|
||||||
|
# load the weights npz
|
||||||
|
weights = mx.load(weights_path)
|
||||||
|
weights = tree_unflatten(list(weights.items()))
|
||||||
|
# create and update the model
|
||||||
|
model = Bert(model_configs[bert_model])
|
||||||
|
model.update(weights)
|
||||||
|
|
||||||
|
tokenizer = BertTokenizer.from_pretrained(bert_model)
|
||||||
|
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def run(bert_model: str, mlx_model: str):
|
def run(bert_model: str, mlx_model: str):
|
||||||
|
model, tokenizer = load_model(bert_model, mlx_model)
|
||||||
|
|
||||||
batch = [
|
batch = [
|
||||||
"This is an example of BERT working on MLX.",
|
"This is an example of BERT working on MLX.",
|
||||||
"A second string",
|
"A second string",
|
||||||
"This is another string.",
|
"This is another string.",
|
||||||
]
|
]
|
||||||
|
|
||||||
model = Bert(model_configs[bert_model])
|
|
||||||
|
|
||||||
weights = mx.load(mlx_model)
|
|
||||||
weights = tree_unflatten(list(weights.items()))
|
|
||||||
weights = tree_map(lambda p: mx.array(p), weights)
|
|
||||||
|
|
||||||
model.update(weights)
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
|
||||||
|
|
||||||
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
tokens = tokenizer(batch, return_tensors="np", padding=True)
|
||||||
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
tokens = {key: mx.array(v) for key, v in tokens.items()}
|
||||||
|
|
||||||
@ -228,7 +230,7 @@ def run(bert_model: str, mlx_model: str):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
|
parser = argparse.ArgumentParser(description="Run the BERT model using MLX.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bert-model",
|
"--bert-model",
|
||||||
type=str,
|
type=str,
|
||||||
@ -239,7 +241,7 @@ if __name__ == "__main__":
|
|||||||
"--mlx-model",
|
"--mlx-model",
|
||||||
type=str,
|
type=str,
|
||||||
default="weights/bert-base-uncased.npz",
|
default="weights/bert-base-uncased.npz",
|
||||||
help="The output path for the MLX BERT weights.",
|
help="The path of the stored MLX BERT weights (npz file).",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user