1 Commits

Author SHA1 Message Date
Alex Barron
64ceb62674 load q4_k_m inefficiently 2024-12-03 19:54:57 -08:00
137 changed files with 16416 additions and 6193 deletions

67
.circleci/config.yml Normal file
View File

@@ -0,0 +1,67 @@
version: 2.1
orbs:
apple: ml-explore/pr-approval@0.1.0
jobs:
linux_build_and_test:
docker:
- image: cimg/python:3.9
steps:
- checkout
- run:
name: Run style checks
command: |
pip install pre-commit
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
mlx_lm_build_and_test:
macos:
xcode: "15.2.0"
resource_class: macos.m1.large.gen1
steps:
- checkout
- run:
name: Install dependencies
command: |
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install unittest-xml-reporting
cd llms/
pip install -e ".[testing]"
- run:
name: Run Python tests
command: |
source env/bin/activate
python -m xmlrunner discover -v llms/tests -o test-results/
- store_test_results:
path: test-results
workflows:
build_and_test:
when:
matches:
pattern: "^(?!pull/)[-\\w]+$"
value: << pipeline.git.branch >>
jobs:
- mlx_lm_build_and_test
- linux_build_and_test
prb:
when:
matches:
pattern: "^pull/\\d+(/head)?$"
value: << pipeline.git.branch >>
jobs:
- hold:
type: approval
- apple/authenticate:
context: pr-approval
- mlx_lm_build_and_test:
requires: [ hold ]
- linux_build_and_test:
requires: [ hold ]

View File

@@ -1,25 +0,0 @@
name: Test
on:
push:
branches: ["main"]
pull_request:
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.ref != 'refs/head/main' }}
jobs:
check_lint:
if: github.repository == 'ml-explore/mlx-examples'
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v6
with:
python-version: "3.10"
- uses: pre-commit/action@v3.0.1

View File

@@ -1,10 +1,10 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 25.1.0
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 6.0.0
rev: 5.13.2
hooks:
- id: isort
args:

View File

@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1`, `OLMoE` archtectures and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Mamba` and support for `full-fine-tuning`.

View File

@@ -4,12 +4,12 @@ This repo contains a variety of standalone examples using the [MLX
framework](https://github.com/ml-explore/mlx).
The [MNIST](mnist) example is a good starting point to learn how to use MLX.
Some more useful examples are listed below. Check-out [MLX
LM](https://github.com/ml-explore/mlx-lm) for a more fully featured Python
package for LLMs with MLX.
Some more useful examples are listed below.
### Text Models
- [MLX LM](llms/README.md) a package for LLM text generation, fine-tuning, and more.
- [Transformer language model](transformer_lm) training.
- Minimal examples of large scale text generation with [LLaMA](llms/llama),
[Mistral](llms/mistral), and more in the [LLMs](llms) directory.
@@ -30,7 +30,6 @@ package for LLMs with MLX.
- Speech recognition with [OpenAI's Whisper](whisper).
- Audio compression and generation with [Meta's EnCodec](encodec).
- Music generation with [Meta's MusicGen](musicgen).
### Multimodal models
@@ -46,7 +45,7 @@ package for LLMs with MLX.
### Hugging Face
You can directly use or download converted checkpoints from the [MLX
Note: You can now directly download a few converted checkpoints from the [MLX
Community](https://huggingface.co/mlx-community) organization on Hugging Face.
We encourage you to join the community and [contribute new
models](https://github.com/ml-explore/mlx-examples/issues/155).

View File

@@ -48,17 +48,3 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.
## Distributed training
The example also supports distributed data parallel training. You can launch a
distributed training as follows:
```shell
$ cat >hostfile.json
[
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```

View File

@@ -1,4 +1,3 @@
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10
@@ -13,11 +12,8 @@ def get_cifar10(batch_size, root=None):
x = x.astype("float32") / 255.0
return (x - mean) / std
group = mx.distributed.init()
tr_iter = (
tr.shuffle()
.partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@@ -29,11 +25,6 @@ def get_cifar10(batch_size, root=None):
)
test = load_cifar10(root=root, train=False)
test_iter = (
test.to_stream()
.partition_if(group.size() > 1, group.size(), group.rank())
.key_transform("image", normalize)
.batch(batch_size)
)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
return tr_iter, test_iter

View File

@@ -23,13 +23,6 @@ parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")
def print_zero(group, *args, **kwargs):
if group.rank() != 0:
return
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
@@ -41,20 +34,9 @@ def train_epoch(model, train_iter, optimizer, epoch):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc
world = mx.distributed.init()
losses = 0
accuracies = 0
samples_per_sec = 0
count = 0
def average_stats(stats, count):
if world.size() == 1:
return [s / count for s in stats]
with mx.stream(mx.cpu):
stats = mx.distributed.all_sum(mx.array(stats))
count = mx.distributed.all_sum(count)
return (stats / count).tolist()
losses = []
accs = []
samples_per_sec = []
state = [model.state, optimizer.state]
@@ -62,7 +44,6 @@ def train_epoch(model, train_iter, optimizer, epoch):
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
grads = nn.utils.average_gradients(grads)
optimizer.update(model, grads)
return loss, acc
@@ -71,79 +52,69 @@ def train_epoch(model, train_iter, optimizer, epoch):
y = mx.array(batch["label"])
tic = time.perf_counter()
loss, acc = step(x, y)
mx.eval(loss, acc, state)
mx.eval(state)
toc = time.perf_counter()
losses += loss.item()
accuracies += acc.item()
samples_per_sec += x.shape[0] / (toc - tic)
count += 1
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
if batch_counter % 10 == 0:
l, a, s = average_stats(
[losses, accuracies, world.size() * samples_per_sec],
count,
)
print_zero(
world,
print(
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {l:.3f}",
f"Train acc {a:.3f}",
f"Throughput: {s:.2f} images/second",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
)
),
)
)
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
def test_epoch(model, test_iter, epoch):
accuracies = 0
count = 0
accs = []
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
accuracies += acc.item()
count += 1
with mx.stream(mx.cpu):
accuracies = mx.distributed.all_sum(accuracies)
count = mx.distributed.all_sum(count)
return (accuracies / count).item()
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
def main(args):
mx.random.seed(args.seed)
# Initialize the distributed group and report the nodes that showed up
world = mx.distributed.init()
if world.size() > 1:
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
model = getattr(resnet, args.arch)()
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
optimizer = optim.Adam(learning_rate=args.lr)
train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print_zero(
world,
print(
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss:.3f}",
f"avg. Train acc {tr_acc:.3f}",
f"Throughput: {throughput:.2f} images/sec",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
)
),
)
)
test_acc = test_epoch(model, test_data, epoch)
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
train_data.reset()
test_data.reset()

View File

@@ -121,7 +121,7 @@ if __name__ == "__main__":
mlx_path.mkdir(parents=True, exist_ok=True)
print("[INFO] Loading")
torch_weights = torch.load(torch_path / "pytorch_model.bin", weights_only=True)
torch_weights = torch.load(torch_path / "pytorch_model.bin")
print("[INFO] Converting")
mlx_weights = {
k: torch_to_mx(v, dtype=args.dtype) for k, v in torch_weights.items()

View File

@@ -167,9 +167,8 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```
Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
```shell
python dreambooth.py \
@@ -211,71 +210,3 @@ speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .
Distributed Computation
------------------------
The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set-up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
### Distributed Finetuning
Distributed finetuning scales very well with FLUX and all one has to do is
adjust the gradient accumulation and training iterations so that the batch
size remains the same. For instance, to replicate the following training
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
On 4 machines we simply run
```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 2 \
mlx-community/dreambooth-dog6
```
Note the iterations that changed to 300 from 1200 and the gradient accumulations to 2 from 8.
### Distributed Inference
Distributed inference can be divided in two different approaches. The first
approach is the data-parallel approach, where each node generates its own
images and shares them at the end. The second approach is the model-parallel
approach where the model is shared across the nodes and they collaboratively
generate the images.
The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.
For better performance in the model-parallel approach we suggest that you use a
[thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
All you have to do once again is use `mlx.launch` as follows
```shell
mlx.launch --verbose --hostfile hostfile.json -- \
python txt2image.py --model schnell \
--n-images 8 \
--image-size 512x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars'
```
for model-parallel generation you may want to also pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch` which is an experimental setting that
reduces the CPU/GPU synchronization overhead.

View File

@@ -289,4 +289,4 @@ if __name__ == "__main__":
tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)
print("Training successful.")
print(f"Training successful. Saved final weights to {args.adapter_file}.")

View File

@@ -178,8 +178,6 @@ class DoubleStreamBlock(nn.Module):
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.sharding_group = None
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
@@ -218,35 +216,18 @@ class DoubleStreamBlock(nn.Module):
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# Project - cat - average - split
txt_attn = self.txt_attn.proj(txt_attn)
img_attn = self.img_attn.proj(img_attn)
if self.sharding_group is not None:
attn = mx.concatenate([txt_attn, img_attn], axis=1)
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks
img = img + img_mod1.gate * img_attn
img_mlp = self.img_mlp(
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt bloks
txt = txt + txt_mod1.gate * txt_attn
txt_mlp = self.txt_mlp(
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
if self.sharding_group is not None:
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
# finalize the img/txt blocks
img = img + img_mod2.gate * img_mlp
txt = txt + txt_mod2.gate * txt_mlp
return img, txt

View File

@@ -5,7 +5,6 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from .layers import (
DoubleStreamBlock,
@@ -86,8 +85,6 @@ class Flux(nn.Module):
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if k.startswith("model.diffusion_model."):
k = k[22:]
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
@@ -97,47 +94,6 @@ class Flux(nn.Module):
new_weights[k] = w
return new_weights
def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()
N = group.size()
if N == 1:
return
for block in self.double_blocks:
block.num_heads //= N
block.img_attn.num_heads //= N
block.txt_attn.num_heads //= N
block.sharding_group = group
block.img_attn.qkv = shard_linear(
block.img_attn.qkv, "all-to-sharded", segments=3, group=group
)
block.txt_attn.qkv = shard_linear(
block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
)
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
block.img_mlp.layers[0] = shard_linear(
block.img_mlp.layers[0], "all-to-sharded", group=group
)
block.txt_mlp.layers[0] = shard_linear(
block.txt_mlp.layers[0], "all-to-sharded", group=group
)
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
for block in self.single_blocks:
block.num_heads //= N
block.hidden_size //= N
block.linear1 = shard_linear(
block.linear1,
"all-to-sharded",
segments=[1 / 7, 2 / 7, 3 / 7],
group=group,
)
block.linear2 = shard_linear(
block.linear2, "sharded-to-all", segments=[1 / 5], group=group
)
def __call__(
self,
img: mx.array,

View File

@@ -7,7 +7,7 @@ import mlx.core as mx
class FluxSampler:
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15):
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
@@ -25,7 +25,7 @@ class FluxSampler:
):
t = mx.linspace(start, stop, num_steps + 1)
if not self._schnell:
if self._schnell:
t = self._time_shift(image_sequence_length, t)
return t.tolist()
@@ -50,7 +50,6 @@ class FluxSampler:
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
t = t.reshape([-1] + [1] * (x.ndim - 1))
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):

View File

@@ -1,109 +0,0 @@
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def print_zero(group, *args, **kwargs):
if group.rank() == 0:
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using FLUX"
)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
group = mx.distributed.init()
if group.size() > 1:
flux.flow.shard(group)
print_zero(group, "Loading models")
flux.ensure_models_are_loaded()
def print_help():
print_zero(group, "The command list:")
print_zero(group, "- 'q' to exit")
print_zero(group, "- 's HxW' to change the size of the image")
print_zero(group, "- 'n S' to change the number of steps")
print_zero(group, "- 'h' to print this help")
print_zero(group, "FLUX interactive session")
print_help()
seed = 0
size = (512, 512)
latent_size = to_latent_size(size)
steps = 50 if args.model == "dev" else 4
while True:
prompt = input(">> " if group.rank() == 0 else "")
if prompt == "q":
break
if prompt == "h":
print_help()
continue
if prompt.startswith("s "):
size = tuple([int(xi) for xi in prompt[2:].split("x")])
print_zero(group, "Setting the size to", size)
latent_size = to_latent_size(size)
continue
if prompt.startswith("n "):
steps = int(prompt[2:])
print_zero(group, "Setting the steps to", steps)
continue
seed += 1
latents = flux.generate_latents(
prompt,
n_images=1,
num_steps=steps,
latent_size=latent_size,
guidance=4.0,
seed=seed,
)
print_zero(group, "Processing prompt")
mx.eval(next(latents))
print_zero(group, "Generating latents")
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
mx.eval(xt)
print_zero(group, "Generating image")
xt = flux.decode(xt, latent_size)
xt = (xt * 255).astype(mx.uint8)
mx.eval(xt)
im = Image.fromarray(np.array(xt[0]))
im.save(args.output)
print_zero(group, "Saved at", args.output, end="\n\n")

View File

@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using FLUX"
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
@@ -62,7 +62,6 @@ if __name__ == "__main__":
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
parser.add_argument("--force-shard", action="store_true")
args = parser.parse_args()
# Load the models
@@ -77,24 +76,6 @@ if __name__ == "__main__":
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
# Figure out what kind of distributed generation we should do
group = mx.distributed.init()
n_images = args.n_images
should_gather = False
if group.size() > 1:
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
flux.flow.shard(group)
else:
n_images //= group.size()
should_gather = True
# If we are sharding we should have the same seed and if we are doing
# data parallel generation we should have different seeds
if args.seed is None:
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
if should_gather:
args.seed = args.seed + group.rank()
if args.preload_models:
flux.ensure_models_are_loaded()
@@ -102,7 +83,7 @@ if __name__ == "__main__":
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=n_images,
n_images=args.n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
@@ -112,8 +93,8 @@ if __name__ == "__main__":
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
@@ -121,42 +102,36 @@ if __name__ == "__main__":
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.get_peak_memory() / 1024**3
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
# Gather them if each node has different images
decoded = mx.concatenate(decoded, axis=0)
if should_gather:
decoded = mx.distributed.all_gather(decoded)
mx.eval(decoded)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = decoded
x = mx.concatenate(decoded, axis=0)
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = decoded
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@@ -168,7 +143,7 @@ if __name__ == "__main__":
im.save(args.output)
# Report the peak memory used during generation
if args.verbose and group.rank() == 0:
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")

View File

@@ -79,10 +79,10 @@ def load_image(image_source):
def prepare_inputs(processor, image, prompt):
if isinstance(image, str):
image = load_image(image)
inputs = processor(image, prompt, return_tensors="np")
inputs = processor(prompt, image, return_tensors="np")
pixel_values = mx.array(inputs["pixel_values"])
input_ids = mx.array(inputs["input_ids"])
return pixel_values, input_ids
return input_ids, pixel_values
def load_model(model_path, tokenizer_config={}):
@@ -126,7 +126,8 @@ def main():
processor, model = load_model(args.model, tokenizer_config)
prompt = codecs.decode(args.prompt, "unicode_escape")
pixel_values, input_ids = prepare_inputs(processor, args.image, prompt)
input_ids, pixel_values = prepare_inputs(processor, args.image, prompt)
print(prompt)
generated_text = generate_text(

View File

@@ -104,21 +104,31 @@ class LlavaModel(nn.Module):
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
batch_size, num_image_patches, embed_dim = image_features.shape
num_images, num_image_patches, embed_dim = image_features.shape
# Positions of <image> tokens in input_ids, assuming batch size is 1
image_positions = mx.array(
np.where(input_ids[0] == image_token_index)[0], mx.uint32
)
image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()
if len(image_positions) != num_image_patches:
if len(image_positions) != num_images:
raise ValueError(
f"The number of image tokens ({len(image_positions)}) does not "
f" match the number of image patches ({num_image_patches})."
f" match the number of image inputs ({num_images})."
)
inputs_embeds[0, image_positions] = image_features
return inputs_embeds
text_segments = []
start_idx = 0
for position in image_positions:
text_segments.append(inputs_embeds[:, start_idx:position])
start_idx = position + 1
image_embeddings = mx.split(image_features, image_features.shape[0])
final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p]
final_embeddings += [inputs_embeds[:, start_idx:]]
# Create a final embedding of shape
# (1, num_image_patches*num_images + sequence_len, embed_dim)
return mx.concatenate(final_embeddings, axis=1)
def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embddings = self.get_input_embeddings(input_ids, pixel_values)

47
llms/CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,47 @@
# Contributing to MLX LM
Below are some tips to port LLMs available on Hugging Face to MLX.
Before starting checkout the [general contribution
guidelines](https://github.com/ml-explore/mlx-examples/blob/main/CONTRIBUTING.md).
Next, from this directory, do an editable install:
```shell
pip install -e .
```
Then check if the model has weights in the
[safetensors](https://huggingface.co/docs/safetensors/index) format. If not
[follow instructions](https://huggingface.co/spaces/safetensors/convert) to
convert it.
After that, add the model file to the
[`mlx_lm/models`](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/models)
directory. You can see other examples there. We recommend starting from a model
that is similar to the model you are porting.
Make sure the name of the new model file is the same as the `model_type` in the
`config.json`, for example
[starcoder2](https://huggingface.co/bigcode/starcoder2-7b/blob/main/config.json#L17).
To determine the model layer names, we suggest either:
- Refer to the Transformers implementation if you are familiar with the
codebase.
- Load the model weights and check the weight names which will tell you about
the model structure.
- Look at the names of the weights by inspecting `model.safetensors.index.json`
in the Hugging Face repo.
To add LoRA support edit
[`mlx_lm/tuner/utils.py`](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/tuner/utils.py#L27-L60)
Finally, add a test for the new modle type to the [model
tests](https://github.com/ml-explore/mlx-examples/blob/main/llms/tests/test_models.py).
From the `llms/` directory, you can run the tests with:
```shell
python -m unittest discover tests/
```

2
llms/MANIFEST.in Normal file
View File

@@ -0,0 +1,2 @@
include mlx_lm/requirements.txt
recursive-include mlx_lm/ *.py

View File

@@ -1,6 +1,277 @@
# MOVE NOTICE
## Generate Text with LLMs and MLX
The mlx-lm package has moved to a [new repo](https://github.com/ml-explore/mlx-lm).
The easiest way to get started is to install the `mlx-lm` package:
The package has been removed from the MLX Examples repo. Send new contributions
and issues to the MLX LM repo.
**With `pip`**:
```sh
pip install mlx-lm
```
**With `conda`**:
```sh
conda install -c conda-forge mlx-lm
```
The `mlx-lm` package also has:
- [LoRA, QLoRA, and full fine-tuning](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md)
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
```python
from mlx_lm import load, generate
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
To see a description of all the arguments you can do:
```
>>> help(generate)
```
Check out the [generation
example](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm/examples/generate_response.py)
to see how to use the API in more detail.
The `mlx-lm` package also comes with functionality to quantize and optionally
upload models to the Hugging Face Hub.
You can convert models in the Python API with:
```python
from mlx_lm import convert
repo = "mistralai/Mistral-7B-Instruct-v0.3"
upload_repo = "mlx-community/My-Mistral-7B-Instruct-v0.3-4bit"
convert(repo, quantize=True, upload_repo=upload_repo)
```
This will generate a 4-bit quantized Mistral 7B and upload it to the repo
`mlx-community/My-Mistral-7B-Instruct-v0.3-4bit`. It will also save the
converted model in the path `mlx_model` by default.
To see a description of all the arguments you can do:
```
>>> help(convert)
```
#### Streaming
For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text, token, and log probabilities.
For example,
```python
from mlx_lm import load, stream_generate
repo = "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
model, tokenizer = load(repo)
prompt = "Write a story about Einstein"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
print()
```
### Command Line
You can also use `mlx-lm` from the command line with:
```
mlx_lm.generate --model mistralai/Mistral-7B-Instruct-v0.3 --prompt "hello"
```
This will download a Mistral 7B model from the Hugging Face Hub and generate
text using the given prompt.
For a full list of options run:
```
mlx_lm.generate --help
```
To quantize a model from the command line run:
```
mlx_lm.convert --hf-path mistralai/Mistral-7B-Instruct-v0.3 -q
```
For more options run:
```
mlx_lm.convert --help
```
You can upload new models to Hugging Face by specifying `--upload-repo` to
`convert`. For example, to upload a quantized Mistral-7B model to the
[MLX Hugging Face community](https://huggingface.co/mlx-community) you can do:
```
mlx_lm.convert \
--hf-path mistralai/Mistral-7B-Instruct-v0.3 \
-q \
--upload-repo mlx-community/my-4bit-mistral
```
### Long Prompts and Generations
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
To use the rotating key-value cache pass the argument `--max-kv-size n` where
`n` can be any integer. Smaller values like `512` will use very little RAM but
result in worse quality. Larger values like `4096` or higher will use more RAM
but have better quality.
Caching prompts can substantially speedup reusing the same long context with
different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
```bash
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
Prompt caching can also be used in the Python API in order to to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
### Supported Models
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.
Here are a few examples of Hugging Face models that work with this example:
- [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
- [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
- [deepseek-ai/deepseek-coder-6.7b-instruct](https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct)
- [01-ai/Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)
- [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
- [mistralai/Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Qwen/Qwen-7B](https://huggingface.co/Qwen/Qwen-7B)
- [pfnet/plamo-13b](https://huggingface.co/pfnet/plamo-13b)
- [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
- [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
- [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
Most
[Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
[Llama](https://huggingface.co/models?library=transformers,safetensors&other=llama&sort=trending),
[Phi-2](https://huggingface.co/models?library=transformers,safetensors&other=phi&sort=trending),
and
[Mixtral](https://huggingface.co/models?library=transformers,safetensors&other=mixtral&sort=trending)
style models should work out of the box.
For some models (such as `Qwen` and `plamo`) the tokenizer requires you to
enable the `trust_remote_code` option. You can do this by passing
`--trust-remote-code` in the command line. If you don't specify the flag
explicitly, you will be prompted to trust remote code in the terminal when
running the model.
For `Qwen` models you must also specify the `eos_token`. You can do this by
passing `--eos-token "<|endoftext|>"` in the command
line.
These options can also be set in the Python API. For example:
```python
model, tokenizer = load(
"qwen/Qwen-7B",
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
```
### Large Models
> [!NOTE]
This requires macOS 15.0 or higher to work.
Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
occupied by the model and cache. This requires macOS 15 or higher to
work.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be slow on the given machine. If the model fits in
RAM then it can often be sped up by increasing the system wired memory limit.
To increase the limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.

View File

@@ -40,7 +40,7 @@ def generate(
if len(tokens) == 0:
print("No tokens generated for this prompt")
return
prompt_tps = len(prompt) / prompt_time
prompt_tps = prompt.size / prompt_time
gen_tps = (len(tokens) - 1) / gen_time
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")

View File

@@ -19,10 +19,10 @@ class ModelArgs:
rms_norm_eps: float
vocab_size: int
context_length: int
num_key_value_heads: Optional[int] = None
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
model_type: Optional[str] = None
model_type: str = None
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
@@ -54,7 +54,7 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
@@ -66,7 +66,7 @@ class Attention(nn.Module):
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / float(args.rope_scaling["factor"])
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
@@ -254,7 +254,7 @@ def translate_weight_names(name):
return name
def load(gguf_file: str, repo: Optional[str] = None):
def load(gguf_file: str, repo: str = None):
# If the gguf_file exists, try to load model from it.
# Otherwise try to download and cache from the HF repo
if not Path(gguf_file).exists():

View File

@@ -7,7 +7,6 @@ import glob
import json
import shutil
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
@@ -150,8 +149,7 @@ def quantize(weights, config, args):
def make_shards(weights: dict, max_file_size_gibibyte: int = 15):
max_file_size_bytes = max_file_size_gibibyte << 30
shards = []
shard: Dict[str, mx.array] = {}
shard_size = 0
shard, shard_size = {}, 0
for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)

View File

@@ -23,7 +23,7 @@ class ModelArgs:
n_kv_heads: int
norm_eps: float
vocab_size: int
moe: dict
moe: dict = None
class Attention(nn.Module):
@@ -91,6 +91,7 @@ class FeedForward(nn.Module):
class MOEFeedForward(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_experts = args.moe["num_experts"]
self.num_experts_per_tok = args.moe["num_experts_per_tok"]
self.experts = [FeedForward(args) for _ in range(self.num_experts)]
@@ -114,6 +115,7 @@ class MOEFeedForward(nn.Module):
yt = (yt * st).sum(axis=-1)
y.append(yt[None, :])
y = mx.concatenate(y)
return y.reshape(orig_shape)

357
llms/mlx_lm/LORA.md Normal file
View File

@@ -0,0 +1,357 @@
# Fine-Tuning with LoRA or QLoRA
You can use use the `mlx-lm` package to fine-tune an LLM with low rank
adaptation (LoRA) for a target task.[^lora] The example also supports quantized
LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Mistral
- Llama
- Phi2
- Mixtral
- Qwen2
- Gemma
- OLMo
- MiniCPM
- InternLM2
## Contents
- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
- [Data](#Data)
- [Memory Issues](#Memory-Issues)
## Run
The main command is `mlx_lm.lora`. To see a full list of command-line options run:
```shell
mlx_lm.lora --help
```
Note, in the following the `--model` argument can be any compatible Hugging
Face repo or a local path to a converted model.
You can also specify a YAML config with `-c`/`--config`. For more on the format see the
[example YAML](examples/lora_config.yaml). For example:
```shell
mlx_lm.lora --config /path/to/config.yaml
```
If command-line flags are also used, they will override the corresponding
values in the config.
### Fine-tune
To fine-tune a model use:
```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--data <path_to_data> \
--iters 600
```
To fine-tune the full model weights, add the `--fine-tune-type full` flag.
Currently supported fine-tuning types are `lora` (default), `dora`, and `full`.
The `--data` argument must specify a path to a `train.jsonl`, `valid.jsonl`
when using `--train` and a path to a `test.jsonl` when using `--test`. For more
details on the data format see the section on [Data](#Data).
For example, to fine-tune a Mistral 7B you can use `--model
mistralai/Mistral-7B-v0.1`.
If `--model` points to a quantized model, then the training will use QLoRA,
otherwise it will use regular LoRA.
By default, the adapter config and learned weights are saved in `adapters/`.
You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.
### Evaluate
To compute test set perplexity use:
```shell
mlx_lm.lora \
--model <path_to_model> \
--adapter-path <path_to_adapters> \
--data <path_to_data> \
--test
```
### Generate
For generation use `mlx_lm.generate`:
```shell
mlx_lm.generate \
--model <path_to_model> \
--adapter-path <path_to_adapters> \
--prompt "<your_model_prompt>"
```
## Fuse
You can generate a model fused with the low-rank adapters using the
`mlx_lm.fuse` command. This command also allows you to optionally:
- Upload the fused model to the Hugging Face Hub.
- Export the fused model to GGUF. Note GGUF support is limited to Mistral,
Mixtral, and Llama style models in fp16 precision.
To see supported options run:
```shell
mlx_lm.fuse --help
```
To generate the fused model run:
```shell
mlx_lm.fuse --model <path_to_model>
```
This will by default load the adapters from `adapters/`, and save the fused
model in the path `fused_model/`. All of these are configurable.
To upload a fused model, supply the `--upload-repo` and `--hf-path` arguments
to `mlx_lm.fuse`. The latter is the repo name of the original model, which is
useful for the sake of attribution and model versioning.
For example, to fuse and upload a model derived from Mistral-7B-v0.1, run:
```shell
mlx_lm.fuse \
--model mistralai/Mistral-7B-v0.1 \
--upload-repo mlx-community/my-lora-mistral-7b \
--hf-path mistralai/Mistral-7B-v0.1
```
To export a fused model to GGUF, run:
```shell
mlx_lm.fuse \
--model mistralai/Mistral-7B-v0.1 \
--export-gguf
```
This will save the GGUF model in `fused_model/ggml-model-f16.gguf`. You
can specify the file name with `--gguf-path`.
## Data
The LoRA command expects you to provide a dataset with `--data`. The MLX
Examples GitHub repo has an [example of the WikiSQL
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
correct format.
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
Face.
### Local Datasets
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.
Currently, `*.jsonl` files support `chat`, `tools`, `completions`, and `text`
data formats. Here are examples of these formats:
`chat`:
```jsonl
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello."}, {"role": "assistant", "content": "How can I assistant you today."}]}
```
`tools`:
```jsonl
{"messages":[{"role":"user","content":"What is the weather in San Francisco?"},{"role":"assistant","tool_calls":[{"id":"call_id","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"}}]}],"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and country, eg. San Francisco, USA"},"format":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location","format"]}}}]}
```
<details>
<summary>View the expanded single data tool format</summary>
```jsonl
{
"messages": [
{ "role": "user", "content": "What is the weather in San Francisco?" },
{
"role": "assistant",
"tool_calls": [
{
"id": "call_id",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}"
}
}
]
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and country, eg. San Francisco, USA"
},
"format": { "type": "string", "enum": ["celsius", "fahrenheit"] }
},
"required": ["location", "format"]
}
}
}
]
}
```
The format for the `arguments` field in a function varies for different models.
Common formats include JSON strings and dictionaries. The example provided
follows the format used by
[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples)
and [Mistral
AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct).
A dictionary format is used in Hugging Face's [chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example).
Refer to the documentation for the model you are fine-tuning for more details.
</details>
`completions`:
```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
> [!NOTE]
> Each example in the datasets must be on a single line. Do not put more than
> one example per line and do not split an example across multiple lines.
### Hugging Face Datasets
To use Hugging Face datasets, first install the `datasets` package:
```
pip install datasets
```
If the Hugging Face dataset is already in a supported format, you can specify
it on the command line. For example, pass `--data mlx-community/wikisql` to
train on the pre-formatted WikiwSQL data.
Otherwise, provide a mapping of keys in the dataset to the features MLX LM
expects. Use a YAML config to specify the Hugging Face dataset arguments. For
example:
```
hf_dataset:
name: "billsum"
prompt_feature: "text"
completion_feature: "summary"
```
- Use `prompt_feature` and `completion_feature` to specify keys for a
`completions` dataset. Use `text_feature` to specify the key for a `text`
dataset.
- To specify the train, valid, or test splits, set the corresponding
`{train,valid,test}_split` argument.
- Arguments specified in `config` will be passed as keyword arguments to
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
In general, for the `chat`, `tools` and `completions` formats, Hugging Face
[chat
templates](https://huggingface.co/docs/transformers/main/en/chat_templating)
are used. This applies the model's chat template by default. If the model does
not have a chat template, then Hugging Face will use a default. For example,
the final text in the `chat` example above with Hugging Face's default template
becomes:
```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assistant you today.<|im_end|>
```
If you are unsure of the format to use, the `chat` or `completions` are good to
start with. For custom requirements on the format of the dataset, use the
`text` format to assemble the content yourself.
## Memory Issues
Fine-tuning a large model with LoRA requires a machine with a decent amount
of memory. Here are some tips to reduce memory use should you need to do so:
1. Try quantization (QLoRA). You can use QLoRA by generating a quantized model
with `convert.py` and the `-q` flag. See the [Setup](#setup) section for
more details.
2. Try using a smaller batch size with `--batch-size`. The default is `4` so
setting this to `2` or `1` will reduce memory consumption. This may slow
things down a little, but will also reduce the memory use.
3. Reduce the number of layers to fine-tune with `--num-layers`. The default
is `16`, so you can try `8` or `4`. This reduces the amount of memory
needed for back propagation. It may also reduce the quality of the
fine-tuned model if you are fine-tuning with a lot of data.
4. Longer examples require more memory. If it makes sense for your data, one thing
you can do is break your examples into smaller
sequences when making the `{train, valid, test}.jsonl` files.
5. Gradient checkpointing lets you trade-off memory use (less) for computation
(more) by recomputing instead of storing intermediate values needed by the
backward pass. You can use gradient checkpointing by passing the
`--grad-checkpoint` flag. Gradient checkpointing will be more helpful for
larger batch sizes or sequence lengths with smaller or quantized models.
For example, for a machine with 32 GB the following should run reasonably fast:
```
mlx_lm.lora \
--model mistralai/Mistral-7B-v0.1 \
--train \
--batch-size 1 \
--num-layers 4 \
--data wikisql
```
The above command on an M1 Max with 32 GB runs at about 250
tokens-per-second, using the MLX Example
[`wikisql`](https://github.com/ml-explore/mlx-examples/tree/main/lora/data)
data set.
[^lora]: Refer to the [arXiv paper](https://arxiv.org/abs/2106.09685) for more details on LoRA.
[^qlora]: Refer to the paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)

22
llms/mlx_lm/MANAGE.md Normal file
View File

@@ -0,0 +1,22 @@
# Managing Models
You can use `mlx-lm` to manage models downloaded locally in your machine. They
are stored in the Hugging Face cache.
Scan models:
```shell
mlx_lm.manage --scan
```
Specify a `--pattern` to get info on a single or specific set of models:
```shell
mlx_lm.manage --scan --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```
To delete a model (or multiple models):
```shell
mlx_lm.manage --delete --pattern mlx-community/Mistral-7B-Instruct-v0.2-4bit
```

50
llms/mlx_lm/MERGE.md Normal file
View File

@@ -0,0 +1,50 @@
# Model Merging
You can use `mlx-lm` to merge models and upload them to the Hugging
Face hub or save them locally for LoRA fine tuning.
The main command is `mlx_lm.merge`:
```shell
mlx_lm.merge --config config.yaml
```
The merged model will be saved by default in `mlx_merged_model`. To see a
full list of options run:
```shell
mlx_lm.merge --help
```
Here is an example `config.yaml`:
```yaml
models:
- OpenPipe/mistral-ft-optimized-1218
- mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
t:
- filter: self_attn
value: [0, 0.5, 0.3, 0.7, 1]
- filter: mlp
value: [1, 0.5, 0.7, 0.3, 0]
- value: 0.5
```
The `models` field is a list of Hugging Face repo ids. The first model in the
list is treated as the base model into which the remaining models are merged.
The `method` field is the merging method. Right now `slerp` is the only
supported method.
The `parameters` are the corresponding parameters for the given `method`.
Each parameter is a list with `filter` determining which layer the parameter
applies to and `value` determining the actual value used. The last item in
the list without a `filter` field is the default.
If `value` is a list, it specifies the start and end values for the
corresponding segment of blocks. In the example above, the models have 32
blocks. For blocks 1-8, the layers with `self_attn` in the name will use the
values `np.linspace(0, 0.5, 8)`, the same layers in the next 8 blocks (9-16)
will use `np.linspace(0.5, 0.3, 8)`, and so on.

10
llms/mlx_lm/README.md Normal file
View File

@@ -0,0 +1,10 @@
## Generate Text with MLX and :hugs: Hugging Face
This an example of large language model text generation that can pull models from
the Hugging Face Hub.
For more information on this example, see the [README](../README.md) in the
parent directory.
This package also supports fine tuning with LoRA or QLoRA. For more information
see the [LoRA documentation](LORA.md).

131
llms/mlx_lm/SERVER.md Normal file
View File

@@ -0,0 +1,131 @@
# HTTP Model Server
You use `mlx-lm` to make an HTTP API for generating text with any supported
model. The HTTP API is intended to be similar to the [OpenAI chat
API](https://platform.openai.com/docs/api-reference).
> [!NOTE]
> The MLX LM server is not recommended for production as it only implements
> basic security checks.
Start the server with:
```shell
mlx_lm.server --model <path_to_model_or_hf_repo>
```
For example:
```shell
mlx_lm.server --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```
This will start a text generation server on port `8080` of the `localhost`
using Mistral 7B instruct. The model will be downloaded from the provided
Hugging Face repo if it is not already in the local cache.
To see a full list of options run:
```shell
mlx_lm.server --help
```
You can make a request to the model by running:
```shell
curl localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Say this is a test!"}],
"temperature": 0.7
}'
```
### Request Fields
- `messages`: An array of message objects representing the conversation
history. Each message object should have a role (e.g. user, assistant) and
content (the message text).
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
to generate. Defaults to `100`.
- `stream`: (Optional) A boolean indicating if the response should be
streamed. If true, responses are sent as they are generated. Defaults to
false.
- `temperature`: (Optional) A float specifying the sampling temperature.
Defaults to `1.0`.
- `top_p`: (Optional) A float specifying the nucleus sampling parameter.
Defaults to `1.0`.
- `repetition_penalty`: (Optional) Applies a penalty to repeated tokens.
Defaults to `1.0`.
- `repetition_context_size`: (Optional) The size of the context window for
applying repetition penalty. Defaults to `20`.
- `logit_bias`: (Optional) A dictionary mapping token IDs to their bias
values. Defaults to `None`.
- `logprobs`: (Optional) An integer specifying the number of top tokens and
corresponding log probabilities to return for each output in the generated
sequence. If set, this can be any value between 1 and 10, inclusive.
- `model`: (Optional) A string path to a local model or Hugging Face repo id.
If the path is local is must be relative to the directory the server was
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
### List Models
Use the `v1/models` endpoint to list available models:
```shell
curl localhost:8080/v1/models -H "Content-Type: application/json"
```
This will return a list of locally available models where each model in the
list contains the following fields:
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.

37
llms/mlx_lm/UPLOAD.md Normal file
View File

@@ -0,0 +1,37 @@
### Packaging for PyPI
Install `build` and `twine`:
```
pip install --user --upgrade build
pip install --user --upgrade twine
```
Generate the source distribution and wheel:
```
python -m build
```
> [!warning]
> Use a test server first
#### Test Upload
Upload to test server:
```
python -m twine upload --repository testpypi dist/*
```
Install from test server and check that it works:
```
python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx-lm
```
#### Upload
```
python -m twine upload dist/*
```

9
llms/mlx_lm/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
# Copyright © 2023-2024 Apple Inc.
import os
from ._version import __version__
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
from .utils import convert, generate, load, stream_generate

3
llms/mlx_lm/_version.py Normal file
View File

@@ -0,0 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.19.3"

180
llms/mlx_lm/cache_prompt.py Normal file
View File

@@ -0,0 +1,180 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import sys
import time
import mlx.core as mx
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load, maybe_quantize_kv_cache
DEFAULT_QUANTIZED_KV_START = 5000
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
type=str,
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
default=None,
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
required=True,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
args.prompt = sys.stdin.read() if args.prompt == "-" else args.prompt
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a prefix assuming that the suffix will be
# provided at generation time.
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
n = len(test_prompt) - test_prompt.index("<query>") - len("<query>")
prompt = prompt[:-n]
else:
prompt = args.prompt
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:
model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)
maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
print("Saving...")
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":
main()

91
llms/mlx_lm/chat.py Normal file
View File

@@ -0,0 +1,91 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MAX_TOKENS = 256
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
default=DEFAULT_MODEL,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response, *_ in stream_generate(
model,
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,
):
print(response, flush=True, end="")
print()
if __name__ == "__main__":
main()

62
llms/mlx_lm/convert.py Normal file
View File

@@ -0,0 +1,62 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
from .utils import convert
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(
description="Convert Hugging Face model to MLX format"
)
parser.add_argument("--hf-path", type=str, help="Path to the Hugging Face model.")
parser.add_argument(
"--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
)
parser.add_argument(
"-q", "--quantize", help="Generate a quantized model.", action="store_true"
)
parser.add_argument(
"--q-group-size", help="Group size for quantization.", type=int, default=64
)
parser.add_argument(
"--q-bits", help="Bits per weight for quantization.", type=int, default=4
)
parser.add_argument(
"--dtype",
help="Type to save the non-quantized parameters.",
type=str,
choices=["float16", "bfloat16", "float32"],
default="float16",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
parser.add_argument(
"-d",
"--dequantize",
help="Dequantize a quantized model.",
action="store_true",
default=False,
)
return parser
def main():
parser = configure_parser()
args = parser.parse_args()
convert(**vars(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,53 @@
# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")

View File

@@ -0,0 +1,42 @@
# Copyright © 2024 Apple Inc.
from mlx_lm import generate, load
# Specify the checkpoint
checkpoint = "mistralai/Mistral-7B-Instruct-v0.3"
# Load the corresponding model and tokenizer
model, tokenizer = load(path_or_hf_repo=checkpoint)
# Specify the prompt and conversation history
prompt = "Why is the sky blue?"
conversation = [{"role": "user", "content": prompt}]
# Transform the prompt into the chat template
prompt = tokenizer.apply_chat_template(
conversation=conversation, tokenize=False, add_generation_prompt=True
)
# Specify the maximum number of tokens
max_tokens = 1_000
# Specify if tokens and timing information will be printed
verbose = True
# Some optional arguments for causal language model generation
generation_args = {
"temp": 0.7,
"repetition_penalty": 1.2,
"repetition_context_size": 20,
"top_p": 0.95,
}
# Generate a response with the specified settings
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
**generation_args,
)

View File

@@ -0,0 +1,80 @@
# The path to the local model directory or Hugging Face repo.
model: "mlx_model"
# Whether or not to train (boolean)
train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"
# The PRNG seed
seed: 0
# Number of layers to fine-tune
num_layers: 16
# Minibatch size.
batch_size: 4
# Iterations to train for.
iters: 1000
# Number of validation batches, -1 uses the entire validation set.
val_batches: 25
# Adam learning rate.
learning_rate: 1e-5
# Number of training steps between loss reporting.
steps_per_report: 10
# Number of training steps between validations.
steps_per_eval: 200
# Load path to resume training with the given adapter weights.
resume_adapter_file: null
# Save/load path for the trained adapter weights.
adapter_path: "adapters"
# Save the model every N iterations.
save_every: 100
# Evaluate on the test set after training
test: false
# Number of test set batches, -1 uses the entire test set.
test_batches: 100
# Maximum sequence length.
max_seq_length: 2048
# Use gradient checkpointing to reduce memory use.
grad_checkpoint: false
# LoRA parameters can only be specified in a config file
lora_parameters:
# The layer keys to apply LoRA to.
# These will be applied for the last lora_layers
keys: ["self_attn.q_proj", "self_attn.v_proj"]
rank: 8
scale: 20.0
dropout: 0.0
# Schedule can only be specified in a config file, uncomment to use.
#lr_schedule:
# name: cosine_decay
# warmup: 100 # 0 for no warmup
# warmup_init: 1e-7 # 0 if not specified
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
#hf_dataset:
# name: "billsum"
# train_split: "train[:1000]"
# valid_split: "train[-100:]"
# prompt_feature: "text"
# completion_feature: "summary"

View File

@@ -0,0 +1,11 @@
models:
- OpenPipe/mistral-ft-optimized-1218
- mlabonne/NeuralHermes-2.5-Mistral-7B
method: slerp
parameters:
t:
- filter: self_attn
value: [0, 0.5, 0.3, 0.7, 1]
- filter: mlp
value: [1, 0.5, 0.7, 0.3, 0]
- value: 0.5

130
llms/mlx_lm/fuse.py Normal file
View File

@@ -0,0 +1,130 @@
import argparse
import glob
import shutil
from pathlib import Path
from mlx.utils import tree_flatten, tree_unflatten
from .gguf import convert_to_gguf
from .tuner.dora import DoRAEmbedding, DoRALinear
from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
from .tuner.utils import dequantize, load_adapters
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def parse_arguments() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Fuse fine-tuned adapters into the base model."
)
parser.add_argument(
"--model",
default="mlx_model",
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--save-path",
default="fused_model",
help="The path to save the fused model.",
)
parser.add_argument(
"--adapter-path",
type=str,
default="adapters",
help="Path to the trained adapter weights and config.",
)
parser.add_argument(
"--hf-path",
type=str,
default=None,
help="Path to the original Hugging Face model. Required for upload if --model is a local directory.",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
parser.add_argument(
"--de-quantize",
help="Generate a de-quantized model.",
action="store_true",
)
parser.add_argument(
"--export-gguf",
help="Export model weights in GGUF format.",
action="store_true",
)
parser.add_argument(
"--gguf-path",
help="Path to save the exported GGUF format model weights. Default is ggml-model-f16.gguf.",
default="ggml-model-f16.gguf",
type=str,
)
return parser.parse_args()
def main() -> None:
print("Loading pretrained model")
args = parse_arguments()
model_path = get_model_path(args.model)
model, config, tokenizer = fetch_from_hub(model_path)
model.freeze()
model = load_adapters(model, args.adapter_path)
fused_linears = [
(n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse")
]
if fused_linears:
model.update_modules(tree_unflatten(fused_linears))
if args.de_quantize:
print("De-quantizing model")
model = dequantize(model)
weights = dict(tree_flatten(model.parameters()))
save_path = Path(args.save_path)
save_weights(save_path, weights)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, save_path)
tokenizer.save_pretrained(save_path)
if args.de_quantize:
config.pop("quantization", None)
save_config(config, config_path=save_path / "config.json")
if args.export_gguf:
model_type = config["model_type"]
if model_type not in ["llama", "mixtral", "mistral"]:
raise ValueError(
f"Model type {model_type} not supported for GGUF conversion."
)
convert_to_gguf(model_path, weights, config, str(save_path / args.gguf_path))
if args.upload_repo is not None:
hf_path = args.hf_path or (
args.model if not Path(args.model).exists() else None
)
if hf_path is None:
raise ValueError(
"Must provide original Hugging Face repo to upload local model."
)
upload_to_hub(args.save_path, args.upload_repo, hf_path)
if __name__ == "__main__":
main()

275
llms/mlx_lm/generate.py Normal file
View File

@@ -0,0 +1,275 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import sys
import mlx.core as mx
from .models.cache import QuantizedKVCache, load_prompt_cache
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_MIN_P = 0.0
DEFAULT_MIN_TOKENS_TO_KEEP = 1
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
DEFAULT_QUANTIZED_KV_START = 5000
def str2bool(string):
return string.lower() not in ["false", "f"]
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="LLM inference script")
parser.add_argument(
"--model",
type=str,
help=(
"The path to the local model directory or Hugging Face repo. "
f"If no model is specified, then {DEFAULT_MODEL} is used."
),
default=None,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--eos-token",
type=str,
default=None,
help="End of sequence token for tokenizer",
)
parser.add_argument(
"--prompt",
"-p",
default=DEFAULT_PROMPT,
help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
"--max-tokens",
"-m",
type=int,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument(
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
)
parser.add_argument(
"--min-tokens-to-keep",
type=float,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--ignore-chat-template",
action="store_true",
help="Use the raw prompt without the tokenizer's chat template.",
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--verbose",
type=str2bool,
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
parser.add_argument(
"--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
)
parser.add_argument(
"--kv-bits",
type=int,
help="Number of bits for KV cache quantization. "
"Defaults to no quantization.",
default=None,
)
parser.add_argument(
"--kv-group-size",
type=int,
help="Group size for KV cache quantization.",
default=64,
)
parser.add_argument(
"--quantized-kv-start",
help="When --kv-bits is set, start quantizing the KV cache "
"from this step onwards.",
type=int,
default=DEFAULT_QUANTIZED_KV_START,
)
return parser
def colorprint(color, s):
color_codes = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)
def colorprint_by_t0(s, t0):
if t0 > 0.95:
color = "white"
elif t0 > 0.70:
color = "green"
elif t0 > 0.30:
color = "yellow"
else:
color = "red"
colorprint(color, s)
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file,
return_metadata=True,
)
if isinstance(prompt_cache[0], QuantizedKVCache):
if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits:
raise ValueError(
"--kv-bits does not match the kv cache loaded from --prompt-cache-file."
)
if args.kv_group_size != prompt_cache[0].group_size:
raise ValueError(
"--kv-group-size does not match the kv cache loaded from --prompt-cache-file."
)
# Building tokenizer_config
tokenizer_config = (
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
model_path = args.model
if using_cache:
if model_path is None:
model_path = metadata["model"]
elif model_path != metadata["model"]:
raise ValueError(
f"Providing a different model ({model_path}) than that "
f"used to create the prompt cache ({metadata['model']}) "
"is an error."
)
model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
model_path,
adapter_path=args.adapter_path,
tokenizer_config=tokenizer_config,
)
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
hasattr(tokenizer, "apply_chat_template")
and tokenizer.chat_template is not None
):
messages = [
{
"role": "user",
"content": sys.stdin.read() if args.prompt == "-" else args.prompt,
}
]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if using_cache:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
add_generation_prompt=True,
)
prompt = prompt[test_prompt.index("<query>") :]
else:
prompt = args.prompt
if args.colorize and not args.verbose:
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
verbose=args.verbose,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
)
if not args.verbose:
print(response)
if __name__ == "__main__":
main()

616
llms/mlx_lm/gguf.py Normal file
View File

@@ -0,0 +1,616 @@
import importlib
import re
import tempfile
from enum import IntEnum
from pathlib import Path
from typing import Iterable, Optional, Set, Tuple, Union
import gguf
import mlx.core as mx
import mlx.nn as nn
from gguf import GGMLQuantizationType
from gguf.gguf_reader import GGUFReader
from transformers import AutoTokenizer
from .tokenizer_utils import TokenizerWrapper
class TokenType(IntEnum):
NORMAL = 1
UNKNOWN = 2
CONTROL = 3
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
class GGMLFileType(IntEnum):
GGML_TYPE_F16 = 1
# copied from https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L455
class HfVocab:
def __init__(
self,
fname_tokenizer: Path,
fname_added_tokens: Optional[Union[Path, None]] = None,
) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
fname_tokenizer,
cache_dir=fname_tokenizer,
local_files_only=True,
)
self.added_tokens_list = []
self.added_tokens_dict = dict()
self.added_tokens_ids = set()
for tok, tokidx in sorted(
self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
):
if tokidx >= self.tokenizer.vocab_size:
self.added_tokens_list.append(tok)
self.added_tokens_dict[tok] = tokidx
self.added_tokens_ids.add(tokidx)
self.specials = {
tok: self.tokenizer.get_vocab()[tok]
for tok in self.tokenizer.all_special_tokens
}
self.special_ids = set(self.tokenizer.all_special_ids)
self.vocab_size_base = self.tokenizer.vocab_size
self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
reverse_vocab = {
id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()
}
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)
def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(r"<0x[0-9A-Fa-f]{2}>", token_text):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL
def get_token_score(self, token_id: int) -> float:
return -1000.0
def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
toktype = self.get_token_type(self.specials[text], "", self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text, score, toktype
def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
def all_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
yield from self.hf_tokens()
yield from self.added_tokens()
def __repr__(self) -> str:
return f"<HfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
@staticmethod
def load(path: Path) -> "HfVocab":
added_tokens_path = path.parent / "added_tokens.json"
return HfVocab(path, added_tokens_path if added_tokens_path.exists() else None)
def translate_weight_names(name):
name = name.replace("model.layers.", "blk.")
# for mixtral gate
name = name.replace("block_sparse_moe.gate", "ffn_gate_inp")
# for mixtral experts ffns
pattern = r"block_sparse_moe\.experts\.(\d+)\.w1\.weight"
replacement = r"ffn_gate.\1.weight"
name = re.sub(pattern, replacement, name)
pattern = r"block_sparse_moe\.experts\.(\d+)\.w2\.weight"
replacement = r"ffn_down.\1.weight"
name = re.sub(pattern, replacement, name)
pattern = r"block_sparse_moe\.experts\.(\d+)\.w3\.weight"
replacement = r"ffn_up.\1.weight"
name = re.sub(pattern, replacement, name)
name = name.replace("mlp.gate_proj", "ffn_gate")
name = name.replace("mlp.down_proj", "ffn_down")
name = name.replace("mlp.up_proj", "ffn_up")
name = name.replace("self_attn.q_proj", "attn_q")
name = name.replace("self_attn.k_proj", "attn_k")
name = name.replace("self_attn.v_proj", "attn_v")
name = name.replace("self_attn.o_proj", "attn_output")
name = name.replace("input_layernorm", "attn_norm")
name = name.replace("post_attention_layernorm", "ffn_norm")
name = name.replace("model.embed_tokens", "token_embd")
name = name.replace("model.norm", "output_norm")
name = name.replace("lm_head", "output")
return name
def permute_weights(weights, n_head, n_head_kv=None):
if n_head_kv is not None and n_head != n_head_kv:
n_head = n_head_kv
reshaped = weights.reshape(
n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]
)
swapped = reshaped.swapaxes(1, 2)
final_shape = weights.shape
return swapped.reshape(final_shape)
def prepare_metadata(config, vocab):
metadata = {
"general.name": "llama",
"llama.context_length": (
mx.array(config["max_position_embeddings"], dtype=mx.uint32)
if config.get("max_position_embeddings") is not None
else None
),
"llama.embedding_length": (
mx.array(config["hidden_size"], dtype=mx.uint32)
if config.get("hidden_size") is not None
else None
),
"llama.block_count": (
mx.array(config["num_hidden_layers"], dtype=mx.uint32)
if config.get("num_hidden_layers") is not None
else None
),
"llama.feed_forward_length": (
mx.array(config["intermediate_size"], dtype=mx.uint32)
if config.get("intermediate_size") is not None
else None
),
"llama.rope.dimension_count": (
mx.array(
config["hidden_size"] // config["num_attention_heads"], dtype=mx.uint32
)
if config.get("hidden_size") is not None
and config.get("num_attention_heads") is not None
else None
),
"llama.attention.head_count": (
mx.array(config["num_attention_heads"], dtype=mx.uint32)
if config.get("num_attention_heads") is not None
else None
),
"llama.attention.head_count_kv": (
mx.array(
config.get("num_key_value_heads", config["num_attention_heads"]),
dtype=mx.uint32,
)
if config.get("num_attention_heads") is not None
else None
),
"llama.expert_count": (
mx.array(config.get("num_local_experts", None), dtype=mx.uint32)
if config.get("num_local_experts") is not None
else None
),
"llama.expert_used_count": (
mx.array(config.get("num_experts_per_tok", None), dtype=mx.uint32)
if config.get("num_experts_per_tok") is not None
else None
),
"llama.attention.layer_norm_rms_epsilon": (
mx.array(config.get("rms_norm_eps", 1e-05))
if config.get("rms_norm_eps") is not None
else None
),
"llama.rope.freq_base": (
mx.array(config.get("rope_theta", 10000), dtype=mx.float32)
if config.get("rope_theta") is not None
else None
),
}
rope_scaling = config.get("rope_scaling")
if rope_scaling is not None and (typ := rope_scaling.get("type")):
rope_factor = rope_scaling.get("factor")
f_rope_scale = rope_factor
if typ == "linear":
rope_scaling_type = "linear"
metadata["llama.rope.scaling.type"] = rope_scaling_type
metadata["llama.rope.scaling.factor"] = mx.array(f_rope_scale)
metadata["general.file_type"] = mx.array(
GGMLFileType.GGML_TYPE_F16.value,
dtype=mx.uint32,
)
metadata["general.quantization_version"] = mx.array(
GGMLFileType.GGML_TYPE_F16.value,
dtype=mx.uint32,
)
metadata["general.name"] = config.get("_name_or_path", "llama").split("/")[-1]
metadata["general.architecture"] = "llama"
metadata["general.alignment"] = mx.array(32, dtype=mx.uint32)
# add metadata for vocab
metadata["tokenizer.ggml.model"] = "llama"
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype.value)
assert len(tokens) == vocab.vocab_size
metadata["tokenizer.ggml.tokens"] = tokens
metadata["tokenizer.ggml.scores"] = mx.array(scores, dtype=mx.float32)
metadata["tokenizer.ggml.token_type"] = mx.array(toktypes, dtype=mx.uint32)
if vocab.tokenizer.bos_token_id is not None:
metadata["tokenizer.ggml.bos_token_id"] = mx.array(
vocab.tokenizer.bos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.eos_token_id is not None:
metadata["tokenizer.ggml.eos_token_id"] = mx.array(
vocab.tokenizer.eos_token_id, dtype=mx.uint32
)
if vocab.tokenizer.unk_token_id is not None:
metadata["tokenizer.ggml.unknown_token_id"] = mx.array(
vocab.tokenizer.unk_token_id, dtype=mx.uint32
)
metadata = {k: v for k, v in metadata.items() if v is not None}
return metadata
def convert_to_gguf(
model_path: Union[str, Path],
weights: dict,
config: dict,
output_file_path: str,
):
if isinstance(model_path, str):
model_path = Path(model_path)
quantization = config.get("quantization", None)
if quantization:
raise NotImplementedError(
"Conversion of quantized models is not yet supported."
)
print("Converting to GGUF format")
# https://github.com/ggerganov/llama.cpp/blob/master/convert.py#L1182 seems relate to llama.cpp's multihead attention
weights = {
k: (
permute_weights(
v, config["num_attention_heads"], config["num_attention_heads"]
)
if "self_attn.q_proj.weight" in k
else (
permute_weights(
v, config["num_attention_heads"], config["num_key_value_heads"]
)
if "self_attn.k_proj.weight" in k
else v
)
)
for k, v in weights.items()
}
# rename weights for gguf format
weights = {translate_weight_names(k): v for k, v in weights.items()}
if not (model_path / "tokenizer.json").exists():
raise ValueError("Tokenizer json not found")
vocab = HfVocab.load(model_path)
metadata = prepare_metadata(config, vocab)
weights = {
k: (
v.astype(mx.float32).astype(mx.float16)
if v.dtype == mx.bfloat16
else v.astype(mx.float32) if "norm" in k else v
)
for k, v in weights.items()
}
output_file_path = output_file_path
mx.save_gguf(output_file_path, weights, metadata)
print(f"Converted GGUF model saved as: {output_file_path}")
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L568
def parse_q4_k(tensor):
bits = 4
pack_factor = 32 // bits
group_size = 32
block_size = 144
data = mx.array(tensor.data)
shape = [int(d) for d in reversed(tensor.shape)]
wshape = (*shape[:-1], shape[-1] // pack_factor)
gshape = (*shape[:-1], shape[-1] // group_size)
num_blocks = data.size // block_size
kernel = mx.fast.metal_kernel(
name="parse_q4_k",
input_names=["data"],
output_names=["w", "scales", "biases"],
header="""
typedef struct {
float16_t d;
float16_t d_min;
uint8_t scales[12];
uint8_t qs[128];
} block_q4_K;
""",
source="""
uint elem = thread_position_in_grid.x;
const device block_q4_K* block = reinterpret_cast<const device block_q4_K*>(data);
block += elem;
w += elem * 32;
scales += elem * 8;
biases += elem * 8;
// First unpack the quantized scales/biases
for (int j = 0; j < 8; j++) {
uint8_t d, m;
if (j < 4) {
d = block->scales[j] & 63;
m = block->scales[j + 4] & 63;
} else {
d = (block->scales[j + 4] & 0xF) | ((block->scales[j - 4] >> 6) << 4);
m = (block->scales[j + 4] >> 4) | ((block->scales[j - 0] >> 6) << 4);
}
scales[j] = d * block->d;
biases[j] = -m * block->d_min;
}
uint32_t outputs[32] = {0};
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 32; j++) {
uint8_t val = block->qs[i * 32 + j] & 0xf;
int index = i * 8 + (j / 8);
outputs[index] += val << (4 * (j % 8));
}
for (int j = 0; j < 32; j++) {
uint8_t val = block->qs[i * 32 + j] >> 4;
int index = i * 8 + 4 + (j / 8);
outputs[index] += val << (4 * (j % 8));
}
}
for (int i = 0; i < 32; i++) {
w[i] = outputs[i];
}
""",
)
w, scales, biases = kernel(
inputs=[data],
grid=(num_blocks, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[wshape, gshape, gshape],
output_dtypes=[mx.uint32, mx.float16, mx.float16],
)
return w, scales, biases
# Adapted from https://github.com/antirez/gguf-tools/blob/4e6455ecaf92b1a59e6a3291646459af3154bef5/gguflib.c#L658
def parse_q6_k(tensor):
bits = 6
group_size = 16
block_size = 210
data = mx.array(tensor.data)
shape = [int(d) for d in reversed(tensor.shape)]
wshape = (*shape[:-1], shape[-1] * bits // 8)
gshape = (*shape[:-1], shape[-1] // group_size)
num_blocks = data.size // block_size
kernel = mx.fast.metal_kernel(
name="parse_q6_k",
input_names=["data"],
output_names=["w", "scales", "biases"],
header="""
typedef struct {
uint8_t ql[128]; // quants, lower 4 bits
uint8_t qh[64]; // quants, upper 2 bits
int8_t scales[16]; // scales, quantized with 8 bits
float16_t d; // super-block scale
} block_q6_K;
""",
source="""
uint elem = thread_position_in_grid.x;
const device block_q6_K* block = reinterpret_cast<const device block_q6_K*>(data);
block += elem;
w += elem * 192;
scales += elem * 16;
biases += elem * 16;
const device uint8_t* ql = &block->ql[0];
const device uint8_t* qh = &block->qh[0];
const device int8_t* bscales = &block->scales[0];
uint32_t output = 0;
for (int cluster = 0; cluster < 2; cluster++) {
for (uint64_t j = 0; j < 128; j++) {
uint8_t val = ((ql[j%64] >> (j/64*4)) & 0xF) | (((qh[j%32] >> (j/32*2)) & 3) << 4);
output += val << (6 * (j % 4));
// Every 4 values write out 3 bytes
if (j % 4 == 3) {
w[0] = output & 0xff;
w[1] = (output & 0xff00) >> 8;
w[2] = (output & 0xff0000) >> 16;
w += 3;
output = 0;
}
if (j % 16 == 0) {
scales[j/16] = block->d * bscales[j/16];
biases[j/16] = -32.0f * scales[j/16];
}
}
ql += 64;
qh += 32;
bscales += 8;
scales += 8;
biases += 8;
}
""",
)
w, scales, biases = kernel(
inputs=[data],
grid=(num_blocks, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[wshape, gshape, gshape],
output_dtypes=[mx.uint8, mx.float16, mx.float16],
)
w = mx.view(w, dtype=mx.uint32)
return w, scales, biases
def parse_gguf_tensor(tensor):
from gguf import GGMLQuantizationType
if tensor.tensor_type == GGMLQuantizationType.Q4_K:
return parse_q4_k(tensor)
elif tensor.tensor_type == GGMLQuantizationType.Q6_K:
return parse_q6_k(tensor)
elif tensor.tensor_type in [GGMLQuantizationType.F16, GGMLQuantizationType.F32]:
return mx.array(tensor.data)
else:
raise NotImplementedError(f"Type: {tensor.tensor_type} is not yet supported.")
def convert_name(name):
name = name.replace("blk", "model.layers")
name = name.replace("attn_norm", "input_layernorm")
name = name.replace("ffn_norm", "post_attention_layernorm")
name = name.replace("attn_q", "self_attn.q_proj")
name = name.replace("attn_k", "self_attn.k_proj")
name = name.replace("attn_v", "self_attn.v_proj")
name = name.replace("attn_output", "self_attn.o_proj")
name = name.replace("ffn_up", "mlp.up_proj")
name = name.replace("ffn_down", "mlp.down_proj")
name = name.replace("ffn_gate", "mlp.gate_proj")
if "output_norm" in name:
name = name.replace("output_norm", "model.norm")
else:
name = name.replace("output", "lm_head")
name = name.replace("token_embd", "model.embed_tokens")
return name
FIELD_MAPPING = {
"{model}.embedding_length": "hidden_size",
"{model}.feed_forward_length": "intermediate_size",
"{model}.attention.head_count": "num_attention_heads",
"{model}.attention.head_count_kv": "num_key_value_heads",
"{model}.block_count": "num_hidden_layers",
"{model}.attention.layer_norm_rms_epsilon": "rms_norm_eps",
"{model}.rope.freq_base": "rope_theta",
}
QUANT_MAPPING = {
GGMLQuantizationType.Q4_K: {
"bits": 4,
"group_size": 32,
},
GGMLQuantizationType.Q6_K: {
"bits": 6,
"group_size": 16,
},
GGMLQuantizationType.F16: None,
GGMLQuantizationType.F32: None,
}
# from https://github.com/ggerganov/llama.cpp/blob/40c6d79fb52f995f47507fedfeaae2ac05d9b35c/gguf-py/scripts/gguf_new_metadata.py#L46
def decode_field(field):
if field and field.types:
main_type = field.types[0]
if main_type == gguf.GGUFValueType.ARRAY:
sub_type = field.types[-1]
if sub_type == gguf.GGUFValueType.STRING:
return [
str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data
]
else:
return [pv for idx in field.data for pv in field.parts[idx].tolist()]
if main_type == gguf.GGUFValueType.STRING:
return str(bytes(field.parts[-1]), encoding="utf-8")
else:
return field.parts[-1][0]
return None
def load_gguf(model_path: str) -> tuple[nn.Module, TokenizerWrapper]:
with tempfile.TemporaryDirectory() as tmp_dir:
base_name = Path(model_path).name
(Path(tmp_dir) / base_name).symlink_to(model_path)
tokenizer = AutoTokenizer.from_pretrained(tmp_dir, gguf_file=base_name)
reader = GGUFReader(model_path)
model_type = "qwen2"
config = {
"model_type": model_type,
"vocab_size": tokenizer.vocab_size,
"tie_word_embeddings": False,
}
mapping = {k.format(model=model_type): v for k, v in FIELD_MAPPING.items()}
for field in reader.fields:
if field in mapping:
config[mapping[field]] = decode_field(reader.get_field(field))
config["quantization"] = {}
weights = {}
# Look for any extra gguf files
parts = Path(model_path).name.split("-")
parts[-3] = "*"
gguf_pattern = "-".join(parts)
for filename in Path(model_path).parent.glob(gguf_pattern):
reader = GGUFReader(str(filename))
for tensor in reader.tensors:
w = parse_gguf_tensor(tensor)
mx.eval(w)
name = convert_name(tensor.name)
base_name = ".".join(name.split(".")[:-1])
if quant := QUANT_MAPPING[tensor.tensor_type]:
config["quantization"][base_name] = quant
if len(w) == 3:
w, scales, biases = w
weights[name] = w
weights[base_name + ".scales"] = scales
weights[base_name + ".biases"] = biases
else:
weights[name] = w
arch = importlib.import_module(f"mlx_lm.models.{config['model_type']}")
model_class, model_args_class = arch.Model, arch.ModelArgs
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
quant_config = config["quantization"]
def pred(p, m):
return quant_config.get(p)
nn.quantize(model, class_predicate=pred)
model.load_weights(list(weights.items()))
model.eval()
return model, tokenizer

295
llms/mlx_lm/lora.py Normal file
View File

@@ -0,0 +1,295 @@
# Copyright © 2024 Apple Inc.
import argparse
import math
import re
import types
from pathlib import Path
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
load_adapters,
print_trainable_parameters,
)
from .utils import load, save_config
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
"tag:yaml.org,2002:float",
re.compile(
"""^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$""",
re.X,
),
list("-+0123456789."),
)
CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"data": "data/",
"seed": 0,
"num_layers": 16,
"batch_size": 4,
"iters": 1000,
"val_batches": 25,
"learning_rate": 1e-5,
"steps_per_report": 10,
"steps_per_eval": 200,
"resume_adapter_file": None,
"adapter_path": "adapters",
"save_every": 100,
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
def build_parser():
parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
parser.add_argument(
"--model",
help="The path to the local model directory or Hugging Face repo.",
)
# Training args
parser.add_argument(
"--train",
action="store_true",
help="Do training",
default=None,
)
parser.add_argument(
"--data",
type=str,
help=(
"Directory with {train, valid, test}.jsonl files or the name "
"of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
),
)
parser.add_argument(
"--fine-tune-type",
type=str,
choices=["lora", "dora", "full"],
default="lora",
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--num-layers",
type=int,
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
)
parser.add_argument("--batch-size", type=int, help="Minibatch size.")
parser.add_argument("--iters", type=int, help="Iterations to train for.")
parser.add_argument(
"--val-batches",
type=int,
help="Number of validation batches, -1 uses the entire validation set.",
)
parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
parser.add_argument(
"--steps-per-report",
type=int,
help="Number of training steps between loss reporting.",
)
parser.add_argument(
"--steps-per-eval",
type=int,
help="Number of training steps between validations.",
)
parser.add_argument(
"--resume-adapter-file",
type=str,
help="Load path to resume training from the given fine-tuned weights.",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Save/load path for the fine-tuned weights.",
)
parser.add_argument(
"--save-every",
type=int,
help="Save the model every N iterations.",
)
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
default=None,
)
parser.add_argument(
"--test-batches",
type=int,
help="Number of test set batches, -1 uses the entire test set.",
)
parser.add_argument(
"--max-seq-length",
type=int,
help="Maximum sequence length.",
)
parser.add_argument(
"-c",
"--config",
default=None,
help="A YAML configuration file with the training options",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
return parser
def train_model(
args,
model: nn.Module,
tokenizer: TokenizerWrapper,
train_set,
valid_set,
training_callback: TrainingCallback = None,
):
model.freeze()
if args.fine_tune_type == "full":
for l in model.layers[-min(args.num_layers, 0) :]:
l.unfreeze()
elif args.fine_tune_type in ["lora", "dora"]:
# Convert linear layers to lora/dora layers and unfreeze in the process
linear_to_lora_layers(
model,
args.num_layers,
args.lora_parameters,
use_dora=(args.fine_tune_type == "dora"),
)
else:
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
# Resume from weights if provided
if args.resume_adapter_file is not None:
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)
print_trainable_parameters(model)
adapter_path = Path(args.adapter_path)
adapter_path.mkdir(parents=True, exist_ok=True)
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")
# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
def run(args, training_callback: TrainingCallback = None):
np.random.seed(args.seed)
print("Loading pretrained model")
model, tokenizer = load(args.model)
print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args, tokenizer)
if args.test and not args.train:
# Allow testing without LoRA layers by providing empty path
if args.adapter_path != "":
load_adapters(model, args.adapter_path)
elif args.train:
print("Training")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
if args.test:
print("Testing")
evaluate_model(args, model, tokenizer, test_set)
def main():
parser = build_parser()
args = parser.parse_args()
config = args.config
args = vars(args)
if config:
print("Loading configuration file", config)
with open(config, "r") as file:
config = yaml.load(file, yaml_loader)
# Prefer parameters from command-line arguments
for k, v in config.items():
if args.get(k, None) is None:
args[k] = v
# Update defaults for unspecified parameters
for k, v in CONFIG_DEFAULTS.items():
if args.get(k, None) is None:
args[k] = v
run(types.SimpleNamespace(**args))
if __name__ == "__main__":
main()

121
llms/mlx_lm/manage.py Normal file
View File

@@ -0,0 +1,121 @@
import argparse
from typing import List, Union
from huggingface_hub import scan_cache_dir
from transformers.commands.user import tabulate
def ask_for_confirmation(message: str) -> bool:
y = ("y", "yes", "1")
n = ("n", "no", "0")
all_values = y + n + ("",)
full_message = f"{message} (Y/n) "
while True:
answer = input(full_message).lower()
if answer == "":
return False
if answer in y:
return True
if answer in n:
return False
print(f"Invalid input. Must be one of {all_values}")
def main():
parser = argparse.ArgumentParser(description="MLX Model Cache.")
parser.add_argument(
"--scan",
action="store_true",
help="Scan Hugging Face cache for mlx models.",
)
parser.add_argument(
"--delete",
action="store_true",
help="Delete models matching the given pattern.",
)
parser.add_argument(
"--pattern",
type=str,
help="Model repos contain the pattern.",
default="mlx",
)
args = parser.parse_args()
if args.scan:
print(
"Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".'
)
hf_cache_info = scan_cache_dir()
print(
tabulate(
rows=[
[
repo.repo_id,
repo.repo_type,
"{:>12}".format(repo.size_on_disk_str),
repo.nb_files,
repo.last_accessed_str,
repo.last_modified_str,
str(repo.repo_path),
]
for repo in sorted(
hf_cache_info.repos, key=lambda repo: repo.repo_path
)
if args.pattern in repo.repo_id
],
headers=[
"REPO ID",
"REPO TYPE",
"SIZE ON DISK",
"NB FILES",
"LAST_ACCESSED",
"LAST_MODIFIED",
"LOCAL PATH",
],
)
)
if args.delete:
print(f'Deleting models matching pattern "{args.pattern}"')
hf_cache_info = scan_cache_dir()
repos = [
repo
for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
if args.pattern in repo.repo_id
]
if repos:
print(
tabulate(
rows=[
[
repo.repo_id,
str(repo.repo_path),
]
for repo in repos
],
headers=[
"REPO ID",
"LOCAL PATH",
],
)
)
confirmed = ask_for_confirmation(f"Confirm deletion ?")
if confirmed:
for model_info in repos:
for revision in sorted(
model_info.revisions, key=lambda revision: revision.commit_hash
):
strategy = hf_cache_info.delete_revisions(revision.commit_hash)
strategy.execute()
print("Model(s) deleted.")
else:
print("Deletion is cancelled. Do nothing.")
else:
print(f"No models found.")
if __name__ == "__main__":
main()

172
llms/mlx_lm/merge.py Normal file
View File

@@ -0,0 +1,172 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import glob
import shutil
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import yaml
from mlx.utils import tree_flatten, tree_map
from .utils import (
fetch_from_hub,
get_model_path,
save_config,
save_weights,
upload_to_hub,
)
def configure_parser() -> argparse.ArgumentParser:
"""
Configures and returns the argument parser for the script.
Returns:
argparse.ArgumentParser: Configured argument parser.
"""
parser = argparse.ArgumentParser(description="Merge multiple models.")
parser.add_argument("--config", type=str, help="Path to the YAML config.")
parser.add_argument(
"--mlx-path",
type=str,
default="mlx_merged_model",
help="Path to save the MLX model.",
)
parser.add_argument(
"--upload-repo",
help="The Hugging Face repo to upload the model to.",
type=str,
default=None,
)
return parser
def slerp(t, w1, w2, eps=1e-5):
"""
Spherical linear interpolation
Args:
t (float): Interpolation weight in [0.0, 1.0]
w1 (mx.array): First input
w2 (mx.array): Second input
eps (float): Constant for numerical stability
Returns:
mx.array: Interpolated result
"""
t = float(t)
if t == 0:
return w1
elif t == 1:
return w2
# Normalize
v1 = w1 / mx.linalg.norm(w1)
v2 = w2 / mx.linalg.norm(w2)
# Angle
dot = mx.clip((v1 * v2).sum(), 0.0, 1.0)
theta = mx.arccos(dot)
sin_theta = mx.sin(theta + eps)
s1 = mx.sin(theta * (1 - t)) / sin_theta
s2 = mx.sin(theta * t) / sin_theta
return s1 * w1 + s2 * w2
def merge_models(base_model: nn.Module, model: nn.Module, config: dict):
method = config.get("method", None)
if method != "slerp":
raise ValueError(f"Merge method {method} not supported")
num_layers = len(model.layers)
def unpack_values(vals):
if isinstance(vals, (int, float)):
return np.full(num_layers, vals)
bins = len(vals) - 1
sizes = [num_layers // bins] * bins
sizes[-1] = num_layers - sum(sizes[:-1])
return np.concatenate(
[np.linspace(v1, v2, s) for v1, v2, s in zip(vals[:-1], vals[1:], sizes)]
)
param_list = config["parameters"]["t"]
params = {}
filter_keys = set()
for pl in param_list[:-1]:
params[pl["filter"]] = unpack_values(pl["value"])
filter_keys.add(pl["filter"])
default = unpack_values(param_list[-1]["value"])
for e in range(num_layers):
bl = base_model.layers[e]
l = model.layers[e]
base_weights = bl.parameters()
weights = l.parameters()
for k, w1 in base_weights.items():
w2 = weights[k]
t = params.get(k, default)[e]
base_weights[k] = tree_map(lambda x, y: slerp(t, x, y), w1, w2)
base_model.update(base_weights)
def merge(
config: str,
mlx_path: str = "mlx_model",
upload_repo: Optional[str] = None,
):
with open(config, "r") as fid:
merge_conf = yaml.safe_load(fid)
print("[INFO] Loading")
model_paths = merge_conf.get("models", [])
if len(model_paths) < 2:
raise ValueError(f"Expected at least 2 models, got {len(model_paths)}.")
# Load all models
base_hf_path = model_paths[0]
base_path = get_model_path(base_hf_path)
base_model, base_config, tokenizer = fetch_from_hub(base_path, lazy=True)
models = []
for mp in model_paths[1:]:
model, model_config, _ = fetch_from_hub(get_model_path(mp), lazy=True)
base_type = base_config["model_type"]
model_type = model_config["model_type"]
if base_type != model_type:
raise ValueError(
f"Can only merge models of the same type,"
f" but got {base_type} and {model_type}."
)
models.append(model)
# Merge models into base model
for m in models:
merge_models(base_model, m, merge_conf)
# Save base model
mlx_path = Path(mlx_path)
weights = dict(tree_flatten(base_model.parameters()))
del models, base_model
save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(base_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
save_config(config, config_path=mlx_path / "config.json")
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, base_hf_path)
def main():
parser = configure_parser()
args = parser.parse_args()
merge(**vars(args))
if __name__ == "__main__":
main()

View File

113
llms/mlx_lm/models/base.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
from mlx.utils import tree_map
from .cache import QuantizedKVCache
@dataclass
class BaseModelArgs:
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size
else:
offset = c.offset
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None
return mask
def quantized_scaled_dot_product_attention(
queries: mx.array,
q_keys: tuple[mx.array, mx.array, mx.array],
q_values: tuple[mx.array, mx.array, mx.array],
scale: float,
mask: Optional[mx.array],
group_size: int = 64,
bits: int = 8,
) -> mx.array:
B, n_q_heads, L, D = queries.shape
n_kv_heads = q_keys[0].shape[-3]
n_repeats = n_q_heads // n_kv_heads
queries *= scale
if n_repeats > 1:
queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)
scores = mx.quantized_matmul(
queries, *q_keys, transpose=True, group_size=group_size, bits=bits
)
if mask is not None:
scores += mask
scores = mx.softmax(scores, axis=-1, precise=True)
out = mx.quantized_matmul(
scores, *q_values, transpose=False, group_size=group_size, bits=bits
)
if n_repeats > 1:
out = mx.reshape(out, (B, n_q_heads, L, D))
return out
def scaled_dot_product_attention(
queries,
keys,
values,
cache,
scale: float,
mask: Optional[mx.array],
) -> mx.array:
if isinstance(cache, QuantizedKVCache):
return quantized_scaled_dot_product_attention(
queries,
keys,
values,
scale=scale,
mask=mask,
group_size=cache.group_size,
bits=cache.bits,
)
else:
return mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=scale, mask=mask
)

438
llms/mlx_lm/models/cache.py Normal file
View File

@@ -0,0 +1,438 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten
def make_prompt_cache(
model: nn.Module,
max_kv_size: Optional[int] = None,
) -> List[Any]:
"""
Construct the model's cache for use when cgeneration.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class QuantizedKVCache(_BaseCache):
def __init__(self, group_size: int = 64, bits: int = 8):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
self.group_size = group_size
self.bits = bits
def update_and_fetch(self, keys, values):
B, n_kv_heads, num_steps, k_head_dim = keys.shape
v_head_dim = values.shape[-1]
prev = self.offset
if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
el_per_int = 8 * mx.uint32.size // self.bits
new_steps = (self.step + num_steps - 1) // self.step * self.step
shape = (B, n_kv_heads, new_steps)
def init_quant(dim):
return (
mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
)
def expand_quant(x):
new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
return mx.concatenate([x, new_x], axis=-2)
if self.keys is not None:
if prev % self.step != 0:
self.keys, self.values = tree_map(
lambda x: x[..., :prev, :], (self.keys, self.values)
)
self.keys, self.values = tree_map(
expand_quant, (self.keys, self.values)
)
else:
self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)
self.offset += num_steps
keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
for i in range(len(self.keys)):
self.keys[i][..., prev : self.offset, :] = keys[i]
self.values[i][..., prev : self.offset, :] = values[i]
return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))
@property
def state(self):
if self.offset == self.keys[0].shape[2]:
return self.keys, self.values
else:
return tree_map(
lambda x: x[..., : self.offset, :], (self.keys, self.values)
)
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))
@meta_state.setter
def meta_state(self, v):
self.step, self.offset, self.group_size, self.bits = map(int, v)
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
quant_cache.offset = self.offset
if self.keys is not None:
quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
quant_cache.values = mx.quantize(
self.values, group_size=group_size, bits=bits
)
return quant_cache
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
raise NotImplementedError("RotatingKVCache Quantization NYI")
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v

View File

@@ -0,0 +1,192 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 8192
num_hidden_layers: int = 40
intermediate_size: int = 22528
num_attention_heads: int = 64
num_key_value_heads: int = 64
rope_theta: float = 8000000.0
vocab_size: int = 256000
layer_norm_eps: float = 1e-05
logit_scale: float = 0.0625
attention_bias: bool = False
layer_norm_bias: bool = False
use_qk_norm: bool = False
class LayerNorm2D(nn.Module):
def __init__(self, d1, d2, eps):
super().__init__()
self.weight = mx.zeros((d1, d2))
self.eps = eps
def __call__(self, x):
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // args.num_attention_heads
self.scale = head_dim**-0.5
attetion_bias = args.attention_bias
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
self.use_qk_norm = args.use_qk_norm
if self.use_qk_norm:
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
self.k_norm = LayerNorm2D(
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
)
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1)
keys = keys.reshape(B, L, self.n_kv_heads, -1)
if self.use_qk_norm:
queries = self.q_norm(queries)
keys = self.k_norm(keys)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class CohereModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = CohereModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
@property
def layers(self):
return self.model.layers

251
llms/mlx_lm/models/dbrx.py Normal file
View File

@@ -0,0 +1,251 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
d_model: int
ffn_config: dict
attn_config: dict
n_layers: int
n_heads: int
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_heads = args.n_heads
self.d_model = args.d_model
self.head_dim = args.d_model // args.n_heads
self.num_key_value_heads = args.attn_config["kv_n_heads"]
self.clip_qkv = args.attn_config["clip_qkv"]
self.rope_theta = args.attn_config["rope_theta"]
self.scale = self.head_dim**-0.5
self.Wqkv = nn.Linear(
args.d_model,
(self.num_key_value_heads * 2 + self.num_heads) * self.head_dim,
bias=False,
)
self.out_proj = nn.Linear(args.d_model, args.d_model, bias=False)
self.rope = nn.RoPE(
self.head_dim,
traditional=False,
base=self.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
qkv = mx.clip(qkv, a_min=-self.clip_qkv, a_max=self.clip_qkv)
splits = [self.d_model, self.d_model + self.head_dim * self.num_key_value_heads]
queries, keys, values = mx.split(qkv, splits, axis=-1)
B, L, D = x.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class NormAttnNorm(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.norm_1 = nn.LayerNorm(args.d_model, bias=False)
self.norm_2 = nn.LayerNorm(args.d_model, bias=False)
self.attn = Attention(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
return x, self.norm_2(x)
class MLP(nn.Module):
def __init__(self, d_model: int, ffn_dim: int):
super().__init__()
self.v1 = nn.Linear(d_model, ffn_dim, bias=False)
self.w1 = nn.Linear(d_model, ffn_dim, bias=False)
self.w2 = nn.Linear(ffn_dim, d_model, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
current_hidden_states = self.act_fn(self.w1(x)) * self.v1(x)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
class Router(nn.Module):
def __init__(self, d_model: int, num_experts: int):
super().__init__()
self.layer = nn.Linear(d_model, num_experts, bias=False)
def __call__(self, x: mx.array):
return self.layer(x)
class SparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.d_model = args.d_model
self.ffn_dim = args.ffn_config["ffn_hidden_size"]
self.num_experts = args.ffn_config["moe_num_experts"]
self.num_experts_per_tok = args.ffn_config["moe_top_k"]
self.router = Router(self.d_model, self.num_experts)
self.experts = [
MLP(self.d_model, self.ffn_dim) for _ in range(self.num_experts)
]
def __call__(self, x: mx.array) -> mx.array:
ne = self.num_experts_per_tok
orig_shape = x.shape
x = x.reshape(-1, x.shape[-1])
gates = self.router(x)
gates = mx.softmax(gates.astype(mx.float32), axis=-1)
inds = mx.stop_gradient(mx.argpartition(-gates, kth=ne - 1, axis=-1)[:, :ne])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = scores / mx.linalg.norm(scores, ord=1, axis=-1, keepdims=True)
scores = scores.astype(x.dtype)
if self.training:
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
continue
y[idx1, idx2] = expert(x[idx1])
y = (y * scores[:, :, None]).sum(axis=1)
else:
y = []
for xt, st, it in zip(x, scores, inds.tolist()):
yt = mx.stack([self.experts[e](xt) for e in it], axis=-1)
yt = (yt * st).sum(axis=-1)
y.append(yt)
y = mx.stack(y, axis=0)
return y.reshape(orig_shape)
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ffn = SparseMoeBlock(args)
self.norm_attn_norm = NormAttnNorm(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
return out
class DBRX(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.vocab_size = args.vocab_size
self.wte = nn.Embedding(args.vocab_size, args.d_model)
self.blocks = [DecoderLayer(args=args) for _ in range(args.n_layers)]
self.norm_f = nn.LayerNorm(args.d_model, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
for layer, c in zip(self.blocks, cache):
h = layer(h, mask, c)
return self.norm_f(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.transformer = DBRX(args)
self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.transformer(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.transformer.blocks
def sanitize(self, weights):
# Split experts into sub matrices
num_experts = self.args.ffn_config["moe_num_experts"]
dim = self.args.ffn_config["ffn_hidden_size"]
pattern = "experts.mlp"
new_weights = {k: v for k, v in weights.items() if pattern not in k}
for k, v in weights.items():
if pattern in k:
experts = [
(k.replace(".mlp", f".{e}") + ".weight", sv)
for e, sv in enumerate(mx.split(v, num_experts, axis=0))
]
if k.endswith("w2"):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights

View File

@@ -0,0 +1,258 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
attention_bias: bool = False
class DeepseekAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim**-0.5
attention_bias = getattr(config, "attention_bias", False)
self.q_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
self.k_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size,
config.num_attention_heads * self.head_dim,
bias=attention_bias,
)
rope_scale = 1.0
if config.rope_scaling and config.rope_scaling["type"] == "linear":
assert isinstance(config.rope_scaling["factor"], float)
rope_scale = 1 / config.rope_scaling["factor"]
self.rope = nn.RoPE(
self.head_dim,
base=config.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
self.hidden_size = hidden_size or config.hidden_size
self.intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = nn.silu
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(scores, inds, axis=-1)
return inds, scores
class DeepseekMoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekAttention(config)
self.mlp = (
DeepseekMoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekMLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for m in ["gate_proj", "down_proj", "up_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,417 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "deepseek_v2"
vocab_size: int = 102400
hidden_size: int = 4096
intermediate_size: int = 11008
moe_intermediate_size: int = 1407
num_hidden_layers: int = 30
num_attention_heads: int = 32
num_key_value_heads: int = 32
n_shared_experts: Optional[int] = None
n_routed_experts: Optional[int] = None
routed_scaling_factor: float = 1.0
kv_lora_rank: int = 512
q_lora_rank: int = 1536
qk_rope_head_dim: int = 64
v_head_dim: int = 128
qk_nope_head_dim: int = 128
topk_method: str = "gready"
n_group: Optional[int] = None
topk_group: Optional[int] = None
num_experts_per_tok: Optional[int] = None
moe_layer_freq: int = 1
first_k_dense_replace: int = 0
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Dict = None
attention_bias: bool = False
def yarn_find_correction_dim(
num_rotations, dim, base=10000, max_position_embeddings=2048
):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def yarn_find_correction_range(
low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
low = math.floor(
yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
)
high = math.ceil(
yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
)
return max(low, 0), min(high, dim - 1)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
def yarn_linear_ramp_mask(min_val, max_val, dim):
if min_val == max_val:
max_val += 0.001 # Prevent singularity
linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val)
return mx.clip(linear_func, 0, 1)
class DeepseekV2YarnRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scaling_factor=1.0,
original_max_position_embeddings=4096,
beta_fast=32,
beta_slow=1,
mscale=1,
mscale_all_dim=0,
):
super().__init__()
self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
scaling_factor, mscale_all_dim
)
freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim)
freq_inter = scaling_factor * base ** (
mx.arange(0, dim, 2, dtype=mx.float32) / dim
)
low, high = yarn_find_correction_range(
beta_fast,
beta_slow,
dim,
base,
original_max_position_embeddings,
)
freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2)
self._freqs = (freq_inter * freq_extra) / (
freq_inter * freq_mask + freq_extra * (1 - freq_mask)
)
def __call__(self, x, offset=0):
if self.mscale != 1.0:
x = self.mscale * x
return mx.fast.rope(
x,
x.shape[-1],
traditional=True,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
class DeepseekV2Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.q_lora_rank = config.q_lora_rank
self.qk_rope_head_dim = config.qk_rope_head_dim
self.kv_lora_rank = config.kv_lora_rank
self.v_head_dim = config.v_head_dim
self.qk_nope_head_dim = config.qk_nope_head_dim
self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
self.scale = self.q_head_dim**-0.5
if self.q_lora_rank is None:
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.q_head_dim, bias=False
)
else:
self.q_a_proj = nn.Linear(
self.hidden_size, self.q_lora_rank, bias=config.attention_bias
)
self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank)
self.q_b_proj = nn.Linear(
self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
)
self.kv_a_proj_with_mqa = nn.Linear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=config.attention_bias,
)
self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank)
self.kv_b_proj = nn.Linear(
self.kv_lora_rank,
self.num_heads
* (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
bias=False,
)
self.o_proj = nn.Linear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=config.attention_bias,
)
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
for key in [
"original_max_position_embeddings",
"beta_fast",
"beta_slow",
"mscale",
"mscale_all_dim",
]
if key in self.config.rope_scaling
}
self.rope = DeepseekV2YarnRotaryEmbedding(
dim=self.qk_rope_head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
**rope_kwargs,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class DeepseekV2MLP(nn.Module):
def __init__(
self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
self.intermediate_size = (
config.intermediate_size if intermediate_size is None else intermediate_size
)
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x):
down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class MoEGate(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.top_k = config.num_experts_per_tok
self.n_routed_experts = config.n_routed_experts
self.routed_scaling_factor = config.routed_scaling_factor
self.topk_method = config.topk_method
self.n_group = config.n_group
self.topk_group = config.topk_group
self.weight = mx.zeros((self.n_routed_experts, config.hidden_size))
def __call__(self, x):
gates = x @ self.weight.T
scores = mx.softmax(gates, axis=-1, precise=True)
if self.topk_method == "group_limited_greedy":
bsz, seq_len = x.shape[:2]
scores = scores.reshape(bsz, seq_len, self.n_group, -1)
group_scores = scores.max(axis=-1)
k = self.n_group - self.topk_group
group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k]
batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2))
seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2))
scores[batch_idx, seq_idx, group_idx] = 0.0
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
return inds, scores
class DeepseekV2MoE(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.num_experts_per_tok = config.num_experts_per_tok
self.switch_mlp = SwitchGLU(
config.hidden_size, config.moe_intermediate_size, config.n_routed_experts
)
self.gate = MoEGate(config)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
config=config, intermediate_size=intermediate_size
)
def __call__(self, x):
inds, scores = self.gate(x)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(x)
return y
class DeepseekV2DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs, layer_idx: int):
super().__init__()
self.self_attn = DeepseekV2Attention(config)
self.mlp = (
DeepseekV2MoE(config)
if (
config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
)
else DeepseekV2MLP(config)
)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class DeepseekV2Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
DeepseekV2DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.args = config
self.model_type = config.model_type
self.model = DeepseekV2Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{m}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}")
for e in range(self.args.n_routed_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

175
llms/mlx_lm/models/gemma.py Normal file
View File

@@ -0,0 +1,175 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
return out
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,200 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
head_dim: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
rope_theta: float = 10000
rope_traditional: bool = False
attn_logit_softcapping: float = 50.0
final_logit_softcapping: float = 30.0
query_pre_attn_scalar: float = 144.0
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.repeats = n_heads // n_kv_heads
self.head_dim = head_dim = args.head_dim
self.scale = 1.0 / (args.query_pre_attn_scalar**0.5)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.attn_logit_softcapping = args.attn_logit_softcapping
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries * self.scale
if self.repeats > 1:
queries = queries.reshape(
B, self.n_kv_heads, self.repeats, L, self.head_dim
)
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
scores = queries @ keys.swapaxes(-1, -2)
scores = mx.tanh(scores / self.attn_logit_softcapping)
scores *= self.attn_logit_softcapping
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, precise=True, axis=-1)
output = scores @ values
if self.repeats > 1:
output = output.reshape(B, self.n_heads, L, self.head_dim)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.pre_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_feedforward_layernorm = RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
class GemmaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
h = h * (self.args.hidden_size**0.5)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.final_logit_softcapping = args.final_logit_softcapping
self.model = GemmaModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
out = mx.tanh(out / self.final_logit_softcapping)
out = out * self.final_logit_softcapping
return out
@property
def layers(self):
return self.model.layers

198
llms/mlx_lm/models/gpt2.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
n_ctx: int
n_embd: int
n_head: int
n_layer: int
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.n_head
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.n_embd % args.n_head == 0, "n_embd must be divisible by n_head"
self.n_embd = args.n_embd
self.n_head = args.n_head
self.head_dim = self.n_embd // self.n_head
self.scale = self.head_dim**-0.5
self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_head, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.c_fc = nn.Linear(self.n_embd, 4 * self.n_embd)
self.c_proj = nn.Linear(4 * self.n_embd, self.n_embd)
def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu_approx(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.layer_norm_epsilon = args.layer_norm_epsilon
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(
self.n_embd,
eps=self.layer_norm_epsilon,
)
self.ln_2 = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
return out
class GPT2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_embd = args.n_embd
self.n_positions = args.n_positions
self.vocab_size = args.vocab_size
self.n_layer = args.n_layer
self.layer_norm_epsilon = args.layer_norm_epsilon
assert self.vocab_size > 0
self.wte = nn.Embedding(self.vocab_size, self.n_embd)
self.wpe = nn.Embedding(self.n_positions, self.n_embd)
self.h = [TransformerBlock(args=args) for _ in range(self.n_layer)]
self.ln_f = nn.LayerNorm(self.n_embd, eps=self.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None,
):
_, L = inputs.shape
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
return self.ln_f(hidden_states)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPT2Model(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.wte.as_linear(out)
return out
def sanitize(self, weights):
new_weights = {}
for i in range(self.args.n_layer):
if f"h.{i}.attn.bias" in weights:
del weights[f"h.{i}.attn.bias"]
if f"h.{i}.attn.c_attn.weight" in weights:
weights[f"h.{i}.attn.c_attn.weight"] = weights[
f"h.{i}.attn.c_attn.weight"
].transpose(1, 0)
if f"h.{i}.attn.c_proj.weight" in weights:
weights[f"h.{i}.attn.c_proj.weight"] = weights[
f"h.{i}.attn.c_proj.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_fc.weight" in weights:
weights[f"h.{i}.mlp.c_fc.weight"] = weights[
f"h.{i}.mlp.c_fc.weight"
].transpose(1, 0)
if f"h.{i}.mlp.c_proj.weight" in weights:
weights[f"h.{i}.mlp.c_proj.weight"] = weights[
f"h.{i}.mlp.c_proj.weight"
].transpose(1, 0)
for weight in weights:
if not weight.startswith("model."):
new_weights[f"model.{weight}"] = weights[weight]
else:
new_weights[weight] = weights[weight]
return new_weights
@property
def layers(self):
return self.model.h

View File

@@ -0,0 +1,186 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
n_embd: int
n_layer: int
n_inner: int
n_head: int
n_positions: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int = None
multi_query: bool = True
attention_bias: bool = True
mlp_bias: bool = True
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = 1 if self.multi_query else self.n_head
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = dim = args.n_embd
self.n_heads = n_heads = args.n_head
self.n_kv_heads = n_kv_heads = 1 if args.multi_query else args.n_head
self.head_dim = head_dim = dim // n_heads
self.kv_dim = n_kv_heads * head_dim
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.c_attn = nn.Linear(dim, dim + 2 * self.kv_dim, bias=attention_bias)
self.c_proj = nn.Linear(dim, dim, bias=attention_bias)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.c_attn(x)
queries, keys, values = mx.split(
qkv, [self.dim, self.dim + self.kv_dim], axis=-1
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.n_embd
hidden_dim = args.n_inner
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.c_fc = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.c_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.c_proj(nn.gelu(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_head = args.n_head
self.n_embd = args.n_embd
self.attn = Attention(args)
self.mlp = MLP(args)
self.ln_1 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
self.ln_2 = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
r = self.mlp(self.ln_2(h))
out = h + r
return out
class GPTBigCodeModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.wte = nn.Embedding(args.vocab_size, args.n_embd)
self.wpe = nn.Embedding(args.n_positions, args.n_embd)
self.h = [TransformerBlock(args=args) for _ in range(args.n_layer)]
self.ln_f = nn.LayerNorm(args.n_embd, eps=args.layer_norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None,
):
B, L = inputs.shape
hidden_states = self.wte(inputs)
mask = None
if hidden_states.shape[1] > 1:
position_ids = mx.array(np.arange(L))
hidden_states += self.wpe(position_ids)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
return self.ln_f(hidden_states)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = GPTBigCodeModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.transformer(inputs, cache)
if self.args.tie_word_embeddings:
out = self.transformer.wte.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.h

View File

@@ -0,0 +1,216 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
# Based on the transformers implementation at:
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
max_position_embeddings: int
hidden_size: int
num_attention_heads: int
num_hidden_layers: int
layer_norm_eps: float
vocab_size: int
rotary_emb_base: int
rotary_pct: float
num_key_value_heads: int = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert (
args.hidden_size % args.num_attention_heads == 0
), "hidden_size must be divisible by num_attention_heads"
self.hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_attention_heads
self.rope = nn.RoPE(
dims=int(self.head_dim * args.rotary_pct),
traditional=False,
base=args.rotary_emb_base,
)
self.scale = self.head_dim**-0.5
self.query_key_value = nn.Linear(
self.hidden_size, 3 * self.hidden_size, bias=True
)
self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.query_key_value(x)
new_qkv_shape = qkv.shape[:-1] + (self.num_attention_heads, 3 * self.head_dim)
qkv = qkv.reshape(*new_qkv_shape)
queries, keys, values = [x.transpose(0, 2, 1, 3) for x in qkv.split(3, -1)]
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
def __call__(self, x) -> mx.array:
# gelu_approx corresponds to FastGELUActivation in transformers.
return self.dense_4h_to_h(nn.gelu_approx(self.dense_h_to_4h(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.layer_norm_eps = args.layer_norm_eps
self.attention = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.LayerNorm(
self.hidden_size,
eps=self.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm(
self.hidden_size, eps=self.layer_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
attn = self.attention(self.input_layernorm(x), mask, cache)
ffn = self.mlp(self.post_attention_layernorm(x))
out = attn + ffn + residual
return out
class GPTNeoXModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.layer_norm_eps = args.layer_norm_eps
assert self.vocab_size > 0
self.embed_in = nn.Embedding(self.vocab_size, self.hidden_size)
self.embed_out = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.h = [TransformerBlock(args=args) for _ in range(self.num_hidden_layers)]
self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
_, L = inputs.shape
hidden_states = self.embed_in(inputs)
mask = create_attention_mask(hidden_states, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
hidden_states = layer(hidden_states, mask, cache=c)
out = self.final_layer_norm(hidden_states)
out = self.embed_out(out)
return out
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = GPTNeoXModel(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return out
def sanitize(self, weights):
new_weights = {}
for w_key, w_value in weights.items():
# Created through register_buffer in Pytorch, not needed here.
ignore_suffixes = [
".attention.bias",
".attention.masked_bias",
".attention.rotary_emb.inv_freq",
]
skip_weight = False
for ignored_suffix in ignore_suffixes:
if w_key.endswith(ignored_suffix):
skip_weight = True
break
if skip_weight:
continue
if not w_key.startswith("model."):
w_key = f"model.{w_key}"
w_key = w_key.replace(".gpt_neox.layers.", ".h.")
w_key = w_key.replace(".gpt_neox.", ".")
new_weights[w_key] = w_value
return new_weights
@property
def layers(self):
return self.model.h

View File

@@ -0,0 +1,238 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
bias: bool = True
max_position_embeddings: int = 32768
num_key_value_heads: int = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["linear", "dynamic"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear' or 'dynamic"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
):
super().__init__()
self.max_position_embeddings = max_position_embeddings
self.original_base = base
self.dims = dims
self.traditional = traditional
self.scale = scale
def extra_repr(self):
return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}"
def __call__(self, x, offset: int = 0):
seq_len = x.shape[1] + offset
if seq_len > self.max_position_embeddings:
base = self.original_base * (
(self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1)
) ** (self.dims / (self.dims - 2))
else:
base = self.original_base
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=base,
scale=self.scale,
offset=offset,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_kv_groups = n_heads // args.num_key_value_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.wqkv = nn.Linear(
dim, (n_heads + 2 * n_kv_heads) * head_dim, bias=args.bias
)
self.wo = nn.Linear(n_heads * head_dim, dim, bias=args.bias)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 2.0
)
self.rope = DynamicNTKScalingRoPE(
head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv_states = self.wqkv(x)
qkv_states = qkv_states.reshape(B, L, -1, 2 + self.n_kv_groups, self.head_dim)
queries = qkv_states[..., : self.n_kv_groups, :]
queries = queries.reshape(B, L, -1, self.head_dim)
keys = qkv_states[..., -2, :]
values = qkv_states[..., -1, :]
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.wo(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.w2(nn.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.attention = Attention(args)
self.feed_forward = MLP(args.hidden_size, args.intermediate_size)
self.attention_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.ffn_norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out
class InternLM2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
assert args.vocab_size > 0
self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.tok_embeddings(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = InternLM2Model(args)
if not args.tie_word_embeddings:
self.output = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.tok_embeddings.as_linear(out)
else:
out = self.output(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k}
@property
def layers(self):
return self.model.layers

306
llms/mlx_lm/models/llama.py Normal file
View File

@@ -0,0 +1,306 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
num_key_value_heads: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear", "dynamic", "llama3"]:
raise ValueError(
"rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'"
)
class DynamicNTKScalingRoPE(nn.Module):
"""Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE."""
def __init__(
self,
dims: int,
max_position_embeddings: int = 2048,
traditional: bool = False,
base: float = 10000,
scale: float = 1.0,
rope_type: str = "default",
rope_scaling: dict = None,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
self.traditional = traditional
self.scale = scale
self.rope_type = rope_type
self.rope_scaling = rope_scaling
self.base = base
self.compute_freqs()
def compute_freqs(self):
if self.rope_type != "llama3":
self._freqs = None
return
factor = self.rope_scaling["factor"]
low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.rope_scaling.get(
"original_max_position_embeddings",
8192,
)
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
self.base = None
def extra_repr(self):
return (
f"{self.dims}, traditional={self.traditional}, "
f"max_position_embeddings={self.max_position_embeddings}, "
f"scaling_factor={self.scale}, rope_type={self.rope_type}"
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
x,
self.dims,
traditional=self.traditional,
base=self.base,
scale=self.scale,
offset=offset,
freqs=self._freqs,
)
def initialize_rope(args: ModelArgs):
head_dim = args.head_dim or args.hidden_size // args.num_attention_heads
rope_scaling = args.rope_scaling
rope_type = "default"
rope_scale = 1.0
if rope_scaling is not None:
rope_type = (
rope_scaling.get("type") or rope_scaling.get("rope_type") or "default"
)
if rope_type == "linear":
rope_scale = 1 / rope_scaling["factor"]
elif rope_type == "llama3":
rope_scale = 1.0 # The scaling is handled internally for llama3
return DynamicNTKScalingRoPE(
dims=head_dim,
max_position_embeddings=args.max_position_embeddings,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
rope_type=rope_type,
rope_scaling=rope_scaling,
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
self.rope = initialize_rope(args)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
if hasattr(args, "mlp_bias"):
mlp_bias = args.mlp_bias
else:
mlp_bias = False
self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class LlamaModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = LlamaModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

228
llms/mlx_lm/models/mamba.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
state_size: int
num_hidden_layers: int
conv_kernel: int
use_bias: bool
use_conv_bias: bool
time_step_rank: int
tie_word_embeddings: bool = True
use_bcdt_rms: bool = False
mixer_rms_eps: float = 1e-6
def __post_init__(self):
if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
self.hidden_size = self.d_model
if not hasattr(self, "intermediate_size") and hasattr(self, "d_inner"):
self.intermediate_size = self.d_inner
if not hasattr(self, "state_size") and hasattr(self, "d_state"):
self.state_size = self.d_state
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layer"):
self.num_hidden_layers = self.n_layer
if not hasattr(self, "num_hidden_layers") and hasattr(self, "n_layers"):
self.num_hidden_layers = self.n_layers
if not hasattr(self, "conv_kernel") and hasattr(self, "d_conv"):
self.conv_kernel = self.d_conv
if not hasattr(self, "use_bias") and hasattr(self, "bias"):
self.use_bias = self.bias
if not hasattr(self, "use_conv_bias") and hasattr(self, "conv_bias"):
self.use_conv_bias = self.conv_bias
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
if self.model_type == "falcon_mamba":
self.use_bcdt_rms = True
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.padding = padding
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
class MambaBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.ssm_state_size = args.state_size
self.conv_kernel_size = args.conv_kernel
self.intermediate_size = args.intermediate_size
self.time_step_rank = int(args.time_step_rank)
self.use_conv_bias = args.use_conv_bias
self.use_bcdt_rms = args.use_bcdt_rms
if self.use_bcdt_rms:
self.mixer_norm = lambda x: mx.fast.rms_norm(
x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
)
self.in_proj = nn.Linear(
self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
)
self.conv1d = DepthWiseConv1d(
channels=self.intermediate_size,
kernel_size=self.conv_kernel_size,
bias=self.use_conv_bias,
padding=self.conv_kernel_size - 1,
)
self.x_proj = nn.Linear(
self.intermediate_size,
self.time_step_rank + 2 * self.ssm_state_size,
bias=False,
)
self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)
A = mx.repeat(
mx.arange(1.0, self.ssm_state_size + 1.0).reshape([1, self.ssm_state_size]),
repeats=self.intermediate_size,
axis=0,
)
self.A_log = mx.log(A)
self.D = mx.ones([self.intermediate_size])
self.out_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = MambaBlock(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers

View File

@@ -0,0 +1,207 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
dim_model_base: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: int
scale_depth: float
scale_emb: float
rope_theta: float = 1000000.0
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[str, float]]] = None
tie_word_embeddings: bool = False
class MLP(nn.Module):
def __init__(self, args):
super().__init__()
self.gate_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.up_proj = nn.Linear(args.hidden_size, args.intermediate_size, bias=False)
self.down_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.num_heads = n_heads = args.num_attention_heads
self.rope_theta = args.rope_theta
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.num_key_value_heads = args.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
dims=self.head_dim,
traditional=args.rope_traditional,
base=self.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
attn_output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(attn_output)
class DecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.hidden_size = args.hidden_size
self.num_hidden_layers = args.num_hidden_layers
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.scale_depth = args.scale_depth
self.num_hidden_layers = args.num_hidden_layers
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
r = self.mlp(self.post_attention_layernorm(h))
out = h + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
return out
class MiniCPMModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs) * self.args.scale_emb
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = MiniCPMModel(args)
if not self.args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if not self.args.tie_word_embeddings:
out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base))
else:
out = out @ self.model.embed_tokens.weight.T
return out
def sanitize(self, weights):
if "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,217 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int = 32000
hidden_size: int = 4096
intermediate_size: int = 14336
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_experts_per_tok: int = 2
num_key_value_heads: int = 8
num_local_experts: int = 8
rms_norm_eps: float = 1e-5
rope_theta: float = 1e6
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class MixtralAttention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.num_heads = args.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = args.num_key_value_heads
self.rope_theta = args.rope_theta
self.scale = self.head_dim**-0.5
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rope = nn.RoPE(
self.head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
# gating
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class MixtralDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = MixtralAttention(args)
self.block_sparse_moe = MixtralSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.block_sparse_moe(self.post_attention_layernorm(h))
out = h + r
return out
class MixtralModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = MixtralModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,217 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
hidden_act: str
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
norm_eps: float
vocab_size: int
num_key_value_heads: int
head_dim: Optional[int] = None
max_position_embeddings: Optional[int] = None
attention_bias: bool = False
mlp_bias: bool = False
partial_rotary_factor: float = 0.5
rope_theta: float = 10000.0
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.rope_scaling:
if not "factor" in self.rope_scaling:
raise ValueError(f"rope_scaling must contain 'factor'")
rope_type = self.rope_scaling.get("type") or self.rope_scaling.get(
"rope_type"
)
if rope_type is None:
raise ValueError(
f"rope_scaling must contain either 'type' or 'rope_type'"
)
if rope_type not in ["linear"]:
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
@partial(mx.compile, shapeless=True)
def relu_squared(x):
return nn.relu(x).square()
class NemotronLayerNorm1P(nn.LayerNorm):
def __call__(self, x):
weight = self.weight + 1 if "weight" in self else None
bias = self.bias if "bias" in self else None
return mx.fast.layer_norm(x, weight, bias, self.eps)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
self.partial_rotary_factor = args.partial_rotary_factor
self.scale = head_dim**-0.5
if hasattr(args, "attention_bias"):
attention_bias = args.attention_bias
else:
attention_bias = False
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
hidden_dim = args.intermediate_size
mlp_bias = args.mlp_bias
self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias)
self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias)
def __call__(self, x) -> mx.array:
return self.down_proj(relu_squared(self.up_proj(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args)
self.input_layernorm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
self.post_attention_layernorm = NemotronLayerNorm1P(
args.hidden_size, eps=args.norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class NemotronModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = NemotronLayerNorm1P(args.hidden_size, eps=args.norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = NemotronModel(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers

176
llms/mlx_lm/models/olmo.py Normal file
View File

@@ -0,0 +1,176 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
sys.exit(1)
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
d_model: int
n_layers: int
mlp_hidden_size: int
n_heads: int
vocab_size: int
embedding_size: int
rope_theta: float = 10000
rope_traditional: bool = False
mlp_ratio: int = 4
weight_tying: bool = False
def __post_init__(self):
self.mlp_hidden_size = (
self.mlp_hidden_size
if self.mlp_hidden_size is not None
else self.mlp_ratio * self.d_model
)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
dim = args.d_model
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
self.att_norm = nn.LayerNorm(dim, affine=False)
self.ff_norm = nn.LayerNorm(dim, affine=False)
head_dim = dim // self.n_heads
self.scale = head_dim**-0.5
self.att_proj = nn.Linear(dim, 3 * dim, bias=False)
self.attn_out = nn.Linear(dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
self.args = args
def attend(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = mx.split(self.att_proj(x), 3, axis=-1)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores += mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.attn_out(output)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
x1, x2 = mx.split(self.ff_proj(self.ff_norm(h)), 2, axis=-1)
out = h + self.ff_out(nn.silu(x2) * x1)
return out
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_layers = args.n_layers
self.weight_tying = args.weight_tying
self.wte = nn.Embedding(args.embedding_size, args.d_model)
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
if not self.weight_tying:
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
self.norm = nn.LayerNorm(args.d_model, affine=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.wte(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.blocks)
for block, c in zip(self.blocks, cache):
h = block(h, mask, c)
h = self.norm(h)
if self.weight_tying:
return self.wte.as_linear(h), cache
return self.ff_out(h)
class OlmoModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.transformer = Transformer(args)
def __call__(
self,
inputs: mx.array,
cache=None,
):
return self.transformer(inputs, cache)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = OlmoModel(args)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
return self.model(inputs, cache)
@property
def layers(self):
return self.model.transformer.blocks

View File

@@ -0,0 +1,220 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
head_dim: int
num_transformer_layers: int
model_dim: int
vocab_size: int
ffn_dim_divisor: int
num_query_heads: List
num_kv_heads: List
ffn_multipliers: List
ffn_with_glu: bool = True
normalize_qk_projections: bool = True
share_input_output_layers: bool = True
rms_norm_eps: float = 1e-6
rope_freq_constant: float = 10000
def make_divisible(
v: Union[float, int],
divisor: Optional[int] = 8,
min_value: Optional[Union[float, int]] = None,
) -> Union[float, int]:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by the divisor
It can be seen at:
https://github.com/tensorflow/models/blob/2cfc99eff5e5eb729c6793d2f3d03aa1c9be2b15/research/slim/nets/mobilenet/mobilenet.py#L62
Args:
v: input value
divisor: default to 8
min_value: minimum divisor value
Returns:
new_v: new divisible value
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
self.head_dim = head_dim = args.head_dim
self.layer_id = layer_id
self.model_dim = model_dim = args.model_dim
self.n_heads = n_heads = args.num_query_heads[layer_id]
self.n_kv_heads = n_kv_heads = args.num_kv_heads[layer_id]
self.scale = head_dim**-0.5
op_size = (n_heads + (n_kv_heads * 2)) * head_dim
self.qkv_proj = nn.Linear(model_dim, op_size, bias=False)
self.out_proj = nn.Linear(n_heads * head_dim, model_dim, bias=False)
self.normalize_qk_projections = args.normalize_qk_projections
if self.normalize_qk_projections:
self.q_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.k_norm = nn.RMSNorm(head_dim, eps=args.rms_norm_eps)
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_freq_constant)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.qkv_proj(x)
qkv = qkv.reshape(
B, L, self.n_heads + (self.n_kv_heads * 2), self.head_dim
).transpose(0, 2, 1, 3)
queries, keys, values = mx.split(
qkv, [self.n_heads, self.n_heads + self.n_kv_heads], axis=1
)
# Prepare the queries, keys and values for the attention computation
if self.normalize_qk_projections:
queries = self.q_norm(queries)
keys = self.k_norm(keys)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
self.args = args
dim = args.model_dim
ffn_multiplier = args.ffn_multipliers[layer_id]
intermediate_dim = int(
make_divisible(
ffn_multiplier * args.model_dim,
divisor=args.ffn_dim_divisor,
)
)
self.proj_1 = nn.Linear(dim, 2 * intermediate_dim, bias=False)
self.proj_2 = nn.Linear(intermediate_dim, dim, bias=False)
def __call__(self, x) -> mx.array:
x = self.proj_1(x)
gate, x = mx.split(x, 2, axis=-1)
return self.proj_2(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_id: int):
super().__init__()
dim = args.model_dim
self.attn = Attention(args, layer_id=layer_id)
self.ffn = MLP(args, layer_id=layer_id)
self.ffn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
self.attn_norm = nn.RMSNorm(dim, eps=args.rms_norm_eps)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
r = self.ffn(self.ffn_norm(h))
out = h + r
return out
class OpenELMModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_transformer_layers = args.num_transformer_layers
assert self.vocab_size > 0
self.token_embeddings = nn.Embedding(args.vocab_size, args.model_dim)
self.layers = [
TransformerBlock(args, layer_id=layer_id)
for layer_id in range(self.num_transformer_layers)
]
self.norm = nn.RMSNorm(args.model_dim, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.token_embeddings(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.transformer = OpenELMModel(args)
if not args.share_input_output_layers:
self.lm_head = nn.Linear(args.model_dim, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.transformer(inputs, cache)
if self.args.share_input_output_layers:
out = self.transformer.token_embeddings.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.transformer.layers

177
llms/mlx_lm/models/phi.py Normal file
View File

@@ -0,0 +1,177 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "phi"
max_position_embeddings: int = 2048
vocab_size: int = 51200
hidden_size: int = 2560
num_attention_heads: int = 32
num_hidden_layers: int = 32
num_key_value_heads: int = 32
partial_rotary_factor: float = 0.4
intermediate_size: int = 10240
layer_norm_eps: float = 1e-5
rope_theta: float = 10000.0
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class PhiAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.repeats = self.num_heads // self.num_key_value_heads
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.dense = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=True
)
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
traditional=False,
base=self.rope_theta,
)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
B, L, D = queries.shape
n_heads, n_kv_heads = self.num_heads, self.num_key_value_heads
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(
B,
L,
n_heads,
-1,
).moveaxis(1, 2)
keys = keys.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
values = values.reshape(B, L, n_kv_heads, -1).moveaxis(1, 2)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.dense(output)
class PhiMLP(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.act = nn.GELU(approx="precise")
def __call__(self, x) -> mx.array:
return self.fc2(self.act(self.fc1(x)))
class PhiDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.self_attn = PhiAttention(config=config)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.mlp = PhiMLP(config)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
ff_h = self.mlp(h)
return attn_h + ff_h + x
class PhiModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
self.final_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, cache):
x = self.embed_tokens(x)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, c)
return self.final_layernorm(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = PhiModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.args = config
def __call__(
self,
x: mx.array,
cache=None,
) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers

204
llms/mlx_lm/models/phi3.py Normal file
View File

@@ -0,0 +1,204 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 10000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, List[float]]]] = None
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"long_factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] not in ["longrope", "su", "linear"]:
print(
"[WARNING] rope_scaling 'type' currently only supports 'linear', 'su', and 'longrope'; setting rope scaling to false."
)
self.rope_scaling = None
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.num_hidden_layers = args.num_hidden_layers
self.head_dim = head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
self.qkv_proj = nn.Linear(dim, op_size, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
if args.rope_scaling and args.rope_scaling["type"] in ["longrope", "su"]:
self.rope = SuScaledRotaryEmbedding(
head_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
)
else:
rope_scale = 1.0
if args.rope_scaling and args.rope_scaling["type"] == "linear":
assert isinstance(args.rope_scaling["factor"], float)
rope_scale = 1 / args.rope_scaling["factor"]
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.qkv_proj(x)
query_pos = self.n_heads * self.head_dim
queries, keys, values = mx.split(
qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
def __call__(self, x) -> mx.array:
x = self.gate_up_proj(x)
gate, x = mx.split(x, 2, axis=-1)
return self.down_proj(nn.silu(gate) * x)
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.args = args
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,310 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
dense_attention_every_n_layers: int
ff_intermediate_size: int
gegelu_limit: float
num_hidden_layers: int
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@partial(mx.compile, shapeless=True)
def gegelu_impl(a_gelu, a_linear, limit):
a_gelu = mx.where(
mx.isinf(a_gelu),
a_gelu,
mx.clip(a_gelu, a_min=None, a_max=limit),
)
a_linear = mx.where(
mx.isinf(a_linear),
a_linear,
mx.clip(a_linear, a_min=-limit, a_max=limit),
)
out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu)
return out_gelu * (a_linear + 1.0)
def gegelu(x, limit):
a_gelu, a_linear = x[..., ::2], x[..., 1::2]
return gegelu_impl(a_gelu, a_linear, limit)
class Attention(nn.Module):
def __init__(self, args: ModelArgs, layer_idx):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
self.head_dim = head_dim = args.hidden_size // n_heads
self.query_key_value = nn.Linear(
dim, (self.n_heads + 2 * self.n_kv_heads) * head_dim
)
self.dense = nn.Linear(dim, dim)
if args.mup_use_scaling:
norm_factor = head_dim / args.mup_attn_multiplier
else:
norm_factor = math.sqrt(head_dim)
self.scale = 1.0 / norm_factor
self.rope = nn.RoPE(
head_dim,
traditional=False,
base=args.rope_embedding_base,
scale=args.rope_position_scale,
)
if layer_idx % args.dense_attention_every_n_layers == 0:
self.block_sparse = True
self.blocksparse_block_size = args.blocksparse_block_size
if self.blocksparse_block_size not in (32, 64):
raise ValueError(
f"Unsupported block size {self.blocksparse_block_size}"
)
self.blocksparse_num_local_blocks = args.blocksparse_num_local_blocks
self.blocksparse_vert_stride = args.blocksparse_vert_stride
else:
self.block_sparse = False
def _block_sparse_mask(self, q_len, kv_len):
vert_stride = self.blocksparse_vert_stride
local_blocks = self.blocksparse_num_local_blocks
block_size = self.blocksparse_block_size
n_heads = self.n_heads
kv_blocks = (kv_len + block_size - 1) // block_size
q_blocks = (q_len + block_size - 1) // block_size
q_pos = mx.arange(kv_blocks - q_blocks, kv_blocks)[None, :, None]
k_pos = mx.arange(kv_blocks)[None, None]
mask_vert_strided = (
mx.arange(kv_blocks)[None, :] + mx.arange(1, n_heads + 1)[:, None]
) % vert_stride
mask_vert_strided = (mask_vert_strided == 0)[:, None, :]
block_mask = (q_pos >= k_pos) & (
(q_pos - k_pos < local_blocks) | mask_vert_strided
)
block_mask = block_mask.reshape(
self.n_kv_heads, self.n_q_per_kv, *block_mask.shape[-2:]
)
dense_mask = mx.repeat(
mx.repeat(block_mask, block_size, axis=-1), block_size, axis=-2
)
return block_mask, dense_mask[..., -q_len:, :kv_len]
def _block_sparse_attention(self, queries, keys, values, scale, mask):
queries = scale * queries
B = queries.shape[0]
L = queries.shape[2]
queries = mx.reshape(queries, (B, self.n_kv_heads, self.n_q_per_kv, L, -1))
keys = mx.expand_dims(keys, 2)
values = mx.expand_dims(values, 2)
# TODO get rid of dense mask if we have a fill value
block_mask, dense_mask = self._block_sparse_mask(L, keys.shape[-2])
scores = queries @ mx.swapaxes(keys, -1, -2)
# TODO, uncomment when faster
# scores = mx.block_masked_mm(
# queries,
# mx.swapaxes(keys, -1, -2),
# mask_out=block_mask,
# block_size=self.blocksparse_block_size,
# )
if mask is not None:
scores = scores + mask
scores = scores + mx.where(
dense_mask, mx.array(0, scores.dtype), mx.array(-float("inf"), scores.dtype)
)
scores = mx.softmax(scores, axis=-1, precise=True)
output = scores @ values
# TODO, uncomment when faster
# output = mx.block_masked_mm(
# scores, values, mask_lhs=block_mask, block_size=self.blocksparse_block_size
# )
return mx.reshape(output, (B, self.n_heads, L, -1))
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
qkv = self.query_key_value(x)
qkv = qkv.reshape(B, L, -1, self.n_q_per_kv + 2, self.head_dim)
queries = qkv[..., :-2, :].flatten(-3, -2)
keys = qkv[..., -2, :]
values = qkv[..., -1, :]
# Prepare the queries, keys and values for the attention computation
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
if self.block_sparse:
output = self._block_sparse_attention(
queries, keys, values, scale=self.scale, mask=mask
)
else:
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.dense(output)
class MLP(nn.Module):
def __init__(self, args):
super().__init__()
dim = args.hidden_size
hidden_dim = args.ff_intermediate_size
self.gegelu_limit = args.gegelu_limit
self.up_proj = nn.Linear(dim, 2 * hidden_dim)
self.down_proj = nn.Linear(hidden_dim, dim)
def __call__(self, x) -> mx.array:
x = self.up_proj(x)
return self.down_proj(gegelu(x, self.gegelu_limit))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_idx):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args, layer_idx)
self.mlp = MLP(args)
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_epsilon
)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size,
eps=args.layer_norm_epsilon,
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Phi3Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.mup_embedding_multiplier = args.mup_embedding_multiplier
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args, layer_idx=l)
for l in range(args.num_hidden_layers)
]
self.final_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_epsilon
)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
if self.mup_embedding_multiplier:
h = self.mup_embedding_multiplier * h
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.final_layernorm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.model = Phi3Model(args)
self.args = args
self.mup_width_multiplier = args.mup_width_multiplier
self._dummy_tokenizer_ids = mx.array(
[100256, 100258, 100259, 100260, 100264, 100265]
+ list(range(100267, 100352))
)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
out = self.model.embed_tokens.as_linear(out)
if self.mup_width_multiplier:
out = out / self.mup_width_multiplier
out[self._dummy_tokenizer_ids] = -float("inf")
return out
@property
def layers(self):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

View File

@@ -0,0 +1,211 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .su_rope import SuScaledRotaryEmbedding
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "phimoe"
vocab_size: int = 32064
hidden_size: int = 4096
intermediate_size: int = 6400
num_hidden_layers: int = 32
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
original_max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-6
rope_scaling: Dict[str, Union[float, List[float]]] = None
num_local_experts: int = 16
num_experts_per_tok: int = 2
rope_theta: float = 10000.0
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = SuScaledRotaryEmbedding(
head_dim,
base=args.rope_theta,
max_position_embeddings=args.max_position_embeddings,
original_max_position_embeddings=args.original_max_position_embeddings,
short_factor=args.rope_scaling["short_factor"],
long_factor=args.rope_scaling["long_factor"],
short_mscale=args.rope_scaling["short_mscale"],
long_mscale=args.rope_scaling["long_mscale"],
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class PhiMoESparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_dim = args.hidden_size
self.ffn_dim = args.intermediate_size
self.num_experts = args.num_local_experts
self.top_k = args.num_experts_per_tok
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.switch_mlp = SwitchGLU(self.hidden_dim, self.ffn_dim, self.num_experts)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class PhiMoEDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.block_sparse_moe = PhiMoESparseMoeBlock(args)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache=None,
) -> mx.array:
residual = x
hidden_states = self.input_layernorm(x)
hidden_states = self.self_attn(hidden_states, mask=mask, cache=cache)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class PhiMoEModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [PhiMoEDecoderLayer(args) for _ in range(args.num_hidden_layers)]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.block_sparse_moe.experts.0.w1.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.block_sparse_moe.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(
f"{prefix}.block_sparse_moe.experts.{e}.{n}.{k}"
)
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.block_sparse_moe.switch_mlp.{m}.{k}"] = (
mx.stack(to_join)
)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,200 @@
# Copyright © 2023-2024 Apple Inc.
import inspect
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchMLP
@dataclass
class ModelArgs:
model_type: str
num_vocab: int = 51200
model_dim: int = 2560
num_heads: int = 32
num_layers: int = 32
rotary_dim: int = 32
num_experts_per_tok: int = 2
num_local_experts: int = 4
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(rotary_dim, traditional=False)
self.Wqkv = nn.Linear(dims, 3 * dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x, mask=None, cache=None):
qkv = self.Wqkv(x)
queries, keys, values = mx.split(qkv, 3, axis=-1)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries.astype(mx.float32),
keys,
values,
cache=cache,
scale=scale,
mask=mask,
).astype(values.dtype)
output = output.moveaxis(2, 1).reshape(B, L, -1)
return self.out_proj(output)
class MOE(nn.Module):
def __init__(self, args: ModelArgs, dim: int, hidden_dim: int):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.num_experts = args.num_local_experts
self.num_experts_per_tok = args.num_experts_per_tok
self.switch_mlp = SwitchMLP(
self.dim, self.hidden_dim, self.num_experts, bias=True
)
self.gate = nn.Linear(args.model_dim, self.num_experts, bias=False)
def __call__(self, x: mx.array) -> mx.array:
gates = self.gate(x)
k = self.num_experts_per_tok
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1))[..., :k]
scores = mx.take_along_axis(gates, inds, axis=-1)
scores = mx.softmax(scores, axis=-1, precise=True)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
return y
class ParallelBlock(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = nn.LayerNorm(dims)
self.moe = MOE(config, dims, mlp_dims)
def __call__(self, x, mask, cache):
h = self.ln(x)
attn_h = self.mixer(h, mask, cache)
ff_h = self.moe(h)
return attn_h + ff_h + x
class TransformerDecoder(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embd = Embd(config)
self.h = [ParallelBlock(config) for i in range(config.num_layers)]
def __call__(self, x, mask, cache):
x = self.embd(x)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return x
class Embd(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.wte = nn.Embedding(config.num_vocab, config.model_dim)
def __call__(self, x):
return self.wte(x)
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.ln = nn.LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):
return self.linear(self.ln(inputs))
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = TransformerDecoder(config)
self.lm_head = OutputHead(config)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
return self.lm_head(y)
def sanitize(self, weights):
if "transformer.h.0.moe.mlp.0.fc1.weight" not in weights:
return weights
for l in range(self.args.num_layers):
prefix = f"transformer.h.{l}"
for n in ["fc1", "fc2"]:
for k in ["weight", "scales", "biases", "bias"]:
if f"{prefix}.moe.mlp.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.moe.mlp.{e}.{n}.{k}")
for e in range(self.args.num_local_experts)
]
weights[f"{prefix}.moe.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.transformer.h

211
llms/mlx_lm/models/plamo.py Normal file
View File

@@ -0,0 +1,211 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
n_shared_head: int = 8
rope_theta: float = 10000
rope_traditional: bool = False
class Attention(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
head_dim = self.hidden_size // config.num_attention_heads
self.q_num_heads = config.num_attention_heads
self.qk_dim = self.v_dim = head_dim
self.k_num_heads = self.v_num_heads = int(
np.ceil(self.q_num_heads / config.n_shared_head)
)
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(
self.hidden_size, self.q_num_heads * self.qk_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.k_num_heads * self.qk_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.v_num_heads * self.v_dim, bias=False
)
self.o_proj = nn.Linear(
self.q_num_heads * self.v_dim, self.hidden_size, bias=False
)
self.rotary_emb = nn.RoPE(
head_dim,
traditional=config.rope_traditional,
base=config.rope_theta,
scale=1.0,
)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(bsz, q_len, self.q_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
keys = keys.reshape(bsz, q_len, self.k_num_heads, self.qk_dim).transpose(
0, 2, 1, 3
)
values = values.reshape(bsz, q_len, self.v_num_heads, self.v_dim).transpose(
0, 2, 1, 3
)
if cache is not None:
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self.scale,
mask=attention_mask,
)
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def __call__(self, x: mx.array) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) # type: ignore
class PlamoDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.mlp = MLP(config)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
hidden_states = self.norm(hidden_states)
# Self Attention
hidden_states_sa = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
cache=cache,
)
# Fully Connected
hidden_states_mlp = self.mlp(hidden_states)
hidden_states = residual + hidden_states_sa + hidden_states_mlp
return hidden_states
class PlamoDecoder(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.layers = [
PlamoDecoderLayer(config) for _ in range(config.num_hidden_layers)
]
class PlamoModel(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = PlamoDecoder(config) # type: ignore
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None for _ in range(len(self.layers.layers))]
for layer, c in zip(self.layers.layers, cache):
h = layer(h, mask, cache=c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_type = args.model_type
self.model = PlamoModel(args)
self.lm_head: nn.Module = nn.Linear(
args.hidden_size, args.vocab_size, bias=False
)
self.args = args
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers

158
llms/mlx_lm/models/qwen.py Normal file
View File

@@ -0,0 +1,158 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int = 2048
num_attention_heads: int = 16
num_hidden_layers: int = 24
kv_channels: int = 128
max_position_embeddings: int = 8192
layer_norm_epsilon: float = 1e-6
intermediate_size: int = 11008
no_bias: bool = True
vocab_size: int = 151936
num_key_value_heads = None
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
hidden_size = args.hidden_size
self.num_attention_heads = args.num_attention_heads
hidden_size_per_attention_head = hidden_size // self.num_attention_heads
self.rotary_emb = nn.RoPE(hidden_size_per_attention_head, traditional=False)
proj_size = args.kv_channels * self.num_attention_heads
self.c_attn = nn.Linear(hidden_size, proj_size * 3, bias=True)
self.c_proj = nn.Linear(hidden_size, proj_size, bias=not args.no_bias)
self.scale = hidden_size_per_attention_head**-0.5
def __call__(self, x, mask=None, cache=None):
qkv = self.c_attn(x)
q, k, v = mx.split(qkv, 3, axis=-1)
B, L, _ = q.shape
queries = q.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
keys = k.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
values = v.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rotary_emb(queries, offset=cache.offset)
keys = self.rotary_emb(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.c_proj(output)
class MLP(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.w1 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.w2 = nn.Linear(
args.hidden_size, args.intermediate_size // 2, bias=not args.no_bias
)
self.c_proj = nn.Linear(
args.intermediate_size // 2, args.hidden_size, bias=not args.no_bias
)
def __call__(self, x):
a1 = self.w1(x)
a2 = self.w2(x)
return self.c_proj(a1 * nn.silu(a2))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None):
residual = x
x = self.ln_1(x)
x = self.attn(x, mask=mask, cache=cache)
residual = x + residual
x = self.ln_2(residual)
x = self.mlp(x)
x = x + residual
return x
class QwenModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)
mask = create_attention_mask(x, cache)
if cache is None:
cache = [None] * len(self.h)
for layer, c in zip(self.h, cache):
x = layer(x, mask, c)
return self.ln_f(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.transformer = QwenModel(config)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=not config.no_bias
)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h

198
llms/mlx_lm/models/qwen2.py Normal file
View File

@@ -0,0 +1,198 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = True
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
rope_scale = (
1 / args.rope_scaling["factor"]
if args.rope_scaling is not None and args.rope_scaling["type"] == "linear"
else 1
)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
scale=rope_scale,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Qwen2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
def sanitize(self, weights):
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,238 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_experts_per_tok: int
num_experts: int
moe_intermediate_size: int
shared_expert_intermediate_size: int
rms_norm_eps: float
vocab_size: int
num_key_value_heads: Optional[int] = None
rope_theta: float = 1000000
rope_traditional: bool = False
rope_scaling: Optional[Dict[str, Union[float, str]]] = None
tie_word_embeddings: bool = False
def __post_init__(self):
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.rope_scaling:
required_keys = {"factor", "type"}
if not all(key in self.rope_scaling for key in required_keys):
raise ValueError(f"rope_scaling must contain keys {required_keys}")
if self.rope_scaling["type"] != "linear":
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
dim = args.hidden_size
intermediate_size = args.moe_intermediate_size
shared_expert_intermediate_size = args.shared_expert_intermediate_size
self.num_experts = num_experts = args.num_experts
self.top_k = args.num_experts_per_tok
self.gate = nn.Linear(dim, num_experts, bias=False)
self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
self.shared_expert = MLP(dim, shared_expert_intermediate_size)
self.shared_expert_gate = nn.Linear(dim, 1, bias=False)
def __call__(
self,
x: mx.array,
):
gates = self.gate(x)
gates = mx.softmax(gates, axis=-1, precise=True)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k])
scores = mx.take_along_axis(gates, inds, axis=-1)
y = self.switch_mlp(x, inds)
y = (y * scores[..., None]).sum(axis=-2)
shared_expert_output = self.shared_expert(x)
shared_expert_output = (
mx.sigmoid(self.shared_expert_gate(x)) * shared_expert_output
)
return y + shared_expert_output
class Qwen2MoeDecoderLayer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = Qwen2MoeSparseMoeBlock(args)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Qwen2MoeModel(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
Qwen2MoeDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Qwen2MoeModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
def sanitize(self, weights):
if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights:
return weights
for l in range(self.args.num_hidden_layers):
prefix = f"model.layers.{l}"
for n in ["up_proj", "down_proj", "gate_proj"]:
for k in ["weight", "scales", "biases"]:
if f"{prefix}.mlp.experts.0.{n}.{k}" in weights:
to_join = [
weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}")
for e in range(self.args.num_experts)
]
weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join)
return weights
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,456 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
attention_bias: bool
conv1d_width: int
hidden_size: int
intermediate_size: int
logits_soft_cap: float
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
rms_norm_eps: float
rope_theta: float
attention_window_size: int
vocab_size: int
embeddings_scale_by_sqrt_dim: bool = True
block_types: Optional[List[str]] = None
_block_types: Optional[List[str]] = None
def __post_init__(self):
# For some reason these have different names in 2B and 9B
if self.block_types is None:
self.block_types = self._block_types
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
def rnn_scan(x, a, h0):
assert x.ndim == 3
assert a.shape == x.shape[-a.ndim :]
assert a.dtype == x.dtype
if x.shape[1] == 1:
# Using scan in sampling mode.
if h0 is None:
return x, x[:, 0]
else:
y = a * h0[:, None] + x
return y, y[:, -1]
else:
# Using scan in linear mode.
if h0 is not None:
h_t = h0
else:
B, _, D = x.shape
h_t = mx.zeros((B, D), dtype=x.dtype)
y = mx.zeros_like(x)
for t in range(x.shape[1]):
h_t = a[:, t] * h_t + x[:, t]
y[:, t] = h_t
return y, h_t
class Conv1d(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
B, L, C = x.shape
groups, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
"""A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
def __init__(
self,
width: int,
num_heads: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.head_dim = self.width // self.num_heads
self.recurrent_param = mx.zeros((self.width,))
self.input_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
self.recurrent_gate_weight = mx.zeros(
(self.num_heads, self.head_dim, self.head_dim),
)
self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
def __call__(
self,
x: mx.array,
cache=None,
):
B, L, _ = x.shape
def apply_block_linear(h, w, b):
h = h.reshape((B, L, self.num_heads, self.head_dim))
h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
return mx.sigmoid(h.flatten(2, 3))
# Gates for x and a.
gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
gate_a = apply_block_linear(
x, self.recurrent_gate_weight, self.recurrent_gate_bias
)
# Compute the parameter `A` of the recurrence.
log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
a = mx.exp(log_a)
a_square = mx.exp(2 * log_a)
# Gate the input.
gated_x = x * gate_x
# Apply gamma normalization to the input.
multiplier = mx.sqrt(1 - a_square)
if cache is None:
multiplier[:, 0, :] = 1.0
normalized_x = gated_x * multiplier.astype(x.dtype)
y, last_h = rnn_scan(
x=normalized_x,
a=a,
h0=cache,
)
return y, last_h
class RecurrentBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
lru_width: int = None,
conv1d_temporal_width: int = 4,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.lru_width = lru_width or width
self.conv1d_temporal_width = conv1d_temporal_width
self.linear_y = nn.Linear(width, self.lru_width)
self.linear_x = nn.Linear(width, self.lru_width)
self.linear_out = nn.Linear(self.lru_width, width)
self.conv_1d = Conv1d(
channels=self.lru_width,
kernel_size=self.conv1d_temporal_width,
)
self.rg_lru = RGLRU(
width=self.lru_width,
num_heads=self.num_heads,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
# y branch.
y = self.linear_y(x)
y = nn.gelu_approx(y)
# x branch.
x = self.linear_x(x)
if cache is None:
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
return x
class LocalAttentionBlock(nn.Module):
def __init__(
self,
width: int,
num_heads: int,
window_size: int,
):
super().__init__()
self.width = width
self.num_heads = num_heads
self.window_size = window_size
self.scale = (width // num_heads) ** (-0.5)
self.head_dim = self.width // self.num_heads
self.q_proj = nn.Linear(self.width, self.width, bias=False)
self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
self.o_proj = nn.Linear(self.width, self.width, bias=True)
self.rope = nn.RoPE(
self.head_dim // 2,
traditional=False,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLPBlock(nn.Module):
def __init__(self, width: int, expanded_width: int):
super().__init__()
self.up_proj = nn.Linear(width, expanded_width // 2)
self.gate_proj = nn.Linear(width, expanded_width // 2)
self.down_proj = nn.Linear(expanded_width // 2, width)
def __call__(self, x: mx.array):
gate = self.gate_proj(x)
x = self.up_proj(x)
return self.down_proj(nn.gelu_approx(gate) * x)
class ResidualBlock(nn.Module):
def __init__(
self,
width: int,
mlp_expanded_width: int,
num_heads: int,
attention_window_size: int,
temporal_block_type: str,
lru_width: Optional[int] = None,
conv1d_temporal_width: int = 4,
):
"""Initializes the residual block.
Args:
width: The width of the block.
mlp_expanded_width: The width of the expansion inside the MLP block.
num_heads: The number of heads for the Attention or the RG-LRU.
attention_window_size: The window size for the local attention block.
temporal_block_type: Either "recurrent" or "attention", specifying the
type of recurrent block to use.
lru_width: The width of the RG-LRU if different from `width`.
conv1d_temporal_width: The width of the temporal convolution.
"""
super().__init__()
self.width = width
self.mlp_expanded_width = mlp_expanded_width
self.num_heads = num_heads
self.attention_window_size = attention_window_size
self.temporal_block_type = temporal_block_type
self.lru_width = lru_width
self.conv1d_temporal_width = conv1d_temporal_width
self.temporal_pre_norm = RMSNorm(width)
if self.temporal_block_type == "recurrent":
self.temporal_block = RecurrentBlock(
width=self.width,
num_heads=self.num_heads,
lru_width=self.lru_width,
conv1d_temporal_width=self.conv1d_temporal_width,
)
else:
self.temporal_block = LocalAttentionBlock(
width=self.width,
num_heads=self.num_heads,
window_size=self.attention_window_size,
)
self.channel_pre_norm = RMSNorm(width)
self.mlp_block = MLPBlock(
width=self.width,
expanded_width=self.mlp_expanded_width,
)
def __call__(
self,
x: mx.array,
cache=None,
mask=None,
):
raw_x = x
inputs_normalized = self.temporal_pre_norm(raw_x)
x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
residual = x + raw_x
x = self.channel_pre_norm(residual)
x = self.mlp_block(x)
x = x + residual
return x
class Griffin(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(
config.vocab_size,
config.hidden_size,
)
self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
block_types = config.block_types
self.layers = [
ResidualBlock(
width=config.hidden_size,
mlp_expanded_width=config.intermediate_size,
num_heads=config.num_attention_heads,
attention_window_size=config.attention_window_size,
temporal_block_type=block_types[i % len(block_types)],
lru_width=None,
)
for i in range(config.num_hidden_layers)
]
self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
tokens,
cache=None,
):
x = self.embed_tokens(tokens)
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
return self.final_norm(x)
class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
"""
Args:
tokens: Sequence of input tokens.
"""
logits = self.model(tokens, cache=cache)
if "lm_head" in self:
logits = self.lm_head(logits)
else:
logits = self.model.embed_tokens.as_linear(logits)
c = self.args.logits_soft_cap
if c:
logits = mx.tanh(logits / c) * c
return logits
@property
def layers(self):
return self.model.layers
def sanitize(self, weights):
for k, v in weights.items():
if "conv_1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
def make_cache(self):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(MambaCache())
else:
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache

View File

@@ -0,0 +1,208 @@
# Copyright © 2023-2024 Apple Inc.
import math
from dataclasses import dataclass
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
num_attention_heads: int
num_hidden_layers: int
num_key_value_heads: int
intermediate_size: int
rope_theta: float
use_qkv_bias: bool
partial_rotary_factor: float
layer_norm_eps: float
use_parallel_residual: bool = False
qk_layernorm: bool = False
class LayerNormPerHead(nn.Module):
def __init__(self, head_dim, num_heads, eps):
super().__init__()
self.norms = [
nn.LayerNorm(head_dim, eps=eps, bias=False) for _ in range(num_heads)
]
self.eps = eps
def __call__(self, x):
w = mx.stack([n.weight for n in self.norms])
return w * mx.fast.layer_norm(x, None, None, self.eps)
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.use_qkv_bias
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.use_qkv_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.use_qkv_bias,
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.rope = nn.RoPE(
int(self.partial_rotary_factor * self.head_dim),
traditional=False,
base=self.rope_theta,
)
self.qk_layernorm = config.qk_layernorm
if self.qk_layernorm:
self.q_layernorm = LayerNormPerHead(
self.head_dim, self.num_heads, eps=config.layer_norm_eps
)
self.k_layernorm = LayerNormPerHead(
self.head_dim, self.num_key_value_heads, eps=config.layer_norm_eps
)
def __call__(self, x, mask=None, cache=None):
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Extract some shapes
B, L, D = queries.shape
queries = queries.reshape(B, L, self.num_heads, -1)
keys = keys.reshape(B, L, self.num_key_value_heads, -1)
if self.qk_layernorm:
queries = self.q_layernorm(queries)
keys = self.k_layernorm(keys)
queries = queries.transpose(0, 2, 1, 3)
keys = keys.transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(
0, 2, 1, 3
)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
queries = queries.astype(mx.float32)
keys = keys.astype(mx.float32)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=scale, mask=mask
).astype(values.dtype)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x) -> mx.array:
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class DecoderLayer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.self_attn = Attention(config=config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
)
self.use_parallel_residual = config.use_parallel_residual
if not self.use_parallel_residual:
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size,
eps=config.layer_norm_eps,
)
def __call__(self, x, mask, cache):
h = self.input_layernorm(x)
r = self.self_attn(h, mask, cache)
if self.use_parallel_residual:
out = x + r + self.mlp(h)
else:
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class StableLM(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, mask, cache=c)
return self.norm(x)
class Model(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.model_type = config.model_type
self.model = StableLM(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.args = config
def __call__(
self,
x: mx.array,
mask: mx.array = None,
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,166 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
num_hidden_layers: int
intermediate_size: int
num_attention_heads: int
num_key_value_heads: int
norm_epsilon: float = 1e-5
vocab_size: int = 49152
rope_theta: float = 100000
tie_word_embeddings: bool = True
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // args.num_attention_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=True)
self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta)
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output)
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.c_fc = nn.Linear(dim, hidden_dim, bias=True)
self.c_proj = nn.Linear(hidden_dim, dim, bias=True)
def __call__(self, x):
return self.c_proj(nn.gelu(self.c_fc(x)))
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.hidden_size = args.hidden_size
self.n_heads = args.num_attention_heads
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.norm_epsilon
)
self.args = args
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out
class Starcoder2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
def __call__(
self,
inputs: mx.array,
cache=None,
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
return self.norm(h)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(
self,
inputs: mx.array,
cache=None,
):
out = self.model(inputs, cache)
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out
@property
def layers(self):
return self.model.layers

View File

@@ -0,0 +1,64 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import List, Union
import mlx.core as mx
import mlx.nn as nn
class SuScaledRotaryEmbedding(nn.Module):
def __init__(
self,
dims: int,
base: float = 10000.0,
max_position_embeddings: int = 131072,
original_max_position_embeddings: int = 4096,
short_factor: Union[List[float], float] = 1.0,
long_factor: Union[List[float], float] = 1.0,
short_mscale: float = None,
long_mscale: float = None,
):
"""
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
Args:
dims (int): The feature dimensions to be rotated.
base (int, optional): Base for the exponential scaling.
max_position_embeddings (int, optional): The maximum sequence
length that this model was trained with. This is used to determine
the size of the original RoPE embeddings when using long scaling.
Default: ``131072``.
original_max_position_embeddings (int, optional): The maximum
sequence length that this model was trained with. This is used to
determine the size of the original RoPE embeddings when using long
scaling. Default: ``4096``.
short_factor (float or list[float], optional): List of scaling
factors for sequences of length lesser than
``original_max_position_embeddings``. Default: ``1.0``.
long_factor (float or list[float], optional): List of scaling
factors for sequences of length greater than
``original_max_position_embeddings``. Default: ``1.0``.
short_mscale (float, optional): Scale the input prior to embedding.
long_mscale (float, optional): Scale the input prior to embedding.
"""
super().__init__()
freqs = base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
self._freqs = mx.array(long_factor, dtype=mx.float32) * freqs
self.original_max_position_embeddings = original_max_position_embeddings
self.scale = long_mscale or math.sqrt(
1
+ math.log(max_position_embeddings / original_max_position_embeddings)
/ math.log(original_max_position_embeddings)
)
def __call__(self, x, offset: int = 0):
return mx.fast.rope(
self.scale * x,
x.shape[-1],
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)

View File

@@ -0,0 +1,167 @@
# Copyright © 2023-2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
class QuantizedSwitchLinear(nn.Module):
def __init__(
self,
input_dims: int,
output_dims: int,
num_experts: int,
bias: bool = True,
group_size: int = 64,
bits: int = 4,
):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight, self.scales, self.biases = mx.quantize(
mx.random.uniform(
low=-scale,
high=scale,
shape=(num_experts, output_dims, input_dims),
),
group_size=group_size,
bits=bits,
)
if bias:
self.bias = mx.zeros((num_experts, output_dims))
self.group_size = group_size
self.bits = bits
# Freeze this model's parameters
self.freeze()
def unfreeze(self, *args, **kwargs):
"""Wrap unfreeze so that we unfreeze any layers we might contain but
our parameters will remain frozen."""
super().unfreeze(*args, **kwargs)
self.freeze(recurse=False)
@property
def input_dims(self):
return self.scales.shape[2] * self.group_size
@property
def output_dims(self):
return self.weight.shape[1]
@property
def num_experts(self):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.gather_qmm(
x,
self["weight"],
self["scales"],
self["biases"],
rhs_indices=indices,
transpose=True,
group_size=self.group_size,
bits=self.bits,
)
if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2)
return x
class SwitchLinear(nn.Module):
def __init__(
self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True
):
super().__init__()
scale = math.sqrt(1 / input_dims)
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_experts, output_dims, input_dims),
)
if bias:
self.bias = mx.zeros((num_experts, output_dims))
@property
def input_dims(self):
return self.weight.shape[2]
@property
def output_dims(self):
return self.weight.shape[1]
@property
def num_experts(self):
return self.weight.shape[0]
def __call__(self, x, indices):
x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices)
if "bias" in self:
x = x + mx.expand_dims(self["bias"][indices], -2)
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
num_experts, output_dims, input_dims = self.weight.shape
ql = QuantizedSwitchLinear(
input_dims, output_dims, num_experts, False, group_size, bits
)
ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits)
if "bias" in self:
ql.bias = self.bias
return ql
class SwitchGLU(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=nn.silu,
bias: bool = False,
):
super().__init__()
self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
self.activation = activation
def __call__(self, x, indices) -> mx.array:
x = mx.expand_dims(x, (-2, -3))
x_up = self.up_proj(x, indices)
x_gate = self.gate_proj(x, indices)
x = self.down_proj(self.activation(x_gate) * x_up, indices)
return x.squeeze(-2)
class SwitchMLP(nn.Module):
def __init__(
self,
input_dims: int,
hidden_dims: int,
num_experts: int,
activation=nn.gelu_approx,
bias: bool = False,
):
super().__init__()
self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias)
self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias)
self.activation = activation
def __call__(self, x, indices) -> mx.array:
x = mx.expand_dims(x, (-2, -3))
x = self.fc1(x, indices)
x = self.activation(x)
x = self.fc2(x, indices)
return x.squeeze(-2)

1
llms/mlx_lm/py.typed Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,6 @@
mlx>=0.19.2
numpy
transformers[sentencepiece]>=4.39.3
protobuf
pyyaml
jinja2

208
llms/mlx_lm/sample_utils.py Normal file
View File

@@ -0,0 +1,208 @@
# Copyright © 2023-2024 Apple Inc.
from functools import partial
from typing import Callable, Dict, Optional
import mlx.core as mx
def make_sampler(
temp: float = 0.0,
top_p: float = 0.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
) -> Callable[mx.array, mx.array]:
"""
Make a sampler function for use with ``generate_step``.
Args:
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
Returns:
Callable[mx.array, mx.array]:
A sampler which takes log-probabilities and returns tokens.
"""
if temp == 0:
return lambda x: mx.argmax(x, axis=-1)
elif top_p > 0 and top_p < 1.0:
return lambda x: top_p_sampling(x, top_p, temp)
elif min_p != 0.0:
return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
else:
return lambda x: categorical_sampling(x, temp)
def make_logits_processors(
logit_bias: Optional[Dict[int, float]] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
):
"""
Make logits processors for use with ``generate_step``.
Args:
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
logit_bias (dictionary, optional): Additive logit bias.
Returns:
List[Callable[[mx.array, mx.array], mx.array]]:
A list of logits processors. Each processor in the list is a
callable which takes an array of tokens and an array of logits
and returns the updated logits.
"""
logits_processors = []
if logit_bias:
indices = mx.array(list(logit_bias.keys()))
values = mx.array(list(logit_bias.values()))
def logit_bias_processor(_, logits):
logits[:, indices] += values
return logits
logits_processors.append(logit_bias_processor)
if repetition_penalty and repetition_penalty != 0.0:
logits_processors.append(
make_repetition_penalty(repetition_penalty, repetition_context_size)
)
return logits_processors
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logits: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
) -> mx.array:
"""
Apply min-p sampling to the logits.
Min-p keeps all tokens that are above a minimum probability, scaled by the
probability of the most likely token. As a result, the filter is more
aggressive given a very high-probability token.
Args:
logits: The logits from the model's output.
min_p (float): Minimum token probability. Typical values are in the
0.01-0.2 range, comparably selective as setting `top_p` in the
0.99-0.8 range.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered. Default: ``1``.
"""
if not (0 <= min_p <= 1.0):
raise ValueError(
f"`min_p` has to be a float in the [0, 1] interval, but is {min_p}"
)
if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
raise ValueError(
f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605
# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
# Top probability
top_probs = probs[..., sorted_indices[0]]
# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
# Mask tokens that have a probability less than the scaled min_p
tokens_to_remove = sorted_probs < scaled_min_p
tokens_to_remove[..., :min_tokens_to_keep] = False
# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
return sorted_indices[sorted_token]
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
"""
Apply top-p (nucleus) sampling to logits.
Args:
logits: The logits from the model's output.
top_p: The cumulative probability threshold for top-p filtering.
temperature: Temperature parameter for softmax distribution reshaping.
Returns:
token selected based on the top-p criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=-1)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=-1)
sorted_probs = probs[..., sorted_indices.squeeze(0)]
cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
# select tokens with cumulative probs below threshold
top_probs = mx.where(
cumulative_probs > 1 - top_p,
sorted_probs,
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs))
token = sorted_indices.squeeze(0)[sorted_token]
return token
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def categorical_sampling(logits, temp):
return mx.random.categorical(logits * (1 / temp))
def make_repetition_penalty(penalty: float, context_size: int = 20):
"""
Make repetition penalty processor.
Paper: https://arxiv.org/abs/1909.05858
Args:
penalty (float): The repetition penalty factor to be applied.
context_size (int): The number of previous tokens to use.
Default: ``20``.
Returns:
Callable[[mx.array, List[int]], mx.array]:
The repetition penalty processor.
"""
if penalty < 0 or not isinstance(penalty, float):
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
def repetition_penalty_processor(tokens, logits):
if len(tokens) > 0:
tokens = tokens[-context_size:]
selected_logits = logits[:, tokens]
selected_logits = mx.where(
selected_logits < 0,
selected_logits * penalty,
selected_logits / penalty,
)
logits[:, tokens] = selected_logits
return logits
return repetition_penalty_processor

766
llms/mlx_lm/server.py Normal file
View File

@@ -0,0 +1,766 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import load, stream_generate
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
def stopping_criteria(
tokens: List[int],
stop_id_sequences: List[List[int]],
eos_token_id: Union[int, None],
) -> StopCondition:
"""
Determines whether the token generation should stop based on predefined
conditions.
Args:
tokens (List[int]): The current sequence of generated tokens.
stop_id_sequences (List[List[[int]]): A list of integer lists, each
representing a sequence of token IDs. If the end of the `tokens`
list matches any of these sequences, the generation should stop.
eos_token_id (Union[int, None]): The token ID that represents the
end-of-sequence. If the last token in `tokens` matches this, the
generation should stop.
Returns:
StopCondition: A named tuple indicating whether the stop condition has
been met (`stop_met`) and how many tokens should be trimmed from the
end if it has (`trim_length`).
"""
if tokens and tokens[-1] == eos_token_id:
return StopCondition(stop_met=True, trim_length=0)
for stop_ids in stop_id_sequences:
if len(tokens) >= len(stop_ids):
if tokens[-len(stop_ids) :] == stop_ids:
return StopCondition(stop_met=True, trim_length=len(stop_ids))
return StopCondition(stop_met=False, trim_length=0)
def sequence_overlap(s1: Sequence, s2: Sequence) -> bool:
"""
Checks if a suffix of s1 has overlap with a prefix of s2
Args:
s1 (Sequence): The first sequence
s2 (Sequence): The second sequence
Returns:
bool: If the two sequences have overlap
"""
max_overlap = min(len(s1), len(s2))
return any(s1[-i:] == s2[:i] for i in range(1, max_overlap + 1))
def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
default_role_mapping = {
"system_prompt": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant follows the given rules no matter what."
),
"system": "ASSISTANT's RULE: ",
"user": "USER: ",
"assistant": "ASSISTANT: ",
"stop": "\n",
}
role_mapping = role_mapping if role_mapping is not None else default_role_mapping
prompt = ""
for line in messages:
role_prefix = role_mapping.get(line["role"], "")
stop = role_mapping.get("stop", "")
content = line.get("content", "")
prompt += f"{role_prefix}{content}{stop}"
prompt += role_mapping.get("assistant", "")
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
self.cli_args = cli_args
self.model_key = None
self.model = None
self.tokenizer = None
# Preload the default model if it is provided
if self.cli_args.model is not None:
self.load("default_model")
def _validate_model_path(self, model_path: str):
model_path = Path(model_path)
if model_path.exists() and not model_path.is_relative_to(Path.cwd()):
raise RuntimeError(
"Local models must be relative to the current working dir."
)
# Added in adapter_path to load dynamically
def load(self, model_path, adapter_path=None):
if self.model_key == (model_path, adapter_path):
return self.model, self.tokenizer
# Remove the old model if it exists.
self.model = None
self.tokenizer = None
self.model_key = None
# Building tokenizer_config
tokenizer_config = {
"trust_remote_code": True if self.cli_args.trust_remote_code else None
}
if self.cli_args.chat_template:
tokenizer_config["chat_template"] = self.cli_args.chat_template
if model_path == "default_model" and self.cli_args.model is not None:
model, tokenizer = load(
self.cli_args.model,
adapter_path=(
adapter_path if adapter_path else self.cli_args.adapter_path
), # if the user doesn't change the model but adds an adapter path
tokenizer_config=tokenizer_config,
)
else:
self._validate_model_path(model_path)
model, tokenizer = load(
model_path, adapter_path=adapter_path, tokenizer_config=tokenizer_config
)
if self.cli_args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
self.model_key = (model_path, adapter_path)
self.model = model
self.tokenizer = tokenizer
return self.model, self.tokenizer
class APIHandler(BaseHTTPRequestHandler):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Methods", "*")
self.send_header("Access-Control-Allow-Headers", "*")
def _set_completion_headers(self, status_code: int = 200):
self.send_response(status_code)
self.send_header("Content-type", "application/json")
self._set_cors_headers()
def _set_stream_headers(self, status_code: int = 200):
self.send_response(status_code)
self.send_header("Content-type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self._set_cors_headers()
def do_OPTIONS(self):
self._set_completion_headers(204)
self.end_headers()
def do_POST(self):
"""
Respond to a POST request from a client.
"""
endpoints = {
"/v1/completions": self.handle_text_completions,
"/v1/chat/completions": self.handle_chat_completions,
"/chat/completions": self.handle_chat_completions,
}
if self.path not in endpoints:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Fetch and parse request body
content_length = int(self.headers["Content-Length"])
raw_body = self.rfile.read(content_length)
self.body = json.loads(raw_body.decode())
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Incoming Request Body: {json.dumps(self.body, indent=indent)}")
assert isinstance(
self.body, dict
), f"Request should be dict, but got {type(self.body)}"
# Extract request parameters from the body
self.stream = self.body.get("stream", False)
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 0.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
self.repetition_context_size = self.body.get("repetition_context_size", 20)
self.logit_bias = self.body.get("logit_bias", None)
self.logprobs = self.body.get("logprobs", -1)
self.validate_model_parameters()
# Load the model if needed
try:
self.model, self.tokenizer = self.model_provider.load(
self.requested_model, self.adapter
)
except:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
return
# Get stop id sequences, if provided
stop_words = self.body.get("stop")
stop_words = stop_words or []
stop_words = [stop_words] if isinstance(stop_words, str) else stop_words
stop_id_sequences = [
self.tokenizer.encode(stop_word, add_special_tokens=False)
for stop_word in stop_words
]
# Send header type
(
self._set_stream_headers(200)
if self.stream
else self._set_completion_headers(200)
)
# Call endpoint specific method
prompt = endpoints[self.path]()
self.handle_completion(prompt, stop_id_sequences)
def validate_model_parameters(self):
"""
Validate the model parameters passed in the request for the correct types and values.
"""
if not isinstance(self.stream, bool):
raise ValueError("stream must be a boolean")
if not isinstance(self.max_tokens, int) or self.max_tokens < 0:
raise ValueError("max_tokens must be a non-negative integer")
if not isinstance(self.temperature, (float, int)) or self.temperature < 0:
raise ValueError("temperature must be a non-negative float")
if not isinstance(self.top_p, (float, int)) or self.top_p < 0 or self.top_p > 1:
raise ValueError("top_p must be a float between 0 and 1")
if (
not isinstance(self.repetition_penalty, (float, int))
or self.repetition_penalty < 0
):
raise ValueError("repetition_penalty must be a non-negative float")
if self.logprobs != -1 and not (0 < self.logprobs <= 10):
raise ValueError(
f"logprobs must be between 1 and 10 but got {self.logprobs:,}"
)
if (
not isinstance(self.repetition_context_size, int)
or self.repetition_context_size < 0
):
raise ValueError("repetition_context_size must be a non-negative integer")
if self.logit_bias is not None:
if not isinstance(self.logit_bias, dict):
raise ValueError("logit_bias must be a dict of int to float")
try:
self.logit_bias = {int(k): v for k, v in self.logit_bias.items()}
except ValueError:
raise ValueError("logit_bias must be a dict of int to float")
if not isinstance(self.requested_model, str):
raise ValueError("model must be a string")
if self.adapter is not None and not isinstance(self.adapter, str):
raise ValueError("adapter must be a string")
def generate_response(
self,
text: str,
finish_reason: Union[Literal["length", "stop"], None],
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
token_logprobs: Optional[List[float]] = None,
top_tokens: Optional[List[Dict[int, float]]] = None,
tokens: Optional[List[int]] = None,
) -> dict:
"""
Generate a single response packet based on response type (stream or
not), completion type and parameters.
Args:
text (str): Text generated by model
finish_reason (Union[Literal["length", "stop"], None]): The reason the
response is being sent: "length", "stop" or `None`.
prompt_token_count (Optional[int]): The number of tokens in the prompt,
used to populate the "usage" field (not used when stream).
completion_token_count (Optional[int]): The number of tokens in the
response, used to populate the "usage" field (not used when stream).
token_logprobs (Optional[List[float]]): The log probabilities per token,
in token order.
top_tokens (Optional[List[Dict[int, float]]]): List of dictionaries mapping
tokens to logprobs for the top N tokens at each token position.
tokens (Optional[List[int]]): List of tokens to return with logprobs structure
Returns:
dict: A dictionary containing the response, in the same format as
OpenAI's API.
"""
token_logprobs = token_logprobs if token_logprobs else []
top_logprobs = top_tokens if top_tokens else []
# Static response
response = {
"id": self.request_id,
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
"choices": [
{
"index": 0,
"logprobs": {
"token_logprobs": token_logprobs,
"top_logprobs": top_logprobs,
"tokens": tokens,
},
"finish_reason": finish_reason,
}
],
}
if not self.stream:
if not (
isinstance(prompt_token_count, int)
and isinstance(completion_token_count, int)
):
raise ValueError(
"Response type is complete, but token counts not provided"
)
response["usage"] = {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
}
choice = response["choices"][0]
# Add dynamic response
if self.object_type.startswith("chat.completion"):
key_name = "delta" if self.stream else "message"
choice[key_name] = {"role": "assistant", "content": text}
elif self.object_type == "text_completion":
choice.update(text=text)
else:
ValueError(f"Unsupported response type: {self.object_type}")
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
tokens = []
finish_reason = "length"
stop_sequence_suffix = None
if self.stream:
self.end_headers()
logging.debug(f"Starting stream:")
else:
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
prompt = self.get_prompt_cache(prompt)
text = ""
tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate(
stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
text += segment
logging.debug(text)
tokens.append(token)
if self.logprobs > 0:
sorted_indices = mx.argpartition(-logprobs, kth=self.logprobs - 1)
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
stop_condition = stopping_criteria(
tokens, stop_id_sequences, self.tokenizer.eos_token_id
)
if stop_condition.stop_met:
finish_reason = "stop"
if stop_condition.trim_length:
stop_sequence_suffix = self.tokenizer.decode(
tokens[-stop_condition.trim_length :]
)
text = text[: -len(stop_sequence_suffix)]
break
if self.stream:
# If the end of tokens overlaps with a stop sequence, generate new
# tokens until we know if the stop sequence is hit or not
if any(
(
sequence_overlap(tokens, sequence)
for sequence in stop_id_sequences
)
):
continue
elif segment:
response = self.generate_response(segment, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
gen_time = time.perf_counter() - tic
prompt_tps = len(prompt) / prompt_time
gen_tps = len(tokens) / gen_time
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
if self.stream:
response = self.generate_response(segment, finish_reason)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if self.stream_options is not None and self.stream_options["include_usage"]:
response = self.completion_usage_response(len(prompt), len(tokens))
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode())
self.wfile.flush()
else:
response = self.generate_response(
text,
finish_reason,
len(prompt),
len(tokens),
token_logprobs=token_logprobs,
top_tokens=top_tokens,
tokens=tokens,
)
response_json = json.dumps(response).encode()
indent = "\t" # Backslashes can't be inside of f-strings
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
# Send an additional Content-Length header when it is known
self.send_header("Content-Length", str(len(response_json)))
self.end_headers()
self.wfile.write(response_json)
self.wfile.flush()
def completion_usage_response(
self,
prompt_token_count: Optional[int] = None,
completion_token_count: Optional[int] = None,
):
response = {
"id": self.request_id,
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
"choices": [],
"usage": {
"prompt_tokens": prompt_token_count,
"completion_tokens": completion_token_count,
"total_tokens": prompt_token_count + completion_token_count,
},
}
return response
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
Returns:
mx.array: A mx.array of the tokenized prompt from the request body
"""
body = self.body
assert "messages" in body, "Request did not contain messages"
# Determine response type
self.request_id = f"chatcmpl-{uuid.uuid4()}"
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
):
prompt = self.tokenizer.apply_chat_template(
body["messages"],
body.get("tools", None),
tokenize=True,
add_generation_prompt=True,
)
else:
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return prompt
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
Returns:
mx.array: A mx.array of the tokenized prompt from the request body
"""
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
Respond to a GET request from a client.
"""
if self.path == "/v1/models":
self.handle_models_request()
else:
self._set_completion_headers(404)
self.end_headers()
self.wfile.write(b"Not Found")
def handle_models_request(self):
"""
Handle a GET request for the /v1/models endpoint.
"""
self._set_completion_headers(200)
self.end_headers()
# Scan the cache directory for downloaded mlx models
hf_cache_info = scan_cache_dir()
downloaded_models = [
repo for repo in hf_cache_info.repos if "mlx" in repo.repo_id
]
# Create a list of available models
models = [
{
"id": repo.repo_id,
"object": "model",
"created": self.created,
}
for repo in downloaded_models
]
response = {"object": "list", "data": models}
response_json = json.dumps(response).encode()
self.wfile.write(response_json)
self.wfile.flush()
def run(
host: str,
port: int,
model_provider: ModelProvider,
server_class=HTTPServer,
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "
"it only implements basic security checks."
)
logging.info(f"Starting httpd at {host} on port {port}...")
httpd.serve_forever()
def main():
parser = argparse.ArgumentParser(description="MLX Http Server.")
parser.add_argument(
"--model",
type=str,
help="The path to the MLX model weights, tokenizer, and config",
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="Host for the HTTP server (default: 127.0.0.1)",
)
parser.add_argument(
"--port",
type=int,
default=8080,
help="Port for the HTTP server (default: 8080)",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Enable trusting remote code for tokenizer",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level (default: INFO)",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
parser.add_argument(
"--chat-template",
type=str,
default="",
help="Specify a chat template for the tokenizer",
required=False,
)
parser.add_argument(
"--use-default-chat-template",
action="store_true",
help="Use the default chat template",
)
args = parser.parse_args()
logging.basicConfig(
level=getattr(logging, args.log_level.upper(), None),
format="%(asctime)s - %(levelname)s - %(message)s",
)
if args.cache_limit_gb is not None:
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
run(args.host, args.port, ModelProvider(args))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,345 @@
import json
from functools import partial
from transformers import AutoTokenizer
REPLACEMENT_CHAR = "\ufffd"
class StreamingDetokenizer:
"""The streaming detokenizer interface so that we can detokenize one token at a time.
Example usage is as follows:
detokenizer = ...
# Reset the tokenizer state
detokenizer.reset()
for token in generate(...):
detokenizer.add_token(token.item())
# Contains the whole text so far. Some tokens may not be included
# since it contains whole words usually.
detokenizer.text
# Contains the printable segment (usually a word) since the last
# time it was accessed
detokenizer.last_segment
# Contains all the tokens added so far
detokenizer.tokens
# Make sure that we detokenize any remaining tokens
detokenizer.finalize()
# Now detokenizer.text should match tokenizer.decode(detokenizer.tokens)
"""
__slots__ = ("text", "tokens", "offset")
def reset(self):
raise NotImplementedError()
def add_token(self, token):
raise NotImplementedError()
def finalize(self):
raise NotImplementedError()
@property
def last_segment(self):
"""Return the last segment of readable text since last time this property was accessed."""
text = self.text
if text and text[-1] != REPLACEMENT_CHAR:
segment = text[self.offset :]
self.offset = len(text)
return segment
return ""
class NaiveStreamingDetokenizer(StreamingDetokenizer):
"""NaiveStreamingDetokenizer relies on the underlying tokenizer
implementation and should work with every tokenizer.
Its complexity is O(T^2) where T is the longest line since it will
repeatedly detokenize the same tokens until a new line is generated.
"""
def __init__(self, tokenizer):
self._tokenizer = tokenizer
self._tokenizer.decode([0])
self.reset()
def reset(self):
self.offset = 0
self._tokens = []
self._text = ""
self._current_tokens = []
self._current_text = ""
def add_token(self, token):
self._current_tokens.append(token)
def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = []
self._current_text = ""
@property
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
self._current_tokens.clear()
self._current_text = ""
return self._text + self._current_text
@property
def tokens(self):
return self._tokens
class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
It adds tokens to the text if the next token starts with the special SPM
underscore which results in linear complexity.
"""
def __init__(self, tokenizer, trim_space=True):
self.trim_space = trim_space
self._sep = "\u2581".encode()
# Extract the tokens in a list from id to text
self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
for value, tokenid in tokenizer.vocab.items():
if value.startswith("<0x"):
# Replace bytes with their value
self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
else:
self.tokenmap[tokenid] = value.encode()
self.reset()
def reset(self):
self.offset = 0
self._unflushed = b""
self.text = ""
self.tokens = []
def _flush(self):
text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
if not self.text and self.trim_space and text and text[0] == " ":
text = text[1:]
self.text += text
def add_token(self, token):
v = self.tokenmap[token]
if v.startswith(self._sep):
self._flush()
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
self._flush()
self._unflushed = b""
class BPEStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for OpenAI style BPE models.
It adds tokens to the text if the next token starts with a space similar to
the SPM detokenizer.
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
for value, tokenid in tokenizer.vocab.items():
self.tokenmap[tokenid] = value
self.reset()
# Make the BPE byte decoder from
# https://github.com/openai/gpt-2/blob/master/src/encoder.py
self.make_byte_decoder()
self._added_ids = set(tokenizer.added_tokens_decoder.keys())
def reset(self):
self.offset = 0
self._unflushed = ""
self.text = ""
self.tokens = []
def _maybe_trim_space(self, current_text):
if len(current_text) == 0:
return current_text
elif current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
v = self.tokenmap[token]
is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
self.text += self._maybe_trim_space(current_text)
if is_added:
self.text += v
self._unflushed = ""
else:
self._unflushed = v
else:
self._unflushed += v
def finalize(self):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
def make_byte_decoder(cls):
"""See https://github.com/openai/gpt-2/blob/master/src/encoder.py for the rationale."""
if cls._byte_decoder is not None:
return
char_to_bytes = {}
limits = [
0,
ord("!"),
ord("~") + 1,
ord("¡"),
ord("¬") + 1,
ord("®"),
ord("ÿ") + 1,
]
n = 0
for i, (start, stop) in enumerate(zip(limits, limits[1:])):
if i % 2 == 0:
for b in range(start, stop):
char_to_bytes[chr(2**8 + n)] = b
n += 1
else:
for b in range(start, stop):
char_to_bytes[chr(b)] = b
cls._byte_decoder = char_to_bytes
class TokenizerWrapper:
"""A wrapper that combines an HF tokenizer and a detokenizer.
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer):
self._tokenizer = tokenizer
self._detokenizer = detokenizer_class(tokenizer)
def __getattr__(self, attr):
if attr == "detokenizer":
return self._detokenizer
elif attr.startswith("_"):
return self.__getattribute__(attr)
else:
return getattr(self._tokenizer, attr)
def __setattr__(self, attr, value):
if attr == "detokenizer":
raise AttributeError("Cannot set the detokenizer.")
elif attr.startswith("_"):
super().__setattr__(attr, value)
else:
setattr(self._tokenizer, attr, value)
def _match(a, b):
if type(a) != type(b):
return False
if isinstance(a, dict):
return len(a) == len(b) and all(k in b and _match(a[k], b[k]) for k in a)
if isinstance(a, list):
return len(a) == len(b) and all(_match(ai, bi) for ai, bi in zip(a, b))
return a == b
def _is_spm_decoder(decoder):
_target_description = {
"type": "Sequence",
"decoders": [
{"type": "Replace", "pattern": {"String": ""}, "content": " "},
{"type": "ByteFallback"},
{"type": "Fuse"},
{"type": "Strip", "content": " ", "start": 1, "stop": 0},
],
}
return _match(_target_description, decoder)
def _is_spm_decoder_no_space(decoder):
_target_description = {
"type": "Sequence",
"decoders": [
{"type": "Replace", "pattern": {"String": ""}, "content": " "},
{"type": "ByteFallback"},
{"type": "Fuse"},
],
}
return _match(_target_description, decoder)
def _is_bpe_decoder(decoder):
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
Note, to use a fast streaming tokenizer, pass a local file path rather than
a Hugging Face repo ID.
"""
detokenizer_class = NaiveStreamingDetokenizer
tokenizer_file = model_path / "tokenizer.json"
if tokenizer_file.exists():
with open(tokenizer_file, "r") as fid:
tokenizer_content = json.load(fid)
if "decoder" in tokenizer_content:
if _is_spm_decoder(tokenizer_content["decoder"]):
detokenizer_class = SPMStreamingDetokenizer
elif _is_spm_decoder_no_space(tokenizer_content["decoder"]):
detokenizer_class = partial(SPMStreamingDetokenizer, trim_space=False)
elif _is_bpe_decoder(tokenizer_content["decoder"]):
detokenizer_class = BPEStreamingDetokenizer
return TokenizerWrapper(
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
detokenizer_class,
)

View File

@@ -0,0 +1,2 @@
from .trainer import TrainingArgs, evaluate, train
from .utils import linear_to_lora_layers

View File

@@ -0,0 +1,191 @@
import json
from pathlib import Path
from typing import Dict, List
from transformers import PreTrainedTokenizer
class Dataset:
"""
Light-weight wrapper to hold a dataset.
"""
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
self._text_key = text_key
self._data = data
def __getitem__(self, idx: int):
return self._data[idx][self._text_key]
def __len__(self):
if self._data is None:
return 0
return len(self._data)
class ChatDataset(Dataset):
"""
A dataset for chat data in the format of {"messages": [...]}
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
super().__init__(data)
self._tokenizer = tokenizer
def __getitem__(self, idx: int):
messages = self._data[idx]["messages"]
text = self._tokenizer.apply_chat_template(
messages,
tools=self._data[idx].get("tools", None),
tokenize=False,
add_generation_prompt=True,
)
return text
class CompletionsDataset(Dataset):
"""
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
or using user-provided keys for prompt and completion values
https://platform.openai.com/docs/guides/fine-tuning/example-format
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
completion_key: str = "completion",
):
super().__init__(data)
self._tokenizer = tokenizer
self._prompt_key = prompt_key
self._completion_key = completion_key
def __getitem__(self, idx: int):
data = self._data[idx]
text = self._tokenizer.apply_chat_template(
[
{"role": "user", "content": data[self._prompt_key]},
{"role": "assistant", "content": data[self._completion_key]},
],
tokenize=False,
add_generation_prompt=True,
)
return text
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "completion" in sample:
return CompletionsDataset(data, tokenizer)
elif "text" in sample:
return Dataset(data)
else:
raise ValueError(
"Unsupported data format, check the supported formats here:\n"
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
)
def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer):
def load_subset(path):
if not path.exists():
return []
with open(path, "r") as fid:
data = [json.loads(l) for l in fid]
return create_dataset(data, tokenizer)
names = ("train", "valid", "test")
train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
return train, valid, test
def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
from datasets import exceptions, load_dataset
try:
dataset = load_dataset(data_id)
names = ("train", "valid", "test")
train, valid, test = [
create_dataset(dataset[n], tokenizer) if n in dataset.keys() else []
for n in names
]
except exceptions.DatasetNotFoundError:
raise ValueError(f"Not found Hugging Face dataset: {data_id} .")
return train, valid, test
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets
hf_args = args.hf_dataset
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature = hf_args.get("text_feature")
prompt_feature = hf_args.get("prompt_feature")
completion_feature = hf_args.get("completion_feature")
def create_hf_dataset(split: str = None):
ds = datasets.load_dataset(
dataset_name,
split=split,
**hf_args.get("config", {}),
)
if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif text_feature:
return Dataset(train_ds, text_key=text_feature)
else:
raise ValueError(
"Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset."
)
if args.train:
train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split)
valid = create_hf_dataset(split=valid_split)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(split=hf_args.get("test_split"))
else:
test = []
return train, valid, test
def load_dataset(args, tokenizer: PreTrainedTokenizer):
if getattr(args, "hf_dataset", None) is not None:
train, valid, test = load_custom_hf_dataset(args, tokenizer)
else:
data_path = Path(args.data)
if data_path.exists():
train, valid, test = load_local_dataset(data_path, tokenizer)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args.data, tokenizer)
if args.train and len(train) == 0:
raise ValueError(
"Training set not found or empty. Must provide training set for fine-tuning."
)
if args.train and len(valid) == 0:
raise ValueError(
"Validation set not found or empty. Must provide validation set for fine-tuning."
)
if args.test and len(test) == 0:
raise ValueError(
"Test set not found or empty. Must provide test set for evaluation."
)
return train, valid, test

228
llms/mlx_lm/tuner/dora.py Normal file
View File

@@ -0,0 +1,228 @@
# Copyright © 2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
class DoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
dora_lin = DoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
dora_lin.set_linear(linear)
return dora_lin
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = self._dequantized_weight()
# Use the same type as the linear weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
weight = weight + lora_b @ lora_a
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
fused_linear.weight = norm_scale[:, None] * weight
if bias:
fused_linear.bias = linear.bias
if self._is_quantized() and not de_quantize:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def set_linear(self, linear):
"""
Set the self.linear layer and recompute self.m.
"""
self.linear = linear
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
def _dequantized_weight(self):
"""
Return the weight of linear layer and dequantize it if is quantized
"""
weight = self.linear.weight
if self._is_quantized():
weight = mx.dequantize(
weight,
self.linear.scales,
self.linear.biases,
self.linear.group_size,
self.linear.bits,
)
return weight
def _is_quantized(self):
return isinstance(self.linear, nn.QuantizedLinear)
def __call__(self, x):
# Regular LoRA (without a bias)
w = self._dequantized_weight()
y = x @ w.T
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
out = (self.m / denom).astype(x.dtype) * out
if "bias" in self.linear:
out = out + self.linear.bias
return out
class DoRAEmbedding(nn.Module):
def from_base(
embedding: nn.Embedding,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
num_embeddings, dims = embedding.weight.shape
# TODO support quantized weights in DoRALinear
if isinstance(embedding, nn.QuantizedLinear):
raise ValueError("DoRAEmbedding does not yet support quantization.")
dora_embedding = DoRAEmbedding(
num_embeddings=num_embeddings,
dims=dims,
r=r,
dropout=dropout,
scale=scale,
)
dora_embedding.set_embedding(embedding)
return dora_embedding
def fuse(self, de_quantize: bool = False):
embedding = self.embedding
weight = embedding.weight
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
num_embeddings, dims = weight.shape
fused_embedding = nn.Embedding(num_embeddings, dims)
lora_a = (self.scale * self.lora_a).astype(dtype)
lora_b = self.lora_b.astype(dtype)
weight = weight + lora_a @ lora_b
norm_scale = self.m / mx.linalg.norm(weight, axis=1)
fused_embedding.weight = norm_scale[:, None] * weight
return fused_embedding
def __init__(
self,
num_embeddings: int,
dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
super().__init__()
# Regular embedding layer weights
self.set_embedding(nn.Embedding(num_embeddings, dims))
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(num_embeddings)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_embeddings, r),
)
self.lora_b = mx.zeros(shape=(r, dims))
def set_embedding(self, embedding: nn.Module):
self.embedding = embedding
self.m = mx.linalg.norm(embedding.weight, axis=1)
def __call__(self, x):
y = self.embedding(x)
z = self.scale * self.lora_a[x] @ self.lora_b
out = y + self.dropout(z).astype(y.dtype)
# Compute the norm of the adapted weights for the individual embeddings
adapted = y + z
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
# Remove the norm and scale by the learned magnitude
out = (self.m[x] / denom)[..., None] * out
return out
def as_linear(self, x):
y = self.embedding.as_linear(x)
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
out = (self.m / denom) * out
return out

285
llms/mlx_lm/tuner/lora.py Normal file
View File

@@ -0,0 +1,285 @@
# Copyright © 2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
input_dims *= 32 // linear.bits
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, nn.QuantizedLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = linear.scales.dtype
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized and not de_quantize:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
return y + (self.scale * z).astype(x.dtype)
class LoRASwitchLinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Module,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
lora_lin = LoRASwitchLinear(
input_dims=linear.input_dims,
output_dims=linear.output_dims,
num_experts=linear.num_experts,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
is_quantized = isinstance(linear, QuantizedSwitchLinear)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = mx.float16
weight = mx.dequantize(
weight,
linear.scales,
linear.biases,
linear.group_size,
linear.bits,
)
num_experts, output_dims, input_dims = weight.shape
fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
lora_b = (self.scale * self.lora_b).astype(dtype)
lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype)
fused_linear.weight = weight + lora_b @ lora_a
if bias:
fused_linear.bias = linear.bias
if is_quantized and not de_quantize:
fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
num_experts: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(r * num_experts, input_dims),
)
self.lora_b = mx.zeros(shape=(num_experts, output_dims, r))
self.num_experts = num_experts
def __call__(self, x, indices):
shape = x.shape[:-3] + (self.num_experts, -1)
y = self.linear(x, indices)
z = (self.dropout(x) @ self.lora_a.T).reshape(shape)
z = mx.take_along_axis(z, indices[..., None], axis=-2)
z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
return y + (self.scale * z).astype(x.dtype)
class LoRAEmbedding(nn.Module):
@staticmethod
def from_base(
embedding: nn.Embedding,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
num_embeddings, dims = embedding.weight.shape
if isinstance(embedding, nn.QuantizedEmbedding):
dims *= 32 // embedding.bits
lora_embedding = LoRAEmbedding(
num_embeddings=num_embeddings,
dims=dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_embedding.embedding = embedding
return lora_embedding
def fuse(self, de_quantize: bool = False):
embedding = self.embedding
weight = embedding.weight
is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
# Use the same type as the linear weight if not quantized
dtype = weight.dtype
if is_quantized:
dtype = embedding.scales.dtype
weight = mx.dequantize(
weight,
embedding.scales,
embedding.biases,
embedding.group_size,
embedding.bits,
)
num_embeddings, dims = weight.shape
fused_embedding = nn.Embedding(num_embeddings, dims)
lora_a = (self.scale * self.lora_a).astype(dtype)
lora_b = self.lora_b.astype(dtype)
fused_embedding.weight = weight + lora_a @ lora_b
if is_quantized and not de_quantize:
fused_embedding = nn.QuantizedEmbedding.from_embedding(
fused_embedding,
embedding.group_size,
embedding.bits,
)
return fused_embedding
def __init__(
self,
num_embeddings: int,
dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 20.0,
):
super().__init__()
# Regular embedding layer
self.embedding = nn.Embedding(num_embeddings, dims)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(num_embeddings)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(num_embeddings, r),
)
self.lora_b = mx.zeros(shape=(r, dims))
def __call__(self, x):
y = self.embedding(x)
z = self.dropout(self.lora_a[x] @ self.lora_b)
out = y + (self.scale * z).astype(y.dtype)
return out
def as_linear(self, x):
y = self.embedding.as_linear(x)
z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
return y + (self.scale * z).astype(x.dtype)

View File

@@ -0,0 +1,333 @@
# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
def grad_checkpoint(layer):
"""
Update all instances of type(layer) to use gradient checkpointing.
"""
fn = type(layer).__call__
def checkpointed_fn(model, *args, **kwargs):
def inner_fn(params, *args, **kwargs):
model.update(params)
return fn(model, *args, **kwargs)
return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)
type(layer).__call__ = checkpointed_fn
@dataclass
class TrainingArgs:
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
default=25,
metadata={
"help": "Number of validation batches, -1 uses the entire validation set."
},
)
steps_per_report: int = field(
default=10,
metadata={"help": "Number of training steps between loss reporting."},
)
steps_per_eval: int = field(
default=200, metadata={"help": "Number of training steps between validations."}
)
steps_per_save: int = field(
default=100, metadata={"help": "Save the model every number steps"}
)
max_seq_length: int = field(
default=2048, metadata={"help": "Maximum sequence length."}
)
adapter_file: str = field(
default="adapters.safetensors",
metadata={"help": "Save/load path for the trained adapter weights."},
)
grad_checkpoint: bool = field(
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)
def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)
length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
# If running in distributed mode (N machines) then each one should skip N-1
# samples
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Make the batches:
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
indices = np.random.permutation(len(batch_idx))
for i in indices:
# Encode batch
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
for b in batch:
if b[-1] != tokenizer.eos_token_id:
b.append(tokenizer.eos_token_id)
lengths = [len(x) for x in batch]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
)
# Pad to the nearest multiple of 8 or the maximum length
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
for j in range(batch_size // step):
truncated_length = min(lengths[j], max_seq_length)
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
lengths[j] = (
truncated_length # Update lengths to match truncated lengths
)
batch = mx.array(batch_arr)
yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
if not train:
break
def evaluate(
model,
dataset,
tokenizer,
batch_size,
num_batches,
max_seq_length=2048,
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
):
all_losses = 0
ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
losses, toks = loss(model, *batch)
all_losses += losses * toks
ntokens += toks
mx.eval(all_losses, ntokens)
all_losses = mx.distributed.all_sum(all_losses)
ntokens = mx.distributed.all_sum(ntokens)
return (all_losses / ntokens).item()
class TrainingCallback:
def on_train_loss_report(self, train_info: dict):
"""Called to report training loss at specified intervals."""
pass
def on_val_loss_report(self, val_info: dict):
"""Called to report validation loss at specified intervals or the beginning."""
pass
def train(
model,
tokenizer,
optimizer,
train_dataset,
val_dataset,
args: TrainingArgs = TrainingArgs(),
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
training_callback: TrainingCallback = None,
):
print(f"Starting training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
# Forward and backward pass
(lvalue, toks), grad = loss_value_and_grad(model, *batch)
# All reduce the gradients if running in distributed mode
grad = average_gradients(grad)
# Model update
optimizer.update(model, grad)
return lvalue, toks
loss_value_and_grad = nn.value_and_grad(model, loss)
losses = 0
n_tokens = 0
steps = 0
trained_tokens = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
loss=loss,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
val_info = {
"iteration": it,
"val_loss": val_loss,
"val_time": val_time,
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
lvalue, toks = step(batch)
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
train_info = {
"iteration": it,
"train_loss": train_loss,
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
}
training_callback.on_train_loss_report(train_info)
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()
# Save adapter weights
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")

268
llms/mlx_lm/tuner/utils.py Normal file
View File

@@ -0,0 +1,268 @@
# Copyright © 2024 Apple Inc.
import json
import types
from pathlib import Path
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten, tree_unflatten
from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
from .dora import DoRAEmbedding, DoRALinear
from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
def build_schedule(schedule_config: Dict):
"""
Build a learning rate schedule from the given config.
"""
schedule_fn = getattr(opt.schedulers, schedule_config["name"])
arguments = schedule_config["arguments"]
initial_lr = arguments[0]
bound_schedule_fn = schedule_fn(*arguments)
if warmup_steps := schedule_config.get("warmup", 0):
warmup_init = schedule_config.get("warmup_init", 0.0)
warmup_fn = opt.schedulers.linear_schedule(
warmup_init, initial_lr, warmup_steps
)
return opt.schedulers.join_schedules(
[warmup_fn, bound_schedule_fn], [warmup_steps + 1]
)
else:
return bound_schedule_fn
def linear_to_lora_layers(
model: nn.Module,
num_layers: int,
config: Dict,
use_dora: bool = False,
):
"""
Convert some of the models linear layers to lora layers.
Args:
model (nn.Module): The neural network model.
num_layers (int): The number of blocks to convert to lora layers
starting from the last layer.
config (dict): More configuration parameters for LoRA, including the
rank, scale, and optional layer keys.
use_dora (bool): If True, uses DoRA instead of LoRA.
Default: ``False``
"""
if num_layers > len(model.layers):
raise ValueError(
f"Requested {num_layers} LoRA layers "
f"but the model only has {len(model.layers)} layers."
)
def to_lora(layer):
if isinstance(layer, (nn.Linear, nn.QuantizedLinear)):
LoRALayer = DoRALinear if use_dora else LoRALinear
elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)):
if use_dora:
raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
LoRALayer = LoRASwitchLinear
elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
else:
raise ValueError(
f"Can't convert layer of type {type(layer).__name__} to LoRA"
)
return LoRALayer.from_base(
layer,
r=config["rank"],
scale=config["scale"],
dropout=config["dropout"],
)
keys = config.get("keys", None)
if keys is not None:
keys = set(keys)
elif model.model_type in [
"mistral",
"llama",
"phi",
"mixtral",
"nemotron",
"stablelm",
"qwen2",
"qwen2_moe",
"phimoe",
"gemma",
"gemma2",
"starcoder2",
"cohere",
"minicpm",
"deepseek",
]:
keys = set(["self_attn.q_proj", "self_attn.v_proj"])
if model.model_type in ["mixtral", "phimoe"]:
keys.add("block_sparse_moe.gate")
if model.model_type == "qwen2_moe":
keys.add("mlp.gate")
keys.add("mlp.shared_expert_gate")
elif model.model_type == "gpt_bigcode":
keys = set(["attn.c_attn"])
elif model.model_type == "gpt2":
keys = set(["attn.c_attn"])
elif model.model_type == "gpt_neox":
keys = set(["attention.query_key_value"])
elif model.model_type == "olmo":
keys = set(["att_proj"])
elif model.model_type == "openelm":
keys = set(["attn.qkv_proj"])
elif model.model_type == "phi3":
keys = set(["self_attn.qkv_proj"])
elif model.model_type == "phi-msft":
keys = set(["mixer.Wqkv", "moe.gate"])
elif model.model_type == "dbrx":
keys = set(["norm_attn_norm.attn.Wqkv", "ffn.router.layer"])
elif model.model_type == "internlm2":
keys = set(["attention.wqkv", "attention.wo"])
elif model.model_type == "deepseek_v2":
keys = set(
[
"self_attn.q_proj",
"self_attn.q_a_proj",
"self_attn.q_b_proj",
"self_attn.kv_a_proj_with_mqa",
"self_attn.kv_b_proj",
]
)
elif model.model_type == "mamba":
keys = set(
[
"mixer.in_proj",
"mixer.x_proj",
"mixer.dt_proj",
"mixer.out_proj",
]
)
else:
raise ValueError(f"Lora does not support {model.model_type}")
for l in model.layers[-min(num_layers, 0) :]:
lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
if lora_layers:
l.update_modules(tree_unflatten(lora_layers))
lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
if lora_modules:
model.update_modules(tree_unflatten(lora_modules))
def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
"""
Load any fine-tuned adapters / layers.
Args:
model (nn.Module): The neural network model.
adapter_path (str): Path to the adapter configuration file.
Returns:
nn.Module: The updated model with LoRA layers applied.
"""
adapter_path = Path(adapter_path)
if not adapter_path.exists():
raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
with open(adapter_path / "adapter_config.json", "r") as fid:
config = types.SimpleNamespace(**json.load(fid))
fine_tune_type = getattr(config, "fine_tune_type", "lora")
if fine_tune_type != "full":
linear_to_lora_layers(
model,
config.num_layers,
config.lora_parameters,
use_dora=(fine_tune_type == "dora"),
)
model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
return model
def dequantize(model: nn.Module) -> nn.Module:
"""
Dequantize the quantized linear layers in the model.
Args:
model (nn.Module): The model with quantized linear layers.
Returns:
nn.Module: The model with dequantized layers.
"""
de_quantize_layers = []
for name, module in model.named_modules():
if isinstance(module, nn.QuantizedLinear):
bias = "bias" in module
weight = module.weight
weight = mx.dequantize(
weight,
module.scales,
module.biases,
module.group_size,
module.bits,
).astype(mx.float16)
output_dims, input_dims = weight.shape
linear = nn.Linear(input_dims, output_dims, bias=bias)
linear.weight = weight
if bias:
linear.bias = module.bias
de_quantize_layers.append((name, linear))
if isinstance(module, nn.QuantizedEmbedding):
weight = mx.dequantize(
module.weight,
module.scales,
module.biases,
module.group_size,
module.bits,
).astype(mx.float16)
num_embeddings, dims = weight.shape
emb = nn.Embedding(num_embeddings, dims)
emb.weight = weight
de_quantize_layers.append((name, emb))
if len(de_quantize_layers) > 0:
model.update_modules(tree_unflatten(de_quantize_layers))
return model
def remove_lora_layers(model: nn.Module) -> nn.Module:
"""
Remove the LoRA layers from the model.
Args:
model (nn.Module): The model with LoRA layers.
Returns:
nn.Module: The model without LoRA layers.
"""
reset_layers = []
for name, module in model.named_modules():
if isinstance(module, LoRALinear):
reset_layers.append((name, module.linear))
if len(reset_layers) > 0:
model.update_modules(tree_unflatten(reset_layers))
return model
def print_trainable_parameters(model):
def nparams(m):
if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)):
return m.weight.size * (32 // m.bits)
return sum(v.size for _, v in tree_flatten(m.parameters()))
leaf_modules = tree_flatten(
model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
)
total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
trainable_p = (
sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
)
print(
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)

812
llms/mlx_lm/utils.py Normal file
View File

@@ -0,0 +1,812 @@
# Copyright © 2023-2024 Apple Inc.
import contextlib
import copy
import glob
import importlib
import json
import logging
import shutil
import time
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
from .gguf import load_gguf
from .models import cache
from .sample_utils import make_logits_processors, make_sampler
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
from .tuner.utils import load_adapters
# Constants
MODEL_REMAPPING = {
"mistral": "llama", # mistral is compatible with llama
"phi-msft": "phixtral",
"falcon_mamba": "mamba",
}
MAX_FILE_SIZE_GB = 5
# A stream on the default device just for generation
generation_stream = mx.new_stream(mx.default_device())
class ModelNotFoundError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
"""
A context manager to temporarily change the wired limit.
Note, the wired limit should not be changed during an async eval. If an
async eval could be running pass in the streams to synchronize with prior
to exiting the context manager.
"""
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
print(
f"[WARNING] Generating with a model that requires {model_mb} MB "
f"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be slow. See the documentation for possible work-arounds: "
"https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
)
old_limit = mx.metal.set_wired_limit(max_rec_size)
try:
yield None
finally:
if streams is not None:
for s in streams:
mx.synchronize(s)
else:
mx.synchronize()
mx.metal.set_wired_limit(old_limit)
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
Args:
config (dict): The model configuration.
Returns:
A tuple containing the Model class and the ModelArgs class.
"""
model_type = config["model_type"]
model_type = MODEL_REMAPPING.get(model_type, model_type)
try:
arch = importlib.import_module(f"mlx_lm.models.{model_type}")
except ImportError:
msg = f"Model type {model_type} not supported."
logging.error(msg)
raise ValueError(msg)
return arch.Model, arch.ModelArgs
def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path:
"""
Ensures the model is available locally. If the path does not exist locally,
it is downloaded from the Hugging Face Hub.
Args:
path_or_hf_repo (str): The local path or Hugging Face repository ID of the model.
revision (str, optional): A revision id which can be a branch name, a tag, or a commit hash.
Returns:
Path: The path to the model.
"""
model_path = Path(path_or_hf_repo)
if not model_path.exists():
try:
model_path = Path(
snapshot_download(
repo_id=path_or_hf_repo,
revision=revision,
allow_patterns=[
"*.json",
"*.safetensors",
"*.py",
"tokenizer.model",
"*.tiktoken",
"*.txt",
],
)
)
except:
raise ModelNotFoundError(
f"Model not found for path or HF repo: {path_or_hf_repo}.\n"
"Please make sure you specified the local path or Hugging Face"
" repo id correctly.\nIf you are trying to access a private or"
" gated Hugging Face repo, make sure you are authenticated:\n"
"https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login"
) from None
return model_path
def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
if (
kv_bits is not None
and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
def generate_step(
prompt: mx.array,
model: nn.Module,
temp: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = 20,
top_p: float = 1.0,
min_p: float = 0.0,
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
temp (float): The temperature for sampling, if 0 the argmax is used.
Default: ``0``.
repetition_penalty (float, optional): The penalty factor for repeating
tokens.
repetition_context_size (int, optional): The number of tokens to
consider for repetition penalty. Default: ``20``.
top_p (float, optional): Nulceus sampling, higher means model considers
more less likely words.
min_p (float, optional): The minimum value (scaled by the top token's
probability) that a token probability must have to be considered.
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
be filtered by min_p sampling.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
logits_processors = logits_processors or []
logits_processors.extend(
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
)
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processors:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processors:
logits = processor(tokens, logits)
maybe_quantize_kv_cache(
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
y, logprobs = _step(y)
mx.async_eval(y, logprobs)
n = 0
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs
def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Generator[Tuple[str, int, mx.array], None, None]:
"""
A generator producing text based on the given prompt from the model.
Args:
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Yields:
Tuple[str, int, mx.array]:
The next text segment, token, and vector of log probabilities.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
prompt_tokens = mx.array(
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
)
detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
detokenizer.reset()
for n, (token, logits) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if n == (max_tokens - 1):
break
yield detokenizer.last_segment, token, logits
detokenizer.finalize()
yield detokenizer.last_segment, token, logits
def generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: str,
max_tokens: int = 100,
verbose: bool = False,
formatter: Optional[Callable] = None,
**kwargs,
) -> str:
"""
Generate a complete response from the model.
Args:
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
formatter (Optional[Callable]): A function which takes a token and a
probability and displays it.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
"""
if not isinstance(tokenizer, TokenizerWrapper):
tokenizer = TokenizerWrapper(tokenizer)
if verbose:
print("=" * 10)
print("Prompt:", prompt)
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
if verbose:
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)
token_count = n + 1
detokenizer.finalize()
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 1e9
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
return config
def load_model(
model_path: Path,
lazy: bool = False,
model_config: dict = {},
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
"""
Load and initialize the model from a given path.
Args:
model_path (Path): The path to load the model from.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
model_config (dict, optional): Configuration parameters for the model.
Defaults to an empty dictionary.
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
A function that returns the model class and model args class given a config.
Defaults to the _get_classes function.
Returns:
nn.Module: The loaded and initialized model.
Raises:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
config = load_config(model_path)
config.update(model_config)
weight_files = glob.glob(str(model_path / "model*.safetensors"))
if not weight_files:
# Try weight for back-compat
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
if not weight_files:
logging.error(f"No safetensors found in {model_path}")
raise FileNotFoundError(f"No safetensors found in {model_path}")
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
model_class, model_args_class = get_model_classes(config=config)
model_args = model_args_class.from_dict(config)
model = model_class(model_args)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
if (quantization := config.get("quantization", None)) is not None:
def class_predicate(p, m):
# Handle custom per layer quantizations
if p in config["quantization"]:
return config["quantization"][p]
if not hasattr(m, "to_quantized"):
return False
# Handle legacy models which may not have everything quantized
return f"{p}.scales" in weights
nn.quantize(
model,
group_size=quantization["group_size"],
bits=quantization["bits"],
class_predicate=class_predicate,
)
model.load_weights(list(weights.items()))
if not lazy:
mx.eval(model.parameters())
model.eval()
return model
def load(
path_or_hf_repo: str,
tokenizer_config={},
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
"""
Load the model and tokenizer from a given path or a huggingface repository.
Args:
path_or_hf_repo (Path): The path or the huggingface repository to load the model from.
tokenizer_config (dict, optional): Configuration parameters specifically for the tokenizer.
Defaults to an empty dictionary.
model_config(dict, optional): Configuration parameters specifically for the model.
Defaults to an empty dictionary.
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
to the model. Default: ``None``.
lazy (bool): If False eval the model parameters to make sure they are
loaded in memory before returning, otherwise they will be loaded
when needed. Default: ``False``
Returns:
Tuple[nn.Module, TokenizerWrapper]: A tuple containing the loaded model and tokenizer.
Raises:
FileNotFoundError: If config file or safetensors are not found.
ValueError: If model class or args class are not found.
"""
if path_or_hf_repo.endswith(".gguf"):
model, tokenizer = load_gguf(path_or_hf_repo)
return model, tokenizer
model_path = get_model_path(path_or_hf_repo)
model = load_model(model_path, lazy, model_config)
if adapter_path is not None:
model = load_adapters(model, adapter_path)
model.eval()
tokenizer = load_tokenizer(model_path, tokenizer_config)
return model, tokenizer
def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = load_config(model_path)
tokenizer = load_tokenizer(model_path)
return model, config, tokenizer
def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
"""
Splits the weights into smaller shards.
Args:
weights (dict): Model weights.
max_file_size_gb (int): Maximum size of each shard in gigabytes.
Returns:
list: List of weight shards.
"""
max_file_size_bytes = max_file_size_gb << 30
shards = []
shard, shard_size = {}, 0
for k, v in weights.items():
if shard_size + v.nbytes > max_file_size_bytes:
shards.append(shard)
shard, shard_size = {}, 0
shard[k] = v
shard_size += v.nbytes
shards.append(shard)
return shards
def upload_to_hub(path: str, upload_repo: str, hf_path: str):
"""
Uploads the model to Hugging Face hub.
Args:
path (str): Local path to the model.
upload_repo (str): Name of the HF repo to upload to.
hf_path (str): Path to the original Hugging Face model.
"""
import os
from huggingface_hub import HfApi, ModelCard, logging
from . import __version__
card = ModelCard.load(hf_path)
card.data.tags = ["mlx"] if card.data.tags is None else card.data.tags + ["mlx"]
card.data.base_model = hf_path
card.text = dedent(
f"""
# {upload_repo}
The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
using mlx-lm version **{__version__}**.
## Use with mlx
```bash
pip install mlx-lm
```
```python
from mlx_lm import load, generate
model, tokenizer = load("{upload_repo}")
prompt="hello"
if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
messages = [{{"role": "user", "content": prompt}}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
response = generate(model, tokenizer, prompt=prompt, verbose=True)
```
"""
)
card.save(os.path.join(path, "README.md"))
logging.set_verbosity_info()
api = HfApi()
api.create_repo(repo_id=upload_repo, exist_ok=True)
api.upload_folder(
folder_path=path,
repo_id=upload_repo,
repo_type="model",
multi_commits=True,
multi_commits_verbose=True,
)
print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.")
def save_weights(
save_path: Union[str, Path],
weights: Dict[str, Any],
*,
donate_weights: bool = False,
) -> None:
"""Save model weights into specified directory."""
if isinstance(save_path, str):
save_path = Path(save_path)
save_path.mkdir(parents=True, exist_ok=True)
shards = make_shards(weights)
shards_count = len(shards)
shard_file_format = (
"model-{:05d}-of-{:05d}.safetensors"
if shards_count > 1
else "model.safetensors"
)
total_size = sum(v.nbytes for v in weights.values())
index_data = {"metadata": {"total_size": total_size}, "weight_map": {}}
# Write the weights and make sure no references are kept other than the
# necessary ones
if donate_weights:
weights.clear()
del weights
for i in range(len(shards)):
shard = shards[i]
shards[i] = None
shard_name = shard_file_format.format(i + 1, shards_count)
shard_path = save_path / shard_name
mx.save_safetensors(str(shard_path), shard, metadata={"format": "mlx"})
for weight_name in shard.keys():
index_data["weight_map"][weight_name] = shard_name
del shard
index_data["weight_map"] = {
k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"])
}
with open(save_path / "model.safetensors.index.json", "w") as f:
json.dump(
index_data,
f,
indent=4,
)
def quantize_model(
model: nn.Module,
config: dict,
q_group_size: int,
q_bits: int,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
) -> Tuple:
"""
Applies quantization to the model weights.
Args:
model (nn.Module): The model to be quantized.
config (dict): Model configuration.
q_group_size (int): Group size for quantization.
q_bits (int): Bits per weight for quantization.
quant_predicate (Callable): A callable that decides how
to quantize each layer based on the path.
Accepts the layer `path`, the `module` and the model `config`.
Returns either a bool to signify quantize/no quantize or
a dict of quantization parameters to pass to `to_quantized`.
Returns:
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
# Add any custom quantization parameters to the config as we go
def _class_predicate(p, m):
bool_or_params = quant_predicate(p, m, config)
if isinstance(bool_or_params, dict):
quantized_config["quantization"][p] = bool_or_params
return bool_or_params
nn.quantize(
model,
q_group_size,
q_bits,
class_predicate=_class_predicate if quant_predicate else None,
)
# support hf model tree #957
quantized_config["quantization_config"] = quantized_config["quantization"]
quantized_weights = dict(tree_flatten(model.parameters()))
return quantized_weights, quantized_config
def save_config(
config: dict,
config_path: Union[str, Path],
) -> None:
"""Save the model configuration to the ``config_path``.
The final configuration will be sorted before saving for better readability.
Args:
config (dict): The model configuration.
config_path (Union[str, Path]): Model configuration file path.
"""
# Clean unused keys
config.pop("_name_or_path", None)
# sort the config for better readability
config = dict(sorted(config.items()))
# write the updated config to the config_path (if provided)
with open(config_path, "w") as fid:
json.dump(config, fid, indent=4)
def convert(
hf_path: str,
mlx_path: str = "mlx_model",
quantize: bool = False,
q_group_size: int = 64,
q_bits: int = 4,
dtype: str = "float16",
upload_repo: str = None,
revision: Optional[str] = None,
dequantize: bool = False,
quant_predicate: Optional[
Callable[[str, nn.Module, dict], Union[bool, dict]]
] = None,
):
# Check the save path is empty
if isinstance(mlx_path, str):
mlx_path = Path(mlx_path)
if mlx_path.exists():
raise ValueError(
f"Cannot save to the path {mlx_path} as it already exists."
" Please delete the file/directory or specify a new path to save to."
)
print("[INFO] Loading")
model_path = get_model_path(hf_path, revision=revision)
model, config, tokenizer = fetch_from_hub(model_path, lazy=True)
weights = dict(tree_flatten(model.parameters()))
dtype = getattr(mx, dtype)
weights = {k: v.astype(dtype) for k, v in weights.items()}
if quantize and dequantize:
raise ValueError("Choose either quantize or dequantize, not both.")
if quantize:
print("[INFO] Quantizing")
model.load_weights(list(weights.items()))
weights, config = quantize_model(
model, config, q_group_size, q_bits, quant_predicate=quant_predicate
)
if dequantize:
print("[INFO] Dequantizing")
model = dequantize_model(model)
weights = dict(tree_flatten(model.parameters()))
del model
save_weights(mlx_path, weights, donate_weights=True)
py_files = glob.glob(str(model_path / "*.py"))
for file in py_files:
shutil.copy(file, mlx_path)
tokenizer.save_pretrained(mlx_path)
save_config(config, config_path=mlx_path / "config.json")
if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)

45
llms/setup.py Normal file
View File

@@ -0,0 +1,45 @@
# Copyright © 2024 Apple Inc.
import sys
from pathlib import Path
from setuptools import setup
package_dir = Path(__file__).parent / "mlx_lm"
with open(package_dir / "requirements.txt") as fid:
requirements = [l.strip() for l in fid.readlines()]
sys.path.append(str(package_dir))
from _version import __version__
setup(
name="mlx-lm",
version=__version__,
description="LLMs on Apple silicon with MLX and the Hugging Face Hub",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
readme="README.md",
author_email="mlx@group.apple.com",
author="MLX Contributors",
url="https://github.com/ml-explore/mlx-examples",
license="MIT",
install_requires=requirements,
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
python_requires=">=3.8",
extras_require={
"testing": ["datasets"],
},
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",
"mlx_lm.lora = mlx_lm.lora:main",
"mlx_lm.merge = mlx_lm.merge:main",
"mlx_lm.server = mlx_lm.server:main",
"mlx_lm.manage = mlx_lm.manage:main",
]
},
)

View File

@@ -160,12 +160,12 @@ class SpeculativeDecoder:
)
n_accepted += num_to_accept
n_draft += len(draft_tokens)
n_draft += draft_tokens.size
# Rewind the cache for unaccepted tokens:
if (n := len(draft_tokens)) > num_to_accept:
self.draft_model.truncate_cache(n - len(new_tokens))
self.model.truncate_cache(n - len(new_tokens) + 1)
if (n := draft_tokens.size) > num_to_accept:
self.draft_model.truncate_cache(n - new_tokens.size)
self.model.truncate_cache(n - new_tokens.size + 1)
n_steps += 1
@@ -181,7 +181,7 @@ class SpeculativeDecoder:
if ntoks >= max_tokens or new_tokens[-1] == self.tokenizer.eos_id:
break
draft_inputs = new_tokens[max(len(new_tokens) - 2, 0) :]
draft_inputs = new_tokens[max(new_tokens.size - 2, 0) :]
inputs = draft_inputs[-1:]
print(self.tokenizer.decode(outputs)[skip:], end="", flush=True)

View File

@@ -213,10 +213,10 @@ class TransformerDecoderLayer(nn.Module):
memory: mx.array,
mask: mx.array,
memory_mask: mx.array,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
):
y = self.ln1(x)
y, new_cache = self.self_attention(y, y, y, mask, cache)
y, cache = self.self_attention(y, y, y, mask, cache)
x = x + y
y = self.ln2(x)
@@ -227,7 +227,7 @@ class TransformerDecoderLayer(nn.Module):
y = self.dense(y)
x = x + y
return x, new_cache
return x, cache
def create_additive_causal_mask(N: int, offset: int = 0):

View File

@@ -0,0 +1,99 @@
# Copyright © 2024 Apple Inc.
import json
import os
import tempfile
import types
import unittest
from mlx_lm.tuner import datasets
from transformers import AutoTokenizer
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestDatasets(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir_fid.name)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def save_data(self, data):
for ds in ["train", "valid"]:
with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
for l in data:
json.dump(l, fid)
fid.write("\n")
def test_text(self):
data = {"text": "This is an example for the model."}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
train, valid, test = datasets.load_dataset(args, None)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertTrue(isinstance(train, datasets.Dataset))
def test_completions(self):
data = {"prompt": "What is the capital of France?", "completion": "Paris."}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertTrue(isinstance(train, datasets.CompletionsDataset))
def test_chat(self):
data = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "How can I assistant you today."},
]
}
self.save_data(4 * [data])
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertEqual(len(train), 4)
self.assertEqual(len(valid), 4)
self.assertEqual(len(test), 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertTrue(isinstance(train, datasets.ChatDataset))
def test_hf(self):
args = types.SimpleNamespace(
hf_dataset={
"name": "billsum",
"prompt_feature": "text",
"completion_feature": "summary",
},
test=False,
train=True,
)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
train, valid, test = datasets.load_dataset(args, tokenizer)
self.assertTrue(len(train) > 0)
self.assertTrue(len(train[0]) > 0)
self.assertTrue(len(valid) > 0)
self.assertTrue(len(valid[0]) > 0)
self.assertEqual(len(test), 0)
if __name__ == "__main__":
unittest.main()

447
llms/tests/test_finetune.py Normal file
View File

@@ -0,0 +1,447 @@
# Copyright © 2024 Apple Inc.
import math
import sys
import unittest
from contextlib import contextmanager
from io import StringIO
from unittest.mock import MagicMock
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@contextmanager
def swapped_with_identity(obj, func):
old_func = getattr(obj, func)
setattr(obj, func, lambda x: x)
yield
setattr(obj, func, old_func)
class TestLora(unittest.TestCase):
def setUp(self):
self.capturedOutput = StringIO()
sys.stdout = self.capturedOutput
def tearDown(self):
sys.stdout = sys.__stdout__
def test_llama(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
tie_word_embeddings=False,
)
lora_layers = 4
def check_config(params, expected_trainable_parameters=None):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
model = llama.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, lora_layers, params)
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
expected_trainable_parameters = expected_trainable_parameters or (
lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
)
self.assertEqual(trainable_params, expected_trainable_parameters)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
params["rank"] = 1
check_config(params)
params["keys"] = ["self_attn.k_proj"]
check_config(params)
params["keys"] = ["lm_head"]
check_config(
params,
expected_trainable_parameters=(
params["rank"] * (args.hidden_size + args.vocab_size)
),
)
params["keys"] = ["model.embed_tokens"]
check_config(
params,
expected_trainable_parameters=(
params["rank"] * (args.hidden_size + args.vocab_size)
),
)
def test_gpt_neox(self):
from mlx_lm.models import gpt_neox
args = gpt_neox.ModelArgs(
model_type="gpt_neox",
max_position_embeddings=2048,
hidden_size=6144,
num_attention_heads=64,
num_hidden_layers=44,
layer_norm_eps=1e-5,
vocab_size=50432,
rotary_emb_base=10_000,
rotary_pct=0.25,
)
num_lora_layers = 4
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
model = gpt_neox.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
def test_lora_embedding(self):
num_embeddings = 256
dims = 512
tokens = mx.array([1, 2, 3])
embedding = nn.QuantizedEmbedding(num_embeddings, dims)
dequantized_weight = mx.dequantize(
embedding.weight,
embedding.scales,
embedding.biases,
embedding.group_size,
embedding.bits,
)
lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
new_embedding = lora_emb.fuse(de_quantize=True)
self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
# as_linear
attn_output = mx.random.uniform(shape=(dims,))
embedding_lin_out = lora_emb.as_linear(attn_output)
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
self.assertTrue(
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
)
# change the value of lora_b and the embeddings will no longer be equal
lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
new_embedding = lora_emb.fuse(de_quantize=True)
self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
class TestDora(unittest.TestCase):
def test_dora_embedding(self):
num_embeddings = 256
dims = 512
tokens = mx.array([1, 2, 3])
embedding = nn.Embedding(num_embeddings, dims)
dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
new_embedding = dora_emb.fuse()
self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
# as_linear
attn_output = mx.random.uniform(shape=(dims,))
embedding_lin_out = dora_emb.as_linear(attn_output)
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
self.assertTrue(
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
)
# change the value of lora_b and the embeddings will no longer be equal
dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
new_embedding = dora_emb.fuse()
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
def test_llama(self):
from mlx_lm.models import llama
hidden_size = 1024
intermediate_size = 2048
args = llama.ModelArgs(
model_type="llama",
hidden_size=hidden_size,
num_hidden_layers=4,
intermediate_size=intermediate_size,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
dora_layers = 4
def check_config(params):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
model = llama.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
self.assertEqual(
trainable_params,
dora_layers
* (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
params["rank"] = 1
check_config(params)
params["keys"] = ["self_attn.k_proj"]
check_config(params)
def test_dora_m_parameter(self):
dora_lin = DoRALinear(input_dims=100, output_dims=100)
self.assertTrue(
mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
)
# Recomputes m when changing Linear
inital_m = dora_lin.m
lin = nn.Linear(10, 10)
dora_lin.set_linear(lin)
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
# Works with quantized weights
quantized_linear = nn.QuantizedLinear(512, 512)
dora_lin.set_linear(quantized_linear)
dequantized_weight = mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
self.assertTrue(
mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
)
def test_dora_from_linear(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims)
dora_lin = DoRALinear.from_base(linear, r)
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
self.assertEqual(dora_lin.m.shape, (out_dims,))
quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
dequantized_weight = mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
self.assertTrue(
mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
)
self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
def test_dora_to_linear(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims, bias=True)
dora_lin = DoRALinear.from_base(linear, r)
to_linear = dora_lin.fuse()
self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
def dequantize_weight(quantized_linear):
return mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
# Dequantize
to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
self.assertTrue(
mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
)
self.assertTrue(
mx.allclose(
dequantize_weight(quantized_linear), to_linear_from_quantized.weight
)
)
def test_dora_dtype(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims, bias=True)
linear.set_dtype(mx.float16)
dora_lin = DoRALinear.from_base(linear, r)
x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
self.assertEqual(dora_lin(x).dtype, mx.float16)
class TestScheduleConfig(unittest.TestCase):
def test_join(self):
config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
cos_with_warmup = build_schedule(config)
self.assertIsNotNone(cos_with_warmup)
self.assertEqual(cos_with_warmup(0), 0.0)
self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
optimizer = opt.Adam(learning_rate=cos_with_warmup)
for _ in range(100):
optimizer.update({}, {})
self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
for _ in range(100):
optimizer.update({}, {})
expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
def test_single_schedule(self):
config = {
"name": "cosine_decay",
"arguments": [0.1, 10],
}
lr_schedule = build_schedule(config)
lr = lr_schedule(4)
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
def test_non_zero_warmup(self):
config = {
"name": "cosine_decay",
"warmup": 10,
"warmup_init": 1e-6,
"arguments": [1e-5, 20],
}
lr_schedule = build_schedule(config)
lr = lr_schedule(0)
self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
def test_malformed_config(self):
config = {"warmup": 100}
self.assertRaises(KeyError, build_schedule, config)
config = {"cosine_decay": None}
self.assertRaises(KeyError, build_schedule, config)
def test_evaluate_calls(self):
mock_model = MagicMock()
mock_dataset = MagicMock()
mock_tokenizer = MagicMock()
mock_default_loss = MagicMock()
mock_iterate_batches = MagicMock()
mock_iterate_batches.return_value = [
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
]
mock_default_loss.side_effect = [
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=2,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
max_seq_length=2048,
)
self.assertEqual(mock_default_loss.call_count, 2)
def test_evaluate_infinite_batches(self):
mock_model = MagicMock()
mock_dataset = MagicMock()
mock_tokenizer = MagicMock()
mock_default_loss = MagicMock()
mock_iterate_batches = MagicMock()
mock_iterate_batches.return_value = [
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
(MagicMock(), MagicMock()),
]
mock_default_loss.side_effect = [
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
]
with swapped_with_identity(mx.distributed, "all_sum"):
evaluate(
model=mock_model,
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
num_batches=-1,
max_seq_length=2048,
loss=mock_default_loss,
iterate_batches=mock_iterate_batches,
)
mock_iterate_batches.assert_called_once_with(
dataset=mock_dataset,
tokenizer=mock_tokenizer,
batch_size=2,
max_seq_length=2048,
)
self.assertEqual(mock_default_loss.call_count, 3)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,55 @@
# Copyright © 2024 Apple Inc.
import unittest
from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
@classmethod
def setUpClass(cls):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
def test_generate(self):
# Simple test that generation runs
text = generate(
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
)
def test_generate_with_logit_bias(self):
logit_bias = {0: 2000.0, 1: -20.0}
text = generate(
self.model,
self.tokenizer,
"hello",
max_tokens=5,
verbose=False,
logit_bias=logit_bias,
)
self.assertEqual(text, "!!!!!")
def test_generate_with_processor(self):
init_toks = self.tokenizer.encode("hello")
all_toks = None
def logits_processor(toks, logits):
nonlocal all_toks
all_toks = toks
return logits
generate(
self.model,
self.tokenizer,
"hello",
max_tokens=5,
verbose=False,
logits_processors=[logits_processor],
)
self.assertEqual(len(all_toks), len(init_toks) + 5)
if __name__ == "__main__":
unittest.main()

58
llms/tests/test_gguf.py Normal file
View File

@@ -0,0 +1,58 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch
import mlx.core as mx
from mlx_lm.gguf import convert_to_gguf
class TestConvertToGGUFWithoutMocks(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
cls.tokenizer_file_path = os.path.join(cls.test_dir, "tokenizer.json")
with open(cls.tokenizer_file_path, "w") as f:
f.write("{}")
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
@patch("transformers.AutoTokenizer.from_pretrained")
@patch("mlx.core.save_gguf")
def test_convert_to_gguf(
self,
mock_save_gguf,
mock_from_pretrained,
):
mock_tokenizer = MagicMock()
mock_tokenizer.vocab_size = 3
mock_tokenizer.get_added_vocab.return_value = {}
mock_tokenizer.get_vocab.return_value = {"<pad>": 0, "hello": 1, "world": 2}
mock_tokenizer.all_special_tokens = ["<pad>"]
mock_tokenizer.all_special_ids = [0]
mock_from_pretrained.return_value = mock_tokenizer
model_path = Path(self.test_dir)
weights = {
"self_attn.q_proj.weight": mx.random.uniform(shape=[768, 768]),
}
config = {
"num_attention_heads": 1,
"num_hidden_layers": 1,
"hidden_size": 768,
"intermediate_size": 3072,
"_name_or_path": "test-llama",
}
output_file_path = "/fake/output/path/gguf_model.gguf"
convert_to_gguf(model_path, weights, config, output_file_path)
called_args, _ = mock_save_gguf.call_args
self.assertEqual(called_args[0], output_file_path)
if __name__ == "__main__":
unittest.main()

765
llms/tests/test_models.py Normal file
View File

@@ -0,0 +1,765 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
cache = KVCache()
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
self.assertEqual(cache.offset, 1)
k = mx.ones((1, 4, cache.step, 32), mx.float16)
v = mx.ones((1, 4, cache.step, 32), mx.float16)
k_up, v_up = cache.update_and_fetch(k, v)
expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
self.assertTrue(mx.array_equal(k_up, expected))
self.assertTrue(mx.array_equal(v_up, expected))
self.assertEqual(cache.offset, cache.step + 1)
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
self.assertEqual(cache.offset, 2)
k = mx.random.uniform(shape=(b, h, 5, d))
v = mx.random.uniform(shape=(b, h, 5, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))
k = mx.random.uniform(shape=(b, h, 4, d))
v = mx.random.uniform(shape=(b, h, 4, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))
idx = 0
for _ in range(10):
k = mx.random.uniform(shape=(b, h, 1, d))
v = mx.random.uniform(shape=(b, h, 1, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
idx += 1
idx %= 8
# Try with nonzero keep
cache = RotatingKVCache(max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
v = mx.random.uniform(shape=(b, h, 20, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up, k))
self.assertTrue(mx.array_equal(v_up, v))
# A bunch of small updates
self.assertEqual(cache.offset, 20)
idx = 2
for i in range(10):
k = mx.random.uniform(shape=(b, h, 1, d))
v = mx.random.uniform(shape=(b, h, 1, d))
k_up, v_up = cache.update_and_fetch(k, v)
self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
self.assertEqual(cache.offset, 21 + i)
idx += 1
if idx >= 8:
idx = 2
def test_rotating_kv_cache_chat_mode(self):
# Test that the rotating kv cache can handle
# alternating prompt/prefill with generation
d = 4
h = 2
cache = RotatingKVCache(max_size=18, step=4)
x = mx.random.uniform(shape=(1, h, 8, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 8)
self.assertEqual(cache.offset, 8)
x = mx.random.uniform(shape=(1, h, 1, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 9)
self.assertEqual(cache.offset, 9)
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 11)
self.assertEqual(cache.offset, 11)
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
x = mx.random.uniform(shape=(1, h, 3, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 14)
self.assertEqual(cache.offset, 14)
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
x = mx.random.uniform(shape=(1, h, 6, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 20)
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
self.assertEqual(model.model_type, model_type)
for t in [mx.float32, mx.float16]:
model.update(tree_map(lambda p: p.astype(t), model.parameters()))
inputs = mx.array([[0, 1]])
outputs = model(inputs)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)
def test_llama(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = llama.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi2(self):
from mlx_lm.models import phi
args = phi.ModelArgs()
model = phi.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phixtral(self):
from mlx_lm.models import phixtral
args = phixtral.ModelArgs(
"phixtral", num_vocab=1000, num_layers=4, model_dim=1024
)
model = phixtral.Model(args)
self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)
def test_phi3(self):
from mlx_lm.models import phi3
args = phi3.ModelArgs(
model_type="phi3",
hidden_size=3072,
num_hidden_layers=32,
intermediate_size=8192,
num_attention_heads=32,
rms_norm_eps=1e-5,
vocab_size=32064,
)
model = phi3.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma(self):
from mlx_lm.models import gemma
args = gemma.ModelArgs(
model_type="gemma",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
head_dim=128,
rms_norm_eps=1e-5,
vocab_size=10_000,
num_key_value_heads=4,
)
model = gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_mixtral(self):
from mlx_lm.models import mixtral
# Make a baby mixtral, because it will actually do the
# eval
args = mixtral.ModelArgs(
model_type="mixtral",
vocab_size=100,
hidden_size=32,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=4,
num_experts_per_tok=2,
num_key_value_heads=2,
num_local_experts=4,
)
model = mixtral.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
@unittest.skip("requires ai2-olmo")
def test_olmo(self):
from mlx_lm.models import olmo
args = olmo.ModelArgs(
model_type="olmo",
d_model=1024,
n_layers=4,
mlp_hidden_size=2048,
n_heads=2,
vocab_size=10_000,
embedding_size=10_000,
)
model = olmo.Model(args)
self.model_test_runner(
model,
args.model_type,
args.vocab_size,
args.n_layers,
)
def test_qwen2_moe(self):
from mlx_lm.models import qwen2_moe
args = qwen2_moe.ModelArgs(
model_type="qwen2_moe",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
num_experts_per_tok=4,
num_experts=16,
moe_intermediate_size=1024,
shared_expert_intermediate_size=2048,
)
model = qwen2_moe.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_qwen2(self):
from mlx_lm.models import qwen2
args = qwen2.ModelArgs(
model_type="qwen2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = qwen2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_qwen(self):
from mlx_lm.models import qwen
args = qwen.ModelArgs(
model_type="qwen",
)
model = qwen.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_plamo(self):
from mlx_lm.models import plamo
args = plamo.ModelArgs(
model_type="plamo",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=8,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = plamo.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_stablelm(self):
from mlx_lm.models import stablelm
args = stablelm.ModelArgs(
model_type="stablelm",
vocab_size=10_000,
hidden_size=1024,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
partial_rotary_factor=1.0,
intermediate_size=2048,
layer_norm_eps=1e-2,
rope_theta=10_000,
use_qkv_bias=False,
)
model = stablelm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
# StableLM 2
args = stablelm.ModelArgs(
model_type="stablelm",
vocab_size=10000,
hidden_size=512,
num_attention_heads=8,
num_hidden_layers=4,
num_key_value_heads=2,
partial_rotary_factor=0.25,
intermediate_size=1024,
layer_norm_eps=1e-5,
rope_theta=10000,
use_qkv_bias=True,
)
model = stablelm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_starcoder2(self):
from mlx_lm.models import starcoder2
args = starcoder2.ModelArgs(
model_type="starcoder2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
num_key_value_heads=4,
)
model = starcoder2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_cohere(self):
from mlx_lm.models import cohere
args = cohere.ModelArgs(
model_type="cohere",
)
model = cohere.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_dbrx(self):
from mlx_lm.models import dbrx
args = dbrx.ModelArgs(
model_type="dbrx",
d_model=1024,
ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
n_layers=4,
n_heads=4,
vocab_size=10_000,
)
model = dbrx.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)
def test_minicpm(self):
from mlx_lm.models import minicpm
args = minicpm.ModelArgs(
model_type="minicpm",
hidden_size=1024,
dim_model_base=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-4,
vocab_size=10000,
num_key_value_heads=2,
scale_depth=1.0,
scale_emb=1.0,
)
model = minicpm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_mamba(self):
from mlx_lm.models import mamba
args = mamba.ModelArgs(
model_type="mamba",
vocab_size=10000,
use_bias=False,
use_conv_bias=True,
conv_kernel=4,
hidden_size=768,
num_hidden_layers=24,
state_size=16,
intermediate_size=1536,
time_step_rank=48,
)
model = mamba.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt2(self):
from mlx_lm.models import gpt2
args = gpt2.ModelArgs(
model_type="gpt2",
n_ctx=1024,
n_embd=768,
n_head=12,
n_layer=12,
n_positions=1024,
layer_norm_epsilon=1e-5,
vocab_size=50256,
)
model = gpt2.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_gpt_neox(self):
from mlx_lm.models import gpt_neox
args = gpt_neox.ModelArgs(
model_type="gpt_neox",
max_position_embeddings=2048,
hidden_size=6144,
num_attention_heads=64,
num_hidden_layers=44,
layer_norm_eps=1e-5,
vocab_size=50432,
rotary_emb_base=10_000,
rotary_pct=0.25,
)
model = gpt_neox.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_openelm(self):
from mlx_lm.models import openelm
args = openelm.ModelArgs(
model_type="openelm",
ffn_dim_divisor=256,
ffn_multipliers=[
0.5,
0.73,
0.97,
1.2,
1.43,
1.67,
1.9,
2.13,
2.37,
2.6,
2.83,
3.07,
3.3,
3.53,
3.77,
4.0,
],
head_dim=64,
model_dim=1280,
normalize_qk_projections=True,
num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
num_query_heads=[
12,
12,
12,
12,
12,
16,
16,
16,
16,
16,
16,
16,
20,
20,
20,
20,
],
num_transformer_layers=16,
vocab_size=32000,
)
model = openelm.Model(args)
self.model_test_runner(
model,
args.model_type,
args.vocab_size,
len(args.ffn_multipliers),
)
def test_internlm2(self):
from mlx_lm.models import internlm2
args = internlm2.ModelArgs(
model_type="internlm2",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10000,
)
model = internlm2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_llama3_1(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
max_position_embeddings=128,
mlp_bias=False,
num_key_value_heads=2,
rope_scaling={
"factor": 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
},
)
model = llama.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek(self):
from mlx_lm.models import deepseek
args = deepseek.ModelArgs(
model_type="deepseek",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=4,
)
model = deepseek.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v2(self):
from mlx_lm.models import deepseek_v2
args = deepseek_v2.ModelArgs(
model_type="deepseek_v2",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2
args = gemma2.ModelArgs(
model_type="gemma2",
hidden_size=128,
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=2,
head_dim=32,
rms_norm_eps=1e-4,
vocab_size=1024,
num_key_value_heads=2,
)
model = gemma2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt_bigcode(self):
from mlx_lm.models import gpt_bigcode
args = gpt_bigcode.ModelArgs(
model_type="gpt_bigcode",
n_embd=128,
n_layer=128,
n_inner=256,
n_head=4,
n_positions=1000,
layer_norm_epsilon=1e-5,
vocab_size=1024,
)
model = gpt_bigcode.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_nemotron(self):
from mlx_lm.models import nemotron
args = nemotron.ModelArgs(
model_type="nemotron",
hidden_size=128,
hidden_act="gelu",
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=4,
norm_eps=1e-5,
vocab_size=1024,
num_key_value_heads=2,
)
model = nemotron.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi3small(self):
from mlx_lm.models import phi3small
args = phi3small.ModelArgs(
model_type="phi3small",
hidden_size=128,
dense_attention_every_n_layers=2,
ff_intermediate_size=256,
gegelu_limit=1.0,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
layer_norm_epsilon=1e-4,
vocab_size=1000,
)
model = phi3small.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phimoe(self):
from mlx_lm.models import phimoe
args = phimoe.ModelArgs(
model_type="phimoe",
vocab_size=320,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
rope_scaling={
"long_factor": [1.0] * 16,
"long_mscale": 1.243163121016122,
"original_max_position_embeddings": 4096,
"short_factor": [1.0] * 16,
"short_mscale": 1.243163121016122,
"type": "longrope",
},
)
model = phimoe.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_recurrent_gemma(self):
from mlx_lm.models import recurrent_gemma
args = recurrent_gemma.ModelArgs(
model_type="recurrent_gemma",
hidden_size=128,
attention_bias=False,
conv1d_width=3,
intermediate_size=256,
logits_soft_cap=1.0,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
attention_window_size=1024,
vocab_size=1000,
block_types=["recurrent", "recurrent", "attention"],
)
model = recurrent_gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()

Some files were not shown because too many files have changed in this diff Show More