Merge branch 'ml-explore:main' into adding-support-for-mamba2

Gökdeniz Gülmez 2024-10-16 18:57:55 +02:00 committed by GitHub
commit 855fcc4327
18 changed files with 756 additions and 428 deletions

View File

@ -26,8 +26,8 @@ jobs:
- run:
name: Install dependencies
command: |
brew install python@3.8
python3.8 -m venv env
brew install python@3.9
python3.9 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install unittest-xml-reporting

View File

@ -20,8 +20,10 @@ Some more useful examples are listed below.
### Image Models
- Generating images
- [FLUX](flux)
- [Stable Diffusion or SDXL](stable_diffusion)
- Image classification using [ResNets on CIFAR-10](cifar).
- Generating images with [Stable Diffusion or SDXL](stable_diffusion).
- Convolutional variational autoencoder [(CVAE) on MNIST](cvae).
### Audio Models

View File

@ -21,13 +21,34 @@ The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly
You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt
Usage
---------
Use the following command to generate an image. The `--output` flag specifies where the image is saved (default: `out.png`).
```shell
python txt2image.py --model schnell \
--n-images 1 \
--image-size 256x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars.'
```
For the full list of options, run the script with `--help`:
```shell
python txt2image.py --help
```
Inference
---------
@ -78,7 +99,11 @@ except for some additional logic to quantize and/or load trained adapters. One
can use the script as follows:
```shell
python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.'
python txt2image.py \
--n-images 4 \
--n-rows 2 \
--image-size 256x512 \
'A photo of an astronaut riding a horse on Mars.'
```
### Experimental Options
@ -94,17 +119,12 @@ Finetuning
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an
`index.json` file with the following format:
`train.jsonl` file with the following format:
```json
{
"data": [
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
...
]
}
```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
```
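For reference, a minimal sketch of producing such a `train.jsonl` with a hypothetical helper (the directory, filenames, and prompts below are placeholders):
```python
import json
from pathlib import Path

def write_train_jsonl(dataset_dir, captions):
    """Write one JSON object per line: {"image": ..., "prompt": ...}."""
    with open(Path(dataset_dir) / "train.jsonl", "w") as f:
        for image, prompt in captions.items():
            f.write(json.dumps({"image": image, "prompt": prompt}) + "\n")

# Placeholder example
write_train_jsonl("path/to/dataset", {"00.jpg": "A photo of sks dog"})
```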
The training script by default trains for 600 iterations with a batch size of
@ -126,19 +146,15 @@ The training images are the following 5 images [^2]:
![dog6](static/dog6.png)
We start by making the following `index.json` file and placing it in the same
We start by making the following `train.jsonl` file and placing it in the same
folder as the images.
```json
{
"data": [
{"image": "00.jpg", "text": "A photo of sks dog"},
{"image": "01.jpg", "text": "A photo of sks dog"},
{"image": "02.jpg", "text": "A photo of sks dog"},
{"image": "03.jpg", "text": "A photo of sks dog"},
{"image": "04.jpg", "text": "A photo of sks dog"}
]
}
```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```
Subsequently we finetune FLUX using the following command:
@ -151,6 +167,17 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```
Alternatively, you can use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) directly for fine-tuning:
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.

View File

@ -1,7 +1,6 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import time
from functools import partial
from pathlib import Path
@ -13,105 +12,8 @@ import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
class FinetuningDataset:
def __init__(self, flux, args):
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
from flux import FluxPipeline, Trainer, load_dataset
def generate_progress_images(iteration, flux, args):
@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args):
)
if __name__ == "__main__":
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
@ -247,7 +150,11 @@ if __name__ == "__main__":
)
parser.add_argument("dataset")
return parser
if __name__ == "__main__":
parser = setup_arg_parser()
args = parser.parse_args()
# Load the model and set it up for LoRA training. We use the same random
@ -267,7 +174,7 @@ if __name__ == "__main__":
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation.
@ -340,10 +247,10 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
print("Create the training dataset.", flush=True)
dataset = FinetuningDataset(flux, args)
dataset.encode_images()
dataset.encode_prompts()
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
@ -352,7 +259,7 @@ if __name__ == "__main__":
grads = None
losses = []
tic = time.time()
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
@ -361,7 +268,7 @@ if __name__ == "__main__":
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,

View File

@ -1,16 +1,10 @@
# Copyright © 2024 Apple Inc.
import math
import time
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .datasets import Dataset, load_dataset
from .flux import FluxPipeline
from .lora import LoRALinear
from .sampler import FluxSampler
from .trainer import Trainer
from .utils import (
load_ae,
load_clip,
@ -19,230 +13,3 @@ from .utils import (
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))

flux/flux/datasets.py (new file, 75 lines added)
View File

@ -0,0 +1,75 @@
import json
from pathlib import Path
from PIL import Image
class Dataset:
def __getitem__(self, index: int):
raise NotImplementedError()
def __len__(self):
raise NotImplementedError()
class LocalDataset(Dataset):
prompt_key = "prompt"
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
def __len__(self):
return len(self._data)
def __getitem__(self, index: int):
item = self._data[index]
image = Image.open(self.dataset_base / item["image"])
return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
prompt_key = "text"
def __init__(self, dataset: str):
self.dataset_base = Path(dataset)
with open(self.dataset_base / "index.json") as f:
self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
def __init__(self, dataset: str):
from datasets import load_dataset as hf_load_dataset
self._df = hf_load_dataset(dataset)["train"]
def __len__(self):
return len(self._df)
def __getitem__(self, index: int):
item = self._df[index]
return item["image"], item["prompt"]
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
legacy_file = dataset_base / "index.json"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
dataset = LocalDataset(dataset, data_file)
elif legacy_file.exists():
print(f"Load the local dataset {legacy_file} .")
print()
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
print(" See the README for details.")
print(flush=True)
dataset = LegacyDataset(dataset)
else:
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)
return dataset
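For orientation, a rough usage sketch of the dispatch above (the dataset path is a placeholder):
```python
from flux import load_dataset

# Resolves to LocalDataset (train.jsonl), LegacyDataset (index.json), or
# HuggingFaceDataset (anything else is treated as a Hugging Face repo id).
dataset = load_dataset("path/to/dreambooth/dataset/dog6")
image, prompt = dataset[0]  # an image and its prompt string
print(len(dataset), prompt)
```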

flux/flux/flux.py (new file, 246 lines added)
View File

@ -0,0 +1,246 @@
# Copyright © 2024 Apple Inc.
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))
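A minimal inference sketch against this class; the checkpoint name, step count, and latent-size mapping are assumptions (the README drives this through `txt2image.py` instead):
```python
from flux import FluxPipeline

flux = FluxPipeline("flux-schnell")  # assumed checkpoint name
flux.ensure_models_are_loaded()

images = flux.generate_images(
    "A photo of an astronaut riding a horse on Mars.",
    n_images=1,
    num_steps=2,           # low step count in the schnell style (assumption)
    latent_size=(32, 64),  # assumed to correspond to a 256x512 output
    seed=0,
)
print(images.shape)  # (n_images, height, width, 3), values in [0, 1]
```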

flux/flux/trainer.py (new file, 98 lines added)
View File

@ -0,0 +1,98 @@
import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
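A rough sketch of the data flow this class implements, mirroring the updated `dreambooth.py` (the checkpoint name, resolution, and augmentation count are placeholder assumptions):
```python
from types import SimpleNamespace

from flux import FluxPipeline, Trainer, load_dataset

flux = FluxPipeline("flux-dev")  # assumed checkpoint name
dataset = load_dataset("path/to/dreambooth/dataset/dog6")

args = SimpleNamespace(resolution=(512, 512), num_augmentations=5)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()  # pre-compute image latents and text features

x_0, t5_feat, clip_feat = next(trainer.iterate(batch_size=1))
```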

View File

@ -77,7 +77,7 @@ if __name__ == "__main__":
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models:
sd.ensure_models_are_loaded()
flux.ensure_models_are_loaded()
# Make the generator
latent_size = to_latent_size(args.image_size)

View File

@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text_completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each list contains the `logprobs`
top tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either of
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
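As a rough, stdlib-only sketch of reading these fields from a non-streaming chat completion (the host, port, and model name are assumptions taken from the examples above):
```python
import json
from urllib.request import Request, urlopen

request = Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(
        {
            "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
            "messages": [{"role": "user", "content": "Say hello"}],
            "max_tokens": 32,
        }
    ).encode(),
    headers={"Content-Type": "application/json"},
)
response = json.loads(urlopen(request).read())

choice = response["choices"][0]
print(response["model"], choice["finish_reason"])
print(choice["message"])
print(response["usage"]["total_tokens"])
```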
### List Models
@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.
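A similar sketch for the models endpoint; the top-level `data` wrapper follows the OpenAI list format and is an assumption here:
```python
import json
from urllib.request import urlopen

models = json.loads(urlopen("http://localhost:8080/v1/models").read())
for model in models["data"]:  # assumed OpenAI-style {"object": "list", "data": [...]}
    print(model["id"], model["created"])
```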

View File

@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
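A small sketch of the new helper in use; the `mlx_lm.models.cache` module path and the `KVCache` import are assumptions based on the imports elsewhere in this commit:
```python
import mlx.core as mx
from mlx_lm.models.cache import KVCache, can_trim_prompt_cache, trim_prompt_cache

cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))  # (batch, heads, tokens, head_dim)
cache[0].update_and_fetch(x, x)

if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, 2)  # drop the last 2 cached tokens
    print(trimmed, cache[0].offset)        # expected: 2 8
```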

View File

@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
@ -291,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor

View File

@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (mx.array): The tokenized prompt
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
continue
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
self.prompt_cache.tokens.extend(tokens)
# check is there any remaining text to send
detokenizer.finalize()
@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "

View File

@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Extract the tokens in a list from id to text
self.tokenmap = [None] * len(tokenizer.vocab)
@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = ""
self.tokens = []
def _maybe_trim_space(self, current_text):
if current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = v
else:
self._unflushed += v
@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):

View File

@ -246,10 +246,10 @@ def generate_step(
y, logprobs = _step(y)
mx.async_eval(y)
mx.async_eval(y, logprobs)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
@ -348,7 +348,9 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)

View File

@ -1,5 +1,6 @@
# Copyright © 2024 Apple Inc.
import copy
import os
import tempfile
import unittest
@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase):
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
def test_cache_copying(self):
cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)
y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)
old_cache = copy.deepcopy(cache)
trim_prompt_cache(cache, 1)
self.assertTrue(old_cache[0].offset, 11)
self.assertTrue(cache[0].offset, 10)
z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
if __name__ == "__main__":
unittest.main()

View File

@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)
def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"]

View File

@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import unittest
from pathlib import Path
from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
BPEStreamingDetokenizer,
NaiveStreamingDetokenizer,
SPMStreamingDetokenizer,
load_tokenizer,
)
class TestTokenizers(unittest.TestCase):
def download_tokenizer(self, repo):
path = Path(
snapshot_download(
repo_id=repo,
allow_patterns=[
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"tokenizer.model",
],
)
)
return load_tokenizer(path)
def check_tokenizer(self, tokenizer):
def check(tokens):
expected_text = tokenizer.decode(tokens)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
tokens = tokenizer.encode("a ,b")
check(tokens)
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
check(tokens)
tokens = tokenizer.encode("3 3")
check(tokens)
def test_tokenizers(self):
tokenizer_repos = [
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
]
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
with self.subTest(tokenizer=tokenizer_repo):
tokenizer = self.download_tokenizer(tokenizer_repo)
tokenizer.decode([0, 1, 2])
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
self.check_tokenizer(tokenizer)
# Try one with a naive detokenizer
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
self.check_tokenizer(tokenizer)
if __name__ == "__main__":
unittest.main()