Merge branch 'ml-explore:main' into adding-support-for-mamba2

Commit 855fcc4327 by Gökdeniz Gülmez, 2024-10-16 18:57:55 +02:00 (committed by GitHub)
18 changed files with 756 additions and 428 deletions

View File

@ -26,8 +26,8 @@ jobs:
- run: - run:
name: Install dependencies name: Install dependencies
command: | command: |
brew install python@3.8 brew install python@3.9
python3.8 -m venv env python3.9 -m venv env
source env/bin/activate source env/bin/activate
pip install --upgrade pip pip install --upgrade pip
pip install unittest-xml-reporting pip install unittest-xml-reporting

View File

@ -20,8 +20,10 @@ Some more useful examples are listed below.
### Image Models ### Image Models
- Generating images
  - [FLUX](flux)
  - [Stable Diffusion or SDXL](stable_diffusion)
- Image classification using [ResNets on CIFAR-10](cifar). - Image classification using [ResNets on CIFAR-10](cifar).
- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
### Audio Models ### Audio Models

View File

@ -21,13 +21,34 @@ The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints. - `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization - `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script - `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer - `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly
You can install all of the above with the `requirements.txt` as follows: You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt pip install -r requirements.txt
Usage
---------
You can use the following command to generate an image. Use `--output` to specify where the image is saved (the default is `out.png`).
```shell
python txt2image.py --model schnell \
--n-images 1 \
--image-size 256x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars.'
```
Use `--help` to see the full list of options:
```shell
python txt2image.py --help
```
Inference Inference
--------- ---------
@ -78,7 +99,11 @@ except for some additional logic to quantize and/or load trained adapters. One
can use the script as follows: can use the script as follows:
```shell ```shell
python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.' python txt2image.py \
--n-images 4 \
--n-rows 2 \
--image-size 256x512 \
'A photo of an astronaut riding a horse on Mars.'
``` ```
### Experimental Options ### Experimental Options
@ -94,17 +119,12 @@ Finetuning
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an but ymmv) on a provided image dataset. The dataset folder must have an
`index.json` file with the following format: `train.jsonl` file with the following format:
```json ```jsonl
{ {"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
"data": [ {"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, ...
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
...
]
}
``` ```
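Each line of `train.jsonl` is a standalone JSON object. For reference (a sketch, not part of the diff), the entries map to `(image, prompt)` pairs the way the `LocalDataset` class added in `flux/flux/datasets.py` below reads them:

```python
# Minimal sketch: read train.jsonl and yield (PIL image, prompt) pairs,
# mirroring the LocalDataset class added in flux/flux/datasets.py.
import json
from pathlib import Path

from PIL import Image


def iter_examples(dataset_dir: str):
    base = Path(dataset_dir)
    with open(base / "train.jsonl") as fid:
        for line in fid:
            item = json.loads(line)
            # Image paths are stored relative to the dataset folder.
            yield Image.open(base / item["image"]), item["prompt"]
```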
The training script by default trains for 600 iterations with a batch size of The training script by default trains for 600 iterations with a batch size of
@ -126,19 +146,15 @@ The training images are the following 5 images [^2]:
![dog6](static/dog6.png) ![dog6](static/dog6.png)
We start by making the following `index.json` file and placing it in the same We start by making the following `train.jsonl` file and placing it in the same
folder as the images. folder as the images.
```json ```jsonl
{ {"image": "00.jpg", "prompt": "A photo of sks dog"}
"data": [ {"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "00.jpg", "text": "A photo of sks dog"}, {"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "text": "A photo of sks dog"}, {"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "text": "A photo of sks dog"}, {"image": "04.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "text": "A photo of sks dog"},
{"image": "04.jpg", "text": "A photo of sks dog"}
]
}
``` ```
Subsequently we finetune FLUX using the following command: Subsequently we finetune FLUX using the following command:
@ -151,6 +167,17 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6 path/to/dreambooth/dataset/dog6
``` ```
Alternatively, you can fine-tune directly on the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6):
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour. bit more than 1 hour.

View File

@ -1,7 +1,6 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import argparse import argparse
import json
import time import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@ -13,105 +12,8 @@ import numpy as np
from mlx.nn.utils import average_gradients from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline from flux import FluxPipeline, Trainer, load_dataset
class FinetuningDataset:
def __init__(self, flux, args):
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args): def generate_progress_images(iteration, flux, args):
@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
) )
if __name__ == "__main__": def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject" description="Finetune Flux to generate images with a specific subject"
) )
@ -247,7 +150,11 @@ if __name__ == "__main__":
) )
parser.add_argument("dataset") parser.add_argument("dataset")
return parser
if __name__ == "__main__":
parser = setup_arg_parser()
args = parser.parse_args() args = parser.parse_args()
# Load the model and set it up for LoRA training. We use the same random # Load the model and set it up for LoRA training. We use the same random
@ -267,7 +174,7 @@ if __name__ == "__main__":
trainable_params = tree_reduce( trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
) )
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to # Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation. # support gradient accumulation together with compilation.
@ -340,10 +247,10 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads x, t5_feat, clip_feat, guidance, prev_grads
) )
print("Create the training dataset.", flush=True) dataset = load_dataset(args.dataset)
dataset = FinetuningDataset(flux, args) trainer = Trainer(flux, dataset, args)
dataset.encode_images() trainer.encode_dataset()
dataset.encode_prompts()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare # An initial generation to compare
@ -352,7 +259,7 @@ if __name__ == "__main__":
grads = None grads = None
losses = [] losses = []
tic = time.time() tic = time.time()
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state) mx.eval(loss, grads, state)
losses.append(loss.item()) losses.append(loss.item())
@ -361,7 +268,7 @@ if __name__ == "__main__":
toc = time.time() toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3 peak_mem = mx.metal.get_peak_memory() / 1024**3
print( print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} " f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB", f"Peak mem: {peak_mem:.3f} GB",
flush=True, flush=True,

View File

@ -1,16 +1,10 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import math from .datasets import Dataset, load_dataset
import time from .flux import FluxPipeline
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear from .lora import LoRALinear
from .sampler import FluxSampler from .sampler import FluxSampler
from .trainer import Trainer
from .utils import ( from .utils import (
load_ae, load_ae,
load_clip, load_clip,
@ -19,230 +13,3 @@ from .utils import (
load_t5, load_t5,
load_t5_tokenizer, load_t5_tokenizer,
) )
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))

View File: flux/flux/datasets.py (new file, 75 lines)

@ -0,0 +1,75 @@
import json
from pathlib import Path
from PIL import Image
class Dataset:
def __getitem__(self, index: int):
raise NotImplementedError()
def __len__(self):
raise NotImplementedError()
class LocalDataset(Dataset):
prompt_key = "prompt"
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __len__(self):
return len(self._data)
def __getitem__(self, index: int):
item = self._data[index]
image = Image.open(self.dataset_base / item["image"])
return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
prompt_key = "text"
def __init__(self, dataset: str):
self.dataset_base = Path(dataset)
with open(self.dataset_base / "index.json") as f:
self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
def __init__(self, dataset: str):
from datasets import load_dataset as hf_load_dataset
self._df = hf_load_dataset(dataset)["train"]
def __len__(self):
return len(self._df)
def __getitem__(self, index: int):
item = self._df[index]
return item["image"], item["prompt"]
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
legacy_file = dataset_base / "index.json"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
dataset = LocalDataset(dataset, data_file)
elif legacy_file.exists():
print(f"Load the local dataset {legacy_file} .")
print()
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
print(" See the README for details.")
print(flush=True)
dataset = LegacyDataset(dataset)
else:
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)
return dataset
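A short usage sketch of the new loader (not part of the diff); the local path and the Hugging Face repo id are examples taken from the flux README above:

```python
# Sketch: load_dataset() prefers train.jsonl, falls back to the deprecated
# index.json, and otherwise treats its argument as a Hugging Face dataset id.
from flux import load_dataset

dataset = load_dataset("path/to/dreambooth/dataset/dog6")  # or "mlx-community/dreambooth-dog6"
print(f"{len(dataset)} training examples")
for image, prompt in dataset:  # __getitem__/__len__ make the dataset iterable
    print(image.size, prompt)
```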

View File: flux/flux/flux.py (new file, 246 lines)

@ -0,0 +1,246 @@
# Copyright © 2024 Apple Inc.
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))
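For orientation, a hedged sketch of driving the relocated `FluxPipeline` directly; the model name string is an assumption, since the value `txt2image.py` passes for `--model` is not shown in this diff:

```python
# Hedged sketch: generate one image with the relocated FluxPipeline.
# "flux-schnell" is an assumed model name; use whatever your checkout's
# txt2image.py resolves from its --model flag.
import mlx.core as mx
import numpy as np
from PIL import Image

from flux import FluxPipeline

flux = FluxPipeline("flux-schnell")
images = flux.generate_images(
    "A photo of an astronaut riding a horse on Mars.",
    n_images=1,
    seed=2,
)  # (1, 512, 512, 3) array of floats in [0, 1] for the default latent size
Image.fromarray(np.array((images[0] * 255).astype(mx.uint8))).save("out.png")
```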

View File: flux/flux/trainer.py (new file, 98 lines)

@ -0,0 +1,98 @@
import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]

View File

@ -77,7 +77,7 @@ if __name__ == "__main__":
nn.quantize(flux.clip, class_predicate=quantization_predicate) nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models: if args.preload_models:
sd.ensure_models_are_loaded() flux.ensure_models_are_loaded()
# Make the generator # Make the generator
latent_size = to_latent_size(args.image_size) latent_size = to_latent_size(args.image_size)

View File

@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in - `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used. the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are - `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop. sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens - `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in. started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be - `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in. relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text_completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A timestamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each inner list contains the top
`logprobs` tokens (if requested) and their corresponding log probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
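As an illustration of the fields above (a sketch, not part of the diff), a minimal Python client against a locally running `mlx_lm.server`; the host, port, and request body mirror the curl examples earlier in this README:

```python
# Minimal client sketch: POST a chat completion and read the documented
# response fields. Assumes `mlx_lm.server` is already running on port 8080.
import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    headers={"Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 64,
    },
).json()

choice = response["choices"][0]
print(response["id"], response["system_fingerprint"], response["model"])
print(choice["finish_reason"], choice["message"])
print(response["usage"])
```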
### List Models ### List Models
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the This will return a list of locally available models where each model in the
list contains the following fields: list contains the following fields:
- `"id"`: The Hugging Face repo id. - `id`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time. - `created`: A time-stamp representing the model creation time.

View File

@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
""" """
Trim the model's cache by the given number of tokens. Trim the model's cache by the given number of tokens.
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
Returns: Returns:
(int): The number of tokens that were trimmed. (int): The number of tokens that were trimmed.
""" """
if not all(c.is_trimmable() for c in cache) or len(cache) == 0: if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0 return 0
return [c.trim(num_tokens) for c in cache][0] return [c.trim(num_tokens) for c in cache][0]
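A usage sketch of the new helper alongside `trim_prompt_cache` (not from the diff); the import path and the `KVCache` usage follow the prompt-cache test further down in this commit:

```python
# Usage sketch: guard a trim with the new can_trim_prompt_cache helper.
import mlx.core as mx
from mlx_lm.models.cache import KVCache, can_trim_prompt_cache, trim_prompt_cache

cache = [KVCache()]
keys = values = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(keys, values)

if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, num_tokens=4)
    print(f"trimmed {trimmed} tokens, cache offset is now {cache[0].offset}")
```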

View File

@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None: if cache is not None:
q_pe = self.rope(q_pe, cache.offset) q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch( keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values mx.concatenate([k_nope, k_pe], axis=-1), values
) )
else: else:
q_pe = self.rope(q_pe) q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe) k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1) keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1)
@ -291,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1) scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1) scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor scores = scores * self.routed_scaling_factor

View File

@ -3,19 +3,38 @@
import argparse import argparse
import json import json
import logging import logging
import platform
import time import time
import uuid import uuid
import warnings import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx import mlx.core as mx
from huggingface_hub import scan_cache_dir from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple): class StopCondition(NamedTuple):
stop_met: bool stop_met: bool
trim_length: int trim_length: int
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip() return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider: class ModelProvider:
def __init__(self, cli_args: argparse.Namespace): def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process.""" """Load models on demand and persist them across the whole process."""
@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler): class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs): def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
""" """
Create static request specific metadata Create static request specific metadata
""" """
self.created = int(time.time()) self.created = int(time.time())
self.model_provider = model_provider self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def _set_cors_headers(self): def _set_cors_headers(self):
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None) self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model") self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None) self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100) self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0) self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0) self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response # Static response
response = { response = {
"id": self.request_id, "id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}", "system_fingerprint": self.system_fingerprint,
"object": self.object_type, "object": self.object_type,
"model": self.requested_model, "model": self.requested_model,
"created": self.created, "created": self.created,
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion( def handle_completion(
self, self,
prompt: mx.array, prompt: List[int],
stop_id_sequences: List[List[int]], stop_id_sequences: List[List[int]],
): ):
""" """
Generate a response to a prompt and send it to the client in a single batch. Generate a response to a prompt and send it to the client in a single batch.
Args: Args:
prompt (mx.array): The prompt, in token form inside of a mlx array prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function to the stopping_criteria function
""" """
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:") logging.debug(f"Starting completion:")
token_logprobs = [] token_logprobs = []
top_tokens = [] top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step( generate_step(
prompt=prompt, prompt=mx.array(prompt),
model=self.model, model=self.model,
temp=self.temperature, temp=self.temperature,
top_p=self.top_p, top_p=self.top_p,
repetition_penalty=self.repetition_penalty, repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size, repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
), ),
range(self.max_tokens),
): ):
detokenizer.add_token(token) detokenizer.add_token(token)
logging.debug(detokenizer.text) logging.debug(detokenizer.text)
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs] top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices] top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info)) top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item()) token_logprobs.append(logprobs[token].item())
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
) )
break break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize() detokenizer.finalize()
text = ( text = (
detokenizer.text detokenizer.text
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream( def handle_stream(
self, self,
prompt: mx.array, prompt: List[int],
stop_id_sequences: List[List[int]], stop_id_sequences: List[List[int]],
): ):
""" """
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream. Sent Events (SSE) stream.
Args: Args:
prompt (mx.array): The prompt, in token form inside of a mlx array prompt (List[int]): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function the stopping_criteria function
""" """
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None stop_sequence_suffix = None
logging.debug(f"Starting stream:") logging.debug(f"Starting stream:")
for (token, _), _ in zip( prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step( generate_step(
prompt=prompt, prompt=mx.array(prompt),
model=self.model, model=self.model,
temp=self.temperature, temp=self.temperature,
top_p=self.top_p, top_p=self.top_p,
repetition_penalty=self.repetition_penalty, repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size, repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
), ),
range(self.max_tokens),
): ):
detokenizer.add_token(token) detokenizer.add_token(token)
logging.debug(detokenizer.text) logging.debug(detokenizer.text)
@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
continue continue
new_text = detokenizer.last_segment new_text = detokenizer.last_segment
response = self.generate_response(new_text, None) if new_text:
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) response = self.generate_response(new_text, None)
self.wfile.flush() self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check is there any remaining text to send # check is there any remaining text to send
detokenizer.finalize() detokenizer.finalize()
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
): ):
response = { response = {
"id": self.request_id, "id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}", "system_fingerprint": self.system_fingerprint,
"object": "chat.completion", "object": "chat.completion",
"model": self.requested_model, "model": self.requested_model,
"created": self.created, "created": self.created,
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
} }
return response return response
def handle_chat_completions(self) -> mx.array: def handle_chat_completions(self) -> List[int]:
""" """
Handle a chat completion request. Handle a chat completion request.
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = ( self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions" "chat.completions.chunk" if self.stream else "chat.completions"
) )
if ( if (
hasattr(self.tokenizer, "apply_chat_template") hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template and self.tokenizer.chat_template
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt) prompt = self.tokenizer.encode(prompt)
return mx.array(prompt) return prompt
def handle_text_completions(self) -> mx.array: def handle_text_completions(self) -> List[int]:
""" """
Handle a text completion request. Handle a text completion request.
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type # Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}" self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion" self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt" assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"] return self.tokenizer.encode(self.body["prompt"])
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
def do_GET(self): def do_GET(self):
""" """
@ -669,9 +727,16 @@ def run(
handler_class=APIHandler, handler_class=APIHandler,
): ):
server_address = (host, port) server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class( httpd = server_class(
server_address, server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs), lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
) )
warnings.warn( warnings.warn(
"mlx_lm.server is not recommended for production as " "mlx_lm.server is not recommended for production as "

View File

@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self): def text(self):
if self._current_tokens: if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens) self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n": if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens) self._tokens.extend(self._current_tokens)
self._text += self._current_text self._text += self._current_text
@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
""" """
_byte_decoder = None _byte_decoder = None
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False): def __init__(self, tokenizer):
self.trim_space = trim_space
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text # Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab) self.tokenmap = [None] * len(tokenizer.vocab)
@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = "" self.text = ""
self.tokens = [] self.tokens = []
def _maybe_trim_space(self, current_text):
if current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token): def add_token(self, token):
v = self.tokenmap[token] v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32: if self._byte_decoder[v[0]] == 32:
current_text = bytearray( current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed self._byte_decoder[c] for c in self._unflushed
).decode("utf-8") ).decode("utf-8")
if self.text or not self.trim_space: self.text += self._maybe_trim_space(current_text)
self.text += current_text
else:
self.text += _remove_space(current_text)
self._unflushed = v self._unflushed = v
else: else:
self._unflushed += v self._unflushed += v
@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8" "utf-8"
) )
if self.text or not self.trim_space: self.text += self._maybe_trim_space(current_text)
self.text += current_text
else:
self.text += _remove_space(current_text)
self._unflushed = "" self._unflushed = ""
@classmethod @classmethod
@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder): def _is_bpe_decoder(decoder):
_target_description = { return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
def load_tokenizer(model_path, tokenizer_config_extra={}): def load_tokenizer(model_path, tokenizer_config_extra={}):

View File

@ -246,10 +246,10 @@ def generate_step(
y, logprobs = _step(y) y, logprobs = _step(y)
mx.async_eval(y) mx.async_eval(y, logprobs)
while True: while True:
next_y, next_logprobs = _step(y) next_y, next_logprobs = _step(y)
mx.async_eval(next_y) mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs yield y.item(), logprobs
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs
@ -348,7 +348,9 @@ def generate(
if formatter: if formatter:
# We have to finalize so that the prob corresponds to the last segment # We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize() detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else: else:
print(detokenizer.last_segment, end="", flush=True) print(detokenizer.last_segment, end="", flush=True)

View File

@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc. # Copyright © 2024 Apple Inc.
import copy
import os import os
import tempfile import tempfile
import unittest import unittest
@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
) )
def test_cache_copying(self):
cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)
y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)
old_cache = copy.deepcopy(cache)
trim_prompt_cache(cache, 1)
self.assertEqual(old_cache[0].offset, 11)
self.assertEqual(cache[0].offset, 10)
z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self): def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH) self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)
def load(self, model, adapter=None): def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"] assert model in ["default_model", "chat_model"]

View File

@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import unittest
from pathlib import Path
from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
BPEStreamingDetokenizer,
NaiveStreamingDetokenizer,
SPMStreamingDetokenizer,
load_tokenizer,
)
class TestTokenizers(unittest.TestCase):
def download_tokenizer(self, repo):
path = Path(
snapshot_download(
repo_id=repo,
allow_patterns=[
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"tokenizer.model",
],
)
)
return load_tokenizer(path)
def check_tokenizer(self, tokenizer):
def check(tokens):
expected_text = tokenizer.decode(tokens)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
tokens = tokenizer.encode("a ,b")
check(tokens)
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
check(tokens)
tokens = tokenizer.encode("3 3")
check(tokens)
def test_tokenizers(self):
tokenizer_repos = [
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
]
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
with self.subTest(tokenizer=tokenizer_repo):
tokenizer = self.download_tokenizer(tokenizer_repo)
tokenizer.decode([0, 1, 2])
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
self.check_tokenizer(tokenizer)
# Try one with a naive detokenizer
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
self.check_tokenizer(tokenizer)
if __name__ == "__main__":
unittest.main()