diff --git a/.circleci/config.yml b/.circleci/config.yml index 02fa1de8..cecd2d57 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,8 +26,8 @@ jobs: - run: name: Install dependencies command: | - brew install python@3.8 - python3.8 -m venv env + brew install python@3.9 + python3.9 -m venv env source env/bin/activate pip install --upgrade pip pip install unittest-xml-reporting diff --git a/README.md b/README.md index bd180975..88888ad0 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ Some more useful examples are listed below. ### Image Models +- Generating images + - [FLUX](flux) + - [Stable Diffusion or SDXL](stable_diffusion) - Image classification using [ResNets on CIFAR-10](cifar). -- Generating images with [Stable Diffusion or SDXL](stable_diffusion). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). ### Audio Models diff --git a/flux/README.md b/flux/README.md index 62eb9b62..1a17e386 100644 --- a/flux/README.md +++ b/flux/README.md @@ -21,13 +21,34 @@ The dependencies are minimal, namely: - `huggingface-hub` to download the checkpoints. - `regex` for the tokenization -- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script +- `tqdm`, `PIL`, and `numpy` for the scripts - `sentencepiece` for the T5 tokenizer +- `datasets` for using an HF dataset directly You can install all of the above with the `requirements.txt` as follows: pip install -r requirements.txt + +Usage +--------- + +You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`. + +```shell +python txt2image.py --model schnell \ + --n-images 1 \ + --image-size 256x512 \ + --verbose \ + 'A photo of an astronaut riding a horse on Mars.' +``` + +For more parameters, please use the `--help` command to view. + +```shell +python txt2image.py --help +``` + Inference --------- @@ -78,7 +99,11 @@ except for some additional logic to quantize and/or load trained adapters. One can use the script as follows: ```shell -python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.' +python txt2image.py \ + --n-images 4 \ + --n-rows 2 \ + --image-size 256x512 \ + 'A photo of an astronaut riding a horse on Mars.' ``` ### Experimental Options @@ -94,17 +119,12 @@ Finetuning The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell but ymmv) on a provided image dataset. The dataset folder must have an -`index.json` file with the following format: +`train.jsonl` file with the following format: -```json -{ - "data": [ - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - ... - ] -} +```jsonl +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +... ``` The training script by default trains for 600 iterations with a batch size of @@ -126,19 +146,15 @@ The training images are the following 5 images [^2]: ![dog6](static/dog6.png) -We start by making the following `index.json` file and placing it in the same +We start by making the following `train.jsonl` file and placing it in the same folder as the images. -```json -{ - "data": [ - {"image": "00.jpg", "text": "A photo of sks dog"}, - {"image": "01.jpg", "text": "A photo of sks dog"}, - {"image": "02.jpg", "text": "A photo of sks dog"}, - {"image": "03.jpg", "text": "A photo of sks dog"}, - {"image": "04.jpg", "text": "A photo of sks dog"} - ] -} +```jsonl +{"image": "00.jpg", "prompt": "A photo of sks dog"} +{"image": "01.jpg", "prompt": "A photo of sks dog"} +{"image": "02.jpg", "prompt": "A photo of sks dog"} +{"image": "03.jpg", "prompt": "A photo of sks dog"} +{"image": "04.jpg", "prompt": "A photo of sks dog"} ``` Subsequently we finetune FLUX using the following command: @@ -151,6 +167,17 @@ python dreambooth.py \ path/to/dreambooth/dataset/dog6 ``` + +Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning. + +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + mlx-community/dreambooth-dog6 +``` + The training requires approximately 50GB of RAM and on an M2 Ultra it takes a bit more than 1 hour. diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..48dcad47 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -1,7 +1,6 @@ # Copyright © 2024 Apple Inc. import argparse -import json import time from functools import partial from pathlib import Path @@ -13,105 +12,8 @@ import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from tqdm import tqdm -from flux import FluxPipeline - - -class FinetuningDataset: - def __init__(self, flux, args): - self.args = args - self.flux = flux - self.dataset_base = Path(args.dataset) - dataset_index = self.dataset_base / "index.json" - if not dataset_index.exists(): - raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset") - with open(dataset_index, "r") as f: - self.index = json.load(f) - - self.latents = [] - self.t5_features = [] - self.clip_features = [] - - def _random_crop_resize(self, img): - resolution = self.args.resolution - width, height = img.size - - a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() - - # Random crop the input image between 0.8 to 1.0 of its original dimensions - crop_size = ( - max((0.8 + 0.2 * a) * width, resolution[0]), - max((0.8 + 0.2 * a) * height, resolution[1]), - ) - pan = (width - crop_size[0], height - crop_size[1]) - img = img.crop( - ( - pan[0] * b, - pan[1] * c, - crop_size[0] + pan[0] * b, - crop_size[1] + pan[1] * c, - ) - ) - - # Fit the largest rectangle with the ratio of resolution in the image - # rectangle. - width, height = crop_size - ratio = resolution[0] / resolution[1] - r1 = (height * ratio, height) - r2 = (width, width / ratio) - r = r1 if r1[0] <= width else r2 - img = img.crop( - ( - (width - r[0]) / 2, - (height - r[1]) / 2, - (width + r[0]) / 2, - (height + r[1]) / 2, - ) - ) - - # Finally resize the image to resolution - img = img.resize(resolution, Image.LANCZOS) - - return mx.array(np.array(img)) - - def encode_images(self): - """Encode the images in the latent space to prepare for training.""" - self.flux.ae.eval() - for sample in tqdm(self.index["data"]): - input_img = Image.open(self.dataset_base / sample["image"]) - for i in range(self.args.num_augmentations): - img = self._random_crop_resize(input_img) - img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 - x_0 = self.flux.ae.encode(img[None]) - x_0 = x_0.astype(self.flux.dtype) - mx.eval(x_0) - self.latents.append(x_0) - - def encode_prompts(self): - """Pre-encode the prompts so that we don't recompute them during - training (doesn't allow finetuning the text encoders).""" - for sample in tqdm(self.index["data"]): - t5_tok, clip_tok = self.flux.tokenize([sample["text"]]) - t5_feat = self.flux.t5(t5_tok) - clip_feat = self.flux.clip(clip_tok).pooled_output - mx.eval(t5_feat, clip_feat) - self.t5_features.append(t5_feat) - self.clip_features.append(clip_feat) - - def iterate(self, batch_size): - xs = mx.concatenate(self.latents) - t5 = mx.concatenate(self.t5_features) - clip = mx.concatenate(self.clip_features) - mx.eval(xs, t5, clip) - n_aug = self.args.num_augmentations - while True: - x_indices = mx.random.permutation(len(self.latents)) - c_indices = x_indices // n_aug - for i in range(0, len(self.latents), batch_size): - x_i = x_indices[i : i + batch_size] - c_i = c_indices[i : i + batch_size] - yield xs[x_i], t5[c_i], clip[c_i] +from flux import FluxPipeline, Trainer, load_dataset def generate_progress_images(iteration, flux, args): @@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args): ) -if __name__ == "__main__": +def setup_arg_parser(): + """Set up and return the argument parser.""" parser = argparse.ArgumentParser( description="Finetune Flux to generate images with a specific subject" ) @@ -247,7 +150,11 @@ if __name__ == "__main__": ) parser.add_argument("dataset") + return parser + +if __name__ == "__main__": + parser = setup_arg_parser() args = parser.parse_args() # Load the model and set it up for LoRA training. We use the same random @@ -267,7 +174,7 @@ if __name__ == "__main__": trainable_params = tree_reduce( lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 ) - print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) + print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True) # Set up the optimizer and training steps. The steps are a bit verbose to # support gradient accumulation together with compilation. @@ -340,10 +247,10 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - print("Create the training dataset.", flush=True) - dataset = FinetuningDataset(flux, args) - dataset.encode_images() - dataset.encode_prompts() + dataset = load_dataset(args.dataset) + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) # An initial generation to compare @@ -352,7 +259,7 @@ if __name__ == "__main__": grads = None losses = [] tic = time.time() - for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) mx.eval(loss, grads, state) losses.append(loss.item()) @@ -361,7 +268,7 @@ if __name__ == "__main__": toc = time.time() peak_mem = mx.metal.get_peak_memory() / 1024**3 print( - f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " f"Peak mem: {peak_mem:.3f} GB", flush=True, diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..b1122d75 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,16 +1,10 @@ # Copyright © 2024 Apple Inc. -import math -import time -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from tqdm import tqdm - +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline from .lora import LoRALinear from .sampler import FluxSampler +from .trainer import Trainer from .utils import ( load_ae, load_clip, @@ -19,230 +13,3 @@ from .utils import ( load_t5, load_t5_tokenizer, ) - - -class FluxPipeline: - def __init__(self, name: str, t5_padding: bool = True): - self.dtype = mx.bfloat16 - self.name = name - self.t5_padding = t5_padding - - self.ae = load_ae(name) - self.flow = load_flow_model(name) - self.clip = load_clip(name) - self.clip_tokenizer = load_clip_tokenizer(name) - self.t5 = load_t5(name) - self.t5_tokenizer = load_t5_tokenizer(name) - self.sampler = FluxSampler(name) - - def ensure_models_are_loaded(self): - mx.eval( - self.ae.parameters(), - self.flow.parameters(), - self.clip.parameters(), - self.t5.parameters(), - ) - - def reload_text_encoders(self): - self.t5 = load_t5(self.name) - self.clip = load_clip(self.name) - - def tokenize(self, text): - t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) - clip_tokens = self.clip_tokenizer.encode(text) - return t5_tokens, clip_tokens - - def _prepare_latent_images(self, x): - b, h, w, c = x.shape - - # Pack the latent image to 2x2 patches - x = x.reshape(b, h // 2, 2, w // 2, 2, c) - x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) - - # Create positions ids used to positionally encode each patch. Due to - # the way RoPE works, this results in an interesting positional - # encoding where parts of the feature are holding different positional - # information. Namely, the first part holds information independent of - # the spatial position (hence 0s), the 2nd part holds vertical spatial - # information and the last one horizontal. - i = mx.zeros((h // 2, w // 2), dtype=mx.int32) - j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") - x_ids = mx.stack([i, j, k], axis=-1) - x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) - - return x, x_ids - - def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): - # Prepare the text features - txt = self.t5(t5_tokens) - if len(txt) == 1 and n_images > 1: - txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) - txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) - - # Prepare the clip text features - vec = self.clip(clip_tokens).pooled_output - if len(vec) == 1 and n_images > 1: - vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) - - return txt, txt_ids, vec - - def _denoising_loop( - self, - x_t, - x_ids, - txt, - txt_ids, - vec, - num_steps: int = 35, - guidance: float = 4.0, - start: float = 1, - stop: float = 0, - ): - B = len(x_t) - - def scalar(x): - return mx.full((B,), x, dtype=self.dtype) - - guidance = scalar(guidance) - timesteps = self.sampler.timesteps( - num_steps, - x_t.shape[1], - start=start, - stop=stop, - ) - for i in range(num_steps): - t = timesteps[i] - t_prev = timesteps[i + 1] - - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=scalar(t), - guidance=guidance, - ) - x_t = self.sampler.step(pred, x_t, t, t_prev) - - yield x_t - - def generate_latents( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - ): - # Set the PRNG state - if seed is not None: - mx.random.seed(seed) - - # Create the latent variables - x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) - x_T, x_ids = self._prepare_latent_images(x_T) - - # Get the conditioning - t5_tokens, clip_tokens = self.tokenize(text) - txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) - - # Yield the conditioning for controlled evaluation by the caller - yield (x_T, x_ids, txt, txt_ids, vec) - - # Yield the latent sequences from the denoising loop - yield from self._denoising_loop( - x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance - ) - - def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): - h, w = latent_size - x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) - x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) - x = self.ae.decode(x) - return mx.clip(x + 1, 0, 2) * 0.5 - - def generate_images( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - reload_text_encoders: bool = True, - progress: bool = True, - ): - latents = self.generate_latents( - text, n_images, num_steps, guidance, latent_size, seed - ) - mx.eval(next(latents)) - - if reload_text_encoders: - self.reload_text_encoders() - - for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): - mx.eval(x_t) - - images = [] - for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i : i + 1])) - mx.eval(images[-1]) - images = mx.concatenate(images, axis=0) - mx.eval(images) - - return images - - def training_loss( - self, - x_0: mx.array, - t5_features: mx.array, - clip_features: mx.array, - guidance: mx.array, - ): - # Get the text conditioning - txt = t5_features - txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) - vec = clip_features - - # Prepare the latent input - x_0, x_ids = self._prepare_latent_images(x_0) - - # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) - x_t = mx.stop_gradient(x_t) - - # Do the denoising - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t, - guidance=guidance, - ) - - return (pred + x_0 - eps).square().mean() - - def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): - """Swap the linear layers in the transformer blocks with LoRA layers.""" - all_blocks = self.flow.double_blocks + self.flow.single_blocks - all_blocks.reverse() - num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) - for i, block in zip(range(num_blocks), all_blocks): - loras = [] - for name, module in block.named_modules(): - if isinstance(module, nn.Linear): - loras.append((name, LoRALinear.from_base(module, r=rank))) - block.update_modules(tree_unflatten(loras)) - - def fuse_lora_layers(self): - fused_layers = [] - for name, module in self.flow.named_modules(): - if isinstance(module, LoRALinear): - fused_layers.append((name, module.fuse())) - self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py new file mode 100644 index 00000000..d31a09f1 --- /dev/null +++ b/flux/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..3fd044ac --- /dev/null +++ b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. + +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. + i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..40a126e8 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] diff --git a/flux/txt2image.py b/flux/txt2image.py index bf2f7294..5ebec81a 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -77,7 +77,7 @@ if __name__ == "__main__": nn.quantize(flux.clip, class_predicate=quantization_predicate) if args.preload_models: - sd.ensure_models_are_loaded() + flux.ensure_models_are_loaded() # Make the generator latent_size = to_latent_size(args.image_size) diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 55be1c9c..2976a09f 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \ - `role_mapping`: (Optional) A dictionary to customize the role prefixes in the generated prompt. If not provided, the default mappings are used. -- `stop`: (Optional) An array of strings or a single string. Thesse are +- `stop`: (Optional) An array of strings or a single string. These are sequences of tokens on which the generation should stop. - `max_tokens`: (Optional) An integer specifying the maximum number of tokens @@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \ started in. - `adapters`: (Optional) A string path to low-rank adapters. The path must be - rlative to the directory the server was started in. + relative to the directory the server was started in. + +### Response Fields + +- `id`: A unique identifier for the chat. + +- `system_fingerprint`: A unique identifier for the system. + +- `object`: Any of "chat.completions", "chat.completions.chunk" (for + streaming), or "text.completion". + +- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). + +- `created`: A time-stamp for when the request was processed. + +- `choices`: A list of outputs. Each output is a dictionary containing the fields: + - `index`: The index in the list. + - `logprobs`: A dictionary containing the fields: + - `token_logprobs`: A list of the log probabilities for the generated + tokens. + - `tokens`: A list of the generated token ids. + - `top_logprobs`: A list of lists. Each list contains the `logprobs` + top tokens (if requested) with their corresponding probabilities. + - `finish_reason`: The reason the completion ended. This can be either of + `"stop"` or `"length"`. + - `message`: The text response from the model. + +- `usage`: A dictionary containing the fields: + - `prompt_tokens`: The number of prompt tokens processed. + - `completion_tokens`: The number of tokens generated. + - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. ### List Models @@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json" This will return a list of locally available models where each model in the list contains the following fields: -- `"id"`: The Hugging Face repo id. -- `"created"`: A timestamp representing the model creation time. +- `id`: The Hugging Face repo id. +- `created`: A time-stamp representing the model creation time. diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index b06422e5..a6a56e0a 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False): return cache +def can_trim_prompt_cache(cache: List[Any]) -> bool: + """ + Check if model's cache can be trimmed. + """ + return all(c.is_trimmable() for c in cache) + + def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: """ Trim the model's cache by the given number of tokens. @@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: Returns: (int): The number of tokens that were trimmed. """ - if not all(c.is_trimmable() for c in cache) or len(cache) == 0: + if not can_trim_prompt_cache(cache) or len(cache) == 0: return 0 return [c.trim(num_tokens) for c in cache][0] diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 17d061a8..bb3e5184 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module): k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) - if cache is not None: q_pe = self.rope(q_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys, values = cache.update_and_fetch( mx.concatenate([k_nope, k_pe], axis=-1), values ) else: q_pe = self.rope(q_pe) k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys = mx.concatenate([k_nope, k_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1) @@ -291,7 +291,7 @@ class MoEGate(nn.Module): scores = scores.reshape(bsz, seq_len, -1) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] scores = mx.take_along_axis(scores, inds, axis=-1) scores = scores * self.routed_scaling_factor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 42962b54..ec659969 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -3,19 +3,38 @@ import argparse import json import logging +import platform import time import uuid import warnings +from dataclasses import dataclass, field from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union +from typing import ( + Any, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) import mlx.core as mx from huggingface_hub import scan_cache_dir +from ._version import __version__ +from .models.cache import make_prompt_cache from .utils import generate_step, load +def get_system_fingerprint(): + gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else "" + return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): return prompt.rstrip() +@dataclass +class PromptCache: + cache: List[Any] = field(default_factory=list) + model_key: Tuple[str, Optional[str]] = ("", None) + tokens: List[int] = field(default_factory=list) + + class ModelProvider: def __init__(self, cli_args: argparse.Namespace): """Load models on demand and persist them across the whole process.""" @@ -156,12 +182,21 @@ class ModelProvider: class APIHandler(BaseHTTPRequestHandler): - def __init__(self, model_provider: ModelProvider, *args, **kwargs): + def __init__( + self, + model_provider: ModelProvider, + *args, + prompt_cache: Optional[PromptCache] = None, + system_fingerprint: Optional[str] = None, + **kwargs, + ): """ Create static request specific metadata """ self.created = int(time.time()) self.model_provider = model_provider + self.prompt_cache = prompt_cache or PromptCache() + self.system_fingerprint = system_fingerprint or get_system_fingerprint() super().__init__(*args, **kwargs) def _set_cors_headers(self): @@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler): self.stream_options = self.body.get("stream_options", None) self.requested_model = self.body.get("model", "default_model") self.adapter = self.body.get("adapters", None) - self.max_tokens = self.body.get("max_tokens", 100) + self.max_tokens = self.body.get("max_completion_tokens", None) + if self.max_tokens is None: + self.max_tokens = self.body.get("max_tokens", 512) self.temperature = self.body.get("temperature", 1.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) @@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler): # Static response response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": self.object_type, "model": self.requested_model, "created": self.created, @@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler): return response + def get_prompt_cache(self, prompt): + cache_len = len(self.prompt_cache.tokens) + if ( + self.prompt_cache.model_key != self.model_provider.model_key + or cache_len >= len(prompt) + or self.prompt_cache.tokens != prompt[:cache_len] + ): + self.prompt_cache.model_key = self.model_provider.model_key + self.prompt_cache.cache = make_prompt_cache(self.model_provider.model) + else: + prompt = prompt[cache_len:] + self.prompt_cache.tokens.extend(prompt) + return prompt + def handle_completion( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ Generate a response to a prompt and send it to the client in a single batch. Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (List[int]): The tokenized prompt. stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ @@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler): logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] - for (token, logprobs), _ in zip( + + prompt = self.get_prompt_cache(prompt) + + for _, (token, logprobs) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler): top_indices = sorted_indices[: self.logprobs] top_logprobs = logprobs[top_indices] top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) - top_tokens.append(dict(top_token_info)) + top_tokens.append(tuple(top_token_info)) token_logprobs.append(logprobs[token].item()) @@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler): ) break + self.prompt_cache.tokens.extend(tokens) detokenizer.finalize() text = ( detokenizer.text @@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler): def handle_stream( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ @@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler): Sent Events (SSE) stream. Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (mx.array): The tokenized prompt stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ @@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = None logging.debug(f"Starting stream:") - for (token, _), _ in zip( + prompt = self.get_prompt_cache(prompt) + + for _, (token, _) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler): continue new_text = detokenizer.last_segment - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + if new_text: + response = self.generate_response(new_text, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + + self.prompt_cache.tokens.extend(tokens) # check is there any remaining text to send detokenizer.finalize() @@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler): ): response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": "chat.completion", "model": self.requested_model, "created": self.created, @@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler): } return response - def handle_chat_completions(self) -> mx.array: + def handle_chat_completions(self) -> List[int]: """ Handle a chat completion request. @@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler): self.object_type = ( "chat.completions.chunk" if self.stream else "chat.completions" ) - if ( hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template @@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler): prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = self.tokenizer.encode(prompt) - return mx.array(prompt) + return prompt - def handle_text_completions(self) -> mx.array: + def handle_text_completions(self) -> List[int]: """ Handle a text completion request. @@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"cmpl-{uuid.uuid4()}" self.object_type = "text_completion" - assert "prompt" in self.body, "Request did not contain a prompt" - prompt_text = self.body["prompt"] - prompt = self.tokenizer.encode(prompt_text) - return mx.array(prompt) + return self.tokenizer.encode(self.body["prompt"]) def do_GET(self): """ @@ -669,9 +727,16 @@ def run( handler_class=APIHandler, ): server_address = (host, port) + prompt_cache = PromptCache() httpd = server_class( server_address, - lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs), + lambda *args, **kwargs: handler_class( + model_provider, + prompt_cache=prompt_cache, + system_fingerprint=get_system_fingerprint(), + *args, + **kwargs, + ), ) warnings.warn( "mlx_lm.server is not recommended for production as " diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 04bbbcc5..d8694d86 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def text(self): if self._current_tokens: self._current_text = self._tokenizer.decode(self._current_tokens) + if ( + self._tokenizer.clean_up_tokenization_spaces + and self._current_text[-1] == " " + ): + self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": self._tokens.extend(self._current_tokens) self._text += self._current_text @@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): """ _byte_decoder = None + _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re") - def __init__(self, tokenizer, trim_space=False): - self.trim_space = trim_space + def __init__(self, tokenizer): + + self.clean_spaces = tokenizer.clean_up_tokenization_spaces # Extract the tokens in a list from id to text self.tokenmap = [None] * len(tokenizer.vocab) @@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.text = "" self.tokens = [] + def _maybe_trim_space(self, current_text): + if current_text[0] != " ": + return current_text + elif not self.text: + return current_text[1:] + elif self.clean_spaces and current_text[1:].startswith(self._space_matches): + return current_text[1:] + return current_text + def add_token(self, token): v = self.tokenmap[token] - # if the token starts with space if self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed ).decode("utf-8") - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = v else: self._unflushed += v @@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( "utf-8" ) - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = "" @classmethod @@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder): def _is_bpe_decoder(decoder): - _target_description = { - "type": "ByteLevel", - "add_prefix_space": False, - "trim_offsets": False, - "use_regex": False, - } - - return _match(_target_description, decoder) + return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" def load_tokenizer(model_path, tokenizer_config_extra={}): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1e07546e..4f872982 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -246,10 +246,10 @@ def generate_step( y, logprobs = _step(y) - mx.async_eval(y) + mx.async_eval(y, logprobs) while True: next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs y, logprobs = next_y, next_logprobs @@ -348,7 +348,9 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 3c1ef49b..64cd9486 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +import copy import os import tempfile import unittest @@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase): all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) ) + def test_cache_copying(self): + cache = [KVCache()] + + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[0].update_and_fetch(x, x) + + y = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(y, y) + + old_cache = copy.deepcopy(cache) + + trim_prompt_cache(cache, 1) + + self.assertTrue(old_cache[0].offset, 11) + self.assertTrue(cache[0].offset, 10) + + z = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(z, z) + + self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) + self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index cbcccfbe..ad17554d 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -14,6 +14,7 @@ class DummyModelProvider: def __init__(self): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" self.model, self.tokenizer = load(HF_MODEL_PATH) + self.model_key = (HF_MODEL_PATH, None) def load(self, model, adapter=None): assert model in ["default_model", "chat_model"] diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py new file mode 100644 index 00000000..7b4828b1 --- /dev/null +++ b/llms/tests/test_tokenizers.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. + +import unittest +from pathlib import Path + +from huggingface_hub import snapshot_download +from mlx_lm.tokenizer_utils import ( + BPEStreamingDetokenizer, + NaiveStreamingDetokenizer, + SPMStreamingDetokenizer, + load_tokenizer, +) + + +class TestTokenizers(unittest.TestCase): + + def download_tokenizer(self, repo): + path = Path( + snapshot_download( + repo_id=repo, + allow_patterns=[ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "tokenizer.model", + ], + ) + ) + return load_tokenizer(path) + + def check_tokenizer(self, tokenizer): + def check(tokens): + expected_text = tokenizer.decode(tokens) + detokenizer = tokenizer.detokenizer + detokenizer.reset() + text = "" + for t in tokens: + detokenizer.add_token(t) + seg = detokenizer.last_segment + text += seg + detokenizer.finalize() + text += detokenizer.last_segment + self.assertEqual(text, expected_text) + + tokens = tokenizer.encode("a ,b") + check(tokens) + + tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') + check(tokens) + + tokens = tokenizer.encode("3 3") + check(tokens) + + def test_tokenizers(self): + tokenizer_repos = [ + ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), + ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), + ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), + ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), + ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), + ] + for tokenizer_repo, expected_detokenizer in tokenizer_repos: + with self.subTest(tokenizer=tokenizer_repo): + tokenizer = self.download_tokenizer(tokenizer_repo) + tokenizer.decode([0, 1, 2]) + self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) + self.check_tokenizer(tokenizer) + + # Try one with a naive detokenizer + tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit") + tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) + self.check_tokenizer(tokenizer) + + +if __name__ == "__main__": + unittest.main()