qlora
This commit is contained in:
Awni Hannun
2024-01-04 21:05:59 -08:00
committed by GitHub
parent 4fa659acbd
commit 37b41cec60
8 changed files with 137 additions and 51 deletions

View File

@@ -1,5 +1,4 @@
# Copyright © 2023 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
@@ -24,7 +23,11 @@ class ModelArgs:
class LoRALinear(nn.Module):
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(input_dims, output_dims, rank)
lora_lin.linear = linear
return lora_lin
@@ -47,7 +50,10 @@ class LoRALinear(nn.Module):
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
def __call__(self, x):
y = self.linear(x.astype(self.linear.weight.dtype))
dtype = self.linear.weight.dtype
if isinstance(self.linear, nn.QuantizedLinear):
dtype = self.linear.scales.dtype
y = self.linear(x.astype(dtype))
z = (x @ self.lora_a) @ self.lora_b
return y + 2.0 * z