mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Load fused model with transformers (#703)
* save format for transformers compatibility * save format for transformers compatibility arg * hardcode mlx * hardcode mlx format
This commit is contained in:
parent
749cabf299
commit
f20e68fcc0
@ -38,7 +38,7 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--upload-name",
|
"--upload-name",
|
||||||
help="The name of model to upload to Hugging Face MLX Community",
|
help="The name of model to upload to Hugging Face MLX Community.",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
@ -93,15 +93,32 @@ def save_model(save_dir: str, weights, tokenizer, config):
|
|||||||
else "model.safetensors"
|
else "model.safetensors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
total_size = sum(v.nbytes for v in weights.values())
|
||||||
|
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
|
||||||
|
|
||||||
for i, shard in enumerate(shards):
|
for i, shard in enumerate(shards):
|
||||||
shard_name = shard_file_format.format(i + 1, shards_count)
|
shard_name = shard_file_format.format(i + 1, shards_count)
|
||||||
mx.save_safetensors(str(save_dir / shard_name), shard)
|
mx.save_safetensors(
|
||||||
|
str(save_dir / shard_name), shard, metadata={"format": "mlx"}
|
||||||
|
)
|
||||||
|
for weight_name in shard.keys():
|
||||||
|
index_data["weight_map"][weight_name] = shard_name
|
||||||
|
del shard
|
||||||
|
|
||||||
tokenizer.save_pretrained(save_dir)
|
tokenizer.save_pretrained(save_dir)
|
||||||
|
|
||||||
with open(save_dir / "config.json", "w") as fid:
|
with open(save_dir / "config.json", "w") as fid:
|
||||||
json.dump(config, fid, indent=4)
|
json.dump(config, fid, indent=4)
|
||||||
|
|
||||||
|
index_data["weight_map"] = {
|
||||||
|
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
|
||||||
|
}
|
||||||
|
with open(save_dir / "model.safetensors.index.json", "w") as f:
|
||||||
|
json.dump(
|
||||||
|
index_data,
|
||||||
|
f,
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load(path_or_hf_repo: str):
|
def load(path_or_hf_repo: str):
|
||||||
# If the path exists, it will try to load model from it
|
# If the path exists, it will try to load model from it
|
||||||
|
Loading…
Reference in New Issue
Block a user