mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-07-17 07:51:12 +08:00
Merge branch 'ml-explore:main' into adding-support-for-mamba2
This commit is contained in:
commit
855fcc4327
@ -26,8 +26,8 @@ jobs:
|
||||
- run:
|
||||
name: Install dependencies
|
||||
command: |
|
||||
brew install python@3.8
|
||||
python3.8 -m venv env
|
||||
brew install python@3.9
|
||||
python3.9 -m venv env
|
||||
source env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install unittest-xml-reporting
|
||||
|
@ -20,8 +20,10 @@ Some more useful examples are listed below.
|
||||
|
||||
### Image Models
|
||||
|
||||
- Generating images
|
||||
- [FLUX](flux)
|
||||
- [Stable Diffusion or SDXL](stable_diffusion)
|
||||
- Image classification using [ResNets on CIFAR-10](cifar).
|
||||
- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
|
||||
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
|
||||
|
||||
### Audio Models
|
||||
|
@ -21,13 +21,34 @@ The dependencies are minimal, namely:
|
||||
|
||||
- `huggingface-hub` to download the checkpoints.
|
||||
- `regex` for the tokenization
|
||||
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
|
||||
- `tqdm`, `PIL`, and `numpy` for the scripts
|
||||
- `sentencepiece` for the T5 tokenizer
|
||||
- `datasets` for using an HF dataset directly
|
||||
|
||||
You can install all of the above with the `requirements.txt` as follows:
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
|
||||
Usage
|
||||
---------
|
||||
|
||||
You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`.
|
||||
|
||||
```shell
|
||||
python txt2image.py --model schnell \
|
||||
--n-images 1 \
|
||||
--image-size 256x512 \
|
||||
--verbose \
|
||||
'A photo of an astronaut riding a horse on Mars.'
|
||||
```
|
||||
|
||||
For more parameters, please use the `--help` command to view.
|
||||
|
||||
```shell
|
||||
python txt2image.py --help
|
||||
```
|
||||
|
||||
Inference
|
||||
---------
|
||||
|
||||
@ -78,7 +99,11 @@ except for some additional logic to quantize and/or load trained adapters. One
|
||||
can use the script as follows:
|
||||
|
||||
```shell
|
||||
python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.'
|
||||
python txt2image.py \
|
||||
--n-images 4 \
|
||||
--n-rows 2 \
|
||||
--image-size 256x512 \
|
||||
'A photo of an astronaut riding a horse on Mars.'
|
||||
```
|
||||
|
||||
### Experimental Options
|
||||
@ -94,17 +119,12 @@ Finetuning
|
||||
|
||||
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
|
||||
but ymmv) on a provided image dataset. The dataset folder must have an
|
||||
`index.json` file with the following format:
|
||||
`train.jsonl` file with the following format:
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
|
||||
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
|
||||
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
|
||||
...
|
||||
]
|
||||
}
|
||||
```jsonl
|
||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
||||
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
|
||||
...
|
||||
```
|
||||
|
||||
The training script by default trains for 600 iterations with a batch size of
|
||||
@ -126,19 +146,15 @@ The training images are the following 5 images [^2]:
|
||||
|
||||

|
||||
|
||||
We start by making the following `index.json` file and placing it in the same
|
||||
We start by making the following `train.jsonl` file and placing it in the same
|
||||
folder as the images.
|
||||
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{"image": "00.jpg", "text": "A photo of sks dog"},
|
||||
{"image": "01.jpg", "text": "A photo of sks dog"},
|
||||
{"image": "02.jpg", "text": "A photo of sks dog"},
|
||||
{"image": "03.jpg", "text": "A photo of sks dog"},
|
||||
{"image": "04.jpg", "text": "A photo of sks dog"}
|
||||
]
|
||||
}
|
||||
```jsonl
|
||||
{"image": "00.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "01.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "02.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "03.jpg", "prompt": "A photo of sks dog"}
|
||||
{"image": "04.jpg", "prompt": "A photo of sks dog"}
|
||||
```
|
||||
|
||||
Subsequently we finetune FLUX using the following command:
|
||||
@ -151,6 +167,17 @@ python dreambooth.py \
|
||||
path/to/dreambooth/dataset/dog6
|
||||
```
|
||||
|
||||
|
||||
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
|
||||
|
||||
```shell
|
||||
python dreambooth.py \
|
||||
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
|
||||
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
|
||||
--lora-rank 4 --grad-accumulate 8 \
|
||||
mlx-community/dreambooth-dog6
|
||||
```
|
||||
|
||||
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
|
||||
bit more than 1 hour.
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@ -13,105 +12,8 @@ import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
class FinetuningDataset:
|
||||
def __init__(self, flux, args):
|
||||
self.args = args
|
||||
self.flux = flux
|
||||
self.dataset_base = Path(args.dataset)
|
||||
dataset_index = self.dataset_base / "index.json"
|
||||
if not dataset_index.exists():
|
||||
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
|
||||
with open(dataset_index, "r") as f:
|
||||
self.index = json.load(f)
|
||||
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * b,
|
||||
pan[1] * c,
|
||||
crop_size[0] + pan[0] * b,
|
||||
crop_size[1] + pan[1] * c,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def encode_images(self):
|
||||
"""Encode the images in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for sample in tqdm(self.index["data"]):
|
||||
input_img = Image.open(self.dataset_base / sample["image"])
|
||||
for i in range(self.args.num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def encode_prompts(self):
|
||||
"""Pre-encode the prompts so that we don't recompute them during
|
||||
training (doesn't allow finetuning the text encoders)."""
|
||||
for sample in tqdm(self.index["data"]):
|
||||
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
||||
from flux import FluxPipeline, Trainer, load_dataset
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Finetune Flux to generate images with a specific subject"
|
||||
)
|
||||
@ -247,7 +150,11 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser.add_argument("dataset")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load the model and set it up for LoRA training. We use the same random
|
||||
@ -267,7 +174,7 @@ if __name__ == "__main__":
|
||||
trainable_params = tree_reduce(
|
||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||
)
|
||||
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
|
||||
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
|
||||
|
||||
# Set up the optimizer and training steps. The steps are a bit verbose to
|
||||
# support gradient accumulation together with compilation.
|
||||
@ -340,10 +247,10 @@ if __name__ == "__main__":
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
print("Create the training dataset.", flush=True)
|
||||
dataset = FinetuningDataset(flux, args)
|
||||
dataset.encode_images()
|
||||
dataset.encode_prompts()
|
||||
dataset = load_dataset(args.dataset)
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
|
||||
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
|
||||
|
||||
# An initial generation to compare
|
||||
@ -352,7 +259,7 @@ if __name__ == "__main__":
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
|
||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
@ -361,7 +268,7 @@ if __name__ == "__main__":
|
||||
toc = time.time()
|
||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||
print(
|
||||
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
f"Peak mem: {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
|
@ -1,16 +1,10 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset, load_dataset
|
||||
from .flux import FluxPipeline
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .trainer import Trainer
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
@ -19,230 +13,3 @@ from .utils import (
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
||||
|
75
flux/flux/datasets.py
Normal file
75
flux/flux/datasets.py
Normal file
@ -0,0 +1,75 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __getitem__(self, index: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LocalDataset(Dataset):
|
||||
prompt_key = "prompt"
|
||||
|
||||
def __init__(self, dataset: str, data_file):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(data_file, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = Image.open(self.dataset_base / item["image"])
|
||||
return image, item[self.prompt_key]
|
||||
|
||||
|
||||
class LegacyDataset(LocalDataset):
|
||||
prompt_key = "text"
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(self.dataset_base / "index.json") as f:
|
||||
self._data = json.load(f)["data"]
|
||||
|
||||
|
||||
class HuggingFaceDataset(Dataset):
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
from datasets import load_dataset as hf_load_dataset
|
||||
|
||||
self._df = hf_load_dataset(dataset)["train"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._df[index]
|
||||
return item["image"], item["prompt"]
|
||||
|
||||
|
||||
def load_dataset(dataset: str):
|
||||
dataset_base = Path(dataset)
|
||||
data_file = dataset_base / "train.jsonl"
|
||||
legacy_file = dataset_base / "index.json"
|
||||
|
||||
if data_file.exists():
|
||||
print(f"Load the local dataset {data_file} .", flush=True)
|
||||
dataset = LocalDataset(dataset, data_file)
|
||||
elif legacy_file.exists():
|
||||
print(f"Load the local dataset {legacy_file} .")
|
||||
print()
|
||||
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
|
||||
print(" See the README for details.")
|
||||
print(flush=True)
|
||||
dataset = LegacyDataset(dataset)
|
||||
else:
|
||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||
dataset = HuggingFaceDataset(dataset)
|
||||
|
||||
return dataset
|
246
flux/flux/flux.py
Normal file
246
flux/flux/flux.py
Normal file
@ -0,0 +1,246 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
98
flux/flux/trainer.py
Normal file
98
flux/flux/trainer.py
Normal file
@ -0,0 +1,98 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset
|
||||
from .flux import FluxPipeline
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
||||
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
||||
self.flux = flux
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * b) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * c,
|
||||
pan[1] * d,
|
||||
crop_size[0] + pan[0] * c,
|
||||
crop_size[1] + pan[1] * d,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
||||
for i in range(num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def _encode_prompt(self, prompt):
|
||||
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def encode_dataset(self):
|
||||
"""Encode the images & prompt in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
||||
self._encode_image(image, self.args.num_augmentations)
|
||||
self._encode_prompt(prompt)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
@ -77,7 +77,7 @@ if __name__ == "__main__":
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
if args.preload_models:
|
||||
sd.ensure_models_are_loaded()
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
# Make the generator
|
||||
latent_size = to_latent_size(args.image_size)
|
||||
|
@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
|
||||
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
|
||||
the generated prompt. If not provided, the default mappings are used.
|
||||
|
||||
- `stop`: (Optional) An array of strings or a single string. Thesse are
|
||||
- `stop`: (Optional) An array of strings or a single string. These are
|
||||
sequences of tokens on which the generation should stop.
|
||||
|
||||
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
|
||||
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
|
||||
started in.
|
||||
|
||||
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
|
||||
rlative to the directory the server was started in.
|
||||
relative to the directory the server was started in.
|
||||
|
||||
### Response Fields
|
||||
|
||||
- `id`: A unique identifier for the chat.
|
||||
|
||||
- `system_fingerprint`: A unique identifier for the system.
|
||||
|
||||
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
|
||||
streaming), or "text.completion".
|
||||
|
||||
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
|
||||
|
||||
- `created`: A time-stamp for when the request was processed.
|
||||
|
||||
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
|
||||
- `index`: The index in the list.
|
||||
- `logprobs`: A dictionary containing the fields:
|
||||
- `token_logprobs`: A list of the log probabilities for the generated
|
||||
tokens.
|
||||
- `tokens`: A list of the generated token ids.
|
||||
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
|
||||
top tokens (if requested) with their corresponding probabilities.
|
||||
- `finish_reason`: The reason the completion ended. This can be either of
|
||||
`"stop"` or `"length"`.
|
||||
- `message`: The text response from the model.
|
||||
|
||||
- `usage`: A dictionary containing the fields:
|
||||
- `prompt_tokens`: The number of prompt tokens processed.
|
||||
- `completion_tokens`: The number of tokens generated.
|
||||
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
|
||||
|
||||
### List Models
|
||||
|
||||
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
|
||||
This will return a list of locally available models where each model in the
|
||||
list contains the following fields:
|
||||
|
||||
- `"id"`: The Hugging Face repo id.
|
||||
- `"created"`: A timestamp representing the model creation time.
|
||||
- `id`: The Hugging Face repo id.
|
||||
- `created`: A time-stamp representing the model creation time.
|
||||
|
@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
return cache
|
||||
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
|
||||
|
||||
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
|
||||
|
||||
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
|
||||
|
||||
if cache is not None:
|
||||
q_pe = self.rope(q_pe, cache.offset)
|
||||
k_pe = self.rope(k_pe, cache.offset)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys, values = cache.update_and_fetch(
|
||||
mx.concatenate([k_nope, k_pe], axis=-1), values
|
||||
)
|
||||
else:
|
||||
q_pe = self.rope(q_pe)
|
||||
k_pe = self.rope(k_pe)
|
||||
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
|
||||
keys = mx.concatenate([k_nope, k_pe], axis=-1)
|
||||
|
||||
queries = mx.concatenate([q_nope, q_pe], axis=-1)
|
||||
@ -291,7 +291,7 @@ class MoEGate(nn.Module):
|
||||
scores = scores.reshape(bsz, seq_len, -1)
|
||||
|
||||
k = self.top_k
|
||||
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
|
||||
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
|
||||
scores = mx.take_along_axis(scores, inds, axis=-1)
|
||||
scores = scores * self.routed_scaling_factor
|
||||
|
||||
|
@ -3,19 +3,38 @@
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import mlx.core as mx
|
||||
from huggingface_hub import scan_cache_dir
|
||||
|
||||
from ._version import __version__
|
||||
from .models.cache import make_prompt_cache
|
||||
from .utils import generate_step, load
|
||||
|
||||
|
||||
def get_system_fingerprint():
|
||||
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
|
||||
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
|
||||
|
||||
|
||||
class StopCondition(NamedTuple):
|
||||
stop_met: bool
|
||||
trim_length: int
|
||||
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
|
||||
return prompt.rstrip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptCache:
|
||||
cache: List[Any] = field(default_factory=list)
|
||||
model_key: Tuple[str, Optional[str]] = ("", None)
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
|
||||
|
||||
class ModelProvider:
|
||||
def __init__(self, cli_args: argparse.Namespace):
|
||||
"""Load models on demand and persist them across the whole process."""
|
||||
@ -156,12 +182,21 @@ class ModelProvider:
|
||||
|
||||
|
||||
class APIHandler(BaseHTTPRequestHandler):
|
||||
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
model_provider: ModelProvider,
|
||||
*args,
|
||||
prompt_cache: Optional[PromptCache] = None,
|
||||
system_fingerprint: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Create static request specific metadata
|
||||
"""
|
||||
self.created = int(time.time())
|
||||
self.model_provider = model_provider
|
||||
self.prompt_cache = prompt_cache or PromptCache()
|
||||
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def _set_cors_headers(self):
|
||||
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.stream_options = self.body.get("stream_options", None)
|
||||
self.requested_model = self.body.get("model", "default_model")
|
||||
self.adapter = self.body.get("adapters", None)
|
||||
self.max_tokens = self.body.get("max_tokens", 100)
|
||||
self.max_tokens = self.body.get("max_completion_tokens", None)
|
||||
if self.max_tokens is None:
|
||||
self.max_tokens = self.body.get("max_tokens", 512)
|
||||
self.temperature = self.body.get("temperature", 1.0)
|
||||
self.top_p = self.body.get("top_p", 1.0)
|
||||
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
|
||||
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
# Static response
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": self.object_type,
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
return response
|
||||
|
||||
def get_prompt_cache(self, prompt):
|
||||
cache_len = len(self.prompt_cache.tokens)
|
||||
if (
|
||||
self.prompt_cache.model_key != self.model_provider.model_key
|
||||
or cache_len >= len(prompt)
|
||||
or self.prompt_cache.tokens != prompt[:cache_len]
|
||||
):
|
||||
self.prompt_cache.model_key = self.model_provider.model_key
|
||||
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
|
||||
else:
|
||||
prompt = prompt[cache_len:]
|
||||
self.prompt_cache.tokens.extend(prompt)
|
||||
return prompt
|
||||
|
||||
def handle_completion(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
Generate a response to a prompt and send it to the client in a single batch.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
prompt (List[int]): The tokenized prompt.
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||
to the stopping_criteria function
|
||||
"""
|
||||
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
logging.debug(f"Starting completion:")
|
||||
token_logprobs = []
|
||||
top_tokens = []
|
||||
for (token, logprobs), _ in zip(
|
||||
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, logprobs) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=prompt,
|
||||
prompt=mx.array(prompt),
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
logit_bias=self.logit_bias,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
range(self.max_tokens),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
top_indices = sorted_indices[: self.logprobs]
|
||||
top_logprobs = logprobs[top_indices]
|
||||
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
|
||||
top_tokens.append(dict(top_token_info))
|
||||
top_tokens.append(tuple(top_token_info))
|
||||
|
||||
token_logprobs.append(logprobs[token].item())
|
||||
|
||||
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
)
|
||||
break
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
detokenizer.finalize()
|
||||
text = (
|
||||
detokenizer.text
|
||||
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
|
||||
def handle_stream(
|
||||
self,
|
||||
prompt: mx.array,
|
||||
prompt: List[int],
|
||||
stop_id_sequences: List[List[int]],
|
||||
):
|
||||
"""
|
||||
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
Sent Events (SSE) stream.
|
||||
|
||||
Args:
|
||||
prompt (mx.array): The prompt, in token form inside of a mlx array
|
||||
prompt (mx.array): The tokenized prompt
|
||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
||||
the stopping_criteria function
|
||||
"""
|
||||
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
stop_sequence_suffix = None
|
||||
logging.debug(f"Starting stream:")
|
||||
|
||||
for (token, _), _ in zip(
|
||||
prompt = self.get_prompt_cache(prompt)
|
||||
|
||||
for _, (token, _) in zip(
|
||||
range(self.max_tokens),
|
||||
generate_step(
|
||||
prompt=prompt,
|
||||
prompt=mx.array(prompt),
|
||||
model=self.model,
|
||||
temp=self.temperature,
|
||||
top_p=self.top_p,
|
||||
repetition_penalty=self.repetition_penalty,
|
||||
repetition_context_size=self.repetition_context_size,
|
||||
prompt_cache=self.prompt_cache.cache,
|
||||
),
|
||||
range(self.max_tokens),
|
||||
):
|
||||
detokenizer.add_token(token)
|
||||
logging.debug(detokenizer.text)
|
||||
@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
continue
|
||||
|
||||
new_text = detokenizer.last_segment
|
||||
response = self.generate_response(new_text, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
if new_text:
|
||||
response = self.generate_response(new_text, None)
|
||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
self.prompt_cache.tokens.extend(tokens)
|
||||
|
||||
# check is there any remaining text to send
|
||||
detokenizer.finalize()
|
||||
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
):
|
||||
response = {
|
||||
"id": self.request_id,
|
||||
"system_fingerprint": f"fp_{uuid.uuid4()}",
|
||||
"system_fingerprint": self.system_fingerprint,
|
||||
"object": "chat.completion",
|
||||
"model": self.requested_model,
|
||||
"created": self.created,
|
||||
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
}
|
||||
return response
|
||||
|
||||
def handle_chat_completions(self) -> mx.array:
|
||||
def handle_chat_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a chat completion request.
|
||||
|
||||
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
self.object_type = (
|
||||
"chat.completions.chunk" if self.stream else "chat.completions"
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(self.tokenizer, "apply_chat_template")
|
||||
and self.tokenizer.chat_template
|
||||
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
prompt = convert_chat(body["messages"], body.get("role_mapping"))
|
||||
prompt = self.tokenizer.encode(prompt)
|
||||
|
||||
return mx.array(prompt)
|
||||
return prompt
|
||||
|
||||
def handle_text_completions(self) -> mx.array:
|
||||
def handle_text_completions(self) -> List[int]:
|
||||
"""
|
||||
Handle a text completion request.
|
||||
|
||||
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
||||
# Determine response type
|
||||
self.request_id = f"cmpl-{uuid.uuid4()}"
|
||||
self.object_type = "text_completion"
|
||||
|
||||
assert "prompt" in self.body, "Request did not contain a prompt"
|
||||
prompt_text = self.body["prompt"]
|
||||
prompt = self.tokenizer.encode(prompt_text)
|
||||
return mx.array(prompt)
|
||||
return self.tokenizer.encode(self.body["prompt"])
|
||||
|
||||
def do_GET(self):
|
||||
"""
|
||||
@ -669,9 +727,16 @@ def run(
|
||||
handler_class=APIHandler,
|
||||
):
|
||||
server_address = (host, port)
|
||||
prompt_cache = PromptCache()
|
||||
httpd = server_class(
|
||||
server_address,
|
||||
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
|
||||
lambda *args, **kwargs: handler_class(
|
||||
model_provider,
|
||||
prompt_cache=prompt_cache,
|
||||
system_fingerprint=get_system_fingerprint(),
|
||||
*args,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
warnings.warn(
|
||||
"mlx_lm.server is not recommended for production as "
|
||||
|
@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
|
||||
def text(self):
|
||||
if self._current_tokens:
|
||||
self._current_text = self._tokenizer.decode(self._current_tokens)
|
||||
if (
|
||||
self._tokenizer.clean_up_tokenization_spaces
|
||||
and self._current_text[-1] == " "
|
||||
):
|
||||
self._current_text = self._current_text[:-1]
|
||||
if self._current_text and self._current_text[-1] == "\n":
|
||||
self._tokens.extend(self._current_tokens)
|
||||
self._text += self._current_text
|
||||
@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
"""
|
||||
|
||||
_byte_decoder = None
|
||||
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
|
||||
|
||||
def __init__(self, tokenizer, trim_space=False):
|
||||
self.trim_space = trim_space
|
||||
def __init__(self, tokenizer):
|
||||
|
||||
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
|
||||
|
||||
# Extract the tokens in a list from id to text
|
||||
self.tokenmap = [None] * len(tokenizer.vocab)
|
||||
@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
self.text = ""
|
||||
self.tokens = []
|
||||
|
||||
def _maybe_trim_space(self, current_text):
|
||||
if current_text[0] != " ":
|
||||
return current_text
|
||||
elif not self.text:
|
||||
return current_text[1:]
|
||||
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
|
||||
return current_text[1:]
|
||||
return current_text
|
||||
|
||||
def add_token(self, token):
|
||||
v = self.tokenmap[token]
|
||||
# if the token starts with space
|
||||
if self._byte_decoder[v[0]] == 32:
|
||||
current_text = bytearray(
|
||||
self._byte_decoder[c] for c in self._unflushed
|
||||
).decode("utf-8")
|
||||
if self.text or not self.trim_space:
|
||||
self.text += current_text
|
||||
else:
|
||||
self.text += _remove_space(current_text)
|
||||
self.text += self._maybe_trim_space(current_text)
|
||||
self._unflushed = v
|
||||
else:
|
||||
self._unflushed += v
|
||||
@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
|
||||
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
|
||||
"utf-8"
|
||||
)
|
||||
if self.text or not self.trim_space:
|
||||
self.text += current_text
|
||||
else:
|
||||
self.text += _remove_space(current_text)
|
||||
self.text += self._maybe_trim_space(current_text)
|
||||
self._unflushed = ""
|
||||
|
||||
@classmethod
|
||||
@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
|
||||
|
||||
|
||||
def _is_bpe_decoder(decoder):
|
||||
_target_description = {
|
||||
"type": "ByteLevel",
|
||||
"add_prefix_space": False,
|
||||
"trim_offsets": False,
|
||||
"use_regex": False,
|
||||
}
|
||||
|
||||
return _match(_target_description, decoder)
|
||||
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
|
||||
|
||||
|
||||
def load_tokenizer(model_path, tokenizer_config_extra={}):
|
||||
|
@ -246,10 +246,10 @@ def generate_step(
|
||||
|
||||
y, logprobs = _step(y)
|
||||
|
||||
mx.async_eval(y)
|
||||
mx.async_eval(y, logprobs)
|
||||
while True:
|
||||
next_y, next_logprobs = _step(y)
|
||||
mx.async_eval(next_y)
|
||||
mx.async_eval(next_y, next_logprobs)
|
||||
yield y.item(), logprobs
|
||||
y, logprobs = next_y, next_logprobs
|
||||
|
||||
@ -348,7 +348,9 @@ def generate(
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
|
||||
with mx.stream(mx.cpu):
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
|
||||
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
|
||||
)
|
||||
|
||||
def test_cache_copying(self):
|
||||
cache = [KVCache()]
|
||||
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
cache[0].update_and_fetch(x, x)
|
||||
|
||||
y = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
cache[0].update_and_fetch(y, y)
|
||||
|
||||
old_cache = copy.deepcopy(cache)
|
||||
|
||||
trim_prompt_cache(cache, 1)
|
||||
|
||||
self.assertTrue(old_cache[0].offset, 11)
|
||||
self.assertTrue(cache[0].offset, 10)
|
||||
|
||||
z = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
cache[0].update_and_fetch(z, z)
|
||||
|
||||
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
|
||||
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -14,6 +14,7 @@ class DummyModelProvider:
|
||||
def __init__(self):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
self.model, self.tokenizer = load(HF_MODEL_PATH)
|
||||
self.model_key = (HF_MODEL_PATH, None)
|
||||
|
||||
def load(self, model, adapter=None):
|
||||
assert model in ["default_model", "chat_model"]
|
||||
|
76
llms/tests/test_tokenizers.py
Normal file
76
llms/tests/test_tokenizers.py
Normal file
@ -0,0 +1,76 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx_lm.tokenizer_utils import (
|
||||
BPEStreamingDetokenizer,
|
||||
NaiveStreamingDetokenizer,
|
||||
SPMStreamingDetokenizer,
|
||||
load_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class TestTokenizers(unittest.TestCase):
|
||||
|
||||
def download_tokenizer(self, repo):
|
||||
path = Path(
|
||||
snapshot_download(
|
||||
repo_id=repo,
|
||||
allow_patterns=[
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
"tokenizer.model",
|
||||
],
|
||||
)
|
||||
)
|
||||
return load_tokenizer(path)
|
||||
|
||||
def check_tokenizer(self, tokenizer):
|
||||
def check(tokens):
|
||||
expected_text = tokenizer.decode(tokens)
|
||||
detokenizer = tokenizer.detokenizer
|
||||
detokenizer.reset()
|
||||
text = ""
|
||||
for t in tokens:
|
||||
detokenizer.add_token(t)
|
||||
seg = detokenizer.last_segment
|
||||
text += seg
|
||||
detokenizer.finalize()
|
||||
text += detokenizer.last_segment
|
||||
self.assertEqual(text, expected_text)
|
||||
|
||||
tokens = tokenizer.encode("a ,b")
|
||||
check(tokens)
|
||||
|
||||
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
|
||||
check(tokens)
|
||||
|
||||
tokens = tokenizer.encode("3 3")
|
||||
check(tokens)
|
||||
|
||||
def test_tokenizers(self):
|
||||
tokenizer_repos = [
|
||||
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
|
||||
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
|
||||
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
|
||||
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
|
||||
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
|
||||
]
|
||||
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
|
||||
with self.subTest(tokenizer=tokenizer_repo):
|
||||
tokenizer = self.download_tokenizer(tokenizer_repo)
|
||||
tokenizer.decode([0, 1, 2])
|
||||
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
|
||||
self.check_tokenizer(tokenizer)
|
||||
|
||||
# Try one with a naive detokenizer
|
||||
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
|
||||
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
|
||||
self.check_tokenizer(tokenizer)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user