mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
Load fused model with transformers (#703)
* save format for transformers compatibility
* save format for transformers compatibility arg
* hardcode mlx
* hardcode mlx format
This commit is contained in:
parent
749cabf299
commit
f20e68fcc0
@ -38,7 +38,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upload-name",
|
||||
help="The name of model to upload to Hugging Face MLX Community",
|
||||
help="The name of model to upload to Hugging Face MLX Community.",
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
@ -93,15 +93,32 @@ def save_model(save_dir: str, weights, tokenizer, config):
|
||||
else "model.safetensors"
|
||||
)
|
||||
|
||||
total_size = sum(v.nbytes for v in weights.values())
|
||||
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||
|
||||
for i, shard in enumerate(shards):
|
||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||
mx.save_safetensors(str(save_dir / shard_name), shard)
|
||||
mx.save_safetensors(
|
||||
str(save_dir / shard_name), shard, metadata={"format": "mlx"}
|
||||
)
|
||||
for weight_name in shard.keys():
|
||||
index_data["weight_map"][weight_name] = shard_name
|
||||
del shard
|
||||
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
|
||||
with open(save_dir / "config.json", "w") as fid:
|
||||
json.dump(config, fid, indent=4)
|
||||
|
||||
index_data["weight_map"] = {
|
||||
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||
}
|
||||
with open(save_dir / "model.safetensors.index.json", "w") as f:
|
||||
json.dump(
|
||||
index_data,
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
|
||||
def load(path_or_hf_repo: str):
|
||||
# If the path exists, it will try to load the model from it
|
||||
|
Loading…
Reference in New Issue
Block a user