diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index ae79fb65..188fd7b5 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -51,6 +51,9 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false

+# Use DoRA instead of LoRA.
+use_dora: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index b362e7b7..1c7250e7 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -6,6 +6,7 @@ from pathlib import Path
 from mlx.utils import tree_flatten, tree_unflatten

 from .gguf import convert_to_gguf
+from .tuner.dora import DoRALinear
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
@@ -18,7 +19,9 @@ from .utils import (


 def parse_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
+    parser = argparse.ArgumentParser(
+        description="Fuse fine-tuned adapters into the base model."
+    )
     parser.add_argument(
         "--model",
         default="mlx_model",
@@ -79,7 +82,7 @@ def main():
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, LoRALinear)
+        if isinstance(m, (LoRALinear, DoRALinear))
     ]

     model.update_modules(tree_unflatten(fused_linears))
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 477398b6..629bba16 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -123,7 +123,9 @@ def main():
         tokenizer_config["eos_token"] = args.eos_token

     model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
+        args.model,
+        adapter_path=args.adapter_path,
+        tokenizer_config=tokenizer_config,
     )

     if args.use_default_chat_template:
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index df382cfe..5bcfb829 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -54,6 +54,7 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "use_dora": False,
 }


@@ -69,6 +70,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -117,6 +119,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -138,8 +141,12 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
+    )
     return parser


@@ -175,16 +182,20 @@ def run(args, training_callback: TrainingCallback = None):
     adapter_file = adapter_path / "adapters.safetensors"

     if args.test and not args.train:
-        apply_lora_layers(model, adapter_path)
-
-    else:
+        # Allow testing without LoRA layers by providing empty path
+        if args.adapter_path != "":
+            apply_lora_layers(model, adapter_path)
+    elif args.train:
         adapter_path.mkdir(parents=True, exist_ok=True)
         save_config(vars(args), adapter_path / "adapter_config.json")

         # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
-
+        linear_to_lora_layers(
+            model, args.lora_layers, args.lora_parameters, args.use_dora
+        )
         print_trainable_parameters(model)
+    else:
+        raise ValueError("Must provide at least one of --train or --test")

     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
@@ -257,12 +268,12 @@ def main():
             config = yaml.load(file, yaml_loader)
         # Prefer parameters from command-line arguments
         for k, v in config.items():
-            if not args.get(k, None):
+            if args.get(k, None) is None:
                 args[k] = v

     # Update defaults for unspecified parameters
     for k, v in CONFIG_DEFAULTS.items():
-        if not args.get(k, None):
+        if args.get(k, None) is None:
             args[k] = v

     run(types.SimpleNamespace(**args))
diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py
new file mode 100644
index 00000000..364f5b45
--- /dev/null
+++ b/llms/mlx_lm/tuner/dora.py
@@ -0,0 +1,98 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class DoRALinear(nn.Module):
+    @staticmethod
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+    ):
+        # TODO support quantized weights in DoRALinear
+        output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            raise ValueError("DoRALinear does not yet support quantization.")
+        dora_lin = DoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            scale=scale,
+        )
+        dora_lin.linear = linear
+        return dora_lin
+
+    def to_linear(self, de_quantize: bool = False):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = (self.scale * self.lora_b.T).astype(dtype)
+        lora_a = self.lora_a.T.astype(dtype)
+        weight = weight + lora_b @ lora_a
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_linear.weight = norm_scale[:, None] * weight
+
+        if bias:
+            fused_linear.bias = linear.bias
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale * (alpha / r)
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+        self.m = mx.linalg.norm(self.linear.weight, axis=1)
+
+    def __call__(self, x):
+        # Regular LoRA (without a bias)
+        y = x @ self.linear.weight.T
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        if "bias" in self.linear:
+            out = out + self.linear.bias
+        return out
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 0b529366..e976e4af 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -9,6 +9,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten

+from .dora import DoRALinear
 from .lora import LoRALinear


@@ -36,6 +37,7 @@ def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
     config: Dict,
+    use_dora: bool = False,
 ):
     """
     Convert some of the models linear layers to lora layers.
@@ -46,6 +48,8 @@ def linear_to_lora_layers(
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
           rank, alpha, scale, and optional layer keys.
+        use_dora (bool): If True, uses DoRA instead of LoRA.
+          Default: ``False``
     """

     num_layers = len(model.layers)
@@ -54,14 +58,16 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )
+    cls = DoRALinear if use_dora else LoRALinear

-    to_lora = lambda lin: LoRALinear.from_linear(
-        lin,
-        r=config["rank"],
-        alpha=config["alpha"],
-        scale=config["scale"],
-        dropout=config["dropout"],
-    )
+    def to_lora(lin):
+        return cls.from_linear(
+            lin,
+            r=config["rank"],
+            alpha=config["alpha"],
+            scale=config["scale"],
+            dropout=config["dropout"],
+        )

     keys = config.get("keys", None)
     if keys is not None:
@@ -119,7 +125,12 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    linear_to_lora_layers(
+        model,
+        config.lora_layers,
+        config.lora_parameters,
+        getattr(config, "use_dora", False),
+    )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 25ba8398..70b7614f 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.13.0"
+__version__ = "0.13.1"
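
A minimal usage sketch (not part of the patch) of how the new DoRALinear layer and its to_linear() fusion path fit together, based only on the code added above; the layer sizes, the tolerance, and the assumption that the patched mlx_lm package is importable are illustrative.

# Hypothetical sketch: wrap a plain linear layer in DoRA, run it, then fuse it back.
import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.dora import DoRALinear  # module added by this patch

# Wrap an existing linear layer; lora_a, lora_b, and the magnitude vector `m`
# are the trainable DoRA parameters, while the wrapped weight stays untouched.
base = nn.Linear(64, 64, bias=False)
dora = DoRALinear.from_linear(base, r=8, alpha=16, dropout=0.0, scale=10.0)

x = mx.random.normal((2, 64))
y_adapted = dora(x)

# Fuse the adapter back into a plain nn.Linear, as fuse.py now does for
# every DoRALinear module it finds.
fused = dora.to_linear()
y_fused = fused(x)

# With freshly initialized adapters (lora_b == 0) both paths reduce to the
# base layer, so the outputs should agree up to numerical precision.
print(mx.allclose(y_adapted, y_fused, atol=1e-4))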