mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
@@ -1,5 +1,4 @@
|
||||
# Copyright © 2023 Apple Inc.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
@@ -24,7 +23,11 @@ class ModelArgs:
|
||||
class LoRALinear(nn.Module):
|
||||
@staticmethod
def from_linear(linear: nn.Linear, rank: int = 8):
    """Build a ``LoRALinear`` around an existing linear layer.

    The adapter's dimensions are read from ``linear.weight.shape``, which
    is unpacked as ``(output_dims, input_dims)``. The passed layer is then
    stored as the adapter's ``linear`` attribute.

    Args:
        linear: The layer to wrap; ``nn.QuantizedLinear`` instances are
            handled specially (see below).
        rank: Rank passed through to the ``LoRALinear`` constructor.

    Returns:
        The new ``LoRALinear`` whose ``linear`` attribute is ``linear``.
    """
    # TODO remove when input_dims and output_dims are attributes
    # on linear and quantized linear
    out_features, in_features = linear.weight.shape

    if isinstance(linear, nn.QuantizedLinear):
        # Scale the stored input width by the number of ``bits``-wide values
        # per 32 bits. NOTE(review): presumably the quantized weight packs
        # values into 32-bit elements -- confirm against the MLX docs.
        values_per_word = 32 // linear.bits
        in_features = in_features * values_per_word

    adapter = LoRALinear(in_features, out_features, rank)
    adapter.linear = linear
    return adapter
|
||||
@@ -47,7 +50,10 @@ class LoRALinear(nn.Module):
|
||||
self.lora_b = mx.zeros(shape=(lora_rank, output_dims))
|
||||
|
||||
def __call__(self, x):
    """Apply the wrapped linear layer plus the low-rank LoRA update.

    Args:
        x: Input array; it is cast to the base layer's compute dtype for
            the base projection and used as-is for the low-rank path.

    Returns:
        ``self.linear(x) + 2.0 * ((x @ self.lora_a) @ self.lora_b)``.
    """
    # Pick the dtype to cast the input to before calling the base layer.
    # For a quantized layer use the dtype of ``scales`` rather than of
    # ``weight``. NOTE(review): presumably because the quantized weight is
    # stored in a packed integer dtype -- confirm against the MLX docs.
    dtype = self.linear.weight.dtype
    if isinstance(self.linear, nn.QuantizedLinear):
        dtype = self.linear.scales.dtype

    # Single base-layer forward pass. The earlier unconditional
    # ``x.astype(self.linear.weight.dtype)`` call was superseded by this
    # dtype-aware version and has been removed.
    y = self.linear(x.astype(dtype))

    # Low-rank update, added with a fixed factor of 2.0.
    z = (x @ self.lora_a) @ self.lora_b
    return y + 2.0 * z
|
||||
|
||||
|
Reference in New Issue
Block a user