diff --git a/lora/fuse.py b/lora/fuse.py index 6244ecd1..6cae95b1 100644 --- a/lora/fuse.py +++ b/lora/fuse.py @@ -38,7 +38,7 @@ if __name__ == "__main__": ) parser.add_argument( "--upload-name", - help="The name of model to upload to Hugging Face MLX Community", + help="The name of model to upload to Hugging Face MLX Community.", type=str, default=None, ) diff --git a/lora/utils.py b/lora/utils.py index 5c791561..0e7c1fb9 100644 --- a/lora/utils.py +++ b/lora/utils.py @@ -93,15 +93,32 @@ def save_model(save_dir: str, weights, tokenizer, config): else "model.safetensors" ) + total_size = sum(v.nbytes for v in weights.values()) + index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} + for i, shard in enumerate(shards): shard_name = shard_file_format.format(i + 1, shards_count) - mx.save_safetensors(str(save_dir / shard_name), shard) + mx.save_safetensors( + str(save_dir / shard_name), shard, metadata={"format": "mlx"} + ) + for weight_name in shard.keys(): + index_data["weight_map"][weight_name] = shard_name + del shard tokenizer.save_pretrained(save_dir) - with open(save_dir / "config.json", "w") as fid: json.dump(config, fid, indent=4) + index_data["weight_map"] = { + k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) + } + with open(save_dir / "model.safetensors.index.json", "w") as f: + json.dump( + index_data, + f, + indent=4, + ) + def load(path_or_hf_repo: str): # If the path exists, it will try to load model form it