diff --git a/.circleci/config.yml b/.circleci/config.yml index 02fa1de8..cecd2d57 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,8 +26,8 @@ jobs: - run: name: Install dependencies command: | - brew install python@3.8 - python3.8 -m venv env + brew install python@3.9 + python3.9 -m venv env source env/bin/activate pip install --upgrade pip pip install unittest-xml-reporting diff --git a/.gitignore b/.gitignore index f3dfe929..45445fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# Vim +*.swp + # Distribution / packaging .Python build/ diff --git a/README.md b/README.md index bd180975..88888ad0 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ Some more useful examples are listed below. ### Image Models +- Generating images + - [FLUX](flux) + - [Stable Diffusion or SDXL](stable_diffusion) - Image classification using [ResNets on CIFAR-10](cifar). -- Generating images with [Stable Diffusion or SDXL](stable_diffusion). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). ### Audio Models diff --git a/encodec/README.md b/encodec/README.md index 3ab2793c..a3b948bf 100644 --- a/encodec/README.md +++ b/encodec/README.md @@ -33,13 +33,14 @@ An example using the model: ```python import mlx.core as mx -from utils import load, load_audio, save_audio +from encodec import EncodecModel +from utils import load_audio, save_audio # Load the 48 KHz model and preprocessor. -model, processor = load("mlx-community/encodec-48khz-float32") +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") # Load an audio file -audio = load_audio("path/to/aduio", model.sampling_rate, model.channels) +audio = load_audio("path/to/audio", model.sampling_rate, model.channels) # Preprocess the audio (this can also be a list of arrays for batched # processing). diff --git a/encodec/benchmarks/bench_mx.py b/encodec/benchmarks/bench_mx.py index 2acd4b75..61ddaae8 100644 --- a/encodec/benchmarks/bench_mx.py +++ b/encodec/benchmarks/bench_mx.py @@ -3,9 +3,10 @@ import time import mlx.core as mx -from utils import load -model, processor = load("mlx-community/encodec-48khz-float32") +from encodec import EncodecModel + +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") audio = mx.random.uniform(shape=(288000, 2)) feats, mask = processor(audio) diff --git a/encodec/convert.py b/encodec/convert.py index 13bd31a6..5c5f7e22 100644 --- a/encodec/convert.py +++ b/encodec/convert.py @@ -10,7 +10,6 @@ from typing import Any, Dict, Union import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten import encodec diff --git a/encodec/encodec.py b/encodec/encodec.py index 3ef47369..4b85dfdd 100644 --- a/encodec/encodec.py +++ b/encodec/encodec.py @@ -1,7 +1,10 @@ # Copyright © 2024 Apple Inc. 
+import functools +import json import math -from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace from typing import List, Optional, Tuple, Union import mlx.core as mx @@ -669,3 +672,70 @@ class EncodecModel(nn.Module): if padding_mask is not None and padding_mask.shape[1] < audio_values.shape[1]: audio_values = audio_values[:, : padding_mask.shape[1]] return audio_values + + @classmethod + def from_pretrained(cls, path_or_repo: str): + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + + model = EncodecModel(config) + model.load_weights(str(path / "model.safetensors")) + processor = functools.partial( + preprocess_audio, + sampling_rate=config.sampling_rate, + chunk_length=model.chunk_length, + chunk_stride=model.chunk_stride, + ) + mx.eval(model) + return model, processor + + +def preprocess_audio( + raw_audio: Union[mx.array, List[mx.array]], + sampling_rate: int = 24000, + chunk_length: Optional[int] = None, + chunk_stride: Optional[int] = None, +): + r""" + Prepare inputs for the EnCodec model. + + Args: + raw_audio (mx.array or List[mx.array]): The sequence or batch of + sequences to be processed. + sampling_rate (int): The sampling rate at which the audio waveform + should be digitalized. + chunk_length (int, optional): The model's chunk length. + chunk_stride (int, optional): The model's chunk stride. + """ + if not isinstance(raw_audio, list): + raw_audio = [raw_audio] + + raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio] + + max_length = max(array.shape[0] for array in raw_audio) + if chunk_length is not None: + max_length += chunk_length - (max_length % chunk_stride) + + inputs = [] + masks = [] + for x in raw_audio: + length = x.shape[0] + mask = mx.ones((length,), dtype=mx.bool_) + difference = max_length - length + if difference > 0: + mask = mx.pad(mask, (0, difference)) + x = mx.pad(x, ((0, difference), (0, 0))) + inputs.append(x) + masks.append(mask) + return mx.stack(inputs), mx.stack(masks) diff --git a/encodec/example.py b/encodec/example.py index 97b311a1..15ea476c 100644 --- a/encodec/example.py +++ b/encodec/example.py @@ -1,10 +1,12 @@ # Copyright © 2024 Apple Inc. import mlx.core as mx -from utils import load, load_audio, save_audio +from utils import load_audio, save_audio + +from encodec import EncodecModel # Load the 48 KHz model and preprocessor. 
-model, processor = load("mlx-community/encodec-48khz-float32") +model, processor = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") # Load an audio file audio = load_audio("/path/to/audio", model.sampling_rate, model.channels) diff --git a/encodec/test.py b/encodec/test.py index ffc23505..ae565c29 100644 --- a/encodec/test.py +++ b/encodec/test.py @@ -3,9 +3,10 @@ import mlx.core as mx import numpy as np import torch -from datasets import Audio, load_dataset -from transformers import AutoProcessor, EncodecModel -from utils import load, load_audio, preprocess_audio +from transformers import AutoProcessor +from transformers import EncodecModel as PTEncodecModel + +from encodec import EncodecModel, preprocess_audio def compare_processors(): @@ -30,8 +31,8 @@ def compare_processors(): def compare_models(): - pt_model = EncodecModel.from_pretrained("facebook/encodec_48khz") - mx_model, _ = load("mlx-community/encodec-48khz-float32") + pt_model = PTEncodecModel.from_pretrained("facebook/encodec_48khz") + mx_model, _ = EncodecModel.from_pretrained("mlx-community/encodec-48khz-float32") np.random.seed(0) audio_length = 190560 diff --git a/encodec/utils.py b/encodec/utils.py index 18b3f063..b429ed83 100644 --- a/encodec/utils.py +++ b/encodec/utils.py @@ -1,16 +1,7 @@ # Copyright © 2024 Apple Inc. -import functools -import json -from pathlib import Path -from types import SimpleNamespace -from typing import List, Optional, Union - import mlx.core as mx import numpy as np -from huggingface_hub import snapshot_download - -import encodec def save_audio(file: str, audio: mx.array, sampling_rate: int): @@ -59,71 +50,3 @@ def load_audio(file: str, sampling_rate: int, channels: int): out = mx.array(np.frombuffer(out, np.int16)) return out.reshape(-1, channels).astype(mx.float32) / 32767.0 - - -def preprocess_audio( - raw_audio: Union[mx.array, List[mx.array]], - sampling_rate: int = 24000, - chunk_length: Optional[int] = None, - chunk_stride: Optional[int] = None, -): - r""" - Prepare inputs for the EnCodec model. - - Args: - raw_audio (mx.array or List[mx.array]): The sequence or batch of - sequences to be processed. - sampling_rate (int): The sampling rate at which the audio waveform - should be digitalized. - chunk_length (int, optional): The model's chunk length. - chunk_stride (int, optional): The model's chunk stride. - """ - if not isinstance(raw_audio, list): - raw_audio = [raw_audio] - - raw_audio = [x[..., None] if x.ndim == 1 else x for x in raw_audio] - - max_length = max(array.shape[0] for array in raw_audio) - if chunk_length is not None: - max_length += chunk_length - (max_length % chunk_stride) - - inputs = [] - masks = [] - for x in raw_audio: - length = x.shape[0] - mask = mx.ones((length,), dtype=mx.bool_) - difference = max_length - length - if difference > 0: - mask = mx.pad(mask, (0, difference)) - x = mx.pad(x, ((0, difference), (0, 0))) - inputs.append(x) - masks.append(mask) - return mx.stack(inputs), mx.stack(masks) - - -def load(path_or_repo): - """ - Load the model and audo preprocessor. 
- """ - path = Path(path_or_repo) - if not path.exists(): - path = Path( - snapshot_download( - repo_id=path_or_repo, - allow_patterns=["*.json", "*.safetensors", "*.model"], - ) - ) - - with open(path / "config.json", "r") as f: - config = SimpleNamespace(**json.load(f)) - - model = encodec.EncodecModel(config) - model.load_weights(str(path / "model.safetensors")) - processor = functools.partial( - preprocess_audio, - sampling_rate=config.sampling_rate, - chunk_length=model.chunk_length, - chunk_stride=model.chunk_stride, - ) - mx.eval(model) - return model, processor diff --git a/flux/README.md b/flux/README.md new file mode 100644 index 00000000..1a17e386 --- /dev/null +++ b/flux/README.md @@ -0,0 +1,212 @@ +FLUX +==== + +FLUX implementation in MLX. The implementation is ported directly from +[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux) +and the model weights are downloaded directly from the Hugging Face Hub. + +The goal of this example is to be clean, educational and to allow for +experimentation with finetuning FLUX models as well as adding extra +functionality such as in-/outpainting, guidance with custom losses etc. + +![MLX image](static/generated-mlx.png) +*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of +tron emanating futuristic technology with the word "MLX" in the center with +capital red letters.'* + +Installation +------------ + +The dependencies are minimal, namely: + +- `huggingface-hub` to download the checkpoints. +- `regex` for the tokenization +- `tqdm`, `PIL`, and `numpy` for the scripts +- `sentencepiece` for the T5 tokenizer +- `datasets` for using an HF dataset directly + +You can install all of the above with the `requirements.txt` as follows: + + pip install -r requirements.txt + + +Usage +--------- + +You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`. + +```shell +python txt2image.py --model schnell \ + --n-images 1 \ + --image-size 256x512 \ + --verbose \ + 'A photo of an astronaut riding a horse on Mars.' +``` + +For more parameters, please use the `--help` command to view. + +```shell +python txt2image.py --help +``` + +Inference +--------- + +Inference in this example is similar to the stable diffusion example. The +classes to get you started are `FluxPipeline` from the `flux` module. + +```python +import mlx.core as mx +from flux import FluxPipeline + +# This will download all the weights from HF hub +flux = FluxPipeline("flux-schnell") + +# Make a generator that returns the latent variables from the reverse diffusion +# process +latent_generator = flux.generate_latents( + "A photo of an astronaut riding a horse on Mars", + num_steps=4, + latent_size=(32, 64), # 256x512 image +) + +# The first return value of the generator contains the conditioning and the +# random noise at the beginning of the diffusion process. +conditioning = next(latent_generator) +( + x_T, # The initial noise + x_positions, # The integer positions used for image positional encoding + t5_conditioning, # The T5 features from the text prompt + t5_positions, # Integer positions for text (normally all 0s) + clip_conditioning, # The clip text features from the text prompt +) = conditioning + +# Returning the conditioning as the first output from the generator allows us +# to unload T5 and clip before running the diffusion transformer. 
+mx.eval(conditioning) + +# Evaluate each diffusion step +for x_t in latent_generator: + mx.eval(x_t) + +# Note that we need to pass the latent size because it is collapsed and +# patchified in x_t and we need to unwrap it. +img = flux.decode(x_t, latent_size=(32, 64)) +``` + +The above are essentially the implementation of the `txt2image.py` script +except for some additional logic to quantize and/or load trained adapters. One +can use the script as follows: + +```shell +python txt2image.py \ + --n-images 4 \ + --n-rows 2 \ + --image-size 256x512 \ + 'A photo of an astronaut riding a horse on Mars.' +``` + +### Experimental Options + +FLUX pads the prompt to a specific size of 512 tokens for the dev model and +256 for the schnell model. Not applying padding results in faster generation +but it is not clear how it may affect the generated images. To enable that +option in this example pass `--no-t5-padding` to the `txt2image.py` script or +instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`. + +Finetuning +---------- + +The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell +but ymmv) on a provided image dataset. The dataset folder must have an +`train.jsonl` file with the following format: + +```jsonl +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +... +``` + +The training script by default trains for 600 iterations with a batch size of +1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py +--help` for the list of hyperparameters you can tune. + +> [!Note] +> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and +> should reduce this number significantly. + +### Training Example + +This is a step-by-step finetuning example. We will be using the data from +[https://github.com/google/dreambooth](https://github.com/google/dreambooth). +In particular, we will use `dog6` which is a popular example for showcasing +dreambooth [^1]. + +The training images are the following 5 images [^2]: + +![dog6](static/dog6.png) + +We start by making the following `train.jsonl` file and placing it in the same +folder as the images. + +```jsonl +{"image": "00.jpg", "prompt": "A photo of sks dog"} +{"image": "01.jpg", "prompt": "A photo of sks dog"} +{"image": "02.jpg", "prompt": "A photo of sks dog"} +{"image": "03.jpg", "prompt": "A photo of sks dog"} +{"image": "04.jpg", "prompt": "A photo of sks dog"} +``` + +Subsequently we finetune FLUX using the following command: + +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + path/to/dreambooth/dataset/dog6 +``` + + +Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning. + +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + mlx-community/dreambooth-dog6 +``` + +The training requires approximately 50GB of RAM and on an M2 Ultra it takes a +bit more than 1 hour. + +### Using the Adapter + +The adapters are saved in `mlx_output` and can be used directly by the +`txt2image.py` script. 
For instance, + +```shell +python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ + --adapter mlx_output/0001200_adapters.safetensors \ + --fuse-adapter \ + --no-t5-padding \ + 'A photo of an sks dog lying on the sand at a beach in Greece' +``` + +generates an image that looks like the following, + +![dog image](static/dog-r4-g8-1200.png) + +and of course we can pass `--image-size 512x1024` to get larger images with +different aspect ratios, + +![wide dog image](static/dog-r4-g8-1200-512x1024.png) + +The arguments that are relevant to the adapters are of course `--adapter` and +`--fuse-adapter`. The first defines the path to an adapter to apply to the +model and the second fuses the adapter back into the model to get a bit more +speed during generation. + +[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details. +[^2]: The images are from unsplash by https://unsplash.com/@alvannee . diff --git a/flux/dreambooth.py b/flux/dreambooth.py new file mode 100644 index 00000000..48dcad47 --- /dev/null +++ b/flux/dreambooth.py @@ -0,0 +1,285 @@ +# Copyright © 2024 Apple Inc. + +import argparse +import time +from functools import partial +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten, tree_map, tree_reduce +from PIL import Image + +from flux import FluxPipeline, Trainer, load_dataset + + +def generate_progress_images(iteration, flux, args): + """Generate images to monitor the progress of the finetuning.""" + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / f"{iteration:07d}_progress.png" + print(f"Generating {str(out_file)}", flush=True) + + # Generate some images and arrange them in a grid + n_rows = 2 + n_images = 4 + x = flux.generate_images( + args.progress_prompt, + n_images, + args.progress_steps, + ) + x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(n_rows * H, B // n_rows * W, C) + x = mx.pad(x, [(4, 4), (4, 4), (0, 0)]) + x = (x * 255).astype(mx.uint8) + + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(out_file) + + +def save_adapters(iteration, flux, args): + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + print(f"Saving {str(out_file)}") + + mx.save_safetensors( + str(out_file), + dict(tree_flatten(flux.flow.trainable_parameters())), + metadata={ + "lora_rank": str(args.lora_rank), + "lora_blocks": str(args.lora_blocks), + }, + ) + + +def setup_arg_parser(): + """Set up and return the argument parser.""" + parser = argparse.ArgumentParser( + description="Finetune Flux to generate images with a specific subject" + ) + + parser.add_argument( + "--model", + default="dev", + choices=[ + "dev", + "schnell", + ], + help="Which flux model to train", + ) + parser.add_argument( + "--guidance", type=float, default=4.0, help="The guidance factor to use." 
+ ) + parser.add_argument( + "--iterations", + type=int, + default=600, + help="How many iterations to train for", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="The batch size to use when training the stable diffusion model", + ) + parser.add_argument( + "--resolution", + type=lambda x: tuple(map(int, x.split("x"))), + default=(512, 512), + help="The resolution of the training images", + ) + parser.add_argument( + "--num-augmentations", + type=int, + default=5, + help="Augment the images by random cropping and panning", + ) + parser.add_argument( + "--progress-prompt", + required=True, + help="Use this prompt when generating images for evaluation", + ) + parser.add_argument( + "--progress-steps", + type=int, + default=50, + help="Use this many steps when generating images for evaluation", + ) + parser.add_argument( + "--progress-every", + type=int, + default=50, + help="Generate images every PROGRESS_EVERY steps", + ) + parser.add_argument( + "--checkpoint-every", + type=int, + default=50, + help="Save the model every CHECKPOINT_EVERY steps", + ) + parser.add_argument( + "--lora-blocks", + type=int, + default=-1, + help="Train the last LORA_BLOCKS transformer blocks", + ) + parser.add_argument( + "--lora-rank", type=int, default=8, help="LoRA rank for finetuning" + ) + parser.add_argument( + "--warmup-steps", type=int, default=100, help="Learning rate warmup" + ) + parser.add_argument( + "--learning-rate", type=float, default="1e-4", help="Learning rate for training" + ) + parser.add_argument( + "--grad-accumulate", + type=int, + default=4, + help="Accumulate gradients for that many iterations before applying them", + ) + parser.add_argument( + "--output-dir", default="mlx_output", help="Folder to save the checkpoints in" + ) + + parser.add_argument("dataset") + return parser + + +if __name__ == "__main__": + parser = setup_arg_parser() + args = parser.parse_args() + + # Load the model and set it up for LoRA training. We use the same random + # state when creating the LoRA layers so all workers will have the same + # initial weights. + mx.random.seed(0x0F0F0F0F) + flux = FluxPipeline("flux-" + args.model) + flux.flow.freeze() + flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks) + + # Reset the seed to a different seed per worker if we are in distributed + # mode so that each worker is working on different data, diffusion step and + # random noise. + mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank()) + + # Report how many parameters we are training + trainable_params = tree_reduce( + lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 + ) + print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True) + + # Set up the optimizer and training steps. The steps are a bit verbose to + # support gradient accumulation together with compilation. 
+ warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps) + cosine = optim.cosine_decay( + args.learning_rate, args.iterations // args.grad_accumulate + ) + lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps]) + optimizer = optim.Adam(learning_rate=lr_schedule) + state = [flux.flow.state, optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def single_step(x, t5_feat, clip_feat, guidance): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = average_gradients(grads) + optimizer.update(flux.flow, grads) + + return loss + + @partial(mx.compile, inputs=state, outputs=state) + def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): + return nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + + @partial(mx.compile, inputs=state, outputs=state) + def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = tree_map(lambda a, b: a + b, prev_grads, grads) + return loss, grads + + @partial(mx.compile, inputs=state, outputs=state) + def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = tree_map( + lambda a, b: (a + b) / args.grad_accumulate, + prev_grads, + grads, + ) + grads = average_gradients(grads) + optimizer.update(flux.flow, grads) + + return loss + + # We simply route to the appropriate step based on whether we have + # gradients from a previous step and whether we should be performing an + # update or simply computing and accumulating gradients in this step. + def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): + if prev_grads is None: + if perform_step: + return single_step(x, t5_feat, clip_feat, guidance), None + else: + return compute_loss_and_grads(x, t5_feat, clip_feat, guidance) + else: + if perform_step: + return ( + grad_accumulate_and_step( + x, t5_feat, clip_feat, guidance, prev_grads + ), + None, + ) + else: + return compute_loss_and_accumulate_grads( + x, t5_feat, clip_feat, guidance, prev_grads + ) + + dataset = load_dataset(args.dataset) + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) + + # An initial generation to compare + generate_progress_images(0, flux, args) + + grads = None + losses = [] + tic = time.time() + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): + loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) + mx.eval(loss, grads, state) + losses.append(loss.item()) + + if (i + 1) % 10 == 0: + toc = time.time() + peak_mem = mx.metal.get_peak_memory() / 1024**3 + print( + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " + f"It/s: {10 / (toc - tic):.3f} " + f"Peak mem: {peak_mem:.3f} GB", + flush=True, + ) + + if (i + 1) % args.progress_every == 0: + generate_progress_images(i + 1, flux, args) + + if (i + 1) % args.checkpoint_every == 0: + save_adapters(i + 1, flux, args) + + if (i + 1) % 10 == 0: + losses = [] + tic = time.time() diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py new file mode 100644 index 00000000..b1122d75 --- /dev/null +++ b/flux/flux/__init__.py @@ -0,0 +1,15 @@ +# Copyright © 2024 Apple Inc. 
+ +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline +from .lora import LoRALinear +from .sampler import FluxSampler +from .trainer import Trainer +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) diff --git a/flux/flux/autoencoder.py b/flux/flux/autoencoder.py new file mode 100644 index 00000000..6332bb57 --- /dev/null +++ b/flux/flux/autoencoder.py @@ -0,0 +1,357 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import List + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.upsample import upsample_nearest + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: List[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.q = nn.Linear(in_channels, in_channels) + self.k = nn.Linear(in_channels, in_channels) + self.v = nn.Linear(in_channels, in_channels) + self.proj_out = nn.Linear(in_channels, in_channels) + + def __call__(self, x: mx.array) -> mx.array: + B, H, W, C = x.shape + + y = x.reshape(B, 1, -1, C) + y = self.norm(y) + q = self.q(y) + k = self.k(y) + v = self.v(y) + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5)) + y = self.proj_out(y) + + return x + y.reshape(B, H, W, C) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, + dims=out_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Linear(in_channels, out_channels) + + def __call__(self, x): + h = x + h = self.norm1(h) + h = nn.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nn.silu(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def __call__(self, x: mx.array): + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + x = upsample_nearest(x, (2, 2)) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = 
resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = [] + block_in = self.ch + for i_level in range(self.num_resolutions): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = {} + down["block"] = block + down["attn"] = attn + if i_level != self.num_resolutions - 1: + down["downsample"] = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level]["block"][i_block](hs[-1]) + + # TODO: Remove the attn + if len(self.down[i_level]["attn"]) > 0: + h = self.down[i_level]["attn"][i_block](h) + + hs.append(h) + + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level]["downsample"](hs[-1])) + + # middle + h = hs[-1] + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = {} + up["block"] = block + up["attn"] = attn + if i_level != 0: + up["upsample"] = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + 
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def __call__(self, z: mx.array): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level]["block"][i_block](h) + + # TODO: Remove the attn + if len(self.up[i_level]["attn"]) > 0: + h = self.up[i_level]["attn"][i_block](h) + + if i_level != 0: + h = self.up[i_level]["upsample"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class DiagonalGaussian(nn.Module): + def __call__(self, z: mx.array): + mean, logvar = mx.split(z, 2, axis=-1) + if self.training: + std = mx.exp(0.5 * logvar) + eps = mx.random.normal(shape=z.shape, dtype=z.dtype) + return mean + std * eps + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if w.ndim == 4: + w = w.transpose(0, 2, 3, 1) + w = w.reshape(-1).reshape(w.shape) + if w.shape[1:3] == (1, 1): + w = w.squeeze((1, 2)) + new_weights[k] = w + return new_weights + + def encode(self, x: mx.array): + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: mx.array): + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def __call__(self, x: mx.array): + return self.decode(self.encode(x)) diff --git a/flux/flux/clip.py b/flux/flux/clip.py new file mode 100644 index 00000000..d5a30dbf --- /dev/null +++ b/flux/flux/clip.py @@ -0,0 +1,154 @@ +# Copyright © 2024 Apple Inc. 
+ +from dataclasses import dataclass +from typing import List, Optional + +import mlx.core as mx +import mlx.nn as nn + +_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} + + +@dataclass +class CLIPTextModelConfig: + num_layers: int = 23 + model_dims: int = 1024 + num_heads: int = 16 + max_length: int = 77 + vocab_size: int = 49408 + hidden_act: str = "quick_gelu" + + @classmethod + def from_dict(cls, config): + return cls( + num_layers=config["num_hidden_layers"], + model_dims=config["hidden_size"], + num_heads=config["num_attention_heads"], + max_length=config["max_position_embeddings"], + vocab_size=config["vocab_size"], + hidden_act=config["hidden_act"], + ) + + +@dataclass +class CLIPOutput: + # The last_hidden_state indexed at the EOS token and possibly projected if + # the model has a projection layer + pooled_output: Optional[mx.array] = None + + # The full sequence output of the transformer after the final layernorm + last_hidden_state: Optional[mx.array] = None + + # A list of hidden states corresponding to the outputs of the transformer layers + hidden_states: Optional[List[mx.array]] = None + + +class CLIPEncoderLayer(nn.Module): + """The transformer encoder layer from CLIP.""" + + def __init__(self, model_dims: int, num_heads: int, activation: str): + super().__init__() + + self.layer_norm1 = nn.LayerNorm(model_dims) + self.layer_norm2 = nn.LayerNorm(model_dims) + + self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True) + + self.linear1 = nn.Linear(model_dims, 4 * model_dims) + self.linear2 = nn.Linear(4 * model_dims, model_dims) + + self.act = _ACTIVATIONS[activation] + + def __call__(self, x, attn_mask=None): + y = self.layer_norm1(x) + y = self.attention(y, y, y, attn_mask) + x = y + x + + y = self.layer_norm2(x) + y = self.linear1(y) + y = self.act(y) + y = self.linear2(y) + x = y + x + + return x + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from CLIP.""" + + def __init__(self, config: CLIPTextModelConfig): + super().__init__() + + self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) + self.position_embedding = nn.Embedding(config.max_length, config.model_dims) + self.layers = [ + CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act) + for i in range(config.num_layers) + ] + self.final_layer_norm = nn.LayerNorm(config.model_dims) + + def _get_mask(self, N, dtype): + indices = mx.arange(N) + mask = indices[:, None] < indices[None] + mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9) + return mask + + def sanitize(self, weights): + new_weights = {} + for key, w in weights.items(): + # Remove prefixes + if key.startswith("text_model."): + key = key[11:] + if key.startswith("embeddings."): + key = key[11:] + if key.startswith("encoder."): + key = key[8:] + + # Map attention layers + if "self_attn." in key: + key = key.replace("self_attn.", "attention.") + if "q_proj." in key: + key = key.replace("q_proj.", "query_proj.") + if "k_proj." in key: + key = key.replace("k_proj.", "key_proj.") + if "v_proj." 
in key: + key = key.replace("v_proj.", "value_proj.") + + # Map ffn layers + if "mlp.fc1" in key: + key = key.replace("mlp.fc1", "linear1") + if "mlp.fc2" in key: + key = key.replace("mlp.fc2", "linear2") + + new_weights[key] = w + + return new_weights + + def __call__(self, x): + # Extract some shapes + B, N = x.shape + eos_tokens = x.argmax(-1) + + # Compute the embeddings + x = self.token_embedding(x) + x = x + self.position_embedding.weight[:N] + + # Compute the features from the transformer + mask = self._get_mask(N, x.dtype) + hidden_states = [] + for l in self.layers: + x = l(x, mask) + hidden_states.append(x) + + # Apply the final layernorm and return + x = self.final_layer_norm(x) + last_hidden_state = x + + # Select the EOS token + pooled_output = x[mx.arange(len(x)), eos_tokens] + + return CLIPOutput( + pooled_output=pooled_output, + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + ) diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py new file mode 100644 index 00000000..d31a09f1 --- /dev/null +++ b/flux/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..3fd044ac --- /dev/null +++ b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. 
+ +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. 
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = 
self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/layers.py b/flux/flux/layers.py new file mode 100644 index 00000000..12397904 --- /dev/null +++ b/flux/flux/layers.py @@ -0,0 +1,302 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +def _rope(pos: mx.array, dim: int, theta: float): + scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim + omega = 1.0 / (theta**scale) + x = pos[..., None] * omega + cosx = mx.cos(x) + sinx = mx.sin(x) + pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1) + pe = pe.reshape(*pe.shape[:-1], 2, 2) + + return pe + + +@partial(mx.compile, shapeless=True) +def _ab_plus_cd(a, b, c, d): + return a * b + c * d + + +def _apply_rope(x, pe): + s = x.shape + x = x.reshape(*s[:-1], -1, 1, 2) + x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1]) + return x.reshape(s) + + +def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array): + B, H, L, D = q.shape + + q = _apply_rope(q, pe) + k = _apply_rope(k, pe) + x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5)) + + return x.transpose(0, 2, 1, 3).reshape(B, L, -1) + + +def timestep_embedding( + t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0 +): + half = dim // 2 + freqs = mx.arange(0, half, dtype=mx.float32) / half + freqs = freqs * (-math.log(max_period)) + freqs = mx.exp(freqs) + + x = (time_factor * t)[:, None] * freqs[None] + x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1) + + return x.astype(t.dtype) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def __call__(self, ids: mx.array): + n_axes = ids.shape[-1] + pe = mx.concatenate( + [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + axis=-3, + ) + + return pe[:, None] + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + return self.out_layer(nn.silu(self.in_layer(x))) 
+ + +class QKNorm(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = nn.RMSNorm(dim) + self.key_norm = nn.RMSNorm(dim) + + def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]: + return self.query_norm(q), self.key_norm(k) + + +class SelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def __call__(self, x: mx.array, pe: mx.array) -> mx.array: + H = self.num_heads + B, L, _ = x.shape + qkv = self.qkv(x) + q, k, v = mx.split(qkv, 3, axis=-1) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + x = _attention(q, k, v, pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: mx.array + scale: mx.array + gate: mx.array + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]: + x = self.lin(nn.silu(x)) + xs = mx.split(x[:, None, :], self.multiplier, axis=-1) + + mod1 = ModulationOut(*xs[:3]) + mod2 = ModulationOut(*xs[3:]) if self.is_double else None + + return mod1, mod2 + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def __call__( + self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array + ) -> Tuple[mx.array, mx.array]: + B, L, _ = img.shape + _, S, _ = txt.shape + H = self.num_heads + + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1) + img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_q, img_k = self.img_attn.norm(img_q, img_k) + + # prepare txt for 
attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1) + txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k) + + # run actual attention + q = mx.concatenate([txt_q, img_q], axis=2) + k = mx.concatenate([txt_k, img_k], axis=2) + v = mx.concatenate([txt_v, img_v], axis=2) + + attn = _attention(q, k, v, pe) + txt_attn, img_attn = mx.split(attn, [S], axis=1) + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + + return img, txt + + +class SingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: Optional[float] = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approx="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def __call__(self, x: mx.array, vec: mx.array, pe: mx.array): + B, L, _ = x.shape + H = self.num_heads + + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + + q, k, v, mlp = mx.split( + self.linear1(x_mod), + [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size], + axis=-1, + ) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + + # compute attention + y = _attention(q, k, v, pe) + + # compute activation in mlp stream, cat again and run second linear layer + y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2)) + return x + mod.gate * y + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def __call__(self, x: mx.array, vec: mx.array): + shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/flux/flux/lora.py b/flux/flux/lora.py new file mode 100644 index 00000000..b0c8ae56 --- /dev/null +++ b/flux/flux/lora.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. 
+ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class LoRALinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Linear, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + ): + output_dims, input_dims = linear.weight.shape + lora_lin = LoRALinear( + input_dims=input_dims, + output_dims=output_dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + dtype = weight.dtype + + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = self.scale * self.lora_b.T + lora_a = self.lora_a.T + fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype) + if bias: + fused_linear.bias = linear.bias + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, r), + ) + self.lora_b = mx.zeros(shape=(r, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (self.dropout(x) @ self.lora_a) @ self.lora_b + return y + (self.scale * z).astype(x.dtype) diff --git a/flux/flux/model.py b/flux/flux/model.py new file mode 100644 index 00000000..18ea70b0 --- /dev/null +++ b/flux/flux/model.py @@ -0,0 +1,134 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if params.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + + 
self.single_blocks = [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio + ) + for _ in range(params.depth_single_blocks) + ] + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if k.endswith(".scale"): + k = k[:-6] + ".weight" + for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: + if f".{seq}." in k: + k = k.replace(f".{seq}.", f".{seq}.layers.") + break + new_weights[k] = w + return new_weights + + def __call__( + self, + img: mx.array, + img_ids: mx.array, + txt: mx.array, + txt_ids: mx.array, + timesteps: mx.array, + y: mx.array, + guidance: Optional[mx.array] = None, + ) -> mx.array: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = mx.concatenate([txt_ids, img_ids], axis=1) + pe = self.pe_embedder(ids).astype(img.dtype) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = mx.concatenate([txt, img], axis=1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) + + return img diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py new file mode 100644 index 00000000..3bff1ca2 --- /dev/null +++ b/flux/flux/sampler.py @@ -0,0 +1,56 @@ +# Copyright © 2024 Apple Inc. + +import math +from functools import lru_cache + +import mlx.core as mx + + +class FluxSampler: + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + self._base_shift = base_shift + self._max_shift = max_shift + self._schnell = "schnell" in name + + def _time_shift(self, x, t): + x1, x2 = 256, 4096 + t1, t2 = self._base_shift, self._max_shift + exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1) + t = exp_mu / (exp_mu + (1 / t - 1)) + return t + + @lru_cache + def timesteps( + self, num_steps, image_sequence_length, start: float = 1, stop: float = 0 + ): + t = mx.linspace(start, stop, num_steps + 1) + + if self._schnell: + t = self._time_shift(image_sequence_length, t) + + return t.tolist() + + def random_timesteps(self, B, L, dtype=mx.float32, key=None): + if self._schnell: + # TODO: Should we upweigh 1 and 0.75? + t = mx.random.randint(1, 5, shape=(B,), key=key) + t = t.astype(dtype) / 4 + else: + t = mx.random.uniform(shape=(B,), dtype=dtype, key=key) + t = self._time_shift(L, t) + + return t + + def sample_prior(self, shape, dtype=mx.float32, key=None): + return mx.random.normal(shape, dtype=dtype, key=key) + + def add_noise(self, x, t, noise=None, key=None): + noise = ( + noise + if noise is not None + else mx.random.normal(x.shape, dtype=x.dtype, key=key) + ) + return x * (1 - t) + t * noise + + def step(self, pred, x_t, t, t_prev): + return x_t + (t_prev - t) * pred diff --git a/flux/flux/t5.py b/flux/flux/t5.py new file mode 100644 index 00000000..cf0515cd --- /dev/null +++ b/flux/flux/t5.py @@ -0,0 +1,244 @@ +# Copyright © 2024 Apple Inc. 
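`FluxSampler` above implements the rectified-flow schedule: `add_noise` linearly interpolates between data and noise, and `step` takes an Euler step along the predicted velocity. A rough sketch of the sampling loop it is meant to drive, with a dummy `pred_fn` standing in for the flow transformer (the `flux.sampler` import path and the latent shape are illustrative assumptions):

```python
import mlx.core as mx

from flux.sampler import FluxSampler  # assumed module path for the class above

sampler = FluxSampler("flux-schnell")

def pred_fn(x_t, t):
    # Placeholder for the Flux transformer; the real call also takes the text
    # conditioning, positional ids, and (for the dev model) a guidance scale.
    return mx.zeros_like(x_t)

x_t = sampler.sample_prior((1, 1024, 64))  # a batch of latent tokens
ts = sampler.timesteps(num_steps=4, image_sequence_length=1024)
for t, t_prev in zip(ts[:-1], ts[1:]):
    x_t = sampler.step(pred_fn(x_t, t), x_t, t, t_prev)
```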
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +_SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +_ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + + +@dataclass +class T5Config: + vocab_size: int + num_layers: int + num_heads: int + relative_attention_num_buckets: int + d_kv: int + d_model: int + feed_forward_proj: str + tie_word_embeddings: bool + + d_ff: Optional[int] = None + num_decoder_layers: Optional[int] = None + relative_attention_max_distance: int = 128 + layer_norm_epsilon: float = 1e-6 + + @classmethod + def from_dict(cls, config): + return cls( + vocab_size=config["vocab_size"], + num_layers=config["num_layers"], + num_heads=config["num_heads"], + relative_attention_num_buckets=config["relative_attention_num_buckets"], + d_kv=config["d_kv"], + d_model=config["d_model"], + feed_forward_proj=config["feed_forward_proj"], + tie_word_embeddings=config["tie_word_embeddings"], + d_ff=config.get("d_ff", 4 * config["d_model"]), + num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]), + relative_attention_max_distance=config.get( + "relative_attention_max_distance", 128 + ), + layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6), + ) + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding(self.num_buckets, self.n_heads) + + @staticmethod + def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance): + num_buckets = num_buckets // 2 if bidirectional else num_buckets + max_exact = num_buckets // 2 + + abspos = rpos.abs() + is_small = abspos < max_exact + + scale = (num_buckets - max_exact) / math.log(max_distance / max_exact) + buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16) + buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1) + + buckets = mx.where(is_small, abspos, buckets_large) + if bidirectional: + buckets = buckets + (rpos > 0) * num_buckets + else: + buckets = buckets * (rpos < 0) + + return buckets + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + 
def __init__(self, config: T5Config): + super().__init__() + inner_dim = config.d_kv * config.num_heads + self.num_heads = config.num_heads + self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False) + self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array], + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> [mx.array, Tuple[mx.array, mx.array]]: + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, _ = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + keys = mx.concatenate([key_cache, keys], axis=3) + values = mx.concatenate([value_cache, values], axis=2) + + values_hat = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype) + ) + values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat), (keys, values) + + +class DenseActivation(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + mlp_dims = config.d_ff or config.d_model * 4 + self.gated = config.feed_forward_proj.startswith("gated") + if self.gated: + self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) + else: + self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) + self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) + activation = config.feed_forward_proj.removeprefix("gated-") + if activation == "relu": + self.act = nn.relu + elif activation == "gelu": + self.act = nn.gelu + elif activation == "silu": + self.act = nn.silu + else: + raise ValueError(f"Unknown activation: {activation}") + + def __call__(self, x): + if self.gated: + hidden_act = self.act(self.wi_0(x)) + hidden_linear = self.wi_1(x) + x = hidden_act * hidden_linear + else: + x = self.act(self.wi(x)) + return self.wo(x) + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.attention = MultiHeadAttention(config) + self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dense = DenseActivation(config) + + def __call__(self, x, mask): + y = self.ln1(x) + y, _ = self.attention(y, y, y, mask=mask) + x = x + y + + y = self.ln2(x) + y = self.dense(y) + return x + y + + +class TransformerEncoder(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.layers = [ + TransformerEncoderLayer(config) for i in range(config.num_layers) + ] + self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + self.relative_attention_bias = RelativePositionBias(config, bidirectional=True) + + def __call__(self, x: mx.array): + pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1]) + pos_bias = pos_bias.astype(x.dtype) + for layer in self.layers: + x = layer(x, mask=pos_bias) + return self.ln(x) + + +class T5Encoder(nn.Module): + def __init__(self, config: T5Config): + self.wte = nn.Embedding(config.vocab_size, 
config.d_model)
+        self.encoder = TransformerEncoder(config)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            for old, new in _SHARED_REPLACEMENT_PATTERNS:
+                k = k.replace(old, new)
+            if k.startswith("encoder."):
+                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+                    k = k.replace(old, new)
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(self, inputs: mx.array):
+        return self.encoder(self.wte(inputs))
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
new file mode 100644
index 00000000..796ef389
--- /dev/null
+++ b/flux/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+ clean_text = regex.sub(r"\s+", " ", text.lower()) + tokens = regex.findall(self.pat, clean_text) + + # Split the tokens according to the byte-pair merge file + bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] + + # Map to token ids and return + tokens = [self.vocab[t] for t in bpe_tokens] + if prepend_bos: + tokens = [self.bos_token] + tokens + if append_eos: + tokens.append(self.eos_token) + + if len(tokens) > self.max_length: + tokens = tokens[: self.max_length] + if append_eos: + tokens[-1] = self.eos_token + + return tokens + + def encode(self, text): + if not isinstance(text, list): + return self.encode([text]) + + tokens = self.tokenize(text) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([self.eos_token] * (length - len(t))) + + return mx.array(tokens) + + +class T5Tokenizer: + def __init__(self, model_file, max_length=512): + self._tokenizer = SentencePieceProcessor(model_file) + self.max_length = max_length + + @property + def pad(self): + try: + return self._tokenizer.id_to_piece(self.pad_token) + except IndexError: + return None + + @property + def pad_token(self): + return self._tokenizer.pad_id() + + @property + def bos(self): + try: + return self._tokenizer.id_to_piece(self.bos_token) + except IndexError: + return None + + @property + def bos_token(self): + return self._tokenizer.bos_id() + + @property + def eos(self): + try: + return self._tokenizer.id_to_piece(self.eos_token) + except IndexError: + return None + + @property + def eos_token(self): + return self._tokenizer.eos_id() + + def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True): + if isinstance(text, list): + return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text] + + tokens = self._tokenizer.encode(text) + + if prepend_bos and self.bos_token >= 0: + tokens = [self.bos_token] + tokens + if append_eos and self.eos_token >= 0: + tokens.append(self.eos_token) + if pad and len(tokens) < self.max_length and self.pad_token >= 0: + tokens += [self.pad_token] * (self.max_length - len(tokens)) + + return tokens + + def encode(self, text, pad=True): + if not isinstance(text, list): + return self.encode([text], pad=pad) + + pad_token = self.pad_token if self.pad_token >= 0 else 0 + tokens = self.tokenize(text, pad=pad) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([pad_token] * (length - len(t))) + + return mx.array(tokens) diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..40a126e8 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of 
resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] diff --git a/flux/flux/utils.py b/flux/flux/utils.py new file mode 100644 index 00000000..21db17d3 --- /dev/null +++ b/flux/flux/utils.py @@ -0,0 +1,209 @@ +# Copyright © 2024 Apple Inc. 
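`Trainer.iterate` above pairs every augmented latent with the text features of the image it came from: latents are stored `num_augmentations` per image, so integer-dividing a shuffled latent index by `n_aug` recovers the index of the matching prompt features. A tiny standalone illustration of that mapping (all shapes and sizes are made up):

```python
import mlx.core as mx

n_images, n_aug, batch_size = 3, 2, 4
latents = mx.random.normal((n_images * n_aug, 32))  # one row per augmentation
t5_feats = mx.random.normal((n_images, 8))          # one row per image/prompt

x_indices = mx.random.permutation(n_images * n_aug)
c_indices = x_indices // n_aug                      # augmentation -> source image

x_batch = latents[x_indices[:batch_size]]
c_batch = t5_feats[c_indices[:batch_size]]
print(x_batch.shape, c_batch.shape)                 # (4, 32) (4, 8)
```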
+ +import json +import os +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +from huggingface_hub import hf_hub_download + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .clip import CLIPTextModel, CLIPTextModelConfig +from .model import Flux, FluxParams +from .t5 import T5Config, T5Encoder +from .tokenizers import CLIPTokenizer, T5Tokenizer + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: Optional[str] + ae_path: Optional[str] + repo_id: Optional[str] + repo_flow: Optional[str] + repo_ae: Optional[str] + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def load_flow_model(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ckpt_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + # Make the model + model = Flux(configs[name].params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + + return model + + +def load_ae(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ae_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + + # Make the autoencoder + ae = AutoEncoder(configs[name].ae_params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = ae.sanitize(weights) + ae.load_weights(list(weights.items())) + + return ae + + +def load_clip(name: str): + # Load the config + config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json") + with open(config_path) as f: + config = CLIPTextModelConfig.from_dict(json.load(f)) + + # Make the clip text encoder + clip = CLIPTextModel(config) + + # Load the weights + ckpt_path = 
hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors") + weights = mx.load(ckpt_path) + weights = clip.sanitize(weights) + clip.load_weights(list(weights.items())) + + return clip + + +def load_t5(name: str): + # Load the config + config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json") + with open(config_path) as f: + config = T5Config.from_dict(json.load(f)) + + # Make the T5 model + t5 = T5Encoder(config) + + # Load the weights + model_index = hf_hub_download( + configs[name].repo_id, "text_encoder_2/model.safetensors.index.json" + ) + weight_files = set() + with open(model_index) as f: + for _, w in json.load(f)["weight_map"].items(): + weight_files.add(w) + weights = {} + for w in weight_files: + w = f"text_encoder_2/{w}" + w = hf_hub_download(configs[name].repo_id, w) + weights.update(mx.load(w)) + weights = t5.sanitize(weights) + t5.load_weights(list(weights.items())) + + return t5 + + +def load_clip_tokenizer(name: str): + vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json") + with open(vocab_file, encoding="utf-8") as f: + vocab = json.load(f) + + merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt") + with open(merges_file, encoding="utf-8") as f: + bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] + bpe_merges = [tuple(m.split()) for m in bpe_merges] + bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) + + return CLIPTokenizer(bpe_ranks, vocab, max_length=77) + + +def load_t5_tokenizer(name: str, pad: bool = True): + model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") + return T5Tokenizer(model_file, 256 if "schnell" in name else 512) diff --git a/flux/requirements.txt b/flux/requirements.txt new file mode 100644 index 00000000..792205c9 --- /dev/null +++ b/flux/requirements.txt @@ -0,0 +1,7 @@ +mlx>=0.18.1 +huggingface-hub +regex +numpy +tqdm +Pillow +sentencepiece diff --git a/flux/static/dog-r4-g8-1200-512x1024.png b/flux/static/dog-r4-g8-1200-512x1024.png new file mode 100644 index 00000000..7b1ca0e6 Binary files /dev/null and b/flux/static/dog-r4-g8-1200-512x1024.png differ diff --git a/flux/static/dog-r4-g8-1200.png b/flux/static/dog-r4-g8-1200.png new file mode 100644 index 00000000..90e47333 Binary files /dev/null and b/flux/static/dog-r4-g8-1200.png differ diff --git a/flux/static/dog6.png b/flux/static/dog6.png new file mode 100644 index 00000000..2bcf7b8c Binary files /dev/null and b/flux/static/dog6.png differ diff --git a/flux/static/generated-mlx.png b/flux/static/generated-mlx.png new file mode 100644 index 00000000..5c274ef4 Binary files /dev/null and b/flux/static/generated-mlx.png differ diff --git a/flux/txt2image.py b/flux/txt2image.py new file mode 100644 index 00000000..5ebec81a --- /dev/null +++ b/flux/txt2image.py @@ -0,0 +1,150 @@ +# Copyright © 2024 Apple Inc. + +import argparse + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from PIL import Image +from tqdm import tqdm + +from flux import FluxPipeline + + +def to_latent_size(image_size): + h, w = image_size + h = ((h + 15) // 16) * 16 + w = ((w + 15) // 16) * 16 + + if (h, w) != image_size: + print( + "Warning: The image dimensions need to be divisible by 16px. " + f"Changing size to {h}x{w}." 
+ ) + + return (h // 8, w // 8) + + +def quantization_predicate(name, m): + return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0 + + +def load_adapter(flux, adapter_file, fuse=False): + weights, lora_config = mx.load(adapter_file, return_metadata=True) + rank = int(lora_config["lora_rank"]) + num_blocks = int(lora_config["lora_blocks"]) + flux.linear_to_lora_layers(rank, num_blocks) + flux.flow.load_weights(list(weights.items()), strict=False) + if fuse: + flux.fuse_lora_layers() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate images from a textual prompt using stable diffusion" + ) + parser.add_argument("prompt") + parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") + parser.add_argument("--n-images", type=int, default=4) + parser.add_argument( + "--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512) + ) + parser.add_argument("--steps", type=int) + parser.add_argument("--guidance", type=float, default=4.0) + parser.add_argument("--n-rows", type=int, default=1) + parser.add_argument("--decoding-batch-size", type=int, default=1) + parser.add_argument("--quantize", "-q", action="store_true") + parser.add_argument("--preload-models", action="store_true") + parser.add_argument("--output", default="out.png") + parser.add_argument("--save-raw", action="store_true") + parser.add_argument("--seed", type=int) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--adapter") + parser.add_argument("--fuse-adapter", action="store_true") + parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false") + args = parser.parse_args() + + # Load the models + flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding) + args.steps = args.steps or (50 if args.model == "dev" else 2) + + if args.adapter: + load_adapter(flux, args.adapter, fuse=args.fuse_adapter) + + if args.quantize: + nn.quantize(flux.flow, class_predicate=quantization_predicate) + nn.quantize(flux.t5, class_predicate=quantization_predicate) + nn.quantize(flux.clip, class_predicate=quantization_predicate) + + if args.preload_models: + flux.ensure_models_are_loaded() + + # Make the generator + latent_size = to_latent_size(args.image_size) + latents = flux.generate_latents( + args.prompt, + n_images=args.n_images, + num_steps=args.steps, + latent_size=latent_size, + guidance=args.guidance, + seed=args.seed, + ) + + # First we get and eval the conditioning + conditioning = next(latents) + mx.eval(conditioning) + peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 + mx.metal.reset_peak_memory() + + # The following is not necessary but it may help in memory constrained + # systems by reusing the memory kept by the text encoders. + del flux.t5 + del flux.clip + + # Actual denoising loop + for x_t in tqdm(latents, total=args.steps): + mx.eval(x_t) + + # The following is not necessary but it may help in memory constrained + # systems by reusing the memory kept by the flow transformer. 
+ del flux.flow + peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 + mx.metal.reset_peak_memory() + + # Decode them into images + decoded = [] + for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): + decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) + mx.eval(decoded[-1]) + peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_overall = max( + peak_mem_conditioning, peak_mem_generation, peak_mem_decoding + ) + + if args.save_raw: + *name, suffix = args.output.split(".") + name = ".".join(name) + x = mx.concatenate(decoded, axis=0) + x = (x * 255).astype(mx.uint8) + for i in range(len(x)): + im = Image.fromarray(np.array(x[i])) + im.save(".".join([name, str(i), suffix])) + else: + # Arrange them on a grid + x = mx.concatenate(decoded, axis=0) + x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(args.n_rows * H, B // args.n_rows * W, C) + x = (x * 255).astype(mx.uint8) + + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(args.output) + + # Report the peak memory used during generation + if args.verbose: + print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB") + print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") + print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") + print(f"Peak memory used overall: {peak_mem_overall:.3f}GB") diff --git a/llms/README.md b/llms/README.md index 75677865..20863041 100644 --- a/llms/README.md +++ b/llms/README.md @@ -20,6 +20,31 @@ The `mlx-lm` package also has: - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md) - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md) +### Quick Start + +To generate text with an LLM use: + +```bash +mlx_lm.generate --prompt "Hi!" +``` + +To chat with an LLM use: + +```bash +mlx_lm.chat +``` + +This will give you a chat REPL that you can use to interact with the LLM. The +chat context is preserved during the lifetime of the REPL. + +Commands in `mlx-lm` typically take command line options which let you specify +the model, sampling parameters, and more. Use `-h` to see a list of available +options for a command, e.g.: + +```bash +mlx_lm.generate -h +``` + ### Python API You can use `mlx-lm` as a module: @@ -138,7 +163,7 @@ mlx_lm.convert \ ### Long Prompts and Generations -MLX LM has some tools to scale efficiently to long prompts and generations: +`mlx-lm` has some tools to scale efficiently to long prompts and generations: - A rotating fixed-size key-value cache. - Prompt caching @@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example: cat prompt.txt | mlx_lm.cache_prompt \ --model mistralai/Mistral-7B-Instruct-v0.3 \ --prompt - \ - --kv-cache-file mistral_prompt.safetensors + --prompt-cache-file mistral_prompt.safetensors ``` Then use the cached prompt with `mlx_lm.generate`: ``` mlx_lm.generate \ - --kv-cache-file mistral_prompt.safetensors \ + --prompt-cache-file mistral_prompt.safetensors \ --prompt "\nSummarize the above text." ``` @@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice when using a cached prompt, the model to use is read from the cache and need not be supplied explicitly. +Prompt caching can also be used in the Python API in order to to avoid +recomputing the prompt. 
This is useful in multi-turn dialogues or across +requests that use the same context. See the +[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py) +for more usage details. + ### Supported Models -MLX LM supports thousands of Hugging Face format LLMs. If the model you want to +`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to run is not supported, file an [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet, submit a pull request. diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 55be1c9c..2976a09f 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \ - `role_mapping`: (Optional) A dictionary to customize the role prefixes in the generated prompt. If not provided, the default mappings are used. -- `stop`: (Optional) An array of strings or a single string. Thesse are +- `stop`: (Optional) An array of strings or a single string. These are sequences of tokens on which the generation should stop. - `max_tokens`: (Optional) An integer specifying the maximum number of tokens @@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \ started in. - `adapters`: (Optional) A string path to low-rank adapters. The path must be - rlative to the directory the server was started in. + relative to the directory the server was started in. + +### Response Fields + +- `id`: A unique identifier for the chat. + +- `system_fingerprint`: A unique identifier for the system. + +- `object`: Any of "chat.completions", "chat.completions.chunk" (for + streaming), or "text.completion". + +- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). + +- `created`: A time-stamp for when the request was processed. + +- `choices`: A list of outputs. Each output is a dictionary containing the fields: + - `index`: The index in the list. + - `logprobs`: A dictionary containing the fields: + - `token_logprobs`: A list of the log probabilities for the generated + tokens. + - `tokens`: A list of the generated token ids. + - `top_logprobs`: A list of lists. Each list contains the `logprobs` + top tokens (if requested) with their corresponding probabilities. + - `finish_reason`: The reason the completion ended. This can be either of + `"stop"` or `"length"`. + - `message`: The text response from the model. + +- `usage`: A dictionary containing the fields: + - `prompt_tokens`: The number of prompt tokens processed. + - `completion_tokens`: The number of tokens generated. + - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. ### List Models @@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json" This will return a list of locally available models where each model in the list contains the following fields: -- `"id"`: The Hugging Face repo id. -- `"created"`: A timestamp representing the model creation time. +- `id`: The Hugging Face repo id. +- `created`: A time-stamp representing the model creation time. diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 8110c823..70239db6 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
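The new Response Fields section above documents the schema returned by the server. A hedged sketch of calling the `/v1/chat/completions` endpoint from Python and reading those fields (assumes a server is already running on `localhost:8080`; only the standard library is used):

```python
import json
from urllib.request import Request, urlopen

payload = {
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "max_tokens": 16,
}
req = Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as response:
    out = json.loads(response.read())

print(out["choices"][0]["message"])        # the model's reply
print(out["choices"][0]["finish_reason"])  # "stop" or "length"
print(out["usage"]["total_tokens"])        # prompt + completion tokens
```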
-__version__ = "0.18.2" +__version__ = "0.19.1" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9829efb4..04e75a3e 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -7,13 +7,14 @@ import time import mlx.core as mx -from .utils import load, make_kv_caches +from .models.cache import make_prompt_cache, save_prompt_cache +from .utils import load def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser( - description="Cache the KV cache of a prompt to be reused with mlx_lm.generate" + description="Cache the state of a prompt to be reused with mlx_lm.generate" ) parser.add_argument( "--model", @@ -60,7 +61,9 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", ) parser.add_argument( - "--kv-cache-file", help="The file to save the KV caches in", required=True + "--prompt-cache-file", + help="The file to save the prompt cache in", + required=True, ) parser.add_argument( "--prompt", @@ -115,7 +118,7 @@ def main(): else: prompt = args.prompt - cache = make_kv_caches(model, args.max_kv_size) + cache = make_prompt_cache(model, args.max_kv_size) y = mx.array(tokenizer.encode(prompt)) # Process the prompt @@ -137,16 +140,12 @@ def main(): print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") print("Saving...") - cache_dict = {} - for i, c in enumerate(cache): - cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :] - cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :] metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - metadata["max_kv_size"] = str(args.max_kv_size) - mx.save_safetensors(args.kv_cache_file, cache_dict, metadata) + print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + save_prompt_cache(args.prompt_cache_file, cache, metadata) if __name__ == "__main__": diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py new file mode 100644 index 00000000..7968a868 --- /dev/null +++ b/llms/mlx_lm/chat.py @@ -0,0 +1,82 @@ +# Copyright © 2023-2024 Apple Inc. 
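The `cache_prompt` changes above replace the hand-rolled safetensors of keys and values with the new `make_prompt_cache`/`save_prompt_cache` helpers. A hedged sketch of doing the same thing directly from Python, prefilling a cache with a shared prefix and saving its state for later reuse with `load_prompt_cache` (the prompt is a placeholder, and the single forward pass stands in for the chunked prefill the script performs):

```python
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt_cache = make_prompt_cache(model)

# Run the shared prefix through the model once so the cache holds its state.
tokens = mx.array(tokenizer.encode("A long shared context goes here."))[None]
model(tokens, cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])

save_prompt_cache(
    "prefix.safetensors",
    prompt_cache,
    {"model": "mlx-community/Mistral-7B-Instruct-v0.3-4bit"},
)
```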
+ +import argparse +import json + +import mlx.core as mx + +from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .utils import load, stream_generate + +DEFAULT_TEMP = 0.0 +DEFAULT_TOP_P = 1.0 +DEFAULT_SEED = 0 +DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" + + +def setup_arg_parser(): + """Set up and return the argument parser.""" + parser = argparse.ArgumentParser(description="Chat with an LLM") + parser.add_argument( + "--model", + type=str, + help="The path to the local model directory or Hugging Face repo.", + default=DEFAULT_MODEL, + ) + parser.add_argument( + "--adapter-path", + type=str, + help="Optional path for the trained adapter weights and config.", + ) + parser.add_argument( + "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature" + ) + parser.add_argument( + "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" + ) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") + parser.add_argument( + "--max-kv-size", + type=int, + help="Set the maximum key-value cache size", + default=None, + ) + return parser + + +def main(): + parser = setup_arg_parser() + args = parser.parse_args() + + mx.random.seed(args.seed) + + model, tokenizer = load( + args.model, + adapter_path=args.adapter_path, + tokenizer_config={"trust_remote_code": True}, + ) + + print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.") + prompt_cache = make_prompt_cache(model, args.max_kv_size) + while True: + query = input(">> ") + if query == "q": + break + messages = [{"role": "user", "content": query}] + prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + for response in stream_generate( + model, + tokenizer, + prompt, + temp=args.temp, + top_p=args.top_p, + prompt_cache=prompt_cache, + ): + print(response, flush=True, end="") + print() + + +if __name__ == "__main__": + main() diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py new file mode 100644 index 00000000..3bf01688 --- /dev/null +++ b/llms/mlx_lm/examples/chat.py @@ -0,0 +1,53 @@ +# Copyright © 2024 Apple Inc. + +""" +An example of a multi-turn chat with prompt caching. +""" + +from mlx_lm import generate, load +from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache + +model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") + +# Make the initial prompt cache for the model +prompt_cache = make_prompt_cache(model) + +# User turn +prompt = "Hi my name is ." +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + temp=0.0, + prompt_cache=prompt_cache, +) + +# User turn +prompt = "What's my name?" 
+messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + temp=0.0, + prompt_cache=prompt_cache, +) + +# Save the prompt cache to disk to reuse it at a later time +save_prompt_cache("mistral_prompt.safetensors", prompt_cache) + +# Load the prompt cache from disk +prompt_cache = load_prompt_cache("mistral_prompt.safetensors") diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index af599c1b..25730617 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + from mlx_lm import generate, load # Specify the checkpoint diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 537bd853..0bf98ab2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,13 +6,15 @@ import sys import mlx.core as mx +from .models.cache import load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 -DEFAULT_TEMP = 0.6 +DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" def str2bool(string): @@ -25,7 +27,11 @@ def setup_arg_parser(): parser.add_argument( "--model", type=str, - help="The path to the local model directory or Hugging Face repo.", + help=( + "The path to the local model directory or Hugging Face repo. " + f"If no model is specified, then {DEFAULT_MODEL} is used." + ), + default=None, ) parser.add_argument( "--adapter-path", @@ -96,7 +102,7 @@ def setup_arg_parser(): default=None, ) parser.add_argument( - "--kv-cache-file", + "--prompt-cache-file", type=str, default=None, help="A file containing saved KV caches to avoid recomputing them", @@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0): colorprint(color, s) -def load_kv_cache_from_file(kv_cache_file): - if kv_cache_file is None: - return None, None - - kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True) - cache_per_layer = {} - for k, x in kv_cache.items(): - layer, kv_type = k.split("_") - if layer not in cache_per_layer: - cache_per_layer[layer] = {} - cache_per_layer[layer][kv_type] = x - - cache_history = [None] * len(cache_per_layer) - for layer, c in cache_per_layer.items(): - cache_history[int(layer)] = (c["keys"], c["values"]) - return cache_history, metadata - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -158,22 +146,33 @@ def main(): if args.cache_limit_gb is not None: mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the kv cache and metadata if a kv cache file is provided - cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file) + # Load the prompt cache and metadata if a cache file is provided + using_cache = args.prompt_cache_file is not None + if using_cache: + prompt_cache, metadata = load_prompt_cache( + args.prompt_cache_file, return_metadata=True + ) # Building tokenizer_config tokenizer_config = ( - {} if cache_history is None else json.loads(metadata["tokenizer_config"]) + {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) if args.trust_remote_code: tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token - # If no model path is provided then use the one in the kv cache history model_path = args.model - if 
cache_history is not None and model_path is None: - model_path = metadata["model"] + if using_cache: + if model_path is None: + model_path = metadata["model"] + elif model_path != metadata["model"]: + raise ValueError( + f"Providing a different model ({model_path}) than that " + f"used to create the prompt cache ({metadata['model']}) " + "is an error." + ) + model_path = model_path or DEFAULT_MODEL model, tokenizer = load( model_path, @@ -184,7 +183,7 @@ def main(): if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - elif cache_history is not None: + elif using_cache: tokenizer.chat_template = metadata["chat_template"] if not args.ignore_chat_template and ( @@ -203,7 +202,7 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. - if cache_history is not None: + if using_cache: test_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": ""}], tokenize=False, @@ -217,12 +216,6 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None - # Determine the max kv size from the kv cache or passed arguments - max_kv_size = args.max_kv_size - if cache_history is not None: - max_kv_size = metadata["max_kv_size"] - max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - response = generate( model, tokenizer, @@ -232,8 +225,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, - max_kv_size=max_kv_size, - cache_history=cache_history, + max_kv_size=args.max_kv_size, + prompt_cache=prompt_cache if using_cache else None, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index dc19dd05..3628a808 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -2,145 +2,9 @@ import inspect from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Optional import mlx.core as mx -import mlx.nn as nn - - -class KVCache: - - def __init__(self, head_dim, n_kv_heads): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - - def update_and_fetch(self, keys, values): - prev = self.offset - if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: - B = keys.shape[0] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) - v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - if prev % self.step != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - - self.offset += keys.shape[2] - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - - @property - def state(self): - return self.keys, self.values - - -class RotatingKVCache: - - def 
__init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 - - def _trim(self, trim_size, v, append=None): - to_cat = [] - if trim_size > 0: - to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def update_and_fetch(self, keys, values): - prev = self.offset - B, _, S = keys.shape[:3] - - # Prefill mode - if S > 1: - if self.keys is None: - self.keys = keys - self.values = values - else: - # The largest size is self.max_size + S - 1 to ensure - # every token gets at least self.max_size context - trim_size = self.keys.shape[2] - self.max_size + 1 - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += S - self._idx = self.keys.shape[2] - return self.keys, self.values - - # Generation mode - # May not have hit the max size yet, so potentially - # keep growing the cache - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - new_size = min(self.step, self.max_size - prev) - k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim) - v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys) - self.values = self._trim(trim_size, self.values) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + 1, :] = keys - self.values[..., self._idx : self._idx + 1, :] = values - self.offset += 1 - self._idx += 1 - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - @property - def state(self): - return self.keys, self.values @dataclass @@ -156,25 +20,30 @@ class BaseModelArgs: ) -def create_additive_causal_mask(N: int, offset: int = 0): +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) return mask * -1e9 def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: + window_size = None + offset = 0 if cache is not None and cache[0] is not None: c = cache[0] - if isinstance(c, RotatingKVCache): + if hasattr(c, "max_size"): offset = min(c.max_size - 1, c.offset) + window_size = c.max_size else: offset = c.offset - else: - offset = 0 - 
mask = create_additive_causal_mask(T, offset) + mask = create_causal_mask(T, offset, window_size=window_size) mask = mask.astype(h.dtype) else: mask = None diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py new file mode 100644 index 00000000..a6a56e0a --- /dev/null +++ b/llms/mlx_lm/models/cache.py @@ -0,0 +1,340 @@ +# Copyright © 2023-2024 Apple Inc. + +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_unflatten + + +def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: + """ + Construct the model's cache for use when cgeneration. + + This function will defer the cache construction to the model if it has a + ``make_cache`` method, otherwise it will make a default KV cache. + + Args: + model (nn.Module): The language model. + max_kv_size (Optional[int]): If provided and the model does not have a + ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum + size of ``max_kv_size`` + """ + if hasattr(model, "make_cache"): + return model.make_cache() + + num_layers = len(model.layers) + if max_kv_size is not None: + return [ + RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) + ] + else: + return [KVCache() for _ in range(num_layers)] + + +def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}): + """ + Save a pre-computed prompt cache to a file. + + Args: + file_name (str): The ``.safetensors`` file name. + cache (List[Any]): The model state. + metadata (Dict[str, str]): Optional metadata to save along with model + state. + """ + cache_data = [c.state for c in cache] + cache_info = [c.meta_state for c in cache] + cache_data = dict(tree_flatten(cache_data)) + cache_classes = [type(c).__name__ for c in cache] + cache_metadata = [cache_info, metadata, cache_classes] + cache_metadata = dict(tree_flatten(cache_metadata)) + mx.save_safetensors(file_name, cache_data, cache_metadata) + + +def load_prompt_cache(file_name, return_metadata=False): + """ + Load a prompt cache from a file. + + Args: + file_name (str): The ``.safetensors`` file name. + return_metadata (bool): Whether or not to return metadata. + Default: ``False``. + + Returns: + List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and + the metadata if requested. + """ + arrays, cache_metadata = mx.load(file_name, return_metadata=True) + arrays = tree_unflatten(list(arrays.items())) + cache_metadata = tree_unflatten(list(cache_metadata.items())) + info, metadata, classes = cache_metadata + cache = [globals()[c]() for c in classes] + for c, state, meta_state in zip(cache, arrays, info): + c.state = state + c.meta_state = meta_state + if return_metadata: + return cache, metadata + return cache + + +def can_trim_prompt_cache(cache: List[Any]) -> bool: + """ + Check if model's cache can be trimmed. + """ + return all(c.is_trimmable() for c in cache) + + +def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: + """ + Trim the model's cache by the given number of tokens. + + This function will trim the cache if possible (in-place) and return the + number of tokens that were trimmed. + + Args: + cache (List[Any]): The model's cache. + num_tokens (int): The number of tokens to trim. + + Returns: + (int): The number of tokens that were trimmed. 
+ """ + if not can_trim_prompt_cache(cache) or len(cache) == 0: + return 0 + return [c.trim(num_tokens) for c in cache][0] + + +class _BaseCache: + @property + def state(self): + return [] + + @state.setter + def state(self, v): + if v is not None and v: + raise ValueError("This cache has no state but a state was set.") + + @property + def meta_state(self): + return "" + + @meta_state.setter + def meta_state(self, v): + if v is not None and v: + raise ValueError("This cache has no meta_state but a meta_state was set.") + + def is_trimmable(self): + return False + + +class KVCache(_BaseCache): + def __init__(self): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: + B, n_kv_heads, _, k_head_dim = keys.shape + v_head_dim = values.shape[3] + n_steps = (self.step + keys.shape[2] - 1) // self.step + k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) + v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + + @property + def state(self): + if self.offset == self.keys.shape[2]: + return self.keys, self.values + else: + return ( + self.keys[..., : self.offset, :], + self.values[..., : self.offset, :], + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + self.offset = self.keys.shape[2] + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + +class RotatingKVCache(_BaseCache): + + def __init__(self, max_size=None, keep=0, step=256): + self.keep = keep + self.keys = None + self.values = None + self.offset = 0 + self.max_size = max_size + self.step = step + self._idx = 0 + + def _trim(self, trim_size, v, append=None): + to_cat = [] + if trim_size > 0: + to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + + def _temporal_order(self, v): + """ + Rearrange the cache into temporal order, slicing off the end if unused. 
+ """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + return mx.concatenate( + [ + v[..., : self.keep, :], + v[..., self._idx :, :], + v[..., self.keep : self._idx, :], + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + + # The largest size is self.max_size + S - 1 to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + 1 + self.keys = self._trim(trim_size, self.keys, keys) + self.values = self._trim(trim_size, self.values, values) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys) + self.values = self._trim(trim_size, self.values) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values + + def update_and_fetch(self, keys, values): + if keys.shape[2] == 1: + return self._update_in_place(keys, values) + return self._update_concat(keys, values) + + @property + def state(self): + if self.offset < self.keys.shape[2]: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + else: + return self.keys, self.values + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple( + map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) + ) + + @meta_state.setter + def meta_state(self, v): + self.keep, self.max_size, self.step, self.offset, self._idx = map( + int, + v, + ) + + def is_trimmable(self): + return self.offset < self.max_size + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + self._idx -= n + return n + + +class MambaCache(_BaseCache): + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + @state.setter + def state(self, v): + self.cache = v diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 
cfcf2945..057c816d 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -69,7 +69,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -129,7 +129,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.input_layernorm(x) attn_h = self.self_attn(h, mask, cache) @@ -190,11 +190,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index f0214549..3b7e83d7 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -49,7 +49,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: qkv = self.Wqkv(x) @@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.attn(self.norm_1(x), mask=mask, cache=cache) x = h + x @@ -179,7 +179,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r, h = self.norm_attn_norm(x, mask, cache) out = self.ffn(h) + r @@ -249,11 +249,3 @@ class Model(nn.Module): experts = [(s, sv.T) for s, sv in experts] new_weights.update(experts) return new_weights - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.attn_config["kv_n_heads"] diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index dcfa331c..03cb3b1a 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module): def __init__( self, config: ModelArgs, - hidden_size: int | None = None, - intermediate_size: int | None = None, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, ): super().__init__() self.config = config @@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module): self, x: 
mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -210,7 +210,7 @@ class DeepseekModel(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -235,7 +235,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -256,11 +256,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 602a9710..bb3e5184 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int = 2048 rms_norm_eps: float = 1e-6 rope_theta: float = 10000.0 - rope_scaling: Optional[Dict] = None + rope_scaling: Dict = None attention_bias: bool = False @@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module): bias=config.attention_bias, ) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale rope_kwargs = { key: self.config.rope_scaling[key] @@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -221,17 +220,17 @@ class DeepseekV2Attention(nn.Module): k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) - if cache is not None: q_pe = self.rope(q_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys, values = cache.update_and_fetch( mx.concatenate([k_nope, k_pe], axis=-1), values ) else: q_pe = self.rope(q_pe) k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys = mx.concatenate([k_nope, k_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1) @@ -292,7 +291,7 @@ class MoEGate(nn.Module): scores = scores.reshape(bsz, seq_len, -1) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] scores = mx.take_along_axis(scores, inds, axis=-1) scores = scores * self.routed_scaling_factor @@ -347,7 +346,7 @@ 
class DeepseekV2DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -395,7 +394,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -416,14 +415,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, - self.args.v_head_dim, - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index c6150284..61de781e 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -113,7 +113,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -173,11 +173,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 1d410a15..ccc327a8 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -64,7 +64,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) @@ -135,13 +135,11 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: - r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype( - mx.float32 - ) + r = self.mlp(self.pre_feedforward_layernorm(h)) out = h + self.post_feedforward_layernorm(r) return out @@ -200,11 +198,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 8a770936..97d9a8ff 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -46,7 +46,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 652eb9e4..068046ea 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -57,7 +57,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -114,7 +114,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -184,11 +184,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index c2aaa9ea..9f662491 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -120,7 +120,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: residual = x # NeoX runs attention and feedforward network in parallel. @@ -214,11 +214,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index bcc0cf0c..5264cb57 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -116,7 +116,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -171,7 +171,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attention(self.attention_norm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index c4a947a5..7da6b333 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -171,7 +171,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -233,7 +233,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -303,13 +303,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 26408426..d2740dc1 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs +from .cache import MambaCache @dataclass @@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -class MambaCache: - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -223,7 +209,7 @@ class Model(nn.Module): weights[k] = v.moveaxis(2, 1) return weights - def make_cache(self, batch_size: int = 1): + def make_cache(self): return [MambaCache() for _ in range(len(self.layers))] @property diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index df0670be..4ac3c3b4 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,7 +85,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ): B, L, _ = x.shape @@ -135,7 +135,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) @@ -205,11 +205,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 2db57752..20944fe3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -66,7 +66,7 @@ class MixtralAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,11 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index ef55d1d7..3ea06e27 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -94,7 +94,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -151,7 +151,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,13 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 59849c96..3627df06 100644 --- 
a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -1,8 +1,8 @@ # Copyright © 2023-2024 Apple Inc. +import sys from dataclasses import dataclass -from sys import exit -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -13,7 +13,7 @@ try: import hf_olmo except ImportError: print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) + sys.exit(1) @dataclass @@ -68,7 +68,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attend(self.att_norm(x), mask, cache) h = x + r @@ -174,11 +174,3 @@ class Model(nn.Module): @property def layers(self): return self.model.transformer.blocks - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.n_heads diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 19d3c027..090e21c6 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -80,7 +80,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -152,7 +152,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.attn_norm(x), mask, cache) h = x + r @@ -218,11 +218,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_kv_heads diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index fd3fd709..56b383b2 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -162,19 +162,11 @@ class Model(nn.Module): def __call__( self, x: mx.array, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.model(x, cache) return self.lm_head(y) @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 112ade7d..9ef76f04 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .su_rope import SuScaledRotaryEmbedding @@ -84,7 +84,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -143,7 +143,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -202,11 +202,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 665dbc73..6b0759b4 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -3,12 +3,12 @@ import math from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: Optional[int] = None + num_key_value_heads: int mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: Tuple[int] = (64,) + blocksparse_block_size: int = 64 blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -61,7 +61,6 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -161,7 +160,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -230,7 +229,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -304,16 +303,8 @@ class Model(nn.Module): def layers(self): return self.model.layers - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - def sanitize(self, weights): # Remove unused precomputed rotary freqs return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index db6bd4b5..ca20a388 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module): class 
Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.args = args self.model = PhiMoEModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) @@ -208,11 +209,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index bb67615d..865d0d8e 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -168,8 +168,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) @@ -193,11 +193,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.model_dim // self.args.num_heads - - @property - def n_kv_heads(self): - return self.args.num_heads diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 5d2b7586..b0fd1a6c 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -62,8 +62,8 @@ class Attention(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + cache: Optional[Any] = None, + ) -> mx.array: bsz, q_len, _ = hidden_states.shape queries = self.q_proj(hidden_states) @@ -89,6 +89,9 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) + keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) + values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) + output = mx.fast.scaled_dot_product_attention( queries, keys, @@ -127,8 +130,8 @@ class PlamoDecoderLayer(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[Any, ...]: + cache: Optional[Any] = None, + ): # from LlamaDecoder residual = hidden_states @@ -169,8 +172,8 @@ class PlamoModel(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None, - ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]: + cache: Optional[Any] = None, + ) -> mx.array: h = self.embed_tokens(inputs) mask = create_attention_mask(h, cache) @@ -197,19 +200,11 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ) -> Tuple[mx.array, mx.array]: + cache: Optional[Any] = None, + ) -> mx.array: out = self.model(inputs, cache) return self.lm_head(out) @property def layers(self): return self.model.layers.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads // self.args.n_shared_head diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 6d2c7bbf..2b69d5ec 100644 --- a/llms/mlx_lm/models/qwen.py +++ 
b/llms/mlx_lm/models/qwen.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -149,19 +148,11 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.transformer(x, mask, cache) return self.lm_head(y) @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b3ce02a3..4e7858de 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -124,7 +124,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index ff7831f3..d199116f 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 34750ace..06a307a6 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,13 +7,13 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import 
BaseModelArgs, create_attention_mask +from .cache import MambaCache, RotatingKVCache @dataclass class ModelArgs(BaseModelArgs): model_type: str - hidden_size: int attention_bias: bool conv1d_width: int hidden_size: int @@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs): self.block_types = self._block_types -def create_window_causal_mask(N: int, window_size: int): - inds = mx.arange(N) - linds = inds[:, None] - rinds = inds[None] - mask = (linds < rinds) | (linds > rinds + window_size) - return mask * -1e9 - - -class RecurrentCache: - - def __init__(self): - self._cache = (None, None) - - def __getitem__(self, idx): - return self._cache[idx] - - def update(self, conv_state, recurrent_state): - self._cache = (conv_state, recurrent_state) - - def state(self): - return self._cache - - -class WindowKVCache: - - def __init__(self, window_size): - self.keys = None - self.values = None - self.offset = 0 - self.window_size = window_size - - def update_and_fetch(self, keys, values): - # TODO consider using rotating buffer here - # especially for very long generations - def _update(x, v): - t = x.shape[2] - self.window_size - if t > 0: - x = x[..., t:, :] - return mx.concatenate([x, v], axis=2) - - self.offset += keys.shape[2] - if self.keys is None: - self.keys = keys - self.values = values - else: - self.keys = _update(self.keys, keys) - self.values = _update(self.values, values) - return self.keys, self.values - - def state(self): - return self.keys, self.values - - class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -136,31 +83,22 @@ class Conv1d(nn.Module): kernel_size: int, ): super().__init__() - self.weight = mx.zeros((kernel_size, channels)) + self.weight = mx.zeros((channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) def __call__(self, x, cache=None): - w = self.weight.T[..., None] - kw, groups = self.weight.shape - if cache is not None: - l = [] - # Pad the cache if needed - if cache.shape[1] < kw - 1: - l.append( - mx.zeros( - (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype - ) - ) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True) - else: - y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups) + B, L, C = x.shape + groups, K, _ = self.weight.shape - # The cache is always kw - 1 - cache = x[:, max(x.shape[1] - kw + 1, 0) :, :] + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=groups) y = y + self.bias - return y, cache + + return y, x[:, -K + 1 :, :] class RGLRU(nn.Module): @@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module): # x branch. 
x = self.linear_x(x) if cache is None: - conv_state, recurrent_state = (None, None) - else: - conv_state, recurrent_state = cache[0], cache[1] - x, conv_state = self.conv_1d( - x=x, - cache=conv_state, - ) - x, recurrent_state = self.rg_lru( - x=x, - cache=recurrent_state, - ) - if cache is not None: - cache.update(conv_state, recurrent_state) + cache = [None, None] + x, cache[0] = self.conv_1d(x=x, cache=cache[0]) + x, cache[1] = self.rg_lru(x=x, cache=cache[1]) x = x * y x = self.linear_out(x) @@ -467,12 +395,14 @@ class Griffin(nn.Module): if self.scale_by_sqrt_dim: x = x * math.sqrt(x.shape[-1]) - mask = None - if x.shape[1] > 1: - mask = create_window_causal_mask( - x.shape[1], self.config.attention_window_size - ) - mask = mask.astype(x.dtype) + if cache is None: + cache = [None] * len(self.layers) + + for i, block in enumerate(self.layers): + if block.temporal_block_type != "recurrent": + mask_cache = [cache[i]] + + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -485,6 +415,7 @@ class Model(nn.Module): def __init__(self, config): self.args = config self.model = Griffin(config) + self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__(self, tokens: mx.array, cache=None) -> mx.array: @@ -508,10 +439,9 @@ class Model(nn.Module): return self.model.layers def sanitize(self, weights): - # Remove unused precomputed rotary freqs for k, v in weights.items(): if "conv_1d.weight" in k and v.ndim == 3: - weights[k] = v.squeeze(1).T + weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") return weights @@ -520,7 +450,7 @@ class Model(nn.Module): cache = [] for layer in self.layers: if layer.temporal_block_type == "recurrent": - cache.append(RecurrentCache()) + cache.append(MambaCache()) else: - cache.append(WindowKVCache(self.args.attention_window_size)) + cache.append(RotatingKVCache(max_size=self.args.attention_window_size)) return cache diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index b340de28..11202b02 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -198,8 +197,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.model(x, mask, cache) return self.lm_head(y) @@ -207,11 +206,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 9cec0e39..ce0a2ec5 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -45,7 +45,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -164,11 +164,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 42962b54..ec659969 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -3,19 +3,38 @@ import argparse import json import logging +import platform import time import uuid import warnings +from dataclasses import dataclass, field from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union +from typing import ( + Any, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) import mlx.core as mx from huggingface_hub import scan_cache_dir +from ._version import __version__ +from .models.cache import make_prompt_cache from .utils import generate_step, load +def get_system_fingerprint(): + gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else "" + return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): return prompt.rstrip() +@dataclass +class PromptCache: + cache: List[Any] = field(default_factory=list) + model_key: Tuple[str, Optional[str]] = ("", None) + tokens: List[int] = field(default_factory=list) + + class ModelProvider: def __init__(self, cli_args: argparse.Namespace): """Load models on demand and persist them across the whole process.""" @@ -156,12 +182,21 @@ class ModelProvider: class APIHandler(BaseHTTPRequestHandler): - def __init__(self, model_provider: ModelProvider, *args, **kwargs): + def __init__( + self, + model_provider: ModelProvider, + *args, + prompt_cache: Optional[PromptCache] = None, + system_fingerprint: Optional[str] = None, + **kwargs, + ): """ Create static request specific metadata """ self.created = int(time.time()) self.model_provider = model_provider + self.prompt_cache = prompt_cache or PromptCache() + self.system_fingerprint = system_fingerprint or get_system_fingerprint() super().__init__(*args, **kwargs) def _set_cors_headers(self): @@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler): self.stream_options = self.body.get("stream_options", None) self.requested_model = self.body.get("model", "default_model") self.adapter = self.body.get("adapters", None) - self.max_tokens = self.body.get("max_tokens", 100) + self.max_tokens = self.body.get("max_completion_tokens", None) + if self.max_tokens is None: + self.max_tokens = self.body.get("max_tokens", 512) 
self.temperature = self.body.get("temperature", 1.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) @@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler): # Static response response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": self.object_type, "model": self.requested_model, "created": self.created, @@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler): return response + def get_prompt_cache(self, prompt): + cache_len = len(self.prompt_cache.tokens) + if ( + self.prompt_cache.model_key != self.model_provider.model_key + or cache_len >= len(prompt) + or self.prompt_cache.tokens != prompt[:cache_len] + ): + self.prompt_cache.model_key = self.model_provider.model_key + self.prompt_cache.cache = make_prompt_cache(self.model_provider.model) + else: + prompt = prompt[cache_len:] + self.prompt_cache.tokens.extend(prompt) + return prompt + def handle_completion( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ Generate a response to a prompt and send it to the client in a single batch. Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (List[int]): The tokenized prompt. stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ @@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler): logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] - for (token, logprobs), _ in zip( + + prompt = self.get_prompt_cache(prompt) + + for _, (token, logprobs) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler): top_indices = sorted_indices[: self.logprobs] top_logprobs = logprobs[top_indices] top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) - top_tokens.append(dict(top_token_info)) + top_tokens.append(tuple(top_token_info)) token_logprobs.append(logprobs[token].item()) @@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler): ) break + self.prompt_cache.tokens.extend(tokens) detokenizer.finalize() text = ( detokenizer.text @@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler): def handle_stream( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ @@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler): Sent Events (SSE) stream. 
Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (List[int]): The tokenized prompt stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ tokens = [] finished = False stop_sequence_suffix = None logging.debug(f"Starting stream:") - for (token, _), _ in zip( + prompt = self.get_prompt_cache(prompt) + + for _, (token, _) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -531,9 +590,12 @@ continue new_text = detokenizer.last_segment - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + if new_text: + response = self.generate_response(new_text, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + + self.prompt_cache.tokens.extend(tokens) # check is there any remaining text to send detokenizer.finalize() @@ -559,7 +621,7 @@ ): response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": "chat.completion", "model": self.requested_model, "created": self.created, @@ -572,7 +634,7 @@ } return response - def handle_chat_completions(self) -> mx.array: + def handle_chat_completions(self) -> List[int]: """ Handle a chat completion request. @@ -587,7 +649,6 @@ self.object_type = ( "chat.completions.chunk" if self.stream else "chat.completions" ) - if ( hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template @@ -602,9 +663,9 @@ prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = self.tokenizer.encode(prompt) - return mx.array(prompt) + return prompt - def handle_text_completions(self) -> mx.array: + def handle_text_completions(self) -> List[int]: """ Handle a text completion request.
@@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"cmpl-{uuid.uuid4()}" self.object_type = "text_completion" - assert "prompt" in self.body, "Request did not contain a prompt" - prompt_text = self.body["prompt"] - prompt = self.tokenizer.encode(prompt_text) - return mx.array(prompt) + return self.tokenizer.encode(self.body["prompt"]) def do_GET(self): """ @@ -669,9 +727,16 @@ def run( handler_class=APIHandler, ): server_address = (host, port) + prompt_cache = PromptCache() httpd = server_class( server_address, - lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs), + lambda *args, **kwargs: handler_class( + model_provider, + prompt_cache=prompt_cache, + system_fingerprint=get_system_fingerprint(), + *args, + **kwargs, + ), ) warnings.warn( "mlx_lm.server is not recommended for production as " diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 04bbbcc5..d8694d86 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def text(self): if self._current_tokens: self._current_text = self._tokenizer.decode(self._current_tokens) + if ( + self._tokenizer.clean_up_tokenization_spaces + and self._current_text[-1] == " " + ): + self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": self._tokens.extend(self._current_tokens) self._text += self._current_text @@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): """ _byte_decoder = None + _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re") - def __init__(self, tokenizer, trim_space=False): - self.trim_space = trim_space + def __init__(self, tokenizer): + + self.clean_spaces = tokenizer.clean_up_tokenization_spaces # Extract the tokens in a list from id to text self.tokenmap = [None] * len(tokenizer.vocab) @@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.text = "" self.tokens = [] + def _maybe_trim_space(self, current_text): + if current_text[0] != " ": + return current_text + elif not self.text: + return current_text[1:] + elif self.clean_spaces and current_text[1:].startswith(self._space_matches): + return current_text[1:] + return current_text + def add_token(self, token): v = self.tokenmap[token] - # if the token starts with space if self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed ).decode("utf-8") - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = v else: self._unflushed += v @@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( "utf-8" ) - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = "" @classmethod @@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder): def _is_bpe_decoder(decoder): - _target_description = { - "type": "ByteLevel", - "add_prefix_space": False, - "trim_offsets": False, - "use_regex": False, - } - - return _match(_target_description, decoder) + return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" def load_tokenizer(model_path, tokenizer_config_extra={}): 
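To make the prompt-cache API introduced in `llms/mlx_lm/models/cache.py` above concrete, here is a minimal usage sketch. It is not part of this diff: the model repo id is only an illustration, and it assumes an `mlx_lm` install that includes these changes, with `generate` forwarding extra keyword arguments such as `prompt_cache` through to `generate_step`.

```python
# Sketch only: pre-compute, reuse, save, and trim a prompt cache with the
# new mlx_lm cache API. The repo id below is illustrative.
from mlx_lm import generate, load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# One cache entry per layer; passing max_kv_size switches to RotatingKVCache.
prompt_cache = make_prompt_cache(model, max_kv_size=4096)

# The cache is updated in place, so the second call reuses the keys/values
# computed for the first prompt instead of re-processing them.
generate(model, tokenizer, prompt="Write a haiku about caching.",
         prompt_cache=prompt_cache, verbose=True)
generate(model, tokenizer, prompt="Now explain that haiku.",
         prompt_cache=prompt_cache, verbose=True)

# Persist the cache and restore it later, optionally with metadata.
save_prompt_cache("prompt_cache.safetensors", prompt_cache, {"note": "demo"})
prompt_cache, metadata = load_prompt_cache(
    "prompt_cache.safetensors", return_metadata=True
)

# Drop the last 8 cached tokens if every layer's cache supports trimming.
if can_trim_prompt_cache(prompt_cache):
    num_trimmed = trim_prompt_cache(prompt_cache, 8)
```

Reusing one cache across calls is what lets multi-turn generation pay only for newly appended tokens, and `save_prompt_cache` / `load_prompt_cache` extend that reuse across processes.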
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 54a96457..4f872982 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -18,7 +18,7 @@ from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache, RotatingKVCache +from .models import base, cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits -def make_kv_caches( - model: nn.Module, max_kv_size: Optional[int] = None -) -> List[Union[KVCache, RotatingKVCache]]: - if hasattr(model, "make_cache"): - return model.make_cache() - - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - if max_kv_size is not None: - return [ - RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) - for n in kv_heads - ] - else: - return [KVCache(model.head_dim, n) for n in kv_heads] - - def generate_step( prompt: mx.array, model: nn.Module, @@ -155,7 +135,7 @@ def generate_step( min_tokens_to_keep: int = 1, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, - cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, + prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -180,6 +160,8 @@ def generate_step( prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed @@ -237,20 +219,13 @@ def generate_step( tokens = None # Create the KV cache for generation - cache = make_kv_caches(model, max_kv_size) - - if cache_history is not None: - if len(cache_history) != len(cache): - raise ValueError("Wrong number of layers in the cache history") - - # Set the history in the cache objects and evaluate them to prepare for - # generation. 
- for c, h in zip(cache, cache_history): - c.update_and_fetch(h[0], h[1]) - mx.eval([c.state for c in cache]) + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache(model, max_kv_size) + elif len(prompt_cache) != len(model.layers): + raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): - logits = model(y[None], cache=cache) + logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :] if logits_processor: @@ -264,16 +239,17 @@ def generate_step( return y, logprobs.squeeze(0) while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) - mx.eval([c.state for c in cache]) + model(y[:prefill_step_size][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) y = y[prefill_step_size:] + mx.metal.clear_cache() y, logprobs = _step(y) - mx.async_eval(y) + mx.async_eval(y, logprobs) while True: next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs y, logprobs = next_y, next_logprobs @@ -305,9 +281,9 @@ def stream_generate( detokenizer = tokenizer.detokenizer detokenizer.reset() - for (token, _), n in zip( - generate_step(prompt_tokens, model, **kwargs), + for n, (token, _) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): if token == tokenizer.eos_token_id: break @@ -357,9 +333,9 @@ def generate( tic = time.perf_counter() detokenizer.reset() - for (token, logprobs), n in zip( - generate_step(prompt_tokens, model, **kwargs), + for n, (token, logprobs) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): if n == 0: prompt_time = time.perf_counter() - tic @@ -372,7 +348,9 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) diff --git a/llms/setup.py b/llms/setup.py index e2cfe0cd..1c696dc0 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -32,6 +32,7 @@ setup( entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", + "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index cd7e7fd0..1efde5ae 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -1,17 +1,15 @@ # Copyright © 2024 Apple Inc. 
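The `prompt_cache` argument that replaces `cache_history` is covered by the new `test_prompt_cache.py` below. A minimal sketch of reusing one cache across calls, with the model repo borrowed from those tests; the cache is filled in place, so later calls only pay for the new tokens:

```python
import mlx.core as mx

from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt_cache = make_prompt_cache(model)

# First call: processes the prompt and fills the cache in place.
prompt = mx.array(tokenizer.encode("this is a prompt"))
tokens = []
for _, (token, _) in zip(
    range(16), generate_step(prompt, model, prompt_cache=prompt_cache)
):
    tokens.append(token)

# Follow-up call: only the new token is fed; the cached prefix is reused.
last = mx.array(tokens[-1:])
for _, (token, _) in zip(
    range(16), generate_step(last, model, prompt_cache=prompt_cache)
):
    tokens.append(token)
```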
- import unittest import mlx.core as mx from mlx.utils import tree_map -from mlx_lm.models.base import KVCache, RotatingKVCache -from mlx_lm.utils import make_kv_caches +from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache class TestModels(unittest.TestCase): def test_kv_cache(self): - cache = KVCache(32, 4) + cache = KVCache() k = mx.ones((1, 4, 1, 32), mx.float16) v = mx.ones((1, 4, 1, 32), mx.float16) @@ -32,7 +30,7 @@ class TestModels(unittest.TestCase): def test_rotating_kv_cache(self): b, h, d = 1, 2, 32 - cache = RotatingKVCache(d, h, max_size=8, step=4) + cache = RotatingKVCache(max_size=8, step=4) k = mx.random.uniform(shape=(b, h, 2, d)) v = mx.random.uniform(shape=(b, h, 2, d)) @@ -65,7 +63,7 @@ class TestModels(unittest.TestCase): idx %= 8 # Try with nonzero keep - cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2) + cache = RotatingKVCache(max_size=8, step=4, keep=2) # Check a large update k = mx.random.uniform(shape=(b, h, 20, d)) @@ -88,6 +86,46 @@ class TestModels(unittest.TestCase): if idx >= 8: idx = 2 + def test_rotating_kv_cache_chat_mode(self): + # Test that the rotating kv cache can handle + # alternating prompt/prefill with generation + d = 4 + h = 2 + cache = RotatingKVCache(max_size=18, step=4) + + x = mx.random.uniform(shape=(1, h, 8, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 8) + self.assertEqual(cache.offset, 8) + + x = mx.random.uniform(shape=(1, h, 1, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 9) + self.assertEqual(cache.offset, 9) + self.assertTrue(mx.allclose(x, k[..., 8:9, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 11) + self.assertEqual(cache.offset, 11) + self.assertTrue(mx.allclose(x, k[..., 9:11, :])) + + x = mx.random.uniform(shape=(1, h, 3, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 14) + self.assertEqual(cache.offset, 14) + self.assertTrue(mx.allclose(x, k[..., 11:14, :])) + + x = mx.random.uniform(shape=(1, h, 6, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 20) + self.assertTrue(mx.allclose(x, k[..., -6:, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 22) + self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -101,7 +139,7 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - cache = make_kv_caches(model) + cache = make_prompt_cache(model) outputs = model(inputs, cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -549,6 +587,179 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek(self): + from mlx_lm.models import deepseek + + args = deepseek.ModelArgs( + model_type="deepseek", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=4, + ) + model = deepseek.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_deepseek_v2(self): + from mlx_lm.models import deepseek_v2 + + args = deepseek_v2.ModelArgs( + model_type="deepseek_v2", + vocab_size=1024, + hidden_size=128, + 
intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gemma2(self): + from mlx_lm.models import gemma2 + + args = gemma2.ModelArgs( + model_type="gemma2", + hidden_size=128, + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=2, + head_dim=32, + rms_norm_eps=1e-4, + vocab_size=1024, + num_key_value_heads=2, + ) + model = gemma2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gpt_bigcode(self): + from mlx_lm.models import gpt_bigcode + + args = gpt_bigcode.ModelArgs( + model_type="gpt_bigcode", + n_embd=128, + n_layer=128, + n_inner=256, + n_head=4, + n_positions=1000, + layer_norm_epsilon=1e-5, + vocab_size=1024, + ) + model = gpt_bigcode.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) + + def test_nemotron(self): + from mlx_lm.models import nemotron + + args = nemotron.ModelArgs( + model_type="nemotron", + hidden_size=128, + hidden_act="gelu", + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=4, + norm_eps=1e-5, + vocab_size=1024, + num_key_value_heads=2, + ) + model = nemotron.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phi3small(self): + from mlx_lm.models import phi3small + + args = phi3small.ModelArgs( + model_type="phi3small", + hidden_size=128, + dense_attention_every_n_layers=2, + ff_intermediate_size=256, + gegelu_limit=1.0, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + layer_norm_epsilon=1e-4, + vocab_size=1000, + ) + model = phi3small.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phimoe(self): + from mlx_lm.models import phimoe + + args = phimoe.ModelArgs( + model_type="phimoe", + vocab_size=320, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + rope_scaling={ + "long_factor": [1.0] * 16, + "long_mscale": 1.243163121016122, + "original_max_position_embeddings": 4096, + "short_factor": [1.0] * 16, + "short_mscale": 1.243163121016122, + "type": "longrope", + }, + ) + model = phimoe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_recurrent_gemma(self): + from mlx_lm.models import recurrent_gemma + + args = recurrent_gemma.ModelArgs( + model_type="recurrent_gemma", + hidden_size=128, + attention_bias=False, + conv1d_width=3, + intermediate_size=256, + logits_soft_cap=1.0, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + attention_window_size=1024, + vocab_size=1000, + block_types=["recurrent", "recurrent", "attention"], + ) + model = recurrent_gemma.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py 
b/llms/tests/test_prompt_cache.py new file mode 100644 index 00000000..64cd9486 --- /dev/null +++ b/llms/tests/test_prompt_cache.py @@ -0,0 +1,243 @@ +# Copyright © 2024 Apple Inc. + +import copy +import os +import tempfile +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import ( + KVCache, + MambaCache, + RotatingKVCache, + load_prompt_cache, + make_prompt_cache, + save_prompt_cache, + trim_prompt_cache, +) +from mlx_lm.utils import generate_step, load + +HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + + +class TestPromptCache(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + + def test_save_load(self): + cache = [KVCache() for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_save_load_rotating_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + # Test with rotating cache + cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertEqual(c.keep, lc.keep) + self.assertEqual(c.max_size, lc.max_size) + self.assertEqual(c.step, lc.step) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Do a couple single token updates to get a rotation + for _ in range(2): + for c in cache: + x = mx.random.uniform(shape=(1, 8, 1, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + + for c, lc in zip(cache, loaded_cache): + x = mx.random.uniform(shape=(1, 8, 1, 4)) + k, v = c.update_and_fetch(x, x) + lk, lv = lc.update_and_fetch(x, x) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_save_load_mixed_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()] + for c in cache: + if isinstance(c, MambaCache): + c[0] = mx.random.uniform(shape=(4, 4, 4)) + c[1] = mx.random.uniform(shape=(4, 4, 4)) + else: + x = mx.random.uniform(shape=(4, 4, 7, 4)) + y = mx.random.uniform(shape=(4, 4, 7, 4)) + c.update_and_fetch(x, y) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + for c, lc in zip(cache, loaded_cache): + if isinstance(c, MambaCache): + 
self.assertTrue(mx.array_equal(c[0], lc[0])) + self.assertTrue(mx.array_equal(c[1], lc[1])) + else: + x = mx.random.uniform(shape=(4, 4, 1, 4)) + y = mx.random.uniform(shape=(4, 4, 1, 4)) + k, v = c.update_and_fetch(x, y) + lk, lv = lc.update_and_fetch(x, y) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + + def test_trim_cache(self): + cache = [KVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + # Trim + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + + # Can't trim mamba cache + cache = [MambaCache() for _ in range(2)] + for c in cache: + c.state = mx.zeros((5, 5)) + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 0) + + # All cache's have to be trimmable + cache = [MambaCache(), KVCache()] + cache[0].state = mx.zeros((5, 5)) + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[1].update_and_fetch(x, x) + num_trimmed = trim_prompt_cache(cache, 1) + self.assertEqual(num_trimmed, 0) + + cache = [RotatingKVCache(max_size=6) for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 5, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 4) + + # Can't trim fixed-size KV cache after processing + # more than max_kv_size tokens + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 0) + + def test_trim_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + + prompt_cache = make_prompt_cache(model) + + # Generate one token so we process the full prompt + last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache)) + last_tok = mx.array([last_tok]) + + # Generate two more tokens + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + toks, all_logits = zip(*(r[1] for r in results)) + + # To get back to the cache just after processing the prompt, + # trim by 3 tokens + trim_prompt_cache(prompt_cache, 3) + + # Generate the same thing again + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + second_toks, second_all_logits = zip(*(r[1] for r in results)) + self.assertEqual(toks, second_toks) + self.assertTrue( + all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) + ) + + def test_cache_copying(self): + cache = [KVCache()] + + x = mx.random.uniform(shape=(1, 8, 10, 4)) + 
cache[0].update_and_fetch(x, x) + + y = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(y, y) + + old_cache = copy.deepcopy(cache) + + trim_prompt_cache(cache, 1) + + self.assertTrue(old_cache[0].offset, 11) + self.assertTrue(cache[0].offset, 10) + + z = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(z, z) + + self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) + self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + + +if __name__ == "__main__": + unittest.main() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index cbcccfbe..ad17554d 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -14,6 +14,7 @@ class DummyModelProvider: def __init__(self): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" self.model, self.tokenizer = load(HF_MODEL_PATH) + self.model_key = (HF_MODEL_PATH, None) def load(self, model, adapter=None): assert model in ["default_model", "chat_model"] diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py new file mode 100644 index 00000000..7b4828b1 --- /dev/null +++ b/llms/tests/test_tokenizers.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. + +import unittest +from pathlib import Path + +from huggingface_hub import snapshot_download +from mlx_lm.tokenizer_utils import ( + BPEStreamingDetokenizer, + NaiveStreamingDetokenizer, + SPMStreamingDetokenizer, + load_tokenizer, +) + + +class TestTokenizers(unittest.TestCase): + + def download_tokenizer(self, repo): + path = Path( + snapshot_download( + repo_id=repo, + allow_patterns=[ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "tokenizer.model", + ], + ) + ) + return load_tokenizer(path) + + def check_tokenizer(self, tokenizer): + def check(tokens): + expected_text = tokenizer.decode(tokens) + detokenizer = tokenizer.detokenizer + detokenizer.reset() + text = "" + for t in tokens: + detokenizer.add_token(t) + seg = detokenizer.last_segment + text += seg + detokenizer.finalize() + text += detokenizer.last_segment + self.assertEqual(text, expected_text) + + tokens = tokenizer.encode("a ,b") + check(tokens) + + tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') + check(tokens) + + tokens = tokenizer.encode("3 3") + check(tokens) + + def test_tokenizers(self): + tokenizer_repos = [ + ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), + ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), + ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), + ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), + ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), + ] + for tokenizer_repo, expected_detokenizer in tokenizer_repos: + with self.subTest(tokenizer=tokenizer_repo): + tokenizer = self.download_tokenizer(tokenizer_repo) + tokenizer.decode([0, 1, 2]) + self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) + self.check_tokenizer(tokenizer) + + # Try one with a naive detokenizer + tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit") + tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) + self.check_tokenizer(tokenizer) + + +if __name__ == "__main__": + unittest.main() diff --git a/musicgen/README.md b/musicgen/README.md new file mode 100644 index 00000000..c0e340c9 --- /dev/null +++ b/musicgen/README.md @@ -0,0 +1,30 @@ +# MusicGen + +An example of Meta's MusicGen model in MLX.[^1] MusicGen is used to generate +music 
from text descriptions. + +### Setup + +Install the requirements: + +``` +pip install -r requirements.txt +``` + +### Example + +An example using the model: + +```python +from musicgen import MusicGen +from utils import save_audio + +model = MusicGen.from_pretrained("facebook/musicgen-medium") + +audio = model.generate("happy rock") + +save_audio("out.wav", audio, model.sampling_rate) +``` + +[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2306.05284) and + [code](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details. diff --git a/musicgen/benchmarks/bench_mx.py b/musicgen/benchmarks/bench_mx.py new file mode 100644 index 00000000..b669597b --- /dev/null +++ b/musicgen/benchmarks/bench_mx.py @@ -0,0 +1,28 @@ +# Copyright © 2024 Apple Inc. + +import sys +import time +from pathlib import Path + +import mlx.core as mx + +cur_path = Path(__file__).parents[1].resolve() +sys.path.append(str(cur_path)) + +from musicgen import MusicGen + +text = "folk ballad" +model = MusicGen.from_pretrained("facebook/musicgen-medium") + +max_steps = 100 + +audio = model.generate(text, max_steps=10) +mx.eval(audio) + +tic = time.time() +audio = model.generate(text, max_steps=max_steps) +mx.eval(audio) +toc = time.time() + +ms = 1000 * (toc - tic) / max_steps +print(f"Time (ms) per step: {ms:.3f}") diff --git a/musicgen/benchmarks/bench_pt.py b/musicgen/benchmarks/bench_pt.py new file mode 100644 index 00000000..de01aa66 --- /dev/null +++ b/musicgen/benchmarks/bench_pt.py @@ -0,0 +1,31 @@ +# Copyright © 2024 Apple Inc. + +import time + +import torch +from transformers import AutoProcessor, MusicgenForConditionalGeneration + +model_name = "facebook/musicgen-medium" +processor = AutoProcessor.from_pretrained(model_name) +model = MusicgenForConditionalGeneration.from_pretrained(model_name).to("mps") + +inputs = processor( + text=["folk ballad"], + padding=True, + return_tensors="pt", +) +inputs["input_ids"] = inputs["input_ids"].to("mps") +inputs["attention_mask"] = inputs["attention_mask"].to("mps") + +# warmup +audio_values = model.generate(**inputs, max_new_tokens=10) +torch.mps.synchronize() + +max_steps = 100 +tic = time.time() +audio_values = model.generate(**inputs, max_new_tokens=max_steps) +torch.mps.synchronize() +toc = time.time() + +ms = 1000 * (toc - tic) / max_steps +print(f"Time (ms) per step: {ms:.3f}") diff --git a/musicgen/encodec.py b/musicgen/encodec.py new file mode 120000 index 00000000..8eb278a7 --- /dev/null +++ b/musicgen/encodec.py @@ -0,0 +1 @@ +../encodec/encodec.py \ No newline at end of file diff --git a/musicgen/generate.py b/musicgen/generate.py new file mode 100644 index 00000000..5a6b7804 --- /dev/null +++ b/musicgen/generate.py @@ -0,0 +1,23 @@ +# Copyright © 2024 Apple Inc. 
+ +import argparse + +from utils import save_audio + +from musicgen import MusicGen + + +def main(text: str, output_path: str, model_name: str, max_steps: int): + model = MusicGen.from_pretrained(model_name) + audio = model.generate(text, max_steps=max_steps) + save_audio(output_path, audio, model.sampling_rate) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", required=False, default="facebook/musicgen-medium") + parser.add_argument("--text", required=False, default="happy rock") + parser.add_argument("--output-path", required=False, default="0.wav") + parser.add_argument("--max-steps", required=False, default=500, type=int) + args = parser.parse_args() + main(args.text, args.output_path, args.model, args.max_steps) diff --git a/musicgen/musicgen.py b/musicgen/musicgen.py new file mode 100644 index 00000000..a2d021a5 --- /dev/null +++ b/musicgen/musicgen.py @@ -0,0 +1,358 @@ +# Copyright © 2024 Apple Inc. + +import json +from functools import partial +from pathlib import Path +from types import SimpleNamespace +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +from tqdm import tqdm + +from encodec import EncodecModel +from t5 import T5 + + +class TextConditioner(nn.Module): + def __init__(self, t5_name, input_dim, output_dim): + super().__init__() + self._t5, self.tokenizer = T5.from_pretrained(t5_name) + self.output_proj = nn.Linear(input_dim, output_dim) + + def __call__(self, text): + x = self.tokenizer.encode(text) + x = self._t5.encode(x) + return self.output_proj(x) + + +class KVCache: + def __init__(self, head_dim, n_kv_heads): + self.n_kv_heads = n_kv_heads + if isinstance(head_dim, int): + self.k_head_dim = self.v_head_dim = head_dim + elif isinstance(head_dim, tuple) and len(head_dim) == 2: + self.k_head_dim, self.v_head_dim = head_dim + else: + raise ValueError("head_dim must be an int or a tuple of two ints") + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: + B = keys.shape[0] + n_steps = (self.step + keys.shape[2] - 1) // self.step + k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) + v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + + @property + def state(self): + return self.keys, self.values + + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, n_heads): + super().__init__() + + self.n_heads = n_heads + + head_dim = dim // n_heads + + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, dim, bias=False) + self.k_proj = nn.Linear(dim, dim, bias=False) + self.v_proj = nn.Linear(dim, dim, bias=False) + self.out_proj = nn.Linear(dim, dim, bias=False) + + def __call__( + self, + queries: mx.array, + keys: mx.array, + values: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] 
= None, + ) -> mx.array: + B, L_q, D = queries.shape + L_k = keys.shape[1] + + queries, keys, values = ( + self.q_proj(queries), + self.k_proj(keys), + self.v_proj(values), + ) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L_q, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L_k, self.n_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L_q, -1) + return self.out_proj(output) + + +class TransformerBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.decoder.num_attention_heads + self.hidden_size = config.decoder.hidden_size + self.self_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads) + self.cross_attn = MultiHeadAttention(self.hidden_size, self.num_attention_heads) + self.linear1 = nn.Linear(self.hidden_size, config.decoder.ffn_dim, bias=False) + self.linear2 = nn.Linear(config.decoder.ffn_dim, self.hidden_size, bias=False) + + self.norm1 = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.norm_cross = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, eps=1e-5) + + def __call__( + self, + x: mx.array, + conditioning: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[KVCache] = None, + ) -> mx.array: + xn = self.norm1(x) + x += self.self_attn(xn, xn, xn, mask, cache) + xn = self.norm_cross(x) + x += self.cross_attn(xn, conditioning, conditioning, mask) + xn = self.norm2(x) + x += self.linear2(nn.gelu(self.linear1(xn))) + return x + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_k_sampling( + logits: mx.array, top_k: float, temperature: float, axis: int = -1 +) -> mx.array: + """ + Apply top-k sampling to logits. + + Args: + logits: The logits from the model's output. + top_k: Sample from the top k logits. + temperature: Temperature parameter for softmax distribution reshaping. + axis: Axis along which to sample. + Returns: + token selected based on the top-k criterion. 
+ """ + # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460 + probs = mx.softmax(logits * (1 / temperature), axis=axis) + + # sort probs in ascending order + sorted_indices = mx.argsort(probs, axis=axis) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=axis) + prob_threshold = mx.take(sorted_probs, mx.array(-top_k), axis=axis) + + # select the top K tokens in probability + top_probs = mx.where( + sorted_probs > prob_threshold, + sorted_probs, + 0, + ) + + sorted_token = mx.random.categorical(mx.log(top_probs), axis=axis) + token = mx.take_along_axis( + sorted_indices, mx.expand_dims(sorted_token, axis), axis=axis + ) + + return token + + +def create_sin_embedding(positions: mx.array, dim: int, max_period: float = 10000): + assert dim % 2 == 0 + half_dim = dim // 2 + adim = mx.arange(half_dim).reshape(1, 1, -1) + phase = positions / (max_period ** (adim / (half_dim - 1))) + return mx.concatenate([mx.cos(phase), mx.sin(phase)], axis=-1) + + +class MusicGen(nn.Module): + def __init__(self, config): + self.num_codebooks = config.decoder.num_codebooks + self.codebook_size = config.audio_encoder.codebook_size + self.bos_token_id = config.decoder.bos_token_id + self.hidden_size = config.decoder.hidden_size + self.num_attention_heads = config.decoder.num_attention_heads + self.sampling_rate = config.audio_encoder.sampling_rate + + self.text_conditioner = TextConditioner( + config.text_encoder._name_or_path, + config.text_encoder.d_model, + self.hidden_size, + ) + self.emb = [ + nn.Embedding(self.codebook_size + 1, self.hidden_size) + for _ in range(self.num_codebooks) + ] + self.layers = [ + TransformerBlock(config) for _ in range(config.decoder.num_hidden_layers) + ] + self.out_norm = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.linears = [ + nn.Linear(self.hidden_size, self.codebook_size, bias=False) + for _ in range(self.num_codebooks) + ] + encodec_name = config.audio_encoder._name_or_path.split("/")[-1] + encodec_name = encodec_name.replace("_", "-") + self._audio_decoder, _ = EncodecModel.from_pretrained( + f"mlx-community/{encodec_name}-float32" + ) + + def __call__( + self, + audio_tokens: mx.array, + conditioning: mx.array, + cache: list[KVCache] = None, + ): + + if cache is None: + cache = [None] * len(self.layers) + + x = sum([self.emb[k](audio_tokens[..., k]) for k in range(self.num_codebooks)]) + + offset = cache[0].offset if cache[0] is not None else 0 + pos_emb = create_sin_embedding(offset, self.hidden_size) + x += pos_emb.astype(x.dtype) + + for layer, c in zip(self.layers, cache): + x = layer(x, conditioning, cache=c) + + x = self.out_norm(x) + x = mx.stack([self.linears[k](x) for k in range(self.num_codebooks)], axis=-1) + return x + + def generate( + self, + text: str, + max_steps: int = 200, + top_k: int = 250, + temp: float = 1.0, + guidance_coef: float = 3.0, + ) -> mx.array: + """ + Generates a waveform conditioned on `text`. + + Args: + text (str): The text to condition generation on. + max_steps (int): Max steps to generate. + top_k (int): Top k used in sampling. + temp (float): Sampling softmax temperature. + guidance_coef (float): Classifier free guidance coefficent. + Used to combine conditional and unconditional logits. + + Returns: + An mx.array of audio samples of shape ``(num_samples,)``. 
+ """ + # Assuming no audio prompt we start with all bos token for the codebooks + audio_shape = (1, max_steps + 1, self.num_codebooks) + audio_seq = mx.full(audio_shape, self.bos_token_id) + + text_tokens = self.text_conditioner(text) + # Compute conditional and unconditional logits in one batch + text_tokens = mx.concatenate([text_tokens, mx.zeros_like(text_tokens)], axis=0) + + head_dim = self.hidden_size // self.num_attention_heads + cache = [ + KVCache(head_dim, self.num_attention_heads) for _ in range(len(self.layers)) + ] + for offset in tqdm(range(max_steps)): + audio_input = mx.tile(audio_seq[:, offset : offset + 1], [2, 1, 1]) + audio_logits = self(audio_input, text_tokens, cache) + cond_logits, uncond_logits = audio_logits[:1], audio_logits[1:2] + audio_logits = uncond_logits + (cond_logits - uncond_logits) * guidance_coef + audio_tokens = top_k_sampling(audio_logits, top_k, temp, axis=-2) + # "delay" pattern + audio_tokens[..., offset + 1 :] = self.bos_token_id + audio_tokens[..., : -max_steps + offset] = self.bos_token_id + audio_seq[:, offset + 1 : offset + 2] = audio_tokens + mx.eval(audio_seq) + + # Undo delay + for i in range(self.num_codebooks): + audio_seq[:, : -self.num_codebooks, i] = audio_seq[ + :, i : -self.num_codebooks + i, i + ] + audio_seq = audio_seq[:, 1 : -self.num_codebooks + 1] + + audio_seq = mx.swapaxes(audio_seq, -1, -2)[:, mx.newaxis] + audio = self._audio_decoder.decode(audio_seq, audio_scales=[None]) + return audio[0] + + @classmethod + def sanitize(cls, weights): + out_weights = {} + for k, arr in weights.items(): + if k.startswith("transformer."): + k = k[len("transformer.") :] + + if "cross_attention" in k: + k = k.replace("cross_attention", "cross_attn") + + if "condition_provider" in k: + k = k.replace( + "condition_provider.conditioners.description", "text_conditioner" + ) + + if "in_proj_weight" in k: + dim = arr.shape[0] // 3 + name = "in_proj_weight" + out_weights[k.replace(name, "q_proj.weight")] = arr[:dim] + out_weights[k.replace(name, "k_proj.weight")] = arr[dim : dim * 2] + out_weights[k.replace(name, "v_proj.weight")] = arr[dim * 2 :] + continue + + out_weights[k] = arr + return out_weights + + @classmethod + def from_pretrained(cls, path_or_repo: str): + import torch + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "state_dict.bin"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + config.text_encoder = SimpleNamespace(**config.text_encoder) + config.audio_encoder = SimpleNamespace(**config.audio_encoder) + config.decoder = SimpleNamespace(**config.decoder) + + weights = torch.load(path / "state_dict.bin", weights_only=True)["best_state"] + weights = {k: mx.array(v) for k, v in weights.items()} + weights = cls.sanitize(weights) + + model = MusicGen(config) + model.load_weights(list(weights.items())) + return model diff --git a/musicgen/requirements.txt b/musicgen/requirements.txt new file mode 100644 index 00000000..5c716fe3 --- /dev/null +++ b/musicgen/requirements.txt @@ -0,0 +1,6 @@ +mlx>=0.18 +numpy +huggingface_hub +torch +transformers +scipy diff --git a/musicgen/t5.py b/musicgen/t5.py new file mode 120000 index 00000000..f31e26f9 --- /dev/null +++ b/musicgen/t5.py @@ -0,0 +1 @@ +../t5/t5.py \ No newline at end of file diff --git a/musicgen/utils.py b/musicgen/utils.py new file mode 100644 index 00000000..78e92571 --- /dev/null +++ 
b/musicgen/utils.py @@ -0,0 +1,15 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import numpy as np + + +def save_audio(file: str, audio: mx.array, sampling_rate: int): + """ + Save audio to a wave (.wav) file. + """ + from scipy.io.wavfile import write + + audio = mx.clip(audio, -1, 1) + audio = (audio * 32767).astype(mx.int16) + write(file, sampling_rate, np.array(audio)) diff --git a/t5/README.md b/t5/README.md index a0cc861b..e5165f8f 100644 --- a/t5/README.md +++ b/t5/README.md @@ -7,31 +7,6 @@ tasks by prepending task-specific prefixes to the input, e.g.: This example also supports the FLAN-T5 models variants.[^2] -## Setup - -Download and convert the model: - -```sh -python convert.py --model -``` - -This will make the `.npz` file which MLX can read. - -The `` can be any of the following: - -| Model Name | Model Size | -| ---------- | ---------- -| t5-small | 60 million | -| t5-base | 220 million | -| t5-large | 770 million | -| t5-3b | 3 billion | -| t5-11b | 11 billion | - -The FLAN variants can be specified with `google/flan-t5-small`, -`google/flan-t5-base`, etc. See the [Hugging Face -page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a -complete list of models. - ## Generate Generate text with: @@ -48,6 +23,21 @@ To see a list of options run: python t5.py --help ``` +The `` can be any of the following: + +| Model Name | Model Size | +| ---------- | ---------- +| t5-small | 60 million | +| t5-base | 220 million | +| t5-large | 770 million | +| t5-3b | 3 billion | +| t5-11b | 11 billion | + +The FLAN variants can be specified with `google/flan-t5-small`, +`google/flan-t5-base`, etc. See the [Hugging Face +page](https://huggingface.co/docs/transformers/model_doc/flan-t5) for a +complete list of models. + [^1]: For more information on T5 see the [original paper](https://arxiv.org/abs/1910.10683) or the [Hugging Face page](https://huggingface.co/docs/transformers/model_doc/t5). [^2]: For more information on FLAN-T5 see the [original paper](https://arxiv.org/abs/2210.11416). 
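Returning briefly to `MusicGen.generate` above: under the "delay" pattern, codebook `k` at step `t` carries the code for frame `t - k`, with BOS filling slots that are not yet valid, and the slice assignments at the end of the loop shift each codebook back into alignment. A toy illustration with made-up sizes (not part of the example code):

```python
import mlx.core as mx

num_codebooks, steps, bos = 4, 6, -1

# delayed[t, k] holds the code for frame t - k, or BOS while t < k.
delayed = mx.array(
    [[t - k if t >= k else bos for k in range(num_codebooks)] for t in range(steps)]
)

# Undoing the delay shifts codebook k up by k steps, mirroring the
# slice assignments at the end of MusicGen.generate.
undelayed = mx.stack(
    [delayed[k : k + steps - num_codebooks + 1, k] for k in range(num_codebooks)],
    axis=-1,
)
print(undelayed)  # every row is now a single frame across all four codebooks
```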
diff --git a/t5/convert.py b/t5/convert.py deleted file mode 100644 index e2108a0c..00000000 --- a/t5/convert.py +++ /dev/null @@ -1,75 +0,0 @@ -import numpy as np -from transformers import T5ForConditionalGeneration - -SHARED_REPLACEMENT_PATTERNS = [ - (".block.", ".layers."), - (".k.", ".key_proj."), - (".o.", ".out_proj."), - (".q.", ".query_proj."), - (".v.", ".value_proj."), - ("shared.", "wte."), - ("lm_head.", "lm_head.linear."), - (".layer.0.layer_norm.", ".ln1."), - (".layer.1.layer_norm.", ".ln2."), - (".layer.2.layer_norm.", ".ln3."), - (".final_layer_norm.", ".ln."), - ( - "layers.0.layer.0.SelfAttention.relative_attention_bias.", - "relative_attention_bias.embeddings.", - ), -] - -ENCODER_REPLACEMENT_PATTERNS = [ - (".layer.0.SelfAttention.", ".attention."), - (".layer.1.DenseReluDense.", ".dense."), -] - -DECODER_REPLACEMENT_PATTERNS = [ - (".layer.0.SelfAttention.", ".self_attention."), - (".layer.1.EncDecAttention.", ".cross_attention."), - (".layer.2.DenseReluDense.", ".dense."), -] - - -def replace_key(key: str) -> str: - for old, new in SHARED_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - if key.startswith("encoder."): - for old, new in ENCODER_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - elif key.startswith("decoder."): - for old, new in DECODER_REPLACEMENT_PATTERNS: - key = key.replace(old, new) - return key - - -def convert(model_name, dtype): - dtype = getattr(np, dtype) - model = T5ForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto") - weights = { - replace_key(k): v.numpy().astype(dtype) for k, v in model.state_dict().items() - } - file_name = model_name.replace("/", "-") - print(f"Saving weights to {file_name}.npz") - np.savez(f"{file_name}.npz", **weights) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Convert T5 weights to MLX") - parser.add_argument( - "--model", - type=str, - help="Name of the T5 model.", - default="t5-small", - ) - parser.add_argument( - "--dtype", - help="The model data type.", - type=str, - choices=["float16", "float32"], - default="float32", - ) - args = parser.parse_args() - convert(args.model, args.dtype) diff --git a/t5/t5.py b/t5/t5.py index 89f2e486..04a0da8c 100644 --- a/t5/t5.py +++ b/t5/t5.py @@ -1,12 +1,45 @@ import argparse +import json +from pathlib import Path from time import perf_counter_ns +from types import SimpleNamespace from typing import List, Optional, Tuple import mlx.core as mx import mlx.nn as nn import numpy as np -from mlx.utils import tree_map, tree_unflatten -from transformers import AutoTokenizer, T5Config +from transformers import AutoTokenizer + + +class Tokenizer: + def __init__(self, config, model_name): + self._decoder_start_id = config.decoder_start_token_id + self._tokenizer = AutoTokenizer.from_pretrained( + model_name, + legacy=False, + model_max_length=getattr(config, "n_positions", 512), + ) + + @property + def eos_id(self) -> int: + return self._tokenizer.eos_token_id + + @property + def decoder_start_id(self) -> int: + return self._decoder_start_id + + def encode(self, s: str) -> mx.array: + return mx.array( + self._tokenizer( + s, + return_tensors="np", + return_attention_mask=False, + )["input_ids"] + ) + + def decode(self, t: List[int], with_sep: bool = True) -> str: + tokens = self._tokenizer.convert_ids_to_tokens(t) + return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) def _relative_position_bucket( @@ -60,10 +93,10 @@ def _relative_position_bucket( class 
RelativePositionBias(nn.Module): - def __init__(self, config: T5Config, bidirectional: bool): + def __init__(self, config, bidirectional: bool): self.bidirectional = bidirectional self.num_buckets = config.relative_attention_num_buckets - self.max_distance = config.relative_attention_max_distance + self.max_distance = getattr(config, "relative_attention_max_distance", 128) self.n_heads = config.num_heads self.embeddings = nn.Embedding( config.relative_attention_num_buckets, config.num_heads @@ -91,7 +124,7 @@ class RelativePositionBias(nn.Module): class MultiHeadAttention(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() inner_dim = config.d_kv * config.num_heads self.num_heads = config.num_heads @@ -135,17 +168,21 @@ class MultiHeadAttention(nn.Module): class DenseActivation(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() mlp_dims = config.d_ff or config.d_model * 4 - self.gated = config.feed_forward_proj.startswith("gated") + self.gated = hasattr(config, "feed_forward_proj") + activation = ( + "relu" + if not self.gated + else config.feed_forward_proj.removeprefix("gated-") + ) if self.gated: self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False) self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False) else: self.wi = nn.Linear(config.d_model, mlp_dims, bias=False) self.wo = nn.Linear(mlp_dims, config.d_model, bias=False) - activation = config.feed_forward_proj.removeprefix("gated-") if activation == "relu": self.act = nn.relu elif activation == "gelu": @@ -166,7 +203,7 @@ class DenseActivation(nn.Module): class TransformerEncoderLayer(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() self.attention = MultiHeadAttention(config) self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -184,7 +221,7 @@ class TransformerEncoderLayer(nn.Module): class TransformerEncoder(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() self.layers = [ TransformerEncoderLayer(config) for i in range(config.num_layers) @@ -200,7 +237,7 @@ class TransformerEncoder(nn.Module): class TransformerDecoderLayer(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() self.self_attention = MultiHeadAttention(config) self.cross_attention = MultiHeadAttention(config) @@ -233,7 +270,7 @@ class TransformerDecoderLayer(nn.Module): class TransformerDecoder(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): super().__init__() n_layers = getattr(config, "num_decoder_layers", config.num_layers) self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)] @@ -262,7 +299,7 @@ class TransformerDecoder(nn.Module): class OutputHead(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False) def __call__(self, inputs): @@ -270,11 +307,11 @@ class OutputHead(nn.Module): class T5(nn.Module): - def __init__(self, config: T5Config): + def __init__(self, config): self.wte = nn.Embedding(config.vocab_size, config.d_model) self.encoder = TransformerEncoder(config) self.decoder = TransformerDecoder(config) - self.tie_word_embeddings = config.tie_word_embeddings + self.tie_word_embeddings = getattr(config, "tie_word_embeddings", True) if not self.tie_word_embeddings: self.lm_head = OutputHead(config) self.model_dim = config.d_model @@ 
-313,36 +350,82 @@ class T5(nn.Module): ): return self.decode(decoder_inputs, self.encode(inputs))[0] + @classmethod + def sanitize(cls, weights): + shared_replacement_patterns = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), + ] -class Tokenizer: - def __init__(self, config: T5Config): - self._decoder_start_id = config.decoder_start_token_id - self._tokenizer = AutoTokenizer.from_pretrained( - args.model, - legacy=False, - model_max_length=getattr(config, "n_positions", 512), - ) + encoder_replacement_patterns = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), + ] - @property - def eos_id(self) -> int: - return self._tokenizer.eos_token_id + decoder_replacement_patterns = [ + (".layer.0.SelfAttention.", ".self_attention."), + (".layer.1.EncDecAttention.", ".cross_attention."), + (".layer.2.DenseReluDense.", ".dense."), + ] - @property - def decoder_start_id(self) -> int: - return self._decoder_start_id + ignored_keys = [ + "decoder.layers.0.cross_attention.relative_attention_bias.weight" + ] - def encode(self, s: str) -> mx.array: - return mx.array( - self._tokenizer( - s, - return_tensors="np", - return_attention_mask=False, - )["input_ids"] - ) + def replace_key(key: str) -> str: + for old, new in shared_replacement_patterns: + key = key.replace(old, new) + if key.startswith("encoder."): + for old, new in encoder_replacement_patterns: + key = key.replace(old, new) + elif key.startswith("decoder."): + for old, new in decoder_replacement_patterns: + key = key.replace(old, new) + return key - def decode(self, t: List[int], with_sep: bool = True) -> str: - tokens = self._tokenizer.convert_ids_to_tokens(t) - return "".join(t.replace("▁", " " if with_sep else "") for t in tokens) + weights = {replace_key(k): v for k, v in weights.items()} + for key in ignored_keys: + if key in weights: + del weights[key] + return weights + + @classmethod + def from_pretrained( + cls, path_or_repo: str, dtype: mx.Dtype = mx.bfloat16 + ) -> tuple["T5", Tokenizer]: + from huggingface_hub import snapshot_download + + path = Path(path_or_repo) + if not path.exists(): + path = Path( + snapshot_download( + repo_id=path_or_repo, + allow_patterns=["*.json", "*.safetensors", "*.model"], + ) + ) + + with open(path / "config.json", "r") as f: + config = SimpleNamespace(**json.load(f)) + + model = T5(config) + weights = mx.load(str(path / "model.safetensors")) + weights = cls.sanitize(weights) + weights = {k: v.astype(dtype) for k, v in weights.items()} + model.load_weights(list(weights.items())) + return model, Tokenizer(config, "t5-base") def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] = 0.0): @@ -363,19 +446,6 @@ def generate(prompt: str, model: T5, tokenizer: Tokenizer, temp: Optional[float] yield y.squeeze() -def load_model(model_name: str, dtype: str = "float16"): - config = T5Config.from_pretrained(args.model) - dtype = getattr(mx, dtype) - model = T5(config) - file_name = model_name.replace("/", "-") - weights = mx.load(f"{file_name}.npz") - weights = tree_unflatten(list(weights.items())) - weights = tree_map(lambda p: p.astype(dtype), weights) - 
model.update(weights) - mx.eval(model.parameters()) - return model, Tokenizer(config) - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="T5 Inference script") parser.add_argument( @@ -421,7 +491,8 @@ if __name__ == "__main__": mx.random.seed(args.seed) - model, tokenizer = load_model(args.model, args.dtype) + dtype = getattr(mx, args.dtype) + model, tokenizer = T5.from_pretrained(args.model, dtype) if args.encode_only: print("[INFO] Encoding with T5...", flush=True)
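With `convert.py` removed, loading goes through `T5.from_pretrained`, which downloads the Hugging Face checkpoint and remaps the weights via `sanitize`. A minimal sketch of the Python API; the prompt text and the eos/decoding handling are illustrative, and the `t5.py` CLI remains the reference:

```python
import mlx.core as mx

from t5 import T5, generate

# Any model from the table in the README works here; no .npz conversion step needed.
model, tokenizer = T5.from_pretrained("t5-small", mx.float32)

tokens = []
for y in generate("translate English to German: That is good.", model, tokenizer):
    token = y.item()
    if token == tokenizer.eos_id:
        break
    tokens.append(token)
print(tokenizer.decode(tokens))
```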