feat(lora): add de-quantized support for fuse.py (#351)

* feat(lora): add de-quantized support for fuse.py

* address comments
This commit is contained in:
Anchen
2024-01-22 17:32:24 -08:00
committed by GitHub
parent 30be4c4734
commit 8022083979
3 changed files with 48 additions and 6 deletions

View File

@@ -115,11 +115,21 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
def save_model(save_dir: str, weights, tokenizer, config):
    """Save sharded model weights, tokenizer files, and config to *save_dir*.

    Args:
        save_dir: Destination directory; created (with parents) if missing.
        weights: Mapping of parameter name -> array, split into shards.
        tokenizer: Tokenizer with a ``save_pretrained`` method.
        config: JSON-serializable model configuration dict.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Cap each shard at 5 GiB so files stay under common upload limits.
    shards = make_shards(weights, max_file_size_gibibyte=5)
    shards_count = len(shards)
    # Hugging Face naming scheme: indexed names only when there are
    # multiple shards; a single shard keeps the plain file name.
    shard_file_format = (
        "model-{:05d}-of-{:05d}.safetensors"
        if shards_count > 1
        else "model.safetensors"
    )
    for i, shard in enumerate(shards):
        # Shard indices are 1-based in the HF scheme.
        shard_name = shard_file_format.format(i + 1, shards_count)
        mx.save_safetensors(str(save_dir / shard_name), shard)
    tokenizer.save_pretrained(save_dir)
    with open(save_dir / "config.json", "w") as fid:
        json.dump(config, fid, indent=4)