From f20e68fcc0eab129911828c00cbeb1c2a5246156 Mon Sep 17 00:00:00 2001
From: AlexandrosChrtn <56091961+AlexandrosChrtn@users.noreply.github.com>
Date: Sun, 21 Apr 2024 19:04:44 +0300
Subject: [PATCH] Load fused model with transformers (#703)

* save format for transformers compatibility

* save format for transformers compatibility arg

* hardcode mlx

* hardcode mlx format
---
 lora/fuse.py  |  2 +-
 lora/utils.py | 21 +++++++++++++++++++--
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/lora/fuse.py b/lora/fuse.py
index 6244ecd1..6cae95b1 100644
--- a/lora/fuse.py
+++ b/lora/fuse.py
@@ -38,7 +38,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--upload-name",
-        help="The name of model to upload to Hugging Face MLX Community",
+        help="The name of model to upload to Hugging Face MLX Community.",
         type=str,
         default=None,
     )
diff --git a/lora/utils.py b/lora/utils.py
index 5c791561..0e7c1fb9 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -93,15 +93,32 @@ def save_model(save_dir: str, weights, tokenizer, config):
         else "model.safetensors"
     )
 
+    total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
+
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
-        mx.save_safetensors(str(save_dir / shard_name), shard)
+        mx.save_safetensors(
+            str(save_dir / shard_name), shard, metadata={"format": "mlx"}
+        )
+        for weight_name in shard.keys():
+            index_data["weight_map"][weight_name] = shard_name
         del shard
 
     tokenizer.save_pretrained(save_dir)
-
     with open(save_dir / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
 
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+    with open(save_dir / "model.safetensors.index.json", "w") as f:
+        json.dump(
+            index_data,
+            f,
+            indent=4,
+        )
+
 
 def load(path_or_hf_repo: str):
     # If the path exists, it will try to load model form it
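
For reference, the patch above is what makes the fused output directory readable by Hugging Face transformers: each shard is written with format metadata and a model.safetensors.index.json weight map is emitted, following the conventions transformers uses for safetensors checkpoints. Below is a minimal sketch (not part of the patch) of loading such a directory, assuming transformers is installed; "fused_model" is an illustrative path standing in for whatever save directory fuse.py was given.

    # Sketch: load a directory written by lora/fuse.py with transformers.
    # "fused_model" is an assumed path; substitute the save directory you used.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "fused_model"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)

    # Quick smoke test: generate a few tokens from a short prompt.
    inputs = tokenizer("The capital of France is", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0]))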