Merge branch 'adding-support-for-mamba2' of https://github.com/Goekdeniz-Guelmez/mlx-examples into adding-support-for-mamba2

This commit is contained in:
Goekdeniz-Guelmez 2024-10-16 21:09:42 +02:00
commit 181d6abedc
90 changed files with 4951 additions and 998 deletions

View File

@ -26,8 +26,8 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@3.8
python3.8 -m venv env
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install unittest-xml-reporting

3
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# Vim
*.swp
# Distribution / packaging
.Python
build/

View File

@ -20,8 +20,10 @@ Some more useful examples are listed below.
### Image Models
- Generating images
- [FLUX](flux)
- [Stable Diffusion or SDXL](stable_diffusion)
- Image classification using [ResNets on CIFAR-10](cifar).
- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
### Audio Models

View File

@ -33,13 +33,14 @@ An example using the model:
```python
import mlx.core as mx
from utils import load, load_audio, save_audio
from encodec import EncodecModel
from utils import load_audio, save_audio
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("path/to/aduio", model.sampling_rate, model.channels)
audio = load_audio("path/to/audio", model.sampling_rate, model.channels)
# Preprocess the audio (this can also be a list of arrays for batched
# processing).

View File

@ -3,9 +3,10 @@
import time
import mlx.core as mx
from utils import load
model, processor = load("mlx-community/encodec-48khz-float32")
from encodec import EncodecModel
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
audio = mx.random.uniform(shape=(288000, 2))
feats, mask = processor(audio)

View File

@ -10,7 +10,6 @@ from typing import Any, Dict, Union
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
import encodec

View File

@ -1,7 +1,10 @@
# Copyright © 2024 Apple Inc.
import functools
import json
import math
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
import mlx.core as mx
@ -669,3 +672,70 @@ class EncodecModel(nn.Module):
if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]:
audio_values = audio_values[:, : padding_mask.shape[1]]
return audio_values
@classmethod
def from_pretrained(cls, path_or_repo: str):
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)

View File

@ -1,10 +1,12 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
from utils import load, load_audio, save_audio
from utils import load_audio, save_audio
from encodec import EncodecModel
# Load the 48 KHz model and preprocessor.
model, processor = load("mlx-community/encodec-48khz-float32")
model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
# Load an audio file
audio = load_audio("/path/to/audio", model.sampling_rate, model.channels)

View File

@ -3,9 +3,10 @@
import mlx.core as mx
import numpy as np
import torch
from datasets import Audio, load_dataset
from transformers import AutoProcessor, EncodecModel
from utils import load, load_audio, preprocess_audio
from transformers import AutoProcessor
from transformers import EncodecModel as PTEncodecModel
from encodec import EncodecModel, preprocess_audio
def compare_processors():
@ -30,8 +31,8 @@ def compare_processors():
def compare_models():
pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = load("mlx-community/encodec-48khz-float32")
pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz")
mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32")
np.random.seed(0)
audio_length = 190560

View File

@ -1,16 +1,7 @@
# Copyright © 2024 Apple Inc.
import functools
import json
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Union
import mlx.core as mx
import numpy as np
from huggingface_hub import snapshot_download
import encodec
def save_audio(file: str, audio: mx.array, sampling_rate: int):
@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int):
out = mx.array(np.frombuffer(out, np.int16))
return out.reshape(-1, channels).astype(mx.float32) / 32767.0
def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
chunk_length: Optional[int] = None,
chunk_stride: Optional[int] = None,
):
r"""
Prepare inputs for the EnCodec model.
Args:
raw_audio (mx.array or List[mx.array]): The sequence or batch of
sequences to be processed.
sampling_rate (int): The sampling rate at which the audio waveform
should be digitalized.
chunk_length (int, optional): The model's chunk length.
chunk_stride (int, optional): The model's chunk stride.
"""
if not isinstance(raw_audio, list):
raw_audio = [raw_audio]
raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio]
max_length = max(array.shape[0] for array in raw_audio)
if chunk_length is not None:
max_length += chunk_length - (max_length % chunk_stride)
inputs = []
masks = []
for x in raw_audio:
length = x.shape[0]
mask = mx.ones((length,), dtype=mx.bool_)
difference = max_length - length
if difference > 0:
mask = mx.pad(mask, (0, difference))
x = mx.pad(x, ((0, difference), (0, 0)))
inputs.append(x)
masks.append(mask)
return mx.stack(inputs), mx.stack(masks)
def load(path_or_repo):
"""
Load the model and audo preprocessor.
"""
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
model = encodec.EncodecModel(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
preprocess_audio,
sampling_rate=config.sampling_rate,
chunk_length=model.chunk_length,
chunk_stride=model.chunk_stride,
)
mx.eval(model)
return model, processor

212
flux/README.md Normal file
View File

@ -0,0 +1,212 @@
FLUX
====
FLUX implementation in MLX. The implementation is ported directly from
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
and the model weights are downloaded directly from the Hugging Face Hub.
The goal of this example is to be clean, educational and to allow for
experimentation with finetuning FLUX models as well as adding extra
functionality such as in-/outpainting, guidance with custom losses etc.
![MLX image](static/generated-mlx.png)
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
tron emanating futuristic technology with the word "MLX" in the center with
capital red letters.'*
Installation
------------
The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly
You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt
Usage
---------
You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`.
```shell
python txt2image.py --model schnell \
--n-images 1 \
--image-size 256x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars.'
```
For more parameters, please use the `--help` command to view.
```shell
python txt2image.py --help
```
Inference
---------
Inference in this example is similar to the stable diffusion example. The
classes to get you started are `FluxPipeline` from the `flux` module.
```python
import mlx.core as mx
from flux import FluxPipeline
# This will download all the weights from HF hub
flux = FluxPipeline("flux-schnell")
# Make a generator that returns the latent variables from the reverse diffusion
# process
latent_generator = flux.generate_latents(
"A photo of an astronaut riding a horse on Mars",
num_steps=4,
latent_size=(32, 64), # 256x512 image
)
# The first return value of the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
x_T, # The initial noise
x_positions, # The integer positions used for image positional encoding
t5_conditioning, # The T5 features from the text prompt
t5_positions, # Integer positions for text (normally all 0s)
clip_conditioning, # The clip text features from the text prompt
) = conditioning
# Returning the conditioning as the first output from the generator allows us
# to unload T5 and clip before running the diffusion transformer.
mx.eval(conditioning)
# Evaluate each diffusion step
for x_t in latent_generator:
mx.eval(x_t)
# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))
```
The above are essentially the implementation of the `txt2image.py` script
except for some additional logic to quantize and/or load trained adapters. One
can use the script as follows:
```shell
python txt2image.py \
--n-images 4 \
--n-rows 2 \
--image-size 256x512 \
'A photo of an astronaut riding a horse on Mars.'
```
### Experimental Options
FLUX pads the prompt to a specific size of 512 tokens for the dev model and
256 for the schnell model. Not applying padding results in faster generation
but it is not clear how it may affect the generated images. To enable that
option in this example pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
Finetuning
----------
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an
`train.jsonl` file with the following format:
```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
```
The training script by default trains for 600 iterations with a batch size of
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
--help` for the list of hyperparameters you can tune.
> [!Note]
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
> should reduce this number significantly.
### Training Example
This is a step-by-step finetuning example. We will be using the data from
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
In particular, we will use `dog6` which is a popular example for showcasing
dreambooth [^1].
The training images are the following 5 images [^2]:
![dog6](static/dog6.png)
We start by making the following `train.jsonl` file and placing it in the same
folder as the images.
```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```
Subsequently we finetune FLUX using the following command:
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
path/to/dreambooth/dataset/dog6
```
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.
### Using the Adapter
The adapters are saved in `mlx_output` and can be used directly by the
`txt2image.py` script. For instance,
```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/0001200_adapters.safetensors \
--fuse-adapter \
--no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece'
```
generates an image that looks like the following,
![dog image](static/dog-r4-g8-1200.png)
and of course we can pass `--image-size 512x1024` to get larger images with
different aspect ratios,
![wide dog image](static/dog-r4-g8-1200-512x1024.png)
The arguments that are relevant to the adapters are of course `--adapter` and
`--fuse-adapter`. The first defines the path to an adapter to apply to the
model and the second fuses the adapter back into the model to get a bit more
speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .

285
flux/dreambooth.py Normal file
View File

@ -0,0 +1,285 @@
# Copyright © 2024 Apple Inc.
import argparse
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png"
print(f"Generating {str(out_file)}", flush=True)
# Generate some images and arrange them in a grid
n_rows = 2
n_images = 4
x = flux.generate_images(
args.progress_prompt,
n_images,
args.progress_steps,
)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(n_rows * H, B // n_rows * W, C)
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(out_file)
def save_adapters(iteration, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
print(f"Saving {str(out_file)}")
mx.save_safetensors(
str(out_file),
dict(tree_flatten(flux.flow.trainable_parameters())),
metadata={
"lora_rank": str(args.lora_rank),
"lora_blocks": str(args.lora_blocks),
},
)
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
parser.add_argument(
"--model",
default="dev",
choices=[
"dev",
"schnell",
],
help="Which flux model to train",
)
parser.add_argument(
"--guidance", type=float, default=4.0, help="The guidance factor to use."
)
parser.add_argument(
"--iterations",
type=int,
default=600,
help="How many iterations to train for",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="The batch size to use when training the stable diffusion model",
)
parser.add_argument(
"--resolution",
type=lambda x: tuple(map(int, x.split("x"))),
default=(512, 512),
help="The resolution of the training images",
)
parser.add_argument(
"--num-augmentations",
type=int,
default=5,
help="Augment the images by random cropping and panning",
)
parser.add_argument(
"--progress-prompt",
required=True,
help="Use this prompt when generating images for evaluation",
)
parser.add_argument(
"--progress-steps",
type=int,
default=50,
help="Use this many steps when generating images for evaluation",
)
parser.add_argument(
"--progress-every",
type=int,
default=50,
help="Generate images every PROGRESS_EVERY steps",
)
parser.add_argument(
"--checkpoint-every",
type=int,
default=50,
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--lora-blocks",
type=int,
default=-1,
help="Train the last LORA_BLOCKS transformer blocks",
)
parser.add_argument(
"--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
)
parser.add_argument(
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
)
parser.add_argument(
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
)
parser.add_argument(
"--grad-accumulate",
type=int,
default=4,
help="Accumulate gradients for that many iterations before applying them",
)
parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
)
parser.add_argument("dataset")
return parser
if __name__ == "__main__":
parser = setup_arg_parser()
args = parser.parse_args()
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Reset the seed to a different seed per worker if we are in distributed
# mode so that each worker is working on different data, diffusion step and
# random noise.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Report how many parameters we are training
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation.
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
cosine = optim.cosine_decay(
args.learning_rate, args.iterations // args.grad_accumulate
)
lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(
lambda a, b: (a + b) / args.grad_accumulate,
prev_grads,
grads,
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
if prev_grads is None:
if perform_step:
return single_step(x, t5_feat, clip_feat, guidance), None
else:
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
else:
if perform_step:
return (
grad_accumulate_and_step(
x, t5_feat, clip_feat, guidance, prev_grads
),
None,
)
else:
return compute_loss_and_accumulate_grads(
x, t5_feat, clip_feat, guidance, prev_grads
)
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
grads = None
losses = []
tic = time.time()
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,
)
if (i + 1) % args.progress_every == 0:
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
save_adapters(i + 1, flux, args)
if (i + 1) % 10 == 0:
losses = []
tic = time.time()

15
flux/flux/__init__.py Normal file
View File

@ -0,0 +1,15 @@
# Copyright © 2024 Apple Inc.
from .datasets import Dataset, load_dataset
from .flux import FluxPipeline
from .lora import LoRALinear
from .sampler import FluxSampler
from .trainer import Trainer
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)

357
flux/flux/autoencoder.py Normal file
View File

@ -0,0 +1,357 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.upsample import upsample_nearest
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: List[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.q = nn.Linear(in_channels, in_channels)
self.k = nn.Linear(in_channels, in_channels)
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
def __call__(self, x: mx.array) -> mx.array:
B, H, W, C = x.shape
y = x.reshape(B, 1, -1, C)
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
y = self.proj_out(y)
return x + y.reshape(B, H, W, C)
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32,
dims=out_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x):
h = x
h = self.norm1(h)
h = nn.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = nn.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
x = upsample_nearest(x, (2, 2))
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = []
block_in = self.ch
for i_level in range(self.num_resolutions):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = {}
down["block"] = block
down["attn"] = attn
if i_level != self.num_resolutions - 1:
down["downsample"] = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level]["block"][i_block](hs[-1])
# TODO: Remove the attn
if len(self.down[i_level]["attn"]) > 0:
h = self.down[i_level]["attn"][i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level]["downsample"](hs[-1]))
# middle
h = hs[-1]
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = []
for i_level in reversed(range(self.num_resolutions)):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = {}
up["block"] = block
up["attn"] = attn
if i_level != 0:
up["upsample"] = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def __call__(self, z: mx.array):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level]["block"][i_block](h)
# TODO: Remove the attn
if len(self.up[i_level]["attn"]) > 0:
h = self.up[i_level]["attn"][i_block](h)
if i_level != 0:
h = self.up[i_level]["upsample"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __call__(self, z: mx.array):
mean, logvar = mx.split(z, 2, axis=-1)
if self.training:
std = mx.exp(0.5 * logvar)
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
return mean + std * eps
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if w.ndim == 4:
w = w.transpose(0, 2, 3, 1)
w = w.reshape(-1).reshape(w.shape)
if w.shape[1:3] == (1, 1):
w = w.squeeze((1, 2))
new_weights[k] = w
return new_weights
def encode(self, x: mx.array):
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: mx.array):
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def __call__(self, x: mx.array):
return self.decode(self.encode(x))

154
flux/flux/clip.py Normal file
View File

@ -0,0 +1,154 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return cls(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
hidden_act=config["hidden_act"],
)
@dataclass
class CLIPOutput:
# The last_hidden_state indexed at the EOS token and possibly projected if
# the model has a projection layer
pooled_output: Optional[mx.array] = None
# The full sequence output of the transformer after the final layernorm
last_hidden_state: Optional[mx.array] = None
# A list of hidden states corresponding to the outputs of the transformer layers
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def sanitize(self, weights):
new_weights = {}
for key, w in weights.items():
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
new_weights[key] = w
return new_weights
def __call__(self, x):
# Extract some shapes
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
hidden_states = []
for l in self.layers:
x = l(x, mask)
hidden_states.append(x)
# Apply the final layernorm and return
x = self.final_layer_norm(x)
last_hidden_state = x
# Select the EOS token
pooled_output = x[mx.arange(len(x)), eos_tokens]
return CLIPOutput(
pooled_output=pooled_output,
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
)

75
flux/flux/datasets.py Normal file
View File

@ -0,0 +1,75 @@
import json
from pathlib import Path
from PIL import Image
class Dataset:
def __getitem__(self, index: int):
raise NotImplementedError()
def __len__(self):
raise NotImplementedError()
class LocalDataset(Dataset):
prompt_key = "prompt"
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __len__(self):
return len(self._data)
def __getitem__(self, index: int):
item = self._data[index]
image = Image.open(self.dataset_base / item["image"])
return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
prompt_key = "text"
def __init__(self, dataset: str):
self.dataset_base = Path(dataset)
with open(self.dataset_base / "index.json") as f:
self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
def __init__(self, dataset: str):
from datasets import load_dataset as hf_load_dataset
self._df = hf_load_dataset(dataset)["train"]
def __len__(self):
return len(self._df)
def __getitem__(self, index: int):
item = self._df[index]
return item["image"], item["prompt"]
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
legacy_file = dataset_base / "index.json"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
dataset = LocalDataset(dataset, data_file)
elif legacy_file.exists():
print(f"Load the local dataset {legacy_file} .")
print()
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
print(" See the README for details.")
print(flush=True)
dataset = LegacyDataset(dataset)
else:
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)
return dataset

246
flux/flux/flux.py Normal file
View File

@ -0,0 +1,246 @@
# Copyright © 2024 Apple Inc.
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))

302
flux/flux/layers.py Normal file
View File

@ -0,0 +1,302 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
def _rope(pos: mx.array, dim: int, theta: float):
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
omega = 1.0 / (theta**scale)
x = pos[..., None] * omega
cosx = mx.cos(x)
sinx = mx.sin(x)
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
pe = pe.reshape(*pe.shape[:-1], 2, 2)
return pe
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
return a * b + c * d
def _apply_rope(x, pe):
s = x.shape
x = x.reshape(*s[:-1], -1, 1, 2)
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
return x.reshape(s)
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
B, H, L, D = q.shape
q = _apply_rope(q, pe)
k = _apply_rope(k, pe)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
def timestep_embedding(
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
half = dim // 2
freqs = mx.arange(0, half, dtype=mx.float32) / half
freqs = freqs * (-math.log(max_period))
freqs = mx.exp(freqs)
x = (time_factor * t)[:, None] * freqs[None]
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
return x.astype(t.dtype)
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids: mx.array):
n_axes = ids.shape[-1]
pe = mx.concatenate(
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
)
return pe[:, None]
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.out_layer(nn.silu(self.in_layer(x)))
class QKNorm(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = nn.RMSNorm(dim)
self.key_norm = nn.RMSNorm(dim)
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
return self.query_norm(q), self.key_norm(k)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
H = self.num_heads
B, L, _ = x.shape
qkv = self.qkv(x)
q, k, v = mx.split(qkv, 3, axis=-1)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
x = _attention(q, k, v, pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: mx.array
scale: mx.array
gate: mx.array
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
x = self.lin(nn.silu(x))
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
mod1 = ModulationOut(*xs[:3])
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
return mod1, mod2
class DoubleStreamBlock(nn.Module):
def __init__(
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
B, L, _ = img.shape
_, S, _ = txt.shape
H = self.num_heads
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_q, img_k = self.img_attn.norm(img_q, img_k)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
# run actual attention
q = mx.concatenate([txt_q, img_q], axis=2)
k = mx.concatenate([txt_k, img_k], axis=2)
v = mx.concatenate([txt_v, img_v], axis=2)
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approx="tanh")
self.modulation = Modulation(hidden_size, double=False)
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
B, L, _ = x.shape
H = self.num_heads
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
q, k, v, mlp = mx.split(
self.linear1(x_mod),
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
axis=-1,
)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
# compute attention
y = _attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
return x + mod.gate * y
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def __call__(self, x: mx.array, vec: mx.array):
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

76
flux/flux/lora.py Normal file
View File

@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = self.scale * self.lora_b.T
lora_a = self.lora_a.T
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
if bias:
fused_linear.bias = linear.bias
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
return y + (self.scale * z).astype(x.dtype)

134
flux/flux/model.py Normal file
View File

@ -0,0 +1,134 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if params.guidance_embed
else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = [
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
self.single_blocks = [
SingleStreamBlock(
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
)
for _ in range(params.depth_single_blocks)
]
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
if f".{seq}." in k:
k = k.replace(f".{seq}.", f".{seq}.layers.")
break
new_weights[k] = w
return new_weights
def __call__(
self,
img: mx.array,
img_ids: mx.array,
txt: mx.array,
txt_ids: mx.array,
timesteps: mx.array,
y: mx.array,
guidance: Optional[mx.array] = None,
) -> mx.array:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = mx.concatenate([txt_ids, img_ids], axis=1)
pe = self.pe_embedder(ids).astype(img.dtype)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = mx.concatenate([txt, img], axis=1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec)
return img

56
flux/flux/sampler.py Normal file
View File

@ -0,0 +1,56 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
import mlx.core as mx
class FluxSampler:
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
def _time_shift(self, x, t):
x1, x2 = 256, 4096
t1, t2 = self._base_shift, self._max_shift
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
t = exp_mu / (exp_mu + (1 / t - 1))
return t
@lru_cache
def timesteps(
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
):
t = mx.linspace(start, stop, num_steps + 1)
if self._schnell:
t = self._time_shift(image_sequence_length, t)
return t.tolist()
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
if self._schnell:
# TODO: Should we upweigh 1 and 0.75?
t = mx.random.randint(1, 5, shape=(B,), key=key)
t = t.astype(dtype) / 4
else:
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
t = self._time_shift(L, t)
return t
def sample_prior(self, shape, dtype=mx.float32, key=None):
return mx.random.normal(shape, dtype=dtype, key=key)
def add_noise(self, x, t, noise=None, key=None):
noise = (
noise
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):
return x_t + (t_prev - t) * pred

244
flux/flux/t5.py Normal file
View File

@ -0,0 +1,244 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
_SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
_ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@dataclass
class T5Config:
vocab_size: int
num_layers: int
num_heads: int
relative_attention_num_buckets: int
d_kv: int
d_model: int
feed_forward_proj: str
tie_word_embeddings: bool
d_ff: Optional[int] = None
num_decoder_layers: Optional[int] = None
relative_attention_max_distance: int = 128
layer_norm_epsilon: float = 1e-6
@classmethod
def from_dict(cls, config):
return cls(
vocab_size=config["vocab_size"],
num_layers=config["num_layers"],
num_heads=config["num_heads"],
relative_attention_num_buckets=config["relative_attention_num_buckets"],
d_kv=config["d_kv"],
d_model=config["d_model"],
feed_forward_proj=config["feed_forward_proj"],
tie_word_embeddings=config["tie_word_embeddings"],
d_ff=config.get("d_ff", 4 * config["d_model"]),
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
relative_attention_max_distance=config.get(
"relative_attention_max_distance", 128
),
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
)
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
@staticmethod
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
num_buckets = num_buckets // 2 if bidirectional else num_buckets
max_exact = num_buckets // 2
abspos = rpos.abs()
is_small = abspos < max_exact
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
buckets = mx.where(is_small, abspos, buckets_large)
if bidirectional:
buckets = buckets + (rpos > 0) * num_buckets
else:
buckets = buckets * (rpos < 0)
return buckets
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> [mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
values_hat = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
)
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
pos_bias = pos_bias.astype(x.dtype)
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class T5Encoder(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
for old, new in _SHARED_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
if k.startswith("encoder."):
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
new_weights[k] = w
return new_weights
def __call__(self, inputs: mx.array):
return self.encoder(self.wte(inputs))

185
flux/flux/tokenizers.py Normal file
View File

@ -0,0 +1,185 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import regex
from sentencepiece import SentencePieceProcessor
class CLIPTokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab, max_length=77):
self.max_length = max_length
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
if len(tokens) > self.max_length:
tokens = tokens[: self.max_length]
if append_eos:
tokens[-1] = self.eos_token
return tokens
def encode(self, text):
if not isinstance(text, list):
return self.encode([text])
tokens = self.tokenize(text)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([self.eos_token] * (length - len(t)))
return mx.array(tokens)
class T5Tokenizer:
def __init__(self, model_file, max_length=512):
self._tokenizer = SentencePieceProcessor(model_file)
self.max_length = max_length
@property
def pad(self):
try:
return self._tokenizer.id_to_piece(self.pad_token)
except IndexError:
return None
@property
def pad_token(self):
return self._tokenizer.pad_id()
@property
def bos(self):
try:
return self._tokenizer.id_to_piece(self.bos_token)
except IndexError:
return None
@property
def bos_token(self):
return self._tokenizer.bos_id()
@property
def eos(self):
try:
return self._tokenizer.id_to_piece(self.eos_token)
except IndexError:
return None
@property
def eos_token(self):
return self._tokenizer.eos_id()
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
tokens = self._tokenizer.encode(text)
if prepend_bos and self.bos_token >= 0:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
tokens += [self.pad_token] * (self.max_length - len(tokens))
return tokens
def encode(self, text, pad=True):
if not isinstance(text, list):
return self.encode([text], pad=pad)
pad_token = self.pad_token if self.pad_token >= 0 else 0
tokens = self.tokenize(text, pad=pad)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([pad_token] * (length - len(t)))
return mx.array(tokens)

98
flux/flux/trainer.py Normal file
View File

@ -0,0 +1,98 @@
import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]

209
flux/flux/utils.py Normal file
View File

@ -0,0 +1,209 @@
# Copyright © 2024 Apple Inc.
import json
import os
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
from huggingface_hub import hf_hub_download
from .autoencoder import AutoEncoder, AutoEncoderParams
from .clip import CLIPTextModel, CLIPTextModelConfig
from .model import Flux, FluxParams
from .t5 import T5Config, T5Encoder
from .tokenizers import CLIPTokenizer, T5Tokenizer
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: Optional[str]
ae_path: Optional[str]
repo_id: Optional[str]
repo_flow: Optional[str]
repo_ae: Optional[str]
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
def load_flow_model(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ckpt_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
# Make the model
model = Flux(configs[name].params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()))
return model
def load_ae(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ae_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
# Make the autoencoder
ae = AutoEncoder(configs[name].ae_params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = ae.sanitize(weights)
ae.load_weights(list(weights.items()))
return ae
def load_clip(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
with open(config_path) as f:
config = CLIPTextModelConfig.from_dict(json.load(f))
# Make the clip text encoder
clip = CLIPTextModel(config)
# Load the weights
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
weights = mx.load(ckpt_path)
weights = clip.sanitize(weights)
clip.load_weights(list(weights.items()))
return clip
def load_t5(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
with open(config_path) as f:
config = T5Config.from_dict(json.load(f))
# Make the T5 model
t5 = T5Encoder(config)
# Load the weights
model_index = hf_hub_download(
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
)
weight_files = set()
with open(model_index) as f:
for _, w in json.load(f)["weight_map"].items():
weight_files.add(w)
weights = {}
for w in weight_files:
w = f"text_encoder_2/{w}"
w = hf_hub_download(configs[name].repo_id, w)
weights.update(mx.load(w))
weights = t5.sanitize(weights)
t5.load_weights(list(weights.items()))
return t5
def load_clip_tokenizer(name: str):
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)

7
flux/requirements.txt Normal file
View File

@ -0,0 +1,7 @@
mlx>=0.18.1
huggingface-hub
regex
numpy
tqdm
Pillow
sentencepiece

Binary file not shown.

After

Width:  |  Height:  |  Size: 754 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 423 KiB

BIN
flux/static/dog6.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 434 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 153 KiB

150
flux/txt2image.py Normal file
View File

@ -0,0 +1,150 @@
# Copyright © 2024 Apple Inc.
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def load_adapter(flux, adapter_file, fuse=False):
weights, lora_config = mx.load(adapter_file, return_metadata=True)
rank = int(lora_config["lora_rank"])
num_blocks = int(lora_config["lora_blocks"])
flux.linear_to_lora_layers(rank, num_blocks)
flux.flow.load_weights(list(weights.items()), strict=False)
if fuse:
flux.fuse_lora_layers()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n-images", type=int, default=4)
parser.add_argument(
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
)
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n-rows", type=int, default=1)
parser.add_argument("--decoding-batch-size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--save-raw", action="store_true")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true")
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
args.steps = args.steps or (50 if args.model == "dev" else 2)
if args.adapter:
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models:
flux.ensure_models_are_loaded()
# Make the generator
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
seed=args.seed,
)
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
del flux.t5
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = mx.concatenate(decoded, axis=0)
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")

View File

@ -20,6 +20,31 @@ The `mlx-lm` package also has:
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
@ -138,7 +163,7 @@ mlx_lm.convert \
### Long Prompts and Generations
MLX LM has some tools to scale efficiently to long prompts and generations:
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--kv-cache-file mistral_prompt.safetensors
--prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--kv-cache-file mistral_prompt.safetensors \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
Prompt caching can also be used in the Python API in order to to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
### Supported Models
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

View File

@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
### List Models
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.18.2"
__version__ = "0.19.1"

View File

@ -7,13 +7,14 @@ import time
import mlx.core as mx
from .utils import load, make_kv_caches
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
@ -60,7 +61,9 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file", help="The file to save the KV caches in", required=True
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
@ -115,7 +118,7 @@ def main():
else:
prompt = args.prompt
cache = make_kv_caches(model, args.max_kv_size)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
@ -137,16 +140,12 @@ def main():
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":

82
llms/mlx_lm/chat.py Normal file
View File

@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.
import argparse
import json
import mlx.core as mx
from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .utils import load, stream_generate
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(description="Chat with an LLM")
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
default=DEFAULT_MODEL,
)
parser.add_argument(
"--adapter-path",
type=str,
help="Optional path for the trained adapter weights and config.",
)
parser.add_argument(
"--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
)
parser.add_argument(
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
)
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
parser.add_argument(
"--max-kv-size",
type=int,
help="Set the maximum key-value cache size",
default=None,
)
return parser
def main():
parser = setup_arg_parser()
args = parser.parse_args()
mx.random.seed(args.seed)
model, tokenizer = load(
args.model,
adapter_path=args.adapter_path,
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")
if query == "q":
break
messages = [{"role": "user", "content": query}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response in stream_generate(
model,
tokenizer,
prompt,
temp=args.temp,
top_p=args.top_p,
prompt_cache=prompt_cache,
):
print(response, flush=True, end="")
print()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,53 @@
# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")

View File

@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.
from mlx_lm import generate, load
# Specify the checkpoint

View File

@ -6,13 +6,15 @@ import sys
import mlx.core as mx
from .models.cache import load_prompt_cache
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def str2bool(string):
@ -25,7 +27,11 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
help=(
"The path to the local model directory or Hugging Face repo. "
f"If no model is specified, then {DEFAULT_MODEL} is used."
),
default=None,
)
parser.add_argument(
"--adapter-path",
@ -96,7 +102,7 @@ def setup_arg_parser():
default=None,
)
parser.add_argument(
"--kv-cache-file",
"--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
def load_kv_cache_from_file(kv_cache_file):
if kv_cache_file is None:
return None, None
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
cache_per_layer = {}
for k, x in kv_cache.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
return cache_history, metadata
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@ -158,22 +146,33 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the kv cache and metadata if a kv cache file is provided
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file, return_metadata=True
)
# Building tokenizer_config
tokenizer_config = (
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
# If no model path is provided then use the one in the kv cache history
model_path = args.model
if cache_history is not None and model_path is None:
if using_cache:
if model_path is None:
model_path = metadata["model"]
elif model_path != metadata["model"]:
raise ValueError(
f"Providing a different model ({model_path}) than that "
f"used to create the prompt cache ({metadata['model']}) "
"is an error."
)
model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
model_path,
@ -184,7 +183,7 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif cache_history is not None:
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
@ -203,7 +202,7 @@ def main():
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if cache_history is not None:
if using_cache:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
@ -217,12 +216,6 @@ def main():
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
if cache_history is not None:
max_kv_size = metadata["max_kv_size"]
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
response = generate(
model,
tokenizer,
@ -232,8 +225,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=max_kv_size,
cache_history=cache_history,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
)
if not args.verbose:
print(response)

View File

@ -2,145 +2,9 @@
import inspect
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
class RotatingKVCache:
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
# May not have hit the max size yet, so potentially
# keep growing the cache
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
@property
def state(self):
return self.keys, self.values
@dataclass
@ -156,25 +20,30 @@ class BaseModelArgs:
)
def create_additive_causal_mask(N: int, offset: int = 0):
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if isinstance(c, RotatingKVCache):
if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
window_size = c.max_size
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset)
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None

340
llms/mlx_lm/models/cache.py Normal file
View File

@ -0,0 +1,340 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
"""
Construct the model's cache for use when cgeneration.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -69,7 +69,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
@ -190,11 +190,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -49,7 +49,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
@ -249,11 +249,3 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.attn_config["kv_n_heads"]

View File

@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: int | None = None,
intermediate_size: int | None = None,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@ -235,7 +235,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@ -256,11 +256,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
rope_scaling: Dict = None
attention_bias: bool = False
@ -172,7 +172,6 @@ class DeepseekV2Attention(nn.Module):
bias=config.attention_bias,
)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -221,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
@ -292,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@ -395,7 +394,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@ -416,14 +415,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -173,11 +173,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -64,7 +64,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
mx.float32
)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
@ -200,11 +198,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -46,7 +46,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -57,7 +57,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@ -184,11 +184,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
@ -214,11 +214,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -116,7 +116,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -171,7 +171,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -303,13 +303,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaCache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@ -223,7 +209,7 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -85,7 +85,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
@ -205,11 +205,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -215,11 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -2,12 +2,12 @@
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -94,7 +94,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -215,13 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,8 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@ -13,7 +13,7 @@ try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
sys.exit(1)
@dataclass
@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
@ -174,11 +174,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@ -80,7 +80,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
@ -218,11 +218,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_kv_heads

View File

@ -162,19 +162,11 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@ -84,7 +84,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -202,11 +202,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -3,12 +3,12 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: Optional[int] = None
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: Tuple[int] = (64,)
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@ -61,7 +61,6 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
@ -161,7 +160,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -304,16 +303,8 @@ class Model(nn.Module):
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
@ -208,11 +209,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -168,8 +168,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
@ -193,11 +193,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.model_dim // self.args.num_heads
@property
def n_kv_heads(self):
return self.args.num_heads

View File

@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@ -62,8 +62,8 @@ class Attention(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
@ -89,6 +89,9 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = mx.fast.scaled_dot_product_attention(
queries,
keys,
@ -127,8 +130,8 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
@ -169,8 +172,8 @@ class PlamoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
@ -197,19 +200,11 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Tuple[mx.array, mx.array]:
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads // self.args.n_shared_head

View File

@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@ -149,19 +148,11 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -7,13 +7,13 @@ from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
attention_bias: bool
conv1d_width: int
hidden_size: int
@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
self.block_types = self._block_types
def create_window_causal_mask(N: int, window_size: int):
inds = mx.arange(N)
linds = inds[:, None]
rinds = inds[None]
mask = (linds < rinds) | (linds > rinds + window_size)
return mask * -1e9
class RecurrentCache:
def __init__(self):
self._cache = (None, None)
def __getitem__(self, idx):
return self._cache[idx]
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
def state(self):
return self._cache
class WindowKVCache:
def __init__(self, window_size):
self.keys = None
self.values = None
self.offset = 0
self.window_size = window_size
def update_and_fetch(self, keys, values):
# TODO consider using rotating buffer here
# especially for very long generations
def _update(x, v):
t = x.shape[2] - self.window_size
if t > 0:
x = x[..., t:, :]
return mx.concatenate([x, v], axis=2)
self.offset += keys.shape[2]
if self.keys is None:
self.keys = keys
self.values = values
else:
self.keys = _update(self.keys, keys)
self.values = _update(self.values, values)
return self.keys, self.values
def state(self):
return self.keys, self.values
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@ -136,31 +83,22 @@ class Conv1d(nn.Module):
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((kernel_size, channels))
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
w = self.weight.T[..., None]
kw, groups = self.weight.shape
if cache is not None:
l = []
# Pad the cache if needed
if cache.shape[1] < kw - 1:
l.append(
mx.zeros(
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
)
)
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
else:
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
B, L, C = x.shape
groups, K, _ = self.weight.shape
# The cache is always kw - 1
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, cache
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
# x branch.
x = self.linear_x(x)
if cache is None:
conv_state, recurrent_state = (None, None)
else:
conv_state, recurrent_state = cache[0], cache[1]
x, conv_state = self.conv_1d(
x=x,
cache=conv_state,
)
x, recurrent_state = self.rg_lru(
x=x,
cache=recurrent_state,
)
if cache is not None:
cache.update(conv_state, recurrent_state)
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
@ -467,12 +395,14 @@ class Griffin(nn.Module):
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
mask = None
if x.shape[1] > 1:
mask = create_window_causal_mask(
x.shape[1], self.config.attention_window_size
)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@ -485,6 +415,7 @@ class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@ -508,10 +439,9 @@ class Model(nn.Module):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
weights[k] = v.squeeze(1).T
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
@ -520,7 +450,7 @@ class Model(nn.Module):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(RecurrentCache())
cache.append(MambaCache())
else:
cache.append(WindowKVCache(self.args.attention_window_size))
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache

View File

@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@ -198,8 +197,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@ -207,11 +206,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@ -45,7 +45,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@ -164,11 +164,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (mx.array): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@ -531,10 +590,13 @@ class APIHandler(BaseHTTPRequestHandler):
continue
new_text = detokenizer.last_segment
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check is there any remaining text to send
detokenizer.finalize()
last_segment = detokenizer.last_segment
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "

View File

@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = ""
self.tokens = []
def _maybe_trim_space(self, current_text):
if current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = v
else:
self._unflushed += v
@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):

View File

@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache, RotatingKVCache
from .models import base, cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits
def make_kv_caches(
model: nn.Module, max_kv_size: Optional[int] = None
) -> List[Union[KVCache, RotatingKVCache]]:
if hasattr(model, "make_cache"):
return model.make_cache()
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
return [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
return [KVCache(model.head_dim, n) for n in kv_heads]
def generate_step(
prompt: mx.array,
model: nn.Module,
@ -155,7 +135,7 @@ def generate_step(
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@ -180,6 +160,8 @@ def generate_step(
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
@ -237,20 +219,13 @@ def generate_step(
tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
if cache_history is not None:
if len(cache_history) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, cache_history):
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
def _step(y):
logits = model(y[None], cache=cache)
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processor:
@ -264,16 +239,17 @@ def generate_step(
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
mx.eval([c.state for c in cache])
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
y, logprobs = _step(y)
mx.async_eval(y)
mx.async_eval(y, logprobs)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
@ -305,9 +281,9 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for (token, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
@ -357,9 +333,9 @@ def generate(
tic = time.perf_counter()
detokenizer.reset()
for (token, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
@ -372,7 +348,9 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)

View File

@ -32,6 +32,7 @@ setup(
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",

View File

@ -1,17 +1,15 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache, RotatingKVCache
from mlx_lm.utils import make_kv_caches
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
cache = KVCache(32, 4)
cache = KVCache()
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, step=4)
cache = RotatingKVCache(max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
idx %= 8
# Try with nonzero keep
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
cache = RotatingKVCache(max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
if idx >= 8:
idx = 2
def test_rotating_kv_cache_chat_mode(self):
# Test that the rotating kv cache can handle
# alternating prompt/prefill with generation
d = 4
h = 2
cache = RotatingKVCache(max_size=18, step=4)
x = mx.random.uniform(shape=(1, h, 8, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 8)
self.assertEqual(cache.offset, 8)
x = mx.random.uniform(shape=(1, h, 1, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 9)
self.assertEqual(cache.offset, 9)
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 11)
self.assertEqual(cache.offset, 11)
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
x = mx.random.uniform(shape=(1, h, 3, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 14)
self.assertEqual(cache.offset, 14)
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
x = mx.random.uniform(shape=(1, h, 6, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 20)
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
cache = make_kv_caches(model)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek(self):
from mlx_lm.models import deepseek
args = deepseek.ModelArgs(
model_type="deepseek",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=4,
)
model = deepseek.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v2(self):
from mlx_lm.models import deepseek_v2
args = deepseek_v2.ModelArgs(
model_type="deepseek_v2",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2
args = gemma2.ModelArgs(
model_type="gemma2",
hidden_size=128,
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=2,
head_dim=32,
rms_norm_eps=1e-4,
vocab_size=1024,
num_key_value_heads=2,
)
model = gemma2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt_bigcode(self):
from mlx_lm.models import gpt_bigcode
args = gpt_bigcode.ModelArgs(
model_type="gpt_bigcode",
n_embd=128,
n_layer=128,
n_inner=256,
n_head=4,
n_positions=1000,
layer_norm_epsilon=1e-5,
vocab_size=1024,
)
model = gpt_bigcode.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_nemotron(self):
from mlx_lm.models import nemotron
args = nemotron.ModelArgs(
model_type="nemotron",
hidden_size=128,
hidden_act="gelu",
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=4,
norm_eps=1e-5,
vocab_size=1024,
num_key_value_heads=2,
)
model = nemotron.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi3small(self):
from mlx_lm.models import phi3small
args = phi3small.ModelArgs(
model_type="phi3small",
hidden_size=128,
dense_attention_every_n_layers=2,
ff_intermediate_size=256,
gegelu_limit=1.0,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
layer_norm_epsilon=1e-4,
vocab_size=1000,
)
model = phi3small.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phimoe(self):
from mlx_lm.models import phimoe
args = phimoe.ModelArgs(
model_type="phimoe",
vocab_size=320,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
rope_scaling={
"long_factor": [1.0] * 16,
"long_mscale": 1.243163121016122,
"original_max_position_embeddings": 4096,
"short_factor": [1.0] * 16,
"short_mscale": 1.243163121016122,
"type": "longrope",
},
)
model = phimoe.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_recurrent_gemma(self):
from mlx_lm.models import recurrent_gemma
args = recurrent_gemma.ModelArgs(
model_type="recurrent_gemma",
hidden_size=128,
attention_bias=False,
conv1d_width=3,
intermediate_size=256,
logits_soft_cap=1.0,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
attention_window_size=1024,
vocab_size=1000,
block_types=["recurrent", "recurrent", "attention"],
)
model = recurrent_gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,243 @@
# Copyright © 2024 Apple Inc.
import copy
import os
import tempfile
import unittest
import mlx.core as mx
from mlx_lm.models.cache import (
KVCache,
MambaCache,
RotatingKVCache,
load_prompt_cache,
make_prompt_cache,
save_prompt_cache,
trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestPromptCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_load(self):
cache = [KVCache() for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertTrue(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Test with metadata
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
metadata = {"a": "b", "c": "d"}
save_prompt_cache(cache_file, cache, metadata)
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
self.assertEqual(metadata, loaded_metadata)
def test_save_load_rotating_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
# Test with rotating cache
cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertTrue(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertEqual(c.keep, lc.keep)
self.assertEqual(c.max_size, lc.max_size)
self.assertEqual(c.step, lc.step)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Do a couple single token updates to get a rotation
for _ in range(2):
for c in cache:
x = mx.random.uniform(shape=(1, 8, 1, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
x = mx.random.uniform(shape=(1, 8, 1, 4))
k, v = c.update_and_fetch(x, x)
lk, lv = lc.update_and_fetch(x, x)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_save_load_mixed_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
for c in cache:
if isinstance(c, MambaCache):
c[0] = mx.random.uniform(shape=(4, 4, 4))
c[1] = mx.random.uniform(shape=(4, 4, 4))
else:
x = mx.random.uniform(shape=(4, 4, 7, 4))
y = mx.random.uniform(shape=(4, 4, 7, 4))
c.update_and_fetch(x, y)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
if isinstance(c, MambaCache):
self.assertTrue(mx.array_equal(c[0], lc[0]))
self.assertTrue(mx.array_equal(c[1], lc[1]))
else:
x = mx.random.uniform(shape=(4, 4, 1, 4))
y = mx.random.uniform(shape=(4, 4, 1, 4))
k, v = c.update_and_fetch(x, y)
lk, lv = lc.update_and_fetch(x, y)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1
for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
def test_trim_cache(self):
cache = [KVCache() for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
# Trim
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 7)
# Trim more tokens than remain
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)
# Can't trim mamba cache
cache = [MambaCache() for _ in range(2)]
for c in cache:
c.state = mx.zeros((5, 5))
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 0)
# All cache's have to be trimmable
cache = [MambaCache(), KVCache()]
cache[0].state = mx.zeros((5, 5))
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[1].update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 1)
self.assertEqual(num_trimmed, 0)
cache = [RotatingKVCache(max_size=6) for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 5, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 4)
# Can't trim fixed-size KV cache after processing
# more than max_kv_size tokens
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 0)
def test_trim_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
prompt_cache = make_prompt_cache(model)
# Generate one token so we process the full prompt
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
last_tok = mx.array([last_tok])
# Generate two more tokens
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
toks, all_logits = zip(*(r[1] for r in results))
# To get back to the cache just after processing the prompt,
# trim by 3 tokens
trim_prompt_cache(prompt_cache, 3)
# Generate the same thing again
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
second_toks, second_all_logits = zip(*(r[1] for r in results))
self.assertEqual(toks, second_toks)
self.assertTrue(
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
def test_cache_copying(self):
cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)
y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)
old_cache = copy.deepcopy(cache)
trim_prompt_cache(cache, 1)
self.assertTrue(old_cache[0].offset, 11)
self.assertTrue(cache[0].offset, 10)
z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
if __name__ == "__main__":
unittest.main()

View File

@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)
def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"]

View File

@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import unittest
from pathlib import Path
from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
BPEStreamingDetokenizer,
NaiveStreamingDetokenizer,
SPMStreamingDetokenizer,
load_tokenizer,
)
class TestTokenizers(unittest.TestCase):
def download_tokenizer(self, repo):
path = Path(
snapshot_download(
repo_id=repo,
allow_patterns=[
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"tokenizer.model",
],
)
)
return load_tokenizer(path)
def check_tokenizer(self, tokenizer):
def check(tokens):
expected_text = tokenizer.decode(tokens)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
tokens = tokenizer.encode("a ,b")
check(tokens)
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
check(tokens)
tokens = tokenizer.encode("3 3")
check(tokens)
def test_tokenizers(self):
tokenizer_repos = [
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
]
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
with self.subTest(tokenizer=tokenizer_repo):
tokenizer = self.download_tokenizer(tokenizer_repo)
tokenizer.decode([0, 1, 2])
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
self.check_tokenizer(tokenizer)
# Try one with a naive detokenizer
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
self.check_tokenizer(tokenizer)
if __name__ == "__main__":
unittest.main()

30
musicgen/README.md Normal file
View File

@ -0,0 +1,30 @@
# MusicGen
An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate
music from text descriptions.
### Setup
Install the requirements:
```
pip install -r requirements.txt
```
### Example
An example using the model:
```python
from musicgen import MusicGen
from utils import save_audio
model = MusicGen.from_pretrained("facebook/musicgen-medium")
audio = model.generate("happy rock")
save_audio("out.wav", audio, model.sampling_rate)
```
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and
[code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details.

View File

@ -0,0 +1,28 @@
# Copyright © 2024 Apple Inc.
import sys
import time
from pathlib import Path
import mlx.core as mx
cur_path = Path(__file__).parents[1].resolve()
sys.path.append(str(cur_path))
from musicgen import MusicGen
text = "folk ballad"
model = MusicGen.from_pretrained("facebook/musicgen-medium")
max_steps = 100
audio = model.generate(text, max_steps=10)
mx.eval(audio)
tic = time.time()
audio = model.generate(text, max_steps=max_steps)
mx.eval(audio)
toc = time.time()
ms = 1000 * (toc - tic) / max_steps
print(f"Time (ms) per step: {ms:.3f}")

View File

@ -0,0 +1,31 @@
# Copyright © 2024 Apple Inc.
import time
import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration
model_name = "facebook/musicgen-medium"
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps")
inputs = processor(
text=["folk ballad"],
padding=True,
return_tensors="pt",
)
inputs["input_ids"] = inputs["input_ids"].to("mps")
inputs["attention_mask"] = inputs["attention_mask"].to("mps")
# warmup
audio_values = model.generate(**inputs, max_new_tokens=10)
torch.mps.synchronize()
max_steps = 100
tic = time.time()
audio_values = model.generate(**inputs, max_new_tokens=max_steps)
torch.mps.synchronize()
toc = time.time()
ms = 1000 * (toc - tic) / max_steps
print(f"Time (ms) per step: {ms:.3f}")

1
musicgen/encodec.py Symbolic link
View File

@ -0,0 +1 @@
../encodec/encodec.py

23
musicgen/generate.py Normal file
View File

@ -0,0 +1,23 @@
# Copyright © 2024 Apple Inc.
import argparse
from utils import save_audio
from musicgen import MusicGen
def main(text: str, output_path: str, model_name: str, max_steps: int):
model = MusicGen.from_pretrained(model_name)
audio = model.generate(text, max_steps=max_steps)
save_audio(output_path, audio, model.sampling_rate)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=False, default="facebook/musicgen-medium")
parser.add_argument("--text", required=False, default="happy rock")
parser.add_argument("--output-path", required=False, default="0.wav")
parser.add_argument("--max-steps", required=False, default=500, type=int)
args = parser.parse_args()
main(args.text, args.output_path, args.model, args.max_steps)

358
musicgen/musicgen.py Normal file
View File

@ -0,0 +1,358 @@
# Copyright © 2024 Apple Inc.
import json
from functools import partial
from pathlib import Path
from types import SimpleNamespace
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from tqdm import tqdm
from encodec import EncodecModel
from t5 import T5
class TextConditioner(nn.Module):
def __init__(self, t5_name, input_dim, output_dim):
super().__init__()
self._t5, self.tokenizer = T5.from_pretrained(t5_name)
self.output_proj = nn.Linear(input_dim, output_dim)
def __call__(self, text):
x = self.tokenizer.encode(text)
x = self._t5.encode(x)
return self.output_proj(x)
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
class MultiHeadAttention(nn.Module):
def __init__(self, dim, n_heads):
super().__init__()
self.n_heads = n_heads
head_dim = dim // n_heads
self.scale = head_dim**-0.5
self.q_proj = nn.Linear(dim, dim, bias=False)
self.k_proj = nn.Linear(dim, dim, bias=False)
self.v_proj = nn.Linear(dim, dim, bias=False)
self.out_proj = nn.Linear(dim, dim, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
B, L_q, D = queries.shape
L_k = keys.shape[1]
queries, keys, values = (
self.q_proj(queries),
self.k_proj(keys),
self.v_proj(values),
)
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L_q, self.n_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L_q, -1)
return self.out_proj(output)
class TransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.decoder.num_attention_heads
self.hidden_size = config.decoder.hidden_size
self.self_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
self.cross_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads)
self.linear1 = nn.Linear(self.hidden_size, config.decoder.ffn_dim, bias=False)
self.linear2 = nn.Linear(config.decoder.ffn_dim, self.hidden_size, bias=False)
self.norm1 = nn.LayerNorm(self.hidden_size, eps=1e-5)
self.norm_cross = nn.LayerNorm(self.hidden_size, eps=1e-5)
self.norm2 = nn.LayerNorm(self.hidden_size, eps=1e-5)
def __call__(
self,
x: mx.array,
conditioning: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
) -> mx.array:
xn = self.norm1(x)
x += self.self_attn(xn, xn, xn, mask, cache)
xn = self.norm_cross(x)
x += self.cross_attn(xn, conditioning, conditioning, mask)
xn = self.norm2(x)
x += self.linear2(nn.gelu(self.linear1(xn)))
return x
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def top_k_sampling(
logits: mx.array, top_k: float, temperature: float, axis: int = -1
) -> mx.array:
"""
Apply top-k sampling to logits.
Args:
logits: The logits from the model's output.
top_k: Sample from the top k logits.
temperature: Temperature parameter for softmax distribution reshaping.
axis: Axis along which to sample.
Returns:
token selected based on the top-k criterion.
"""
# referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
probs = mx.softmax(logits * (1 / temperature), axis=axis)
# sort probs in ascending order
sorted_indices = mx.argsort(probs, axis=axis)
sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis)
prob_threshold = mx.take(sorted_probs, mx.array(-top_k), axis=axis)
# select the top K tokens in probability
top_probs = mx.where(
sorted_probs > prob_threshold,
sorted_probs,
0,
)
sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis)
token = mx.take_along_axis(
sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis
)
return token
def create_sin_embedding(positions: mx.array, dim: int, max_period: float = 10000):
assert dim % 2 == 0
half_dim = dim // 2
adim = mx.arange(half_dim).reshape(1, 1, -1)
phase = positions / (max_period ** (adim / (half_dim - 1)))
return mx.concatenate([mx.cos(phase), mx.sin(phase)], axis=-1)
class MusicGen(nn.Module):
def __init__(self, config):
self.num_codebooks = config.decoder.num_codebooks
self.codebook_size = config.audio_encoder.codebook_size
self.bos_token_id = config.decoder.bos_token_id
self.hidden_size = config.decoder.hidden_size
self.num_attention_heads = config.decoder.num_attention_heads
self.sampling_rate = config.audio_encoder.sampling_rate
self.text_conditioner = TextConditioner(
config.text_encoder._name_or_path,
config.text_encoder.d_model,
self.hidden_size,
)
self.emb = [
nn.Embedding(self.codebook_size + 1, self.hidden_size)
for _ in range(self.num_codebooks)
]
self.layers = [
TransformerBlock(config) for _ in range(config.decoder.num_hidden_layers)
]
self.out_norm = nn.LayerNorm(self.hidden_size, eps=1e-5)
self.linears = [
nn.Linear(self.hidden_size, self.codebook_size, bias=False)
for _ in range(self.num_codebooks)
]
encodec_name = config.audio_encoder._name_or_path.split("/")[-1]
encodec_name = encodec_name.replace("_", "-")
self._audio_decoder, _ = EncodecModel.from_pretrained(
f"mlx-community/{encodec_name}-float32"
)
def __call__(
self,
audio_tokens: mx.array,
conditioning: mx.array,
cache: list[KVCache] = None,
):
if cache is None:
cache = [None] * len(self.layers)
x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)])
offset = cache[0].offset if cache[0] is not None else 0
pos_emb = create_sin_embedding(offset, self.hidden_size)
x += pos_emb.astype(x.dtype)
for layer, c in zip(self.layers, cache):
x = layer(x, conditioning, cache=c)
x = self.out_norm(x)
x = mx.stack([self.linears[k](x) for k in range(self.num_codebooks)], axis=-1)
return x
def generate(
self,
text: str,
max_steps: int = 200,
top_k: int = 250,
temp: float = 1.0,
guidance_coef: float = 3.0,
) -> mx.array:
"""
Generates a waveform conditioned on `text`.
Args:
text (str): The text to condition generation on.
max_steps (int): Max steps to generate.
top_k (int): Top k used in sampling.
temp (float): Sampling softmax temperature.
guidance_coef (float): Classifier free guidance coefficent.
Used to combine conditional and unconditional logits.
Returns:
An mx.array of audio samples of shape ``(num_samples,)``.
"""
# Assuming no audio prompt we start with all bos token for the codebooks
audio_shape = (1, max_steps + 1, self.num_codebooks)
audio_seq = mx.full(audio_shape, self.bos_token_id)
text_tokens = self.text_conditioner(text)
# Compute conditional and unconditional logits in one batch
text_tokens = mx.concatenate([text_tokens, mx.zeros_like(text_tokens)], axis=0)
head_dim = self.hidden_size // self.num_attention_heads
cache = [
KVCache(head_dim, self.num_attention_heads) for _ in range(len(self.layers))
]
for offset in tqdm(range(max_steps)):
audio_input = mx.tile(audio_seq[:, offset : offset + 1], [2, 1, 1])
audio_logits = self(audio_input, text_tokens, cache)
cond_logits, uncond_logits = audio_logits[:1], audio_logits[1:2]
audio_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_coef
audio_tokens = top_k_sampling(audio_logits, top_k, temp, axis=-2)
# "delay" pattern
audio_tokens[..., offset + 1 :] = self.bos_token_id
audio_tokens[..., : -max_steps + offset] = self.bos_token_id
audio_seq[:, offset + 1 : offset + 2] = audio_tokens
mx.eval(audio_seq)
# Undo delay
for i in range(self.num_codebooks):
audio_seq[:, : -self.num_codebooks, i] = audio_seq[
:, i : -self.num_codebooks + i, i
]
audio_seq = audio_seq[:, 1 : -self.num_codebooks + 1]
audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis]
audio = self._audio_decoder.decode(audio_seq, audio_scales=[None])
return audio[0]
@classmethod
def sanitize(cls, weights):
out_weights = {}
for k, arr in weights.items():
if k.startswith("transformer."):
k = k[len("transformer.") :]
if "cross_attention" in k:
k = k.replace("cross_attention", "cross_attn")
if "condition_provider" in k:
k = k.replace(
"condition_provider.conditioners.description", "text_conditioner"
)
if "in_proj_weight" in k:
dim = arr.shape[0] // 3
name = "in_proj_weight"
out_weights[k.replace(name, "q_proj.weight")] = arr[:dim]
out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2]
out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :]
continue
out_weights[k] = arr
return out_weights
@classmethod
def from_pretrained(cls, path_or_repo: str):
import torch
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "state_dict.bin"],
)
)
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
config.text_encoder = SimpleNamespace(**config.text_encoder)
config.audio_encoder = SimpleNamespace(**config.audio_encoder)
config.decoder = SimpleNamespace(**config.decoder)
weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"]
weights = {k: mx.array(v) for k, v in weights.items()}
weights = cls.sanitize(weights)
model = MusicGen(config)
model.load_weights(list(weights.items()))
return model

View File

@ -0,0 +1,6 @@
mlx>=0.18
numpy
huggingface_hub
torch
transformers
scipy

1
musicgen/t5.py Symbolic link
View File

@ -0,0 +1 @@
../t5/t5.py

15
musicgen/utils.py Normal file
View File

@ -0,0 +1,15 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import numpy as np
def save_audio(file: str, audio: mx.array, sampling_rate: int):
"""
Save audio to a wave (.wav) file.
"""
from scipy.io.wavfile import write
audio = mx.clip(audio, -1, 1)
audio = (audio * 32767).astype(mx.int16)
write(file, sampling_rate, np.array(audio))

View File

@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.:
This example also supports the FLAN-T5 models variants.[^2]
## Setup
Download and convert the model:
```sh
python convert.py --model <model>
```
This will make the `<model>.npz` file which MLX can read.
The `<model>` can be any of the following:
| Model Name | Model Size |
| ---------- | ----------
| t5-small | 60 million |
| t5-base | 220 million |
| t5-large | 770 million |
| t5-3b | 3 billion |
| t5-11b | 11 billion |
The FLAN variants can be specified with `google/flan-t5-small`,
`google/flan-t5-base`, etc. See the [Hugging Face
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
complete list of models.
## Generate
Generate text with:
@ -48,6 +23,21 @@ To see a list of options run:
python t5.py --help
```
The `<model>` can be any of the following:
| Model Name | Model Size |
| ---------- | ----------
| t5-small | 60 million |
| t5-base | 220 million |
| t5-large | 770 million |
| t5-3b | 3 billion |
| t5-11b | 11 billion |
The FLAN variants can be specified with `google/flan-t5-small`,
`google/flan-t5-base`, etc. See the [Hugging Face
page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a
complete list of models.
[^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683)
or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5).
[^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416).

View File

@ -1,75 +0,0 @@
import numpy as np
from transformers import T5ForConditionalGeneration
SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
DECODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
def replace_key(key: str) -> str:
for old, new in SHARED_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in ENCODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in DECODER_REPLACEMENT_PATTERNS:
key = key.replace(old, new)
return key
def convert(model_name, dtype):
dtype = getattr(np, dtype)
model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
weights = {
replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items()
}
file_name = model_name.replace("/", "-")
print(f"Saving weights to {file_name}.npz")
np.savez(f"{file_name}.npz", **weights)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert T5 weights to MLX")
parser.add_argument(
"--model",
type=str,
help="Name of the T5 model.",
default="t5-small",
)
parser.add_argument(
"--dtype",
help="The model data type.",
type=str,
choices=["float16", "float32"],
default="float32",
)
args = parser.parse_args()
convert(args.model, args.dtype)

181
t5/t5.py
View File

@ -1,12 +1,45 @@
import argparse
import json
from pathlib import Path
from time import perf_counter_ns
from types import SimpleNamespace
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_map, tree_unflatten
from transformers import AutoTokenizer, T5Config
from transformers import AutoTokenizer
class Tokenizer:
def __init__(self, config, model_name):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
model_name,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
)
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
def _relative_position_bucket(
@ -60,10 +93,10 @@ def _relative_position_bucket(
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
def __init__(self, config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.max_distance = getattr(config, "relative_attention_max_distance", 128)
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(
config.relative_attention_num_buckets, config.num_heads
@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module):
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
self.gated = hasattr(config, "feed_forward_proj")
activation = (
"relu"
if not self.gated
else config.feed_forward_proj.removeprefix("gated-")
)
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
@ -166,7 +203,7 @@ class DenseActivation(nn.Module):
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module):
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module):
class TransformerDecoderLayer(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
self.self_attention = MultiHeadAttention(config)
self.cross_attention = MultiHeadAttention(config)
@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module):
class TransformerDecoder(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
super().__init__()
n_layers = getattr(config, "num_decoder_layers", config.num_layers)
self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module):
class OutputHead(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
def __call__(self, inputs):
@ -270,11 +307,11 @@ class OutputHead(nn.Module):
class T5(nn.Module):
def __init__(self, config: T5Config):
def __init__(self, config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
self.decoder = TransformerDecoder(config)
self.tie_word_embeddings = config.tie_word_embeddings
self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
if not self.tie_word_embeddings:
self.lm_head = OutputHead(config)
self.model_dim = config.d_model
@ -313,36 +350,82 @@ class T5(nn.Module):
):
return self.decode(decoder_inputs, self.encode(inputs))[0]
@classmethod
def sanitize(cls, weights):
shared_replacement_patterns = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
class Tokenizer:
def __init__(self, config: T5Config):
self._decoder_start_id = config.decoder_start_token_id
self._tokenizer = AutoTokenizer.from_pretrained(
args.model,
legacy=False,
model_max_length=getattr(config, "n_positions", 512),
encoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
decoder_replacement_patterns = [
(".layer.0.SelfAttention.", ".self_attention."),
(".layer.1.EncDecAttention.", ".cross_attention."),
(".layer.2.DenseReluDense.", ".dense."),
]
ignored_keys = [
"decoder.layers.0.cross_attention.relative_attention_bias.weight"
]
def replace_key(key: str) -> str:
for old, new in shared_replacement_patterns:
key = key.replace(old, new)
if key.startswith("encoder."):
for old, new in encoder_replacement_patterns:
key = key.replace(old, new)
elif key.startswith("decoder."):
for old, new in decoder_replacement_patterns:
key = key.replace(old, new)
return key
weights = {replace_key(k): v for k, v in weights.items()}
for key in ignored_keys:
if key in weights:
del weights[key]
return weights
@classmethod
def from_pretrained(
cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16
) -> tuple["T5", Tokenizer]:
from huggingface_hub import snapshot_download
path = Path(path_or_repo)
if not path.exists():
path = Path(
snapshot_download(
repo_id=path_or_repo,
allow_patterns=["*.json", "*.safetensors", "*.model"],
)
)
@property
def eos_id(self) -> int:
return self._tokenizer.eos_token_id
with open(path / "config.json", "r") as f:
config = SimpleNamespace(**json.load(f))
@property
def decoder_start_id(self) -> int:
return self._decoder_start_id
def encode(self, s: str) -> mx.array:
return mx.array(
self._tokenizer(
s,
return_tensors="np",
return_attention_mask=False,
)["input_ids"]
)
def decode(self, t: List[int], with_sep: bool = True) -> str:
tokens = self._tokenizer.convert_ids_to_tokens(t)
return "".join(t.replace("", " " if with_sep else "") for t in tokens)
model = T5(config)
weights = mx.load(str(path / "model.safetensors"))
weights = cls.sanitize(weights)
weights = {k: v.astype(dtype) for k, v in weights.items()}
model.load_weights(list(weights.items()))
return model, Tokenizer(config, "t5-base")
def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0):
@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float]
yield y.squeeze()
def load_model(model_name: str, dtype: str = "float16"):
config = T5Config.from_pretrained(args.model)
dtype = getattr(mx, dtype)
model = T5(config)
file_name = model_name.replace("/", "-")
weights = mx.load(f"{file_name}.npz")
weights = tree_unflatten(list(weights.items()))
weights = tree_map(lambda p: p.astype(dtype), weights)
model.update(weights)
mx.eval(model.parameters())
return model, Tokenizer(config)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="T5 Inference script")
parser.add_argument(
@ -421,7 +491,8 @@ if __name__ == "__main__":
mx.random.seed(args.seed)
model, tokenizer = load_model(args.model, args.dtype)
dtype = getattr(mx, args.dtype)
model, tokenizer = T5.from_pretrained(args.model, dtype)
if args.encode_only:
print("[INFO] Encoding with T5...", flush=True)