diff --git a/bert/convert.py b/bert/convert.py
index 5a9298d6..82ee8b82 100644
--- a/bert/convert.py
+++ b/bert/convert.py
@@ -32,7 +32,12 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert BERT weights to MLX.")
     parser.add_argument(
         "--bert-model",
-        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
+        choices=[
+            "bert-base-uncased",
+            "bert-base-cased",
+            "bert-large-uncased",
+            "bert-large-cased",
+        ],
         default="bert-base-uncased",
         help="The huggingface name of the BERT model to save.",
     )
@@ -44,4 +49,4 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    convert(args.bert_model, args.mlx_model)
\ No newline at end of file
+    convert(args.bert_model, args.mlx_model)
diff --git a/bert/hf_model.py b/bert/hf_model.py
index 9f73028d..4f07df13 100644
--- a/bert/hf_model.py
+++ b/bert/hf_model.py
@@ -24,10 +24,17 @@ def run(bert_model: str):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Run the BERT model using HuggingFace Transformers.")
+    parser = argparse.ArgumentParser(
+        description="Run the BERT model using HuggingFace Transformers."
+    )
     parser.add_argument(
         "--bert-model",
-        choices=["bert-base-uncased", "bert-base-cased", "bert-large-uncased", "bert-large-cased"],
+        choices=[
+            "bert-base-uncased",
+            "bert-base-cased",
+            "bert-large-uncased",
+            "bert-large-cased",
+        ],
         default="bert-base-uncased",
         help="The huggingface name of the BERT model to save.",
     )
diff --git a/bert/model.py b/bert/model.py
index 446919b1..00344ab6 100644
--- a/bert/model.py
+++ b/bert/model.py
@@ -1,3 +1,4 @@
+import numpy as np
 from typing import Optional
 from dataclasses import dataclass
 from transformers import BertTokenizer
@@ -214,19 +215,29 @@ def run(bert_model: str, mlx_model: str):
         "A second string",
         "This is another string.",
     ]
-    
+
     tokens = tokenizer(batch, return_tensors="np", padding=True)
     tokens = {key: mx.array(v) for key, v in tokens.items()}
 
-    mlx_output, mlx_pooled = model(**tokens)
-    mlx_output = numpy.array(mlx_output)
-    mlx_pooled = numpy.array(mlx_pooled)
+    vs = model_configs[bert_model].vocab_size
+    ts = np.random.randint(0, vs, (8, 512))
+    tokens["input_ids"] = mx.array(ts)
+    tokens["token_type_ids"] = mx.zeros((8, 512), mx.int32)
+    tokens.pop("attention_mask")
 
-    print("MLX BERT:")
-    print(mlx_output)
+    for _ in range(5):
+        out = model(**tokens)
+        mx.eval(out)
 
-    print("\n\nMLX Pooled:")
-    print(mlx_pooled[0, :20])
+    import time
+
+    tic = time.time()
+    for _ in range(10):
+        out = model(**tokens)
+        mx.eval(out)
+    toc = time.time()
+    tps = (8 * 5 * 10) / (toc - tic)
+    print(tps)
 
 
 if __name__ == "__main__":
diff --git a/lora/convert.py b/lora/convert.py
index 2903aae8..02dd06fb 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     torch_path = Path(args.torch_model)
     if not os.path.exists(args.mlx_model):
         os.makedirs(args.mlx_model)
-    mlx_path = Path(args.mlx_model) 
+    mlx_path = Path(args.mlx_model)
 
     state = torch.load(str(torch_path / "consolidated.00.pth"))
     np.savez(
@@ -57,5 +57,3 @@ if __name__ == "__main__":
     config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
     with open(mlx_path / "params.json", "w") as outfile:
         json.dump(config, outfile)
-
-
diff --git a/lora/lora.py b/lora/lora.py
index a7dcdb30..e35f0a86 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -20,9 +20,13 @@ import wikisql
 
 
 def build_parser():
-    parser = argparse.ArgumentParser(description="LoRA finetuning with Llama or Mistral")
+    parser = argparse.ArgumentParser(
+        description="LoRA finetuning with Llama or Mistral"
+    )
     parser.add_argument(
-        "--model", required=True, help="A path to the model files containing the tokenizer, weights, config."
+        "--model",
+        required=True,
+        help="A path to the model files containing the tokenizer, weights, config.",
     )
     # Generation args
     parser.add_argument(
@@ -227,6 +231,7 @@ def generate(model, prompt, tokenizer, args):
 
     def generate_step():
         temp = args.temp
+
         def sample(logits):
             if temp == 0:
                 return mx.argmax(logits, axis=-1)
diff --git a/transformer_lm/datasets.py b/transformer_lm/datasets.py
index f0e8ff51..78db713e 100644
--- a/transformer_lm/datasets.py
+++ b/transformer_lm/datasets.py
@@ -42,7 +42,7 @@ def wikitext(dataset="2", save_dir="/tmp"):
     Load the WikiText-* language modeling dataset:
     https://paperswithcode.com/dataset/wikitext-2
     https://paperswithcode.com/dataset/wikitext-103
-    
+
     """
     if dataset not in ("2", "103"):
         raise ValueError(f'Dataset must be either "2" or "103", got {dataset}')