fix use for llama 2 from meta (#144)

Awni Hannun 2023-12-18 19:33:17 -08:00 committed by GitHub
parent 1d62b3ecc1
commit 1e7f4a5921
2 changed files with 19 additions and 10 deletions

llama/convert.py

@@ -32,21 +32,30 @@ if __name__ == "__main__":
     os.makedirs(args.mlx_model)
     mlx_path = Path(args.mlx_model)
 
+    # Copy the tokenizer
+    tokenizer_path = torch_path / "tokenizer.model"
+    if not tokenizer_path.exists():
+        print(f"Make sure there is a file tokenizer.model in {args.torch_model}")
+        exit(0)
+    shutil.copyfile(
+        str(tokenizer_path),
+        str(mlx_path / "tokenizer.model"),
+    )
+
+    # Copy the model weights
     state = torch.load(str(torch_path / "consolidated.00.pth"))
     np.savez(
         str(mlx_path / "weights.npz"),
-        **{k: v.to(torch.float16).numpy() for k, v in state.items()}
-    )
-
-    # Copy the tokenizer
-    shutil.copyfile(
-        str(torch_path / "tokenizer.model"),
-        str(mlx_path / "tokenizer.model"),
+        **{k: v.to(torch.float16).numpy() for k, v in state.items()},
     )
 
     # Copy the params
     with open(torch_path / "params.json", "r") as f:
         config = json.loads(f.read())
+        unused = ["multiple_of"]
+        for k in unused:
+            if k in config:
+                config.pop(k)
         n_heads = config["n_heads"]
         if "sliding_window" in config:
             config.pop("sliding_window")
@@ -55,6 +64,6 @@ if __name__ == "__main__":
         if "head_dim" not in config:
             config["head_dim"] = config["dim"] // n_heads
         if "hidden_dim" not in config:
-            config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape
+            config["hidden_dim"] = state["layers.0.feed_forward.w1.weight"].shape[0]
     with open(mlx_path / "params.json", "w") as outfile:
-        json.dump(config, outfile)
+        json.dump(config, outfile, indent=4)
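This second hunk fixes two small bugs: hidden_dim must be the first dimension of the w1 weight, which is shaped (hidden_dim, dim), not the whole shape tuple, and the output params.json is now pretty-printed with indent=4. Taken together, the params.json clean-up in this file amounts to roughly the following sketch (key names come from the diff; the helper name sanitize_config is hypothetical):

# Sketch of the params.json sanitization convert.py performs above;
# sanitize_config is a hypothetical name for illustration.
def sanitize_config(config, state):
    # Drop fields the MLX model does not consume.
    for key in ("multiple_of", "sliding_window"):
        config.pop(key, None)
    n_heads = config["n_heads"]
    # Meta's params.json omits these; derive them from what is present.
    config.setdefault("head_dim", config["dim"] // n_heads)
    # w1 has shape (hidden_dim, dim), so take the first dimension.
    config.setdefault(
        "hidden_dim", state["layers.0.feed_forward.w1.weight"].shape[0]
    )
    return config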

llama/llama.py

@@ -332,9 +332,9 @@ def load_model(folder: str, dtype=mx.float16):
     tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
     with open(model_path / "params.json", "r") as f:
         config = json.loads(f.read())
-    model_args = ModelArgs(**config)
     if config.get("vocab_size", -1) < 0:
         config["vocab_size"] = tokenizer.vocab_size
+    model_args = ModelArgs(**config)
     weights = mx.load(str(model_path / "weights.npz"))
     weights = tree_unflatten(list(weights.items()))
     weights = tree_map(lambda p: p.astype(dtype), weights)
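The llama.py change is the core of the fix: Meta's Llama 2 params.json ships vocab_size = -1 as a sentinel, and ModelArgs was previously constructed before the sentinel was replaced with the tokenizer's vocabulary size. Moving the construction after the patch means the model sees the real value. A self-contained sketch of why the ordering matters, with a stand-in dataclass and a hypothetical vocabulary size:

# Sketch of the ordering fix; ModelArgsSketch and the concrete values are
# hypothetical stand-ins for the repo's ModelArgs and real files.
from dataclasses import dataclass

@dataclass
class ModelArgsSketch:
    dim: int
    vocab_size: int

config = {"dim": 4096, "vocab_size": -1}  # Meta's params.json uses -1
tokenizer_vocab_size = 32000              # hypothetical, read from tokenizer.model

# Patch the sentinel first, then construct the args (the fix above):
if config.get("vocab_size", -1) < 0:
    config["vocab_size"] = tokenizer_vocab_size
args = ModelArgsSketch(**config)
assert args.vocab_size == 32000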