Load fused model with transformers (#703)

* save format for transformers compatibility

* save format for transformers compatibility arg

* hardcode mlx

* hardcode mlx format
Author: AlexandrosChrtn
Date: 2024-04-21 19:04:44 +03:00
Committed by: GitHub
Parent: 749cabf299
Commit: f20e68fcc0
2 changed files with 20 additions and 3 deletions


@@ -38,7 +38,7 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--upload-name",
-        help="The name of model to upload to Hugging Face MLX Community",
+        help="The name of model to upload to Hugging Face MLX Community.",
         type=str,
         default=None,
     )
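For context, --upload-name is the flag used when fusing so the resulting model can be pushed to the Hugging Face MLX Community. A hypothetical invocation (the script name and the other flags shown are assumptions for illustration, not part of this diff):

    python fuse.py --model mistralai/Mistral-7B-v0.1 \
        --save-path fused_model \
        --upload-name my-fused-model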


@@ -93,15 +93,32 @@ def save_model(save_dir: str, weights, tokenizer, config):
         else "model.safetensors"
     )
     total_size = sum(v.nbytes for v in weights.values())
+    index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
     for i, shard in enumerate(shards):
         shard_name = shard_file_format.format(i + 1, shards_count)
-        mx.save_safetensors(str(save_dir / shard_name), shard)
+        mx.save_safetensors(
+            str(save_dir / shard_name), shard, metadata={"format": "mlx"}
+        )
+        for weight_name in shard.keys():
+            index_data["weight_map"][weight_name] = shard_name
         del shard
     tokenizer.save_pretrained(save_dir)
     with open(save_dir / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
+    index_data["weight_map"] = {
+        k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
+    }
+    with open(save_dir / "model.safetensors.index.json", "w") as f:
+        json.dump(
+            index_data,
+            f,
+            indent=4,
+        )

 def load(path_or_hf_repo: str):
     # If the path exists, it will try to load model from it
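With the shards saved with metadata={"format": "mlx"} and a standard model.safetensors.index.json weight map written alongside them, the fused model should be loadable directly with transformers. A minimal sketch (the save directory path is an assumption for illustration):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # "fused_model" is the directory save_model wrote to (hypothetical path)
    model = AutoModelForCausalLM.from_pretrained("fused_model")
    tokenizer = AutoTokenizer.from_pretrained("fused_model")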