From 1e7f4a59213dd47023736f535dd6c5e2cf009c13 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 18 Dec 2023 19:33:17 -0800
Subject: [PATCH] fix use for llama 2 from meta (#144)

---
 lora/convert.py | 27 ++++++++++++++++++---------
 lora/lora.py    |  2 +-
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/lora/convert.py b/lora/convert.py
index 16af7931..3fdb5d42 100644
--- a/lora/convert.py
+++ b/lora/convert.py
@@ -32,21 +32,30 @@ if __name__ == "__main__":
         os.makedirs(args.mlx_model)
     mlx_path = Path(args.mlx_model)
 
+    # Copy the tokenizer
+    tokenizer_path = torch_path / "tokenizer.model"
+    if not tokenizer_path.exists():
+        print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
+        exit(0)
+    shutil.copyfile(
+        str(tokenizer_path),
+        str(mlx_path / "tokenizer.model"),
+    )
+
+    # Copy the model weights
     state = torch.load(str(torch_path / "consolidated.00.pth"))
     np.savez(
         str(mlx_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
-    )
-
-    # Copy the tokenizer
-    shutil.copyfile(
-        str(torch_path / "tokenizer.model"),
-        str(mlx_path / "tokenizer.model"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
     )
 
     # Copy the params
     with open(torch_path / "params.json", "r") as f:
         config = json.loads(f.read())
+        unused = ["multiple_of"]
+        for k in unused:
+            if k in config:
+                config.pop(k)
         n_heads = config["n_heads"]
         if "sliding_window" in config:
             config.pop("sliding_window")
@@ -55,6 +64,6 @@ if __name__ == "__main__":
         if "head_dim" not in config:
             config["head_dim"] = config["dim"] // n_heads
         if "hidden_dim" not in config:
-            config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
+            config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
         with open(mlx_path / "params.json", "w") as outfile:
-            json.dump(config, outfile)
+            json.dump(config, outfile, indent=4)
diff --git a/lora/lora.py b/lora/lora.py
index 2e0fa0a1..e1412da3 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -332,9 +332,9 @@ def load_model(folder: str, dtype=mx.float16):
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
     with open(model_path / "params.json", "r") as f:
         config = json.loads(f.read())
-        model_args = ModelArgs(**config)
         if config.get("vocab_size", -1) < 0:
             config["vocab_size"] = tokenizer.vocab_size
+        model_args = ModelArgs(**config)
     weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)