import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    """Linear layer augmented with a trainable low-rank (LoRA) update.

    The forward pass computes
    ``linear(x) + scale * ((dropout(x) @ lora_a) @ lora_b)``,
    where ``lora_a`` has shape ``(input_dims, r)`` and ``lora_b`` has shape
    ``(r, output_dims)``. Because ``lora_b`` is initialised to zeros, a freshly
    constructed layer initially produces the same output as ``linear`` alone.
    """

    @staticmethod
    def from_linear(
        linear: nn.Linear,
        r: int = 8,
        lora_alpha: float = 16,
        lora_dropout: float = 0.05,
        scale: float = 10.0,
    ):
        """Wrap an existing (possibly quantized) linear layer with LoRA adapters.

        Args:
            linear: The ``nn.Linear`` or ``nn.QuantizedLinear`` to wrap. It is
                stored as-is as the returned layer's ``linear`` attribute.
            r: Rank of the low-rank update.
            lora_alpha: Numerator of the ``lora_alpha / r`` factor applied to
                the update.
            lora_dropout: Dropout probability applied to the LoRA branch input.
            scale: Additional multiplier on the LoRA update.

        Returns:
            A new ``LoRALinear`` wrapping ``linear``.
        """
        # TODO remove when input_dims and output_dims are attributes
        # on linear and quantized linear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            # The quantized weight stores 32 // bits values per stored column,
            # so scale the stored width back up to the real input dimension.
            input_dims *= 32 // linear.bits
        lora_lin = LoRALinear(
            input_dims=input_dims,
            output_dims=output_dims,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            scale=scale,
        )
        # Replace the freshly initialised base layer with the provided one.
        lora_lin.linear = linear
        return lora_lin

    def to_linear(self):
        """Return a linear layer with the LoRA update fused into its weight.

        The fused weight is ``W + (scale * lora_b.T) @ lora_a.T``. If the wrapped
        layer is quantized, its weight is dequantized, fused, and then quantized
        again with the original ``group_size`` and ``bits``. The bias, if any,
        is carried over unchanged.

        Returns:
            An ``nn.Linear`` (or ``nn.QuantizedLinear`` when the wrapped layer
            was quantized).
        """
        linear = self.linear
        bias = "bias" in linear
        weight = linear.weight
        is_quantized = isinstance(linear, nn.QuantizedLinear)

        # Use the same type as the linear weight if not quantized. For a
        # quantized layer, use the dtype of its scales -- the same dtype
        # ``__call__`` runs that layer in -- rather than a hard-coded float16,
        # so e.g. bfloat16 quantized models are not fused with float16 factors.
        dtype = weight.dtype

        if is_quantized:
            dtype = linear.scales.dtype
            weight = mx.dequantize(
                weight,
                linear.scales,
                linear.biases,
                linear.group_size,
                linear.bits,
            )
        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        fused_linear.weight = weight + lora_b @ lora_a
        if bias:
            fused_linear.bias = linear.bias

        if is_quantized:
            fused_linear = nn.QuantizedLinear.from_linear(
                fused_linear,
                linear.group_size,
                linear.bits,
            )

        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        r: int = 8,
        lora_alpha: float = 16,
        lora_dropout: float = 0.0,
        scale: float = 10.0,
        bias: bool = False,
    ):
        """Create a LoRA layer with a fresh base ``nn.Linear``.

        Args:
            input_dims: Input feature dimension.
            output_dims: Output feature dimension.
            r: Rank of the low-rank update.
            lora_alpha: Numerator of the ``lora_alpha / r`` factor.
            lora_dropout: Dropout probability on the LoRA branch input.
            scale: Additional multiplier on the LoRA update.
            bias: Whether the base linear layer has a bias.
        """
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)

        self.lora_dropout = nn.Dropout(p=lora_dropout)

        # Scale for low-rank update
        self.scale = scale * (lora_alpha / r)

        # Low rank lora weights: A is uniform in [-1/sqrt(in), 1/sqrt(in)],
        # B is zero so the update starts as a no-op.
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, r),
        )
        self.lora_b = mx.zeros(shape=(r, output_dims))

    def __call__(self, x):
        """Apply the base layer plus the scaled low-rank update to ``x``."""
        # Run the base layer in its own compute dtype (the scales' dtype for a
        # quantized layer); the LoRA branch below uses ``x`` uncast.
        dtype = self.linear.weight.dtype
        if isinstance(self.linear, nn.QuantizedLinear):
            dtype = self.linear.scales.dtype
        y = self.linear(x.astype(dtype))
        z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b
        return y + self.scale * z