From 8022083979972efb92c2064fccc4ed1834ed2c96 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Mon, 22 Jan 2024 17:32:24 -0800
Subject: [PATCH] feat(lora): add de-quantized support for fuse.py (#351)
* feat(lora): add de-quantized support for fuse.py
* address comments
---
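(Reviewer note, below the "---" cut line, so `git am` drops it from the commit
message.) Usage sketch for the new flag:

    python lora/fuse.py --model <path-or-hf-repo> --adapter-file adapters.npz \
        --save-path lora_fused_model --de-quantize

Only -d/--de-quantize is introduced by this patch; the other flag names above
reflect my reading of the existing fuse.py argument parser and are illustrative.
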
 lora/fuse.py        | 34 +++++++++++++++++++++++++++++++++-
 lora/models/lora.py |  4 ++--
 lora/utils.py       | 16 +++++++++++++---
 3 files changed, 48 insertions(+), 6 deletions(-)

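(Also a reviewer note, not part of the patch.) The fuse.py change undoes MLX
weight quantization via mx.dequantize, using the scales/biases/group_size/bits
stored on each nn.QuantizedLinear. A standalone sketch of that round trip, with
an illustrative shape, group_size, and bits rather than values read from a real
model:

    import mlx.core as mx

    # Quantize a random fp16 weight matrix, then recover a dense approximation
    # with the same call pattern fuse.py uses.
    w = mx.random.normal(shape=(256, 512)).astype(mx.float16)
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4).astype(mx.float16)

    print(w.shape == w_hat.shape)          # True: shape is preserved
    print(mx.abs(w - w_hat).max().item())  # small but nonzero: quantization is lossy

The de-quantized model is therefore an approximation of the quantized model,
not a bit-exact recovery of the original unquantized weights.
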
diff --git a/lora/fuse.py b/lora/fuse.py
index 2ea265fb..a957ff28 100644
--- a/lora/fuse.py
+++ b/lora/fuse.py
@@ -4,6 +4,7 @@ import argparse
 from pathlib import Path
 
 import mlx.core as mx
+import mlx.nn as nn
 import utils
 from mlx.utils import tree_flatten, tree_unflatten
 from models.lora import LoRALinear
@@ -41,6 +42,12 @@ if __name__ == "__main__":
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--de-quantize",
+        help="Generate a de-quantized model.",
+        action="store_true",
+    )
 
     print("Loading pretrained model")
     args = parser.parse_args()
@@ -53,7 +60,7 @@ if __name__ == "__main__":
 
     # Freeze all layers other than LORA linears
     model.freeze()
-    for l in model.model.layers[-lora_layers:]:
+    for l in model.model.layers[len(model.model.layers) - lora_layers :]:
         l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
         l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
         if hasattr(l, "block_sparse_moe"):
@@ -67,7 +74,32 @@ if __name__ == "__main__":
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
+
+    if args.de_quantize:
+        de_quantize_layers = []
+        for n, m in model.named_modules():
+            if isinstance(m, nn.QuantizedLinear):
+                bias = "bias" in m
+                weight = m.weight
+                weight = mx.dequantize(
+                    weight,
+                    m.scales,
+                    m.biases,
+                    m.group_size,
+                    m.bits,
+                ).astype(mx.float16)
+                output_dims, input_dims = weight.shape
+                linear = nn.Linear(input_dims, output_dims, bias=bias)
+                linear.weight = weight
+                if bias:
+                    linear.bias = m.bias
+                de_quantize_layers.append((n, linear))
+
+        model.update_modules(tree_unflatten(de_quantize_layers))
+
     weights = dict(tree_flatten(model.parameters()))
+    if args.de_quantize:
+        config.pop("quantization", None)
     utils.save_model(args.save_path, weights, tokenizer, config)
 
     if args.upload_name is not None:
diff --git a/lora/models/lora.py b/lora/models/lora.py
index 3f584cfd..8f3c01eb 100644
--- a/lora/models/lora.py
+++ b/lora/models/lora.py
@@ -16,7 +16,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self):
+    def to_linear(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -43,7 +43,7 @@ class LoRALinear(nn.Module):
         if bias:
             fused_linear.bias = linear.bias
 
-        if is_quantized:
+        if is_quantized and not de_quantize:
             fused_linear = nn.QuantizedLinear.from_linear(
                 fused_linear,
                 linear.group_size,
diff --git a/lora/utils.py b/lora/utils.py
index 80d59399..dee35bc4 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -115,11 +115,21 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
 def save_model(save_dir: str, weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
-    shards = make_shards(weights)
+
+    shards = make_shards(weights, max_file_size_gibibyte=5)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
     for i, shard in enumerate(shards):
-        # TODO use HF file name scheme for simplicity
-        mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard)
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        mx.save_safetensors(str(save_dir / shard_name), shard)
+
     tokenizer.save_pretrained(save_dir)
+
     with open(save_dir / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
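
(Editor's note, outside the patch.) The save_model change switches shard names
to the Hugging Face scheme the removed TODO referred to. A tiny sketch of the
names the new format string produces; shard_names is a hypothetical helper that
only mirrors the logic added above, and the shard counts are made up:

    def shard_names(shards_count: int) -> list[str]:
        # HF-style, 1-indexed, zero-padded names when sharded;
        # a single un-numbered file otherwise.
        if shards_count > 1:
            return [
                "model-{:05d}-of-{:05d}.safetensors".format(i + 1, shards_count)
                for i in range(shards_count)
            ]
        return ["model.safetensors"]

    print(shard_names(1))  # ['model.safetensors']
    print(shard_names(3))  # ['model-00001-of-00003.safetensors', ...,
                           #  'model-00003-of-00003.safetensors']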