diff --git a/lora/fuse.py b/lora/fuse.py
index 2ea265fb..a957ff28 100644
--- a/lora/fuse.py
+++ b/lora/fuse.py
@@ -4,6 +4,7 @@ import argparse
 from pathlib import Path
 
 import mlx.core as mx
+import mlx.nn as nn
 import utils
 from mlx.utils import tree_flatten, tree_unflatten
 from models.lora import LoRALinear
@@ -41,6 +42,12 @@ if __name__ == "__main__":
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "-d",
+        "--de-quantize",
+        help="Generate a de-quantized model.",
+        action="store_true",
+    )
 
     print("Loading pretrained model")
     args = parser.parse_args()
@@ -53,7 +60,7 @@ if __name__ == "__main__":
 
     # Freeze all layers other than LORA linears
     model.freeze()
-    for l in model.model.layers[-lora_layers:]:
+    for l in model.model.layers[len(model.model.layers) - lora_layers :]:
         l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
         l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
         if hasattr(l, "block_sparse_moe"):
@@ -67,7 +74,32 @@ if __name__ == "__main__":
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
+
+    if args.de_quantize:
+        de_quantize_layers = []
+        for n, m in model.named_modules():
+            if isinstance(m, nn.QuantizedLinear):
+                bias = "bias" in m
+                weight = m.weight
+                weight = mx.dequantize(
+                    weight,
+                    m.scales,
+                    m.biases,
+                    m.group_size,
+                    m.bits,
+                ).astype(mx.float16)
+                output_dims, input_dims = weight.shape
+                linear = nn.Linear(input_dims, output_dims, bias=bias)
+                linear.weight = weight
+                if bias:
+                    linear.bias = m.bias
+                de_quantize_layers.append((n, linear))
+
+        model.update_modules(tree_unflatten(de_quantize_layers))
+
     weights = dict(tree_flatten(model.parameters()))
+    if args.de_quantize:
+        config.pop("quantization", None)
     utils.save_model(args.save_path, weights, tokenizer, config)
 
     if args.upload_name is not None:
diff --git a/lora/models/lora.py b/lora/models/lora.py
index 3f584cfd..8f3c01eb 100644
--- a/lora/models/lora.py
+++ b/lora/models/lora.py
@@ -16,7 +16,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self):
+    def to_linear(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -43,7 +43,7 @@ class LoRALinear(nn.Module):
         if bias:
             fused_linear.bias = linear.bias
 
-        if is_quantized:
+        if is_quantized and not de_quantize:
             fused_linear = nn.QuantizedLinear.from_linear(
                 fused_linear,
                 linear.group_size,
diff --git a/lora/utils.py b/lora/utils.py
index 80d59399..dee35bc4 100644
--- a/lora/utils.py
+++ b/lora/utils.py
@@ -115,11 +115,21 @@ def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
 def save_model(save_dir: str, weights, tokenizer, config):
     save_dir = Path(save_dir)
     save_dir.mkdir(parents=True, exist_ok=True)
-    shards = make_shards(weights)
+
+    shards = make_shards(weights, max_file_size_gibibyte=5)
+    shards_count = len(shards)
+    shard_file_format = (
+        "model-{:05d}-of-{:05d}.safetensors"
+        if shards_count > 1
+        else "model.safetensors"
+    )
+
     for i, shard in enumerate(shards):
-        # TODO use HF file name scheme for simplicity
-        mx.save_safetensors(str(save_dir / f"weights.{i:02d}.safetensors"), shard)
+        shard_name = shard_file_format.format(i + 1, shards_count)
+        mx.save_safetensors(str(save_dir / shard_name), shard)
+
     tokenizer.save_pretrained(save_dir)
+
     with open(save_dir / "config.json", "w") as fid:
         json.dump(config, fid, indent=4)
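
For context, the new --de-quantize path reverses MLX's weight quantization with mx.dequantize before the fused weights are saved. Below is a minimal sketch of that round trip, not taken from the diff above; the 256x256 shape and the 4-bit, group-size-64 settings are illustrative assumptions.

    import mlx.core as mx

    # Quantize a toy weight matrix, then recover a float16 approximation,
    # mirroring what fuse.py does for each nn.QuantizedLinear layer.
    w = mx.random.normal(shape=(256, 256))
    w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
    w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4).astype(mx.float16)

    # The reconstruction is close to, but not exactly, the original weights.
    print(mx.abs(w - w_hat).max())

The cast to float16 matches the astype(mx.float16) in the fuse.py change, so the de-quantized model is saved in half precision rather than float32.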