feat: move lora into mlx-lm (#337)

* feat: Add lora and qlora training to mlx-lm


---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Anchen
Date: 2024-01-23 08:44:37 -08:00
Committed by: GitHub
Parent: 85c1ff8fd6
Commit: 362e88a744
13 changed files with 987 additions and 111 deletions
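
The commit title describes moving LoRA/QLoRA fine-tuning into the mlx-lm package. As background for the diff below, here is a minimal sketch of a LoRA-style linear layer in MLX; the class name, rank, and scale values are illustrative assumptions, not the module added by this commit.

import math

import mlx.core as mx
import mlx.nn as nn


class LoRALinear(nn.Module):
    # Minimal LoRA sketch: y = W x + scale * (x A) B, with W kept frozen.
    def __init__(self, linear: nn.Linear, rank: int = 8, scale: float = 20.0):
        super().__init__()
        self.linear = linear  # base projection, frozen during fine-tuning
        out_dims, in_dims = linear.weight.shape
        self.scale = scale
        # A is small random, B starts at zero, so the wrapped layer initially
        # behaves exactly like the base layer.
        bound = 1.0 / math.sqrt(in_dims)
        self.lora_a = mx.random.uniform(low=-bound, high=bound, shape=(in_dims, rank))
        self.lora_b = mx.zeros((rank, out_dims))

    def __call__(self, x):
        return self.linear(x) + self.scale * (x @ self.lora_a) @ self.lora_b

Only lora_a and lora_b receive gradients; in the QLoRA variant the frozen base weights are additionally quantized.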


@@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
+import numpy as np

 from .base import BaseModelArgs
@@ -158,18 +159,34 @@ class MixtralSparseMoeBlock(nn.Module):
         x = x.reshape(-1, x.shape[-1])

         gates = self.gate(x)
-        inds = mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+        inds = mx.stop_gradient(
+            mx.argpartition(-gates, kth=ne, axis=-1)[:, :ne]
+        )  # TODO remove it once we figure out how to fine tune TopK in MOE

         scores = mx.softmax(
             mx.take_along_axis(gates, inds, axis=-1).astype(mx.float32),
             axis=-1,
         ).astype(gates.dtype)

-        y = []
-        for xt, st, it in zip(x, scores, inds.tolist()):
-            yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
-            yt = (yt * st).sum(axis=-1)
-            y.append(yt[None, :])
-        y = mx.concatenate(y)
+        if self.training:
+            mx.eval(inds)
+            inds = np.array(inds)
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            for e, expert in enumerate(self.experts):
+                idx1, idx2 = map(mx.array, np.where(inds == e))
+                if idx1.size == 0:
+                    continue
+                y[idx1, idx2] = expert(x[idx1])
+            y = (y * scores[:, :, None]).sum(axis=1)
+        else:
+            y = []
+            for xt, st, it in zip(x, scores, inds.tolist()):
+                yt = mx.concatenate([self.experts[e](xt)[:, None] for e in it], axis=-1)
+                yt = (yt * st).sum(axis=-1)
+                y.append(yt[None, :])
+            y = mx.concatenate(y)

         return y.reshape(orig_shape)
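
For context on the new training branch: inds holds the top-ne expert indices for each token, and np.where(inds == e) recovers, for expert e, which tokens routed to it (idx1) and in which of their top-k slots (idx2), so each expert runs once over a batch of its tokens instead of once per token. A standalone NumPy sketch with made-up routing (4 tokens, top-2 of 3 experts) shows the bookkeeping:

import numpy as np

# Hypothetical routing table: row = token, columns = its top-2 expert indices.
inds = np.array([[0, 2], [1, 0], [2, 1], [0, 1]])

for e in range(3):
    idx1, idx2 = np.where(inds == e)  # tokens using expert e, and in which slot
    print(f"expert {e}: tokens {idx1.tolist()}, slots {idx2.tolist()}")
# expert 0: tokens [0, 1, 3], slots [0, 1, 0]
# expert 1: tokens [1, 2, 3], slots [0, 1, 1]
# expert 2: tokens [0, 2], slots [1, 0]

In the diff, expert(x[idx1]) then processes those tokens in one call and the results are scattered back with y[idx1, idx2]; mx.stop_gradient and mx.eval keep the integer routing indices out of the gradient graph before they are converted to NumPy.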
@@ -229,7 +246,7 @@ class MixtralModel(nn.Module):
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, mask, cache[e])

-        return self.norm(h[:, T - 1 : T, :]), cache
+        return self.norm(h), cache


 class Model(nn.Module):
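
The last hunk drops the h[:, T - 1 : T, :] slice so MixtralModel returns hidden states for every position rather than only the final token, which is what a training loss over whole sequences needs (generation only ever used the last step). A rough sketch of such a loss, assuming a model that returns (logits, cache) and padded batches with per-example lengths (the names below are assumptions, not code from this commit):

import mlx.core as mx
import mlx.nn as nn


def sequence_loss(model, inputs, targets, lengths):
    # Needs logits at every time step, hence the full (B, T, V) output.
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask positions beyond each example's true length.
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    return ce.sum() / ntoks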