feat: move lora into mlx-lm (#337)

* feat: Add lora and qlora training to mlx-lm


---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Anchen
2024-01-23 08:44:37 -08:00
committed by GitHub
parent 85c1ff8fd6
commit 362e88a744
13 changed files with 987 additions and 111 deletions

View File

@@ -4,7 +4,7 @@ import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -198,7 +198,7 @@ class Model(nn.Module):
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> tuple[mx.array, mx.array]:
) -> Tuple[mx.array, mx.array]:
mask = None
if x.shape[1] > 1:
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])