diff --git a/.circleci/config.yml b/.circleci/config.yml index 02fa1de8..8367281e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,13 +26,13 @@ jobs: - run: name: Install dependencies command: | - brew install python@3.8 - python3.8 -m venv env + brew install python@3.9 + python3.9 -m venv env source env/bin/activate pip install --upgrade pip pip install unittest-xml-reporting cd llms/ - pip install -e ".[testing]" + pip install -e ".[test]" - run: name: Run Python tests command: | diff --git a/README.md b/README.md index bd180975..88888ad0 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ Some more useful examples are listed below. ### Image Models +- Generating images + - [FLUX](flux) + - [Stable Diffusion or SDXL](stable_diffusion) - Image classification using [ResNets on CIFAR-10](cifar). -- Generating images with [Stable Diffusion or SDXL](stable_diffusion). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). ### Audio Models diff --git a/clip/linear_probe.py b/clip/linear_probe.py new file mode 100644 index 00000000..2649e397 --- /dev/null +++ b/clip/linear_probe.py @@ -0,0 +1,56 @@ +# Mirror of the Linear Probe Evaluation Script +# from the official CLIP Repository. + +import mlx.core as mx +import numpy as np +from image_processor import CLIPImageProcessor +from mlx.data.datasets import load_cifar10 +from model import CLIPModel +from PIL import Image +from sklearn.linear_model import LogisticRegression +from tqdm import tqdm + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root).batch(batch_size) + test = load_cifar10(root=root, train=False).batch(batch_size) + + return tr, test + + +def get_features(model, image_proc, iter): + all_features = [] + all_labels = [] + + for batch in tqdm(iter): + image, label = batch["image"], batch["label"] + x = image_proc([Image.fromarray(im) for im in image]) + y = mx.array(label) + + image_embeds = model.get_image_features(x) + mx.eval(image_embeds) + + all_features.append(image_embeds) + all_labels.append(y) + + return mx.concatenate(all_features), mx.concatenate(all_labels) + + +if __name__ == "__main__": + model = CLIPModel.from_pretrained("mlx_model") + image_proc = CLIPImageProcessor.from_pretrained("mlx_model") + + train_iter, test_iter = get_cifar10(batch_size=256) + train_features, train_labels = get_features(model, image_proc, train_iter) + test_features, test_labels = get_features(model, image_proc, test_iter) + + # Perform logistic regression + # NOTE: The value of C should be determined via a hyperparameter sweep + # using a validation split + classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) + classifier.fit(train_features, train_labels) + + # Evaluate using the logistic regression classifier + predictions = classifier.predict(test_features) + accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 + print(f"Accuracy = {accuracy:.3f}") diff --git a/clip/requirements.txt b/clip/requirements.txt index 74f826ea..8e05620e 100644 --- a/clip/requirements.txt +++ b/clip/requirements.txt @@ -1,4 +1,5 @@ mlx +mlx-data numpy transformers torch diff --git a/flux/README.md b/flux/README.md index 62eb9b62..b00a9621 100644 --- a/flux/README.md +++ b/flux/README.md @@ -21,13 +21,34 @@ The dependencies are minimal, namely: - `huggingface-hub` to download the checkpoints. 
- `regex` for the tokenization -- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script +- `tqdm`, `PIL`, and `numpy` for the scripts - `sentencepiece` for the T5 tokenizer +- `datasets` for using an HF dataset directly You can install all of the above with the `requirements.txt` as follows: pip install -r requirements.txt + +Usage +--------- + +You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`. + +```shell +python txt2image.py --model schnell \ + --n-images 1 \ + --image-size 256x512 \ + --verbose \ + 'A photo of an astronaut riding a horse on Mars.' +``` + +For more parameters, please use the `--help` command to view. + +```shell +python txt2image.py --help +``` + Inference --------- @@ -78,7 +99,11 @@ except for some additional logic to quantize and/or load trained adapters. One can use the script as follows: ```shell -python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.' +python txt2image.py \ + --n-images 4 \ + --n-rows 2 \ + --image-size 256x512 \ + 'A photo of an astronaut riding a horse on Mars.' ``` ### Experimental Options @@ -94,17 +119,12 @@ Finetuning The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell but ymmv) on a provided image dataset. The dataset folder must have an -`index.json` file with the following format: +`train.jsonl` file with the following format: -```json -{ - "data": [ - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - ... - ] -} +```jsonl +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +... ``` The training script by default trains for 600 iterations with a batch size of @@ -126,19 +146,15 @@ The training images are the following 5 images [^2]: ![dog6](static/dog6.png) -We start by making the following `index.json` file and placing it in the same +We start by making the following `train.jsonl` file and placing it in the same folder as the images. -```json -{ - "data": [ - {"image": "00.jpg", "text": "A photo of sks dog"}, - {"image": "01.jpg", "text": "A photo of sks dog"}, - {"image": "02.jpg", "text": "A photo of sks dog"}, - {"image": "03.jpg", "text": "A photo of sks dog"}, - {"image": "04.jpg", "text": "A photo of sks dog"} - ] -} +```jsonl +{"image": "00.jpg", "prompt": "A photo of sks dog"} +{"image": "01.jpg", "prompt": "A photo of sks dog"} +{"image": "02.jpg", "prompt": "A photo of sks dog"} +{"image": "03.jpg", "prompt": "A photo of sks dog"} +{"image": "04.jpg", "prompt": "A photo of sks dog"} ``` Subsequently we finetune FLUX using the following command: @@ -151,6 +167,17 @@ python dreambooth.py \ path/to/dreambooth/dataset/dog6 ``` + +Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning. 
+ +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + mlx-community/dreambooth-dog6 +``` + The training requires approximately 50GB of RAM and on an M2 Ultra it takes a bit more than 1 hour. @@ -161,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the ```shell python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ - --adapter mlx_output/0001200_adapters.safetensors \ + --adapter mlx_output/final_adapters.safetensors \ --fuse-adapter \ --no-t5-padding \ 'A photo of an sks dog lying on the sand at a beach in Greece' diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 2cf06582..8739fda1 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -1,7 +1,6 @@ # Copyright © 2024 Apple Inc. import argparse -import json import time from functools import partial from pathlib import Path @@ -14,13 +13,11 @@ import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from tqdm import tqdm from huggingface_hub import HfApi, interpreter_login from huggingface_hub.utils import HfFolder -from flux import FluxPipeline - +from flux import FluxPipeline, Trainer, load_dataset, save_config class FinetuningDataset: def __init__(self, flux, args): @@ -145,10 +142,10 @@ def generate_progress_images(iteration, flux, args): im.save(out_file) -def save_adapters(iteration, flux, args): +def save_adapters(adapter_name, flux, args): out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + out_file = out_dir / adapter_name print(f"Saving {str(out_file)}") mx.save_safetensors( @@ -263,7 +260,8 @@ For more details on using Flux, check the [Flux documentation](https://github.co """ return readme_content -if __name__ == "__main__": +def setup_arg_parser(): + """Set up and return the argument parser.""" parser = argparse.ArgumentParser( description="Finetune Flux to generate images with a specific subject" ) @@ -374,9 +372,17 @@ if __name__ == "__main__": help="Make the Hugging Face repository private", ) parser.add_argument("dataset") + return parser + +if __name__ == "__main__": + parser = setup_arg_parser() args = parser.parse_args() + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + save_config(vars(args), output_path / "adapter_config.json") + # Load the model and set it up for LoRA training. We use the same random # state when creating the LoRA layers so all workers will have the same # initial weights. @@ -394,7 +400,7 @@ if __name__ == "__main__": trainable_params = tree_reduce( lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 ) - print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) + print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True) # Set up the optimizer and training steps. The steps are a bit verbose to # support gradient accumulation together with compilation. 
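(Aside, not part of the patch: the accumulation pattern referred to in the comment above can be sketched as follows. This is a simplified illustration under the assumption that gradients are averaged across micro-batches; the exact `step` logic appears in the hunks below and may differ in detail.)

```python
from mlx.utils import tree_map


def accumulate(prev_grads, grads, accumulation_steps):
    # Keep a running average of the per-micro-batch gradients. The optimizer
    # update is applied only once every `accumulation_steps` iterations; in
    # between, the averaged gradients are carried over to the next step.
    if prev_grads is None:
        return tree_map(lambda g: g / accumulation_steps, grads)
    return tree_map(lambda p, g: p + g / accumulation_steps, prev_grads, grads)
```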
@@ -467,10 +473,10 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - print("Create the training dataset.", flush=True) - dataset = FinetuningDataset(flux, args) - dataset.encode_images() - dataset.encode_prompts() + dataset = load_dataset(args.dataset) + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) # An initial generation to compare @@ -479,7 +485,7 @@ if __name__ == "__main__": grads = None losses = [] tic = time.time() - for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) mx.eval(loss, grads, state) losses.append(loss.item()) @@ -488,7 +494,7 @@ if __name__ == "__main__": toc = time.time() peak_mem = mx.metal.get_peak_memory() / 1024**3 print( - f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " f"Peak mem: {peak_mem:.3f} GB", flush=True, @@ -498,7 +504,7 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - save_adapters(i + 1, flux, args) + save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) if (i + 1) % 10 == 0: losses = [] @@ -506,3 +512,6 @@ if __name__ == "__main__": if args.push_to_hub: push_to_hub(args) + + save_adapters("final_adapters.safetensors", flux, args) + print("Training successful.") diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..3dd423b7 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,16 +1,10 @@ # Copyright © 2024 Apple Inc. -import math -import time -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from tqdm import tqdm - +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline from .lora import LoRALinear from .sampler import FluxSampler +from .trainer import Trainer from .utils import ( load_ae, load_clip, @@ -18,231 +12,5 @@ from .utils import ( load_flow_model, load_t5, load_t5_tokenizer, + save_config, ) - - -class FluxPipeline: - def __init__(self, name: str, t5_padding: bool = True): - self.dtype = mx.bfloat16 - self.name = name - self.t5_padding = t5_padding - - self.ae = load_ae(name) - self.flow = load_flow_model(name) - self.clip = load_clip(name) - self.clip_tokenizer = load_clip_tokenizer(name) - self.t5 = load_t5(name) - self.t5_tokenizer = load_t5_tokenizer(name) - self.sampler = FluxSampler(name) - - def ensure_models_are_loaded(self): - mx.eval( - self.ae.parameters(), - self.flow.parameters(), - self.clip.parameters(), - self.t5.parameters(), - ) - - def reload_text_encoders(self): - self.t5 = load_t5(self.name) - self.clip = load_clip(self.name) - - def tokenize(self, text): - t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) - clip_tokens = self.clip_tokenizer.encode(text) - return t5_tokens, clip_tokens - - def _prepare_latent_images(self, x): - b, h, w, c = x.shape - - # Pack the latent image to 2x2 patches - x = x.reshape(b, h // 2, 2, w // 2, 2, c) - x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) - - # Create positions ids used to positionally encode each patch. Due to - # the way RoPE works, this results in an interesting positional - # encoding where parts of the feature are holding different positional - # information. 
Namely, the first part holds information independent of - # the spatial position (hence 0s), the 2nd part holds vertical spatial - # information and the last one horizontal. - i = mx.zeros((h // 2, w // 2), dtype=mx.int32) - j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") - x_ids = mx.stack([i, j, k], axis=-1) - x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) - - return x, x_ids - - def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): - # Prepare the text features - txt = self.t5(t5_tokens) - if len(txt) == 1 and n_images > 1: - txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) - txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) - - # Prepare the clip text features - vec = self.clip(clip_tokens).pooled_output - if len(vec) == 1 and n_images > 1: - vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) - - return txt, txt_ids, vec - - def _denoising_loop( - self, - x_t, - x_ids, - txt, - txt_ids, - vec, - num_steps: int = 35, - guidance: float = 4.0, - start: float = 1, - stop: float = 0, - ): - B = len(x_t) - - def scalar(x): - return mx.full((B,), x, dtype=self.dtype) - - guidance = scalar(guidance) - timesteps = self.sampler.timesteps( - num_steps, - x_t.shape[1], - start=start, - stop=stop, - ) - for i in range(num_steps): - t = timesteps[i] - t_prev = timesteps[i + 1] - - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=scalar(t), - guidance=guidance, - ) - x_t = self.sampler.step(pred, x_t, t, t_prev) - - yield x_t - - def generate_latents( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - ): - # Set the PRNG state - if seed is not None: - mx.random.seed(seed) - - # Create the latent variables - x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) - x_T, x_ids = self._prepare_latent_images(x_T) - - # Get the conditioning - t5_tokens, clip_tokens = self.tokenize(text) - txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) - - # Yield the conditioning for controlled evaluation by the caller - yield (x_T, x_ids, txt, txt_ids, vec) - - # Yield the latent sequences from the denoising loop - yield from self._denoising_loop( - x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance - ) - - def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): - h, w = latent_size - x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) - x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) - x = self.ae.decode(x) - return mx.clip(x + 1, 0, 2) * 0.5 - - def generate_images( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - reload_text_encoders: bool = True, - progress: bool = True, - ): - latents = self.generate_latents( - text, n_images, num_steps, guidance, latent_size, seed - ) - mx.eval(next(latents)) - - if reload_text_encoders: - self.reload_text_encoders() - - for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): - mx.eval(x_t) - - images = [] - for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i : i + 1])) - mx.eval(images[-1]) - images = mx.concatenate(images, axis=0) - mx.eval(images) - - return images - - def training_loss( - self, - x_0: mx.array, - t5_features: mx.array, - clip_features: mx.array, - guidance: mx.array, - ): - # Get the text conditioning - txt 
= t5_features - txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) - vec = clip_features - - # Prepare the latent input - x_0, x_ids = self._prepare_latent_images(x_0) - - # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) - x_t = mx.stop_gradient(x_t) - - # Do the denoising - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t, - guidance=guidance, - ) - - return (pred + x_0 - eps).square().mean() - - def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): - """Swap the linear layers in the transformer blocks with LoRA layers.""" - all_blocks = self.flow.double_blocks + self.flow.single_blocks - all_blocks.reverse() - num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) - for i, block in zip(range(num_blocks), all_blocks): - loras = [] - for name, module in block.named_modules(): - if isinstance(module, nn.Linear): - loras.append((name, LoRALinear.from_base(module, r=rank))) - block.update_modules(tree_unflatten(loras)) - - def fuse_lora_layers(self): - fused_layers = [] - for name, module in self.flow.named_modules(): - if isinstance(module, LoRALinear): - fused_layers.append((name, module.fuse())) - self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py new file mode 100644 index 00000000..d31a09f1 --- /dev/null +++ b/flux/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..3fd044ac --- /dev/null +++ 
b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. + +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. 
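+        # For example, with a 64x64 latent (32x32 patches per image), x_ids
+        # has shape (b, 1024, 3): column 0 is all zeros, column 1 is the
+        # patch row index and column 2 is the patch column index.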
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = 
self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/model.py b/flux/flux/model.py index 18ea70b0..d8ad9d9b 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -85,6 +85,8 @@ class Flux(nn.Module): def sanitize(self, weights): new_weights = {} for k, w in weights.items(): + if k.startswith("model.diffusion_model."): + k = k[22:] if k.endswith(".scale"): k = k[:-6] + ".weight" for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 3bff1ca2..6f293edc 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -7,7 +7,7 @@ import mlx.core as mx class FluxSampler: - def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15): self._base_shift = base_shift self._max_shift = max_shift self._schnell = "schnell" in name @@ -25,7 +25,7 @@ class FluxSampler: ): t = mx.linspace(start, stop, num_steps + 1) - if self._schnell: + if not self._schnell: t = self._time_shift(image_sequence_length, t) return t.tolist() @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev): diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..40a126e8 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + 
pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] diff --git a/flux/flux/utils.py b/flux/flux/utils.py index 21db17d3..2437f21f 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -3,7 +3,8 @@ import json import os from dataclasses import dataclass -from typing import Optional +from pathlib import Path +from typing import Optional, Union import mlx.core as mx from huggingface_hub import hf_hub_download @@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str): def load_t5_tokenizer(name: str, pad: bool = True): model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") return T5Tokenizer(model_file, 256 if "schnell" in name else 512) + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. 
+ """ + # Sort the config for better readability + config = dict(sorted(config.items())) + + # Write the config to the provided file + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) diff --git a/flux/txt2image.py b/flux/txt2image.py index bf2f7294..5ebec81a 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -77,7 +77,7 @@ if __name__ == "__main__": nn.quantize(flux.clip, class_predicate=quantization_predicate) if args.preload_models: - sd.ensure_models_are_loaded() + flux.ensure_models_are_loaded() # Make the generator latent_size = to_latent_size(args.image_size) diff --git a/llava/generate.py b/llava/generate.py index 8067839e..64313858 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -79,10 +79,10 @@ def load_image(image_source): def prepare_inputs(processor, image, prompt): if isinstance(image, str): image = load_image(image) - inputs = processor(prompt, image, return_tensors="np") + inputs = processor(image, prompt, return_tensors="np") pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) - return input_ids, pixel_values + return pixel_values, input_ids def load_model(model_path, tokenizer_config={}): @@ -126,8 +126,7 @@ def main(): processor, model = load_model(args.model, tokenizer_config) prompt = codecs.decode(args.prompt, "unicode_escape") - - input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + pixel_values, input_ids = prepare_inputs(processor, args.image, prompt) print(prompt) generated_text = generate_text( diff --git a/llava/llava.py b/llava/llava.py index 9e6b7511..c5f190f8 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -104,31 +104,21 @@ class LlavaModel(nn.Module): self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape + batch_size, num_image_patches, embed_dim = image_features.shape # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + image_positions = mx.array( + np.where(input_ids[0] == image_token_index)[0], mx.uint32 + ) - if len(image_positions) != num_images: + if len(image_positions) != num_image_patches: raise ValueError( f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." + f" match the number of image patches ({num_image_patches})." 
) - text_segments = [] - start_idx = 0 - - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 - - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] - - # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + inputs_embeds[0, image_positions] = image_features + return inputs_embeds def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): input_embddings = self.get_input_embeddings(input_ids, pixel_values) diff --git a/llms/README.md b/llms/README.md index 20863041..e943ed69 100644 --- a/llms/README.md +++ b/llms/README.md @@ -58,10 +58,10 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) -response = generate(model, tokenizer, prompt=prompt, verbose=True) +text = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -77,7 +77,7 @@ to see how to use the API in more detail. The `mlx-lm` package also comes with functionality to quantize and optionally upload models to the Hugging Face Hub. -You can convert models in the Python API with: +You can convert models using the Python API: ```python from mlx_lm import convert @@ -100,8 +100,10 @@ To see a description of all the arguments you can do: #### Streaming -For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text. For example, +For streaming generation, use the `stream_generate` function. This yields +a generation response object. + +For example, ```python from mlx_lm import load, stream_generate @@ -113,11 +115,11 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) -for t in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(t, end="", flush=True) +for response in stream_generate(model, tokenizer, prompt, max_tokens=512): + print(response.text, end="", flush=True) print() ``` @@ -161,6 +163,10 @@ mlx_lm.convert \ --upload-repo mlx-community/my-4bit-mistral ``` +Models can also be converted and quantized directly in the +[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging +Face Space. 
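For reference, the same quantize-and-upload flow is also available from the Python API. A minimal sketch follows; the model path is a placeholder and the keyword names are assumed to mirror the CLI flags:

```python
from mlx_lm import convert

# Quantize to 4 bits (the default), write the result to ./mlx_model and
# optionally push it to the Hugging Face Hub.
convert(
    "mistralai/Mistral-7B-Instruct-v0.3",
    quantize=True,
    upload_repo="mlx-community/my-4bit-mistral",
)
```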
+ ### Long Prompts and Generations `mlx-lm` has some tools to scale efficiently to long prompts and generations: @@ -221,6 +227,7 @@ Here are a few examples of Hugging Face models that work with this example: - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct) - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) +- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) Most [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), @@ -248,3 +255,28 @@ model, tokenizer = load( tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, ) ``` + +### Large Models + +> [!NOTE] + This requires macOS 15.0 or higher to work. + +Models which are large relative to the total RAM available on the machine can +be slow. `mlx-lm` will attempt to make them faster by wiring the memory +occupied by the model and cache. This requires macOS 15 or higher to +work. + +If you see the following warning message: + +> [WARNING] Generating with a model that requires ... + +then the model will likely be slow on the given machine. If the model fits in +RAM then it can often be sped up by increasing the system wired memory limit. +To increase the limit, set the following `sysctl`: + +```bash +sudo sysctl iogpu.wired_limit_mb=N +``` + +The value `N` should be larger than the size of the model in megabytes but +smaller than the memory size of the machine. diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d0dcf60..9eac9d7f 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -222,6 +222,17 @@ data formats. Here are examples of these formats: } ``` + +The format for the `arguments` field in a function varies for different models. +Common formats include JSON strings and dictionaries. The example provided +follows the format used by +[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) +and [Mistral +AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). +A dictionary format is used in Hugging Face's [chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). +Refer to the documentation for the model you are fine-tuning for more details. + `completions`: @@ -230,18 +241,29 @@ data formats. Here are examples of these formats: {"prompt": "What is the capital of France?", "completion": "Paris."} ``` +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: + +```yaml +prompt_feature: "input" +completion_feature: "output" +``` + +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. + `text`: ```jsonl {"text": "This is an example for the model."} ``` -Note, the format is automatically determined by the dataset. Note also, keys in -each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than -> one example per line and do not split an example accross multiple lines. +> one example per line and do not split an example across multiple lines. 
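To make the key remapping concrete, with the `prompt_feature`/`completion_feature` overrides shown above each line of a `completions` dataset would then look like this (illustrative example):

```jsonl
{"input": "What is the capital of France?", "output": "Paris."}
```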
### Hugging Face Datasets @@ -259,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. For example: -``` +```yaml hf_dataset: name: "billsum" prompt_feature: "text" diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 55be1c9c..e544c6fa 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \ - `role_mapping`: (Optional) A dictionary to customize the role prefixes in the generated prompt. If not provided, the default mappings are used. -- `stop`: (Optional) An array of strings or a single string. Thesse are +- `stop`: (Optional) An array of strings or a single string. These are sequences of tokens on which the generation should stop. - `max_tokens`: (Optional) An integer specifying the maximum number of tokens @@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \ started in. - `adapters`: (Optional) A string path to low-rank adapters. The path must be - rlative to the directory the server was started in. + relative to the directory the server was started in. + +### Response Fields + +- `id`: A unique identifier for the chat. + +- `system_fingerprint`: A unique identifier for the system. + +- `object`: Any of "chat.completion", "chat.completion.chunk" (for + streaming), or "text.completion". + +- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). + +- `created`: A time-stamp for when the request was processed. + +- `choices`: A list of outputs. Each output is a dictionary containing the fields: + - `index`: The index in the list. + - `logprobs`: A dictionary containing the fields: + - `token_logprobs`: A list of the log probabilities for the generated + tokens. + - `tokens`: A list of the generated token ids. + - `top_logprobs`: A list of lists. Each list contains the `logprobs` + top tokens (if requested) with their corresponding probabilities. + - `finish_reason`: The reason the completion ended. This can be either of + `"stop"` or `"length"`. + - `message`: The text response from the model. + +- `usage`: A dictionary containing the fields: + - `prompt_tokens`: The number of prompt tokens processed. + - `completion_tokens`: The number of tokens generated. + - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. ### List Models @@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json" This will return a list of locally available models where each model in the list contains the following fields: -- `"id"`: The Hugging Face repo id. -- `"created"`: A timestamp representing the model creation time. +- `id`: The Hugging Face repo id. +- `created`: A time-stamp representing the model creation time. diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py index 502c78e5..538be927 100644 --- a/llms/mlx_lm/__init__.py +++ b/llms/mlx_lm/__init__.py @@ -1,4 +1,9 @@ # Copyright © 2023-2024 Apple Inc. +import os + from ._version import __version__ + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" + from .utils import convert, generate, load, stream_generate diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 70239db6..b2f98e6f 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.19.1" +__version__ = "0.21.0" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 04e75a3e..c18f1bae 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,9 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load +from .utils import generate_step, load + +DEFAULT_QUANTIZED_KV_START = 5000 def setup_arg_parser(): @@ -48,12 +50,6 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -70,6 +66,26 @@ def setup_arg_parser(): required=True, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -77,9 +93,6 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} if args.eos_token is not None: @@ -97,54 +110,50 @@ def main(): if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): + if not args.ignore_chat_template and tokenizer.chat_template is not None: messages = [{"role": "user", "content": args.prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=False, continue_final_message=True ) - # Treat the prompt as a prefix assuming that the suffix will be - # provided at generation time. 
- test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], - tokenize=False, - add_generation_prompt=True, - ) - n = len(test_prompt) - test_prompt.index("") - len("") - prompt = prompt[:-n] else: - prompt = args.prompt + prompt = tokenizer.encode(args.prompt) cache = make_prompt_cache(model, args.max_kv_size) - y = mx.array(tokenizer.encode(prompt)) + y = mx.array(prompt) # Process the prompt - processed = 0 - step_size = 512 start = time.time() max_msg_len = 0 - while y.size > 0: - model(y[:step_size][None], cache=cache) - mx.eval([c.state for c in cache]) - processed += min(y.size, step_size) - y = y[step_size:] + + def callback(processed, total_tokens): current = time.time() speed = processed / (current - start) msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" + nonlocal max_msg_len max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) + + for _ in generate_step( + y, + model, + max_tokens=0, + prompt_cache=cache, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, + prompt_progress_callback=callback, + ): + pass + print() - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") print("Saving...") metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 7968a868..e52ad10d 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -5,12 +5,14 @@ import json import mlx.core as mx -from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .models.cache import make_prompt_cache +from .sample_utils import make_sampler from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -41,6 +43,13 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", default=None, ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=DEFAULT_MAX_TOKENS, + help="Maximum number of tokens to generate", + ) return parser @@ -56,25 +65,23 @@ def main(): tokenizer_config={"trust_remote_code": True}, ) - print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.") + print(f"[INFO] Starting chat session with {args.model}. 
To exit, enter 'q'.") prompt_cache = make_prompt_cache(model, args.max_kv_size) while True: query = input(">> ") if query == "q": break messages = [{"role": "user", "content": query}] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) for response in stream_generate( model, tokenizer, prompt, - temp=args.temp, - top_p=args.top_p, + max_tokens=args.max_tokens, + sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): - print(response, flush=True, end="") + print(response.text, flush=True, end="") print() diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index a3f43f71..9bac77a5 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--dtype", - help="Type to save the parameters, ignored if -q is given.", + help="Type to save the non-quantized parameters.", type=str, choices=["float16", "bfloat16", "float32"], default="float16", diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py new file mode 100644 index 00000000..ca5e83bb --- /dev/null +++ b/llms/mlx_lm/evaluate.py @@ -0,0 +1,392 @@ +# Copyright © 2024 Apple Inc. + +""" +Adapted from a PyTorch implementation by David Grangier +""" + +import argparse +import json +import logging +import os +from importlib.metadata import version +from pathlib import Path +from typing import Optional, Union + +import lm_eval +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from tqdm import tqdm + +from .models.cache import make_prompt_cache +from .utils import load, stream_generate + +PAD = 0 + + +def _len_longest_common_prefix(a, b): + l = 0 + for item_a, item_b in zip(a, b): + if item_a != item_b: + break + l += 1 + return l + + +def _rstrip_until(s, untils): + """Limit a string to the first occurrence of any substring in untils.""" + l = len(s) + f = [s.find(u) for u in untils] + f = [l if x < 0 else x for x in f] + return s[: min(f)] + + +def _pad_inputs( + inputs, + maxlen, + genlen=0, + pad_left=False, + pad_multiple=32, + truncate=False, +): + # pad the prompts to the left with at least genlen tokens. 
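+    # Lengths are rounded up to a multiple of `pad_multiple` so batches share
+    # a small set of shapes. Prompts longer than `maxlen` are truncated from
+    # the beginning when `truncate=True`; otherwise a ValueError is raised.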
+ actual_maxlen = max(len(p) for p in inputs) + genlen + if actual_maxlen > maxlen: + if not truncate: + raise ValueError("Inputs are too long.") + else: # drop begining + actual_maxlen = maxlen + inputs = [p[max(0, len(p) - maxlen) :] for p in inputs] + if pad_multiple > 0: + maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple + maxlen *= pad_multiple + assert PAD == 0 + lr = np.array((1, 0) if pad_left else (0, 1)) + return np.stack( + [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs], + axis=0, + ) + + +@register_model("mlxlm") +class MLXLM(LM): + def __init__( + self, + path_or_hf_repo: str, + batch_size: int = 16, + max_tokens: Optional[int] = None, + use_chat_template: Optional[bool] = None, + ) -> None: + super().__init__() + self._batch_size = batch_size + self._model, self.tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self.tokenizer.model_max_length + self.use_chat_template = use_chat_template or ( + self.tokenizer.chat_template is not None + ) + + def _score_fn(self, inputs, tokenize=True, step_size=32): + if tokenize: + inputs = self._tokenize(inputs) + inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) + inputs = mx.array(inputs) + inputs, targets = inputs[..., :-1], inputs[..., 1:] + + cache = make_prompt_cache(self._model) + + mask = targets != PAD + + scores, is_greedy = [], [] + for i in range(0, inputs.shape[1], step_size): + logits = self._model(inputs[:, i : i + step_size], cache=cache) + + log_probs = nn.log_softmax(logits.astype(mx.float32)) + score = mx.take_along_axis( + log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1 + )[..., 0] + ig = mask[:, i : i + step_size] * ( + targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) + ) + + mx.eval(score, ig) + mx.metal.clear_cache() + + is_greedy.append(ig) + scores.append(score) + + scores = mx.concatenate(scores, axis=1) + is_greedy = mx.concatenate(is_greedy, axis=1) + + return scores, mask.sum(axis=-1), is_greedy + + def _loglikelihood(self, texts, score_spans=None, tokenize=True): + # sort by length to get batches with little padding. + sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i])) + sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))] + sorted_spans = None + if score_spans is not None: + sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))] + + results = [] + for i in tqdm(range(0, len(sorted_inputs), self._batch_size)): + batch = sorted_inputs[i : i + self._batch_size] + scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize) + for j in range(len(batch)): + if sorted_spans is None: # full sequence score + mask = mx.arange(scores[j].shape[-1]) < length + score = (scores[j].astype(mx.float32) * mask).sum(axis=-1) + ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1) + else: # subsequence score + start, end = sorted_spans[i + j] + score = scores[j][start:end].astype(mx.float32).sum() + ig = is_greedy[j][start:end].astype(mx.int32).sum() + length = end - start + + results.append((score.item(), ig.item(), length)) + + # reorder the outputs + inv_sort = np.argsort(sorted_indices) + results = [results[inv_sort[i]] for i in range(len(results))] + + return results + + def _tokenize(self, texts): + return [ + tuple( + self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template) + ) + for t in texts + ] + + def loglikelihood(self, requests) -> list[tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. 
+ Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. + :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + logging.info("Estimating loglikelihood for %d pairs." % len(requests)) + + # tokenize prefix and prefix + completion for all requests. + tokenized = self._tokenize( + [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]] + ) + + # max length (prefix + completion) and longest common prefix per question. + length_stats = {} + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8)) + length_stats[prefix] = ( + max(max_completed_l, len(completed)), + min(min_prefix_l, _len_longest_common_prefix(prefix, completed)), + ) + + # truncate requests for completed sequences longer than model context. + shortened = [] + completion_spans = [] + long_completions = 0 + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, prefix_l = length_stats[prefix] + # compute truncation length + truncation = max(0, max_completed_l - self._max_tokens - 1) + prefix_l = prefix_l - truncation + if prefix_l <= 0: + # completion too long, prefix is eliminated for some requests. + long_completions += 1 + truncation = max(0, len(completed) - self._max_tokens - 1) + prefix_l = 1 + # truncate the completed sequence + completed = completed[truncation:] + shortened.append(completed) + # scores do not include initial bos, substract 1 to span bounds + completion_spans.append((prefix_l - 1, len(completed) - 1)) + + if long_completions > 0: + logging.info( + f"Prefix eliminated for {long_completions} requests with " + + "completion longer than context." + ) + + # model scoring, returns num_requests x (logp, is_greedy, length). + results = self._loglikelihood( + shortened, + score_spans=completion_spans, + tokenize=False, + ) + return [(r[0], r[1] == r[2]) for r in results] + + tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name + apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template + + def loglikelihood_rolling(self, requests) -> list[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. 
+ Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: EOT + Max context length: 4 + Resulting input/prediction pairs: + INPUT: EOT 0 1 2 + PRED: 0 1 2 3 + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + INPUT: 5 6 7 8 + PRED: 8 9 + Observe that: + 1. Each token is predicted exactly once + 2. For the last pair, we provide the full context, but only score the last two tokens + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the EOT token. + """ + logging.info( + "Estimating loglikelihood rolling for %d sequences." % len(requests) + ) + inputs = [req.args[0] for req in requests] + return [t[0] for t in self._loglikelihood(inputs)] + + def generate_until(self, requests) -> list[str]: + """Generate greedily until a stopping sequence + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, until). + context: str + Context string + until: [str] + The string sequences to generate until. These string sequences + may each span across multiple tokens, or may be part of one token. + :return: list[str] + A list of strings continuation + continuation: str + The generated continuation. + """ + logging.info("Generating continuation for %d sequences." % len(requests)) + contexts, options = zip(*[req.args for req in requests]) + # contrary to the doc the second element of the tuple contains + # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} + keys = list(options[0].keys()) + assert "until" in keys + untils = [x["until"] for x in options] + completions = [] + + for context, until in tqdm(zip(contexts, untils), total=len(contexts)): + context = self._tokenize(context) + max_tokens = min( + self._max_tokens, + self.tokenizer.model_max_length - len(context), + ) + text = "" + for response in stream_generate( + self._model, self.tokenizer, prompt=context, max_tokens=max_tokens + ): + text += response.text + if any(u in text for u in until): + text = _rstrip_until(text, until) + completions.append(text) + break + else: + completions.append(text) + return completions + + +def main(): + parser = argparse.ArgumentParser( + "Evaluate an MLX model using lm-evaluation-harness." + ) + parser.add_argument("--model", help="Model to evaluate", required=True) + parser.add_argument("--tasks", nargs="+", required=True) + parser.add_argument( + "--output-dir", default=".", help="Output directory for result files." + ) + parser.add_argument("--batch-size", type=int, default=16, help="Batch size") + parser.add_argument("--num-shots", type=int, default=0, help="Number of shots") + parser.add_argument( + "--max-tokens", + type=int, + help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", + ) + parser.add_argument( + "--limit", + default=1.0, + help="Limit the number of examples per task.", + type=float, + ) + parser.add_argument("--seed", type=int, default=123, help="Random seed.") + parser.add_argument( + "--fewshot-as-multiturn", + action="store_true", + help="Whether to provide the fewshot examples as a multiturn " + "conversation or a single user turn.", + default=False, + ) + parser.add_argument( + "--apply-chat-template", + action=argparse.BooleanOptionalAction, + help="Specifies whether to apply a chat template to the prompt. 
If " + "the model has a chat template, this defaults to `True`, " + "otherwise `False`.", + default=None, + ) + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Silence tokenizer warnings + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + mx.random.seed(args.seed) + + lm = MLXLM( + args.model, + batch_size=args.batch_size, + max_tokens=args.max_tokens, + use_chat_template=args.apply_chat_template, + ) + results = lm_eval.simple_evaluate( + model=lm, + tasks=args.tasks, + fewshot_as_multiturn=args.fewshot_as_multiturn, + apply_chat_template=lm.use_chat_template, + num_fewshot=args.num_shots, + limit=args.limit, + random_seed=args.seed, + numpy_random_seed=args.seed, + torch_random_seed=args.seed, + fewshot_random_seed=args.seed, + ) + + model_name = args.model.replace("/", "_") + task_names = "_".join(args.tasks) + ver = version("lm_eval") + filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json" + output_path = output_dir / filename + output_path.write_text(json.dumps(results["results"], indent=4)) + print("Results:") + for result in results["results"].values(): + print(json.dumps(result, indent=4)) diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 3bf01688..4a7020f1 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model) # User turn prompt = "Hi my name is ." messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( @@ -32,9 +30,7 @@ response = generate( # User turn prompt = "What's my name?" 
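`generate_until` in the evaluation script above streams tokens and stops as soon as any `until` string appears, trimming the text back to just before it. A simplified standalone sketch; the `rstrip_until` helper here is a stand-in for the script's `_rstrip_until`, whose exact implementation is not shown in this diff:

```python
def rstrip_until(text, untils):
    # keep only the text before the earliest stop sequence
    cut = min((text.index(u) for u in untils if u in text), default=len(text))
    return text[:cut]

untils = ["\n\n", "Q:"]
streamed = ["The answer", " is 42.", "\n\nQ: next question"]

text = ""
for piece in streamed:          # stands in for stream_generate(...)
    text += piece
    if any(u in text for u in untils):
        text = rstrip_until(text, untils)
        break
print(repr(text))               # 'The answer is 42.'
```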
messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( @@ -42,7 +38,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index 25730617..41eaf1da 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}] # Transform the prompt into the chat template prompt = tokenizer.apply_chat_template( - conversation=conversation, tokenize=False, add_generation_prompt=True + conversation=conversation, add_generation_prompt=True ) # Specify the maximum number of tokens @@ -23,14 +23,6 @@ max_tokens = 1_000 # Specify if tokens and timing information will be printed verbose = True -# Some optional arguments for causal language model generation -generation_args = { - "temp": 0.7, - "repetition_penalty": 1.2, - "repetition_context_size": 20, - "top_p": 0.95, -} - # Generate a response with the specified settings response = generate( model=model, @@ -38,5 +30,4 @@ response = generate( prompt=prompt, max_tokens=max_tokens, verbose=verbose, - **generation_args, ) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 4ec9a23c..530272c7 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -14,7 +14,7 @@ data: "/path/to/training/data" seed: 0 # Number of layers to fine-tune -lora_layers: 16 +num_layers: 16 # Minibatch size. batch_size: 4 diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py new file mode 100644 index 00000000..b98e757b --- /dev/null +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -0,0 +1,75 @@ +# Copyright © 2024 Apple Inc. + +""" +Run with: + +``` +/path/to/mpirun \ + -np 2 \ + --hostfile /path/to/hosts.txt \ + python /path/to/pipeline_generate.py --prompt "hello world" +``` + +Make sure you can run MLX over MPI on two hosts. For more information see the +documentation: + +https://ml-explore.github.io/mlx/build/html/usage/distributed.html). +""" + +import argparse + +import mlx.core as mx +from mlx_lm import load, stream_generate + +parser = argparse.ArgumentParser(description="LLM pipelined inference example") +parser.add_argument( + "--prompt", + "-p", + default="Write a quicksort in C++.", + help="Message to be processed by the model ('-' reads from stdin)", +) +parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=256, + help="Maximum number of tokens to generate", +) +args = parser.parse_args() + +model_repo = "mlx-community/DeepSeek-V3-3bit" + +model, tokenizer = load(model_repo, lazy=True) + +messages = [{"role": "user", "content": args.prompt}] +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + +group = mx.distributed.init() +rank = group.rank() +model.model.pipeline(group) +mx.eval(model.parameters()) + +# Synchronize processes before generation to avoid timeout if downloading +# model for the first time. 
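With the `generation_args` dictionary removed from the example above, sampling options such as temperature and top-p are configured through a sampler object rather than keyword arguments of `generate()`. A minimal sketch using `make_sampler` the same way `mlx_lm/generate.py` does further down in this diff; the model repo is that script's default and only an example:

```python
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# temp / top_p / min_p now live on the sampler, not on generate()
sampler = make_sampler(temp=0.7, top_p=0.95)
response = generate(
    model,
    tokenizer,
    prompt="Write a one-line haiku about the sea.",
    max_tokens=100,
    sampler=sampler,
)
print(response)
```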
+mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) + + +def rprint(*args, **kwargs): + if rank == 0: + print(*args, **kwargs) + + +for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): + rprint(response.text, end="", flush=True) + +rprint() +rprint("=" * 10) +rprint( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" +) +rprint( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" +) +rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..0d286c75 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,15 +6,19 @@ import sys import mlx.core as mx -from .models.cache import load_prompt_cache +from .models.cache import QuantizedKVCache, load_prompt_cache +from .sample_utils import make_sampler from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 +DEFAULT_MIN_P = 0.0 +DEFAULT_MIN_TOKENS_TO_KEEP = 1 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_QUANTIZED_KV_START = 5000 def str2bool(string): @@ -39,18 +43,20 @@ def setup_arg_parser(): help="Optional path for the trained adapter weights and config.", ) parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", + "--extra-eos-token", + type=str, + default=(), + nargs="+", + help="Add tokens in the list of eos tokens that stop generation.", ) parser.add_argument( - "--eos-token", - type=str, + "--system-prompt", default=None, - help="End of sequence token for tokenizer", + help="System prompt to be used for the chat template", ) parser.add_argument( "--prompt", + "-p", default=DEFAULT_PROMPT, help="Message to be processed by the model ('-' reads from stdin)", ) @@ -67,6 +73,15 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) + parser.add_argument( + "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p" + ) + parser.add_argument( + "--min-tokens-to-keep", + type=int, + default=DEFAULT_MIN_TOKENS_TO_KEEP, + help="Minimum tokens to keep for min-p sampling.", + ) parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") parser.add_argument( "--ignore-chat-template", @@ -84,17 +99,6 @@ def setup_arg_parser(): default=True, help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -107,60 +111,69 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. 
" + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) + parser.add_argument( + "--draft-model", + type=str, + help="A model to be used for speculative decoding.", + default=None, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="Number of tokens to draft when using speculative decoding.", + default=2, + ) return parser -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - def main(): parser = setup_arg_parser() args = parser.parse_args() mx.random.seed(args.seed) - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None if using_cache: prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, return_metadata=True + args.prompt_cache_file, + return_metadata=True, ) + if isinstance(prompt_cache[0], QuantizedKVCache): + if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: + raise ValueError( + "--kv-bits does not match the kv cache loaded from --prompt-cache-file." + ) + if args.kv_group_size != prompt_cache[0].group_size: + raise ValueError( + "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." + ) # Building tokenizer_config tokenizer_config = ( {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) - if args.trust_remote_code: - tokenizer_config["trust_remote_code"] = True - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token + tokenizer_config["trust_remote_code"] = True model_path = args.model if using_cache: @@ -179,6 +192,8 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, ) + for eos_token in args.extra_eos_token: + tokenizer.add_eos_token(eos_token) if args.use_default_chat_template: if tokenizer.chat_template is None: @@ -186,16 +201,14 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): - messages = [ - { - "role": "user", - "content": sys.stdin.read() if args.prompt == "-" else args.prompt, - } - ] + prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") + prompt = sys.stdin.read() if prompt == "-" else prompt + if not args.ignore_chat_template and tokenizer.chat_template is not None: + if args.system_prompt is not None: + messages = [{"role": "system", "content": args.system_prompt}] + else: + messages = [] + messages.append({"role": "user", "content": prompt}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -203,30 +216,38 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. 
if using_cache: + messages[-1]["content"] = "" test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], + messages, tokenize=False, add_generation_prompt=True, ) prompt = prompt[test_prompt.index("") :] + prompt = tokenizer.encode(prompt, add_special_tokens=False) else: - prompt = args.prompt - - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None + prompt = tokenizer.encode(prompt) + if args.draft_model is not None: + draft_model, draft_tokenizer = load(args.draft_model) + if draft_tokenizer.vocab_size != tokenizer.vocab_size: + raise ValueError("Draft model tokenizer does not match model tokenizer.") + else: + draft_model = None + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, verbose=args.verbose, - formatter=formatter, - temp=args.temp, - top_p=args.top_p, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, + draft_model=draft_model, + num_draft_tokens=args.num_draft_tokens, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index c96e75a7..4d050bd5 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -2,6 +2,7 @@ import argparse import math +import os import re import types from pathlib import Path @@ -57,6 +58,8 @@ CONFIG_DEFAULTS = { "test": False, "test_batches": 500, "max_seq_length": 2048, + "config": None, + "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } @@ -66,6 +69,7 @@ def build_parser(): parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser.add_argument( "--model", + type=str, help="The path to the local model directory or Hugging Face repo.", ) @@ -74,7 +78,6 @@ def build_parser(): "--train", action="store_true", help="Do training", - default=None, ) parser.add_argument( "--data", @@ -88,7 +91,6 @@ def build_parser(): "--fine-tune-type", type=str, choices=["lora", "dora", "full"], - default="lora", help="Type of fine-tuning to perform: lora, dora, or full.", ) parser.add_argument( @@ -133,7 +135,6 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", - default=None, ) parser.add_argument( "--test-batches", @@ -148,16 +149,15 @@ def build_parser(): parser.add_argument( "-c", "--config", - default=None, + type=str, help="A YAML configuration file with the training options", ) parser.add_argument( "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", - default=None, ) - parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") + parser.add_argument("--seed", type=int, help="The PRNG seed") return parser @@ -271,6 +271,7 @@ def run(args, training_callback: TrainingCallback = None): def main(): + os.environ["TOKENIZERS_PARALLELISM"] = "true" parser = build_parser() args = parser.parse_args() config = args.config diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index bb5c3a09..9827f3dc 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -6,19 +6,18 @@ from transformers.commands.user import tabulate def ask_for_confirmation(message: str) -> bool: + """Ask user for confirmation 
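The rewritten `main()` above forwards the new quantized-KV-cache and speculative-decoding options straight to `generate()`. A sketch of calling the API directly with the same keyword arguments; the draft model repo below is an illustrative choice, not one named in this diff, and it is assumed to share the main model's tokenizer vocabulary:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # assumed compatible

response = generate(
    model,
    tokenizer,
    prompt="Explain KV cache quantization in one sentence.",
    max_tokens=128,
    kv_bits=4,                 # quantize cached keys/values to 4 bits
    kv_group_size=64,
    quantized_kv_start=5000,   # only once the cache grows past this many tokens
    draft_model=draft_model,   # speculative decoding with a smaller draft model
    num_draft_tokens=2,
)
print(response)
```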
with Y/N prompt. + Returns True for Y/yes, False for N/no/empty.""" y = ("y", "yes", "1") - n = ("n", "no", "0") - all_values = y + n + ("",) - full_message = f"{message} (Y/n) " + n = ("n", "no", "0", "") + full_message = f"{message} (y/n) " while True: answer = input(full_message).lower() - if answer == "": - return False if answer in y: return True if answer in n: return False - print(f"Invalid input. Must be one of {all_values}") + print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") def main(): @@ -43,9 +42,7 @@ def main(): args = parser.parse_args() if args.scan: - print( - "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' - ) + print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') hf_cache_info = scan_cache_dir() print( tabulate( @@ -86,35 +83,41 @@ def main(): if args.pattern in repo.repo_id ] if repos: + print("\nFound the following models:") print( tabulate( rows=[ [ repo.repo_id, + repo.size_on_disk_str, # Added size information str(repo.repo_path), ] for repo in repos ], headers=[ "REPO ID", + "SIZE", # Added size header "LOCAL PATH", ], ) ) - confirmed = ask_for_confirmation(f"Confirm deletion ?") + confirmed = ask_for_confirmation( + "\nAre you sure you want to delete these models?" + ) if confirmed: for model_info in repos: + print(f"\nDeleting {model_info.repo_id}...") for revision in sorted( model_info.revisions, key=lambda revision: revision.commit_hash ): strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy.execute() - print("Model(s) deleted.") + print("\nModel(s) deleted successfully.") else: - print("Deletion is cancelled. Do nothing.") + print("\nDeletion cancelled - no changes made.") else: - print(f"No models found.") + print(f'No models found matching pattern "{args.pattern}"') if __name__ == "__main__": diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3628a808..ad7a4a65 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -5,6 +5,9 @@ from dataclasses import dataclass from typing import Any, Optional import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache @dataclass @@ -20,7 +23,12 @@ class BaseModelArgs: ) -def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): +def create_causal_mask( + N: int, + offset: int = 0, + window_size: Optional[int] = None, + lengths: Optional[mx.array] = None, +): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] @@ -28,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non mask = linds < rinds if window_size is not None: mask = mask | (linds > rinds + window_size) + if lengths is not None: + lengths = lengths[:, None, None, None] + mask = mask | (rinds >= lengths) return mask * -1e9 @@ -39,7 +50,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): if cache is not None and cache[0] is not None: c = cache[0] if hasattr(c, "max_size"): - offset = min(c.max_size - 1, c.offset) + offset = min(c.max_size, c.offset) window_size = c.max_size else: offset = c.offset @@ -48,3 +59,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, 
+) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index b06422e5..14026f0c 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten -def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: """ Construct the model's cache for use when cgeneration. @@ -77,6 +80,13 @@ def load_prompt_cache(file_name, return_metadata=False): return cache +def can_trim_prompt_cache(cache: List[Any]) -> bool: + """ + Check if model's cache can be trimmed. + """ + return all(c.is_trimmable() for c in cache) + + def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: """ Trim the model's cache by the given number of tokens. @@ -91,7 +101,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: Returns: (int): The number of tokens that were trimmed. 
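A minimal usage sketch of the cache helpers above, assuming an example model repo: prefill a prompt cache, roll it back a few tokens with `trim_prompt_cache`, then convert each per-layer `KVCache` to the new `QuantizedKVCache` via the `to_quantized()` method defined further down in this diff:

```python
import mlx.core as mx
from mlx_lm import load
from mlx_lm.models.cache import (
    can_trim_prompt_cache,
    make_prompt_cache,
    trim_prompt_cache,
)

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
cache = make_prompt_cache(model)

tokens = mx.array(tokenizer.encode("The quick brown fox jumps over the lazy dog"))[None]
mx.eval(model(tokens, cache=cache))
print(cache[0].offset)                     # tokens currently held in the cache

if can_trim_prompt_cache(cache):
    trimmed = trim_prompt_cache(cache, 4)  # drop the last 4 cached tokens
    print(trimmed, cache[0].offset)

# Quantize the remaining cache (KVCache -> QuantizedKVCache) layer by layer
cache = [c.to_quantized(group_size=64, bits=4) for c in cache]
print(type(cache[0]).__name__, cache[0].bits)
```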
""" - if not all(c.is_trimmable() for c in cache) or len(cache) == 0: + if not can_trim_prompt_cache(cache) or len(cache) == 0: return 0 return [c.trim(num_tokens) for c in cache][0] @@ -119,6 +129,88 @@ class _BaseCache: return False +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + class KVCache(_BaseCache): def __init__(self): self.keys = None @@ -173,6 +265,16 @@ class KVCache(_BaseCache): self.offset -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + if self.keys is not None: + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = mx.quantize( + self.values, group_size=group_size, bits=bits + ) + return quant_cache + class RotatingKVCache(_BaseCache): @@ -223,9 +325,9 @@ class RotatingKVCache(_BaseCache): self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) - # The largest size is self.max_size + S - 1 to ensure + # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context - trim_size = self._idx - self.max_size + 1 + trim_size = self._idx - self.max_size self.keys = self._trim(trim_size, self.keys, keys) 
self.values = self._trim(trim_size, self.values, values) self.offset += keys.shape[2] @@ -313,6 +415,9 @@ class RotatingKVCache(_BaseCache): self._idx -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + raise NotImplementedError("RotatingKVCache Quantization NYI") + class MambaCache(_BaseCache): def __init__(self): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 057c816d..b2d16dd7 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -155,11 +155,13 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -180,9 +182,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py new file mode 100644 index 00000000..19bfa6b6 --- /dev/null +++ b/llms/mlx_lm/models/cohere2.py @@ -0,0 +1,206 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 4096 + head_dim: int = 128 + num_hidden_layers: int = 32 + intermediate_size: int = 14336 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + rope_theta: float = 50000.0 + vocab_size: int = 256000 + layer_norm_eps: float = 1e-05 + logit_scale: float = 0.0625 + attention_bias: bool = False + layer_norm_bias: bool = False + sliding_window: int = 4096 + sliding_window_pattern: int = 4 + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.layer_idx = layer_idx + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim + if (head_dim * n_heads) != dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}" + f" and `num_heads`: {n_heads})." 
+ ) + self.scale = head_dim**-0.5 + + attetion_bias = args.attention_bias + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) + + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0 + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # Apply RoPE only if sliding window is enabled + if self.use_sliding_window: + if cache is None: + queries = self.rope(queries) + keys = self.rope(keys) + else: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + if self.use_sliding_window and mask is not None: + key_len = keys.shape[-2] + if mask.shape[-1] != key_len: + mask = mask[..., -key_len:] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x): + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args, layer_idx) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + h = self.input_layernorm(x) + attn_h = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x + + +class CohereModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + if cache is None: + cache = [None] * len(self.layers) + + if mask is None: + j = self.args.sliding_window_pattern + mask = create_attention_mask(h, cache[j - 1 : j]) + + for layer, c in 
zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = CohereModel(args) + self.args = args + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + out = self.model.embed_tokens.as_linear(out) + out = out * self.model.args.logit_scale + return out + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ): + caches.append(KVCache()) + else: + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) + return caches + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 3b7e83d7..886b5630 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) @@ -197,11 +197,13 @@ class DBRX(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -223,9 +225,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index 03cb3b1a..ffc30c36 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -211,9 +211,11 @@ class DeepseekModel(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -236,8 +238,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = 
self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 17d061a8..9027da7e 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -220,23 +220,23 @@ class DeepseekV2Attention(nn.Module): k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) - if cache is not None: q_pe = self.rope(q_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys, values = cache.update_and_fetch( mx.concatenate([k_nope, k_pe], axis=-1), values ) else: q_pe = self.rope(q_pe) k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys = mx.concatenate([k_nope, k_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -291,7 +291,7 @@ class MoEGate(nn.Module): scores = scores.reshape(bsz, seq_len, -1) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] scores = mx.take_along_axis(scores, inds, axis=-1) scores = scores * self.routed_scaling_factor @@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -395,8 +398,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py new file mode 100644 index 00000000..f95949f9 --- /dev/null +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -0,0 +1,460 @@ +# Copyright © 2024 Apple Inc. 
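The DeepSeek attention above now broadcasts the shared rotary key across heads with `mx.repeat` after applying RoPE, instead of concatenating copies beforehand. For a size-1 head axis the two formulations produce the same tensor, which this tiny check confirms:

```python
import mlx.core as mx

k_pe = mx.random.normal((1, 1, 5, 8))        # (batch, 1 kv head, seq, rope_dim)
num_heads = 4

via_repeat = mx.repeat(k_pe, num_heads, axis=1)
via_concat = mx.concatenate([k_pe] * num_heads, axis=1)
print(via_repeat.shape, mx.allclose(via_repeat, via_concat).item())  # (1, 4, 5, 8) True
```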
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "deepseek_v3" + vocab_size: int = 102400 + hidden_size: int = 4096 + intermediate_size: int = 11008 + moe_intermediate_size: int = 1407 + num_hidden_layers: int = 30 + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + n_shared_experts: Optional[int] = None + n_routed_experts: Optional[int] = None + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + scoring_func: str = "sigmoid" + norm_topk_prob: bool = True + n_group: Optional[int] = None + topk_group: Optional[int] = None + num_experts_per_tok: Optional[int] = None + moe_layer_freq: int = 1 + first_k_dense_replace: int = 0 + max_position_embeddings: int = 2048 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: Dict = None + attention_bias: bool = False + + +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + + linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) + return mx.clip(linear_func, 0, 1) + + +class DeepseekV3YarnRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + super().__init__() + self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + freq_inter = scaling_factor * base ** ( + mx.arange(0, dim, 2, dtype=mx.float32) / dim + ) + low, high = yarn_find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + ) + freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self._freqs = (freq_inter * freq_extra) / ( + freq_inter * freq_mask + freq_extra * (1 - freq_mask) + ) + + def __call__(self, x, offset=0): + if self.mscale != 1.0: + x = self.mscale * x + return mx.fast.rope( + x, + x.shape[-1], + traditional=True, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class DeepseekV3Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + 
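For a quick numeric feel of the YaRN helpers above, here is a standalone reimplementation of two of them (same logic, NumPy instead of mlx) evaluated on small inputs:

```python
import math
import numpy as np

def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

def yarn_linear_ramp_mask(min_val, max_val, dim):
    if min_val == max_val:
        max_val += 0.001  # prevent division by zero
    ramp = (np.arange(dim, dtype=np.float32) - min_val) / (max_val - min_val)
    return np.clip(ramp, 0, 1)

print(yarn_get_mscale(scale=40, mscale=1.0))  # ~1.369: attention scores get rescaled
print(yarn_linear_ramp_mask(4, 12, 16))       # 0 up to index 4, linear ramp, 1 from index 12
```

The ramp mask is what blends the interpolated (`freq_inter`) and extrapolated (`freq_extra`) rotary frequencies in `DeepseekV3YarnRotaryEmbedding`.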
self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.scale = self.q_head_dim**-0.5 + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, self.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) + self.q_b_proj = nn.Linear( + self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale + + rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rope = DeepseekV3YarnRotaryEmbedding( + dim=self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **rope_kwargs, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + if cache is not None: + q_pe = self.rope(q_pe, cache.offset) + k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys, values = cache.update_and_fetch( + mx.concatenate([k_nope, k_pe], axis=-1), values + ) + else: + q_pe = self.rope(q_pe) + k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class DeepseekV3MLP(nn.Module): + def __init__( + self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None 
else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__(self, x): + down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) + self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) + + def __call__(self, x): + gates = x @ self.weight.T + + scores = mx.sigmoid(gates.astype(mx.float32)) + + assert self.topk_method == "noaux_tc", "Unsupported topk method." + bsz, seq_len = x.shape[:2] + scores = scores + self.e_score_correction_bias + scores = scores.reshape(bsz, seq_len, self.n_group, -1) + group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) + k = self.n_group - self.topk_group + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] + batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) + seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) + scores[batch_idx, seq_idx, group_idx] = 0.0 + scores = scores.reshape(bsz, seq_len, -1) + + k = self.top_k + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(scores, inds, axis=-1) + if self.top_k > 1 and self.norm_topk_prob: + denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 + scores = scores / denominator + scores = scores * self.routed_scaling_factor + + return inds, scores + + +class DeepseekV3MoE(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.switch_mlp = SwitchGLU( + config.hidden_size, config.moe_intermediate_size, config.n_routed_experts + ) + + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) + + def __call__(self, x): + inds, scores = self.gate(x) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(x) + + return y + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.self_attn = DeepseekV3Attention(config) + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = 
self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + # Protect against overflow for fp16 + if out.dtype == mx.float16: + out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000) + return out + + +class DeepseekV3Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + DeepseekV3DecoderLayer(config, idx) + for idx in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pipeline_rank = 0 + self.pipeline_size = 1 + + def pipeline(self, group): + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first + self.pipeline_rank = group.rank() + self.pipeline_size = group.size() + layers_per_rank = ( + len(self.layers) + self.pipeline_size - 1 + ) // self.pipeline_size + start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.layers = self.layers[start : start + layers_per_rank] + + def __call__( + self, + x: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ) -> mx.array: + h = self.embed_tokens(x) + + pipeline_rank = self.pipeline_rank + pipeline_size = self.pipeline_size + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + # Receive from the previous process in the pipeline + if pipeline_rank < pipeline_size - 1: + h = mx.distributed.recv_like(h, (pipeline_rank + 1)) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + # Send to the next process in the pipeline + if pipeline_rank != 0: + h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size) + + # Broadcast h while keeping it in the graph + h = mx.distributed.all_gather(h)[: h.shape[0]] + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV3Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, cache, mask) + return self.lm_head(out) + + def sanitize(self, weights): + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") + for e in range(self.args.n_routed_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + + # Remove multi-token prediction layer + return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")} + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py new file mode 100644 index 00000000..ee3ed1e8 --- /dev/null +++ b/llms/mlx_lm/models/exaone.py @@ -0,0 +1,166 @@ +# Copyright © 2024 Apple Inc. 
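A pure-Python toy (no MPI required) of how `DeepseekV3Model.pipeline()` above partitions the decoder layers, assuming 61 layers split across 2 hosts purely for illustration; layers are handed out in reverse so rank 0 ends up with the last chunk:

```python
def pipeline_slice(num_layers, rank, size):
    layers_per_rank = (num_layers + size - 1) // size
    start = (size - rank - 1) * layers_per_rank
    return list(range(num_layers))[start : start + layers_per_rank]

for rank in range(2):
    layers = pipeline_slice(61, rank, 2)
    print(rank, layers[0], "...", layers[-1], f"({len(layers)} layers)")
# rank 0 holds layers 31..60, rank 1 holds layers 0..30
```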
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_layers: int + intermediate_size: int + num_attention_heads: int + vocab_size: int + rope_theta: float + layer_norm_epsilon: float + num_key_value_heads: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + attention_bias: bool = False + mlp_bias: bool = False + + +class AttentionModule(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim or (dim // n_heads) + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + def __call__( + self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None + ) -> mx.array: + B, L, D = x.shape + q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + q = self.rope(q, offset=cache.offset) + k = self.rope(k, offset=cache.offset) + k, v = cache.update_and_fetch(k, v) + else: + q = self.rope(q) + k = self.rope(k) + + out = scaled_dot_product_attention( + q, k, v, cache=cache, scale=self.scale, mask=mask + ) + out = out.transpose(0, 2, 1, 3).reshape(B, L, D) + return self.out_proj(out) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.attention = AttentionModule(args) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + hidden_dim = args.intermediate_size + self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias) + + def __call__(self, x: mx.array) -> mx.array: + return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.attn = Attention(args) + self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.mlp = MLP(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + h = x + self.attn.attention(self.ln_1(x), mask, cache) + out = h + self.mlp(self.ln_2(h)) + return out + + +class ExaoneModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.wte = nn.Embedding(args.vocab_size, 
args.hidden_size) + self.h = [TransformerBlock(args) for _ in range(args.num_layers)] + self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.wte(inputs) + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.h) + + for layer, c in zip(self.h, cache): + h = layer(h, mask, cache=c) + + return self.ln_f(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.transformer = ExaoneModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.transformer(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.transformer.wte.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.transformer.h diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 61de781e..0860ddeb 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -138,12 +138,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -164,9 +166,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) return out diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index ccc327a8..321a58ff 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -111,7 +111,7 @@ class MLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): @@ -160,12 +160,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -187,9 +189,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = 
self.model.embed_tokens.as_linear(out) out = mx.tanh(out / self.final_logit_softcapping) out = out * self.final_logit_softcapping diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 97d9a8ff..5b277734 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -61,8 +61,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -126,6 +126,7 @@ class GPT2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape @@ -138,7 +139,8 @@ class GPT2Model(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -159,9 +161,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.wte.as_linear(out) return out diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 068046ea..1d9794b6 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(output) @@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): B, L = inputs.shape @@ -144,15 +145,16 @@ class GPTBigCodeModel(nn.Module): hidden_states = self.wte(inputs) mask = None - if hidden_states.shape[1] > 1: - - position_ids = mx.array(np.arange(L)) - hidden_states += self.wpe(position_ids) - + if mask is not None and hidden_states.shape[1] > 1: mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) + position_ids = mx.array(np.arange(L)) + else: + position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L)) + + hidden_states += self.wpe(position_ids) for layer, c in zip(self.h, cache): hidden_states = layer(hidden_states, mask, cache=c) @@ -172,9 +174,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gpt_neox.py 
b/llms/mlx_lm/models/gpt_neox.py index 9f662491..5e124a67 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention # Based on the transformers implementation at: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape hidden_states = self.embed_in(inputs) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -176,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return out def sanitize(self, weights): diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py new file mode 100644 index 00000000..f9dc5652 --- /dev/null +++ b/llms/mlx_lm/models/hunyuan.py @@ -0,0 +1,294 @@ +# Copyright © 2023-2024 Apple Inc. + +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + attention_bias: bool + moe_topk: int + num_experts: int + num_shared_expert: int + use_mixed_mlp_moe: bool + use_qk_norm: bool + rms_norm_eps: float + rope_theta: float + use_cla: bool + cla_share_factor: 2 + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + +class DynamicNTKAlphaRoPE(nn.Module): + def __init__( + self, + dims: int, + base: float = 10000, + scaling_alpha: float = 1.0, + ): + super().__init__() + self.dims = dims + base = base * scaling_alpha ** (dims / (dims - 2)) + self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=False, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class Attention(nn.Module): + def __init__(self, kv_proj: bool, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 
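The `DynamicNTKAlphaRoPE` class in the hunyuan.py hunk above rescales the RoPE base by `scaling_alpha ** (dims / (dims - 2))` and then derives one value per dimension pair as `base ** (2i / dims)`, which is handed to `mx.fast.rope` as custom frequencies. A small pure-Python sketch of that arithmetic with hypothetical example values for `head_dim`, `base`, and `scaling_alpha`:

```python
# Sketch of the NTK-alpha base rescaling used by DynamicNTKAlphaRoPE above.
# All three arguments are hypothetical example values, not taken from any config.
def ntk_alpha_freqs(head_dim: int = 128, base: float = 10000.0, scaling_alpha: float = 1000.0):
    scaled_base = base * scaling_alpha ** (head_dim / (head_dim - 2))
    # One entry per dimension pair: scaled_base ** (2i / head_dim), i = 0 .. head_dim/2 - 1
    return [scaled_base ** (2 * i / head_dim) for i in range(head_dim // 2)]


freqs = ntk_alpha_freqs()
print(f"first: {freqs[0]:.1f}, last: {freqs[-1]:.3e}, count: {len(freqs)}")
```

Larger `scaling_alpha` values stretch the long-wavelength components, which is the usual NTK-style way of extending the usable context beyond what the base was tuned for.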
+ + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + if kv_proj: + self.k_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.v_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + self.use_qk_norm = args.use_qk_norm + if self.use_qk_norm: + self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + + self.rope = DynamicNTKAlphaRoPE( + head_dim, + base=args.rope_theta, + scaling_alpha=args.rope_scaling["alpha"], + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + kv_states=None, + ) -> mx.array: + B, L, D = x.shape + + queries = self.q_proj(x) + + if kv_states is None: + keys, values = self.k_proj(x), self.v_proj(x) + kv_states = keys, values + else: + keys, values = kv_states + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + offset = cache.offset if cache else 0 + queries = self.rope(queries, offset=offset) + keys = self.rope(keys, offset=offset) + if self.use_qk_norm: + queries = self.query_layernorm(queries) + keys = self.key_layernorm(keys) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), kv_states + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Gate(nn.Module): + def __init__(self, dim, num_experts): + super().__init__() + self.wg = nn.Linear(dim, num_experts, bias=False) + + def __call__(self, x) -> mx.array: + return self.wg(x) + + +class MoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.intermediate_size + self.use_shared_mlp = args.use_mixed_mlp_moe + + if args.use_mixed_mlp_moe: + self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert) + + self.num_experts = num_experts = args.num_experts + self.top_k = args.moe_topk + + self.gate = Gate(dim, num_experts) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + def __call__( + self, + x: mx.array, + ): + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + if self.use_shared_mlp: + shared_expert_output = self.shared_mlp(x) + y = y + shared_expert_output + + return y + + +class DecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, kv_proj: bool): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = Attention(kv_proj, args) + self.mlp = MoeBlock(args) + + 
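`MoeBlock` in the hunk above softmaxes the gate logits, keeps the `moe_topk` largest scores per token via `argpartition`/`take_along_axis`, and sums the selected experts' outputs weighted by those scores, optionally adding a shared-expert MLP. A minimal pure-Python sketch of the per-token routing step (the gate logits and `k` below are made-up example values):

```python
import math

# Sketch of the top-k routing arithmetic in MoeBlock above, for a single token.
# `logits` are hypothetical gate outputs over 8 experts.
def route(logits, k=2):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    gates = [e / total for e in exps]  # softmax over experts
    top = sorted(range(len(gates)), key=gates.__getitem__, reverse=True)[:k]
    return [(i, gates[i]) for i in top]  # (expert index, routing weight)


print(route([0.1, 2.0, -1.0, 0.7, 1.5, 0.0, -0.3, 0.9], k=2))
```

The routing weights are the raw softmax values of the chosen experts, not renormalized over the top-k, matching how `scores` is applied to the `switch_mlp` output above; when `use_mixed_mlp_moe` is set, `shared_mlp(x)` is simply added on top.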
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None, + ): + r, shared_kv_states = self.self_attn( + self.input_layernorm(x), mask, cache, shared_kv_states + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, shared_kv_states + + +class HunYuanModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, (layer, c) in enumerate(zip(self.layers, cache)): + if i % self.args.cla_share_factor == 0: + shared_kv_states = None + h, shared_kv_states = layer(h, mask, c, shared_kv_states) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = HunYuanModel(args) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + return self.model.embed_tokens.as_linear(out) + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index 5264cb57..28a095e1 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -141,8 +141,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output) @@ -193,11 +193,13 @@ class InternLM2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.tok_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = 
create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -220,9 +222,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.tok_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/internlm3.py b/llms/mlx_lm/models/internlm3.py new file mode 100644 index 00000000..3be6f536 --- /dev/null +++ b/llms/mlx_lm/models/internlm3.py @@ -0,0 +1,241 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + bias: bool = False + qkv_bias: bool = False + max_position_embeddings: int = 32768 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "rope_type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]: + raise ValueError( + "rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic" + ) + + +class DynamicNTKScalingRoPE(nn.Module): + """Implements the rotary positional encoding with Dynamic NTK scaling.""" + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scale: float = 1.0, + ): + super().__init__() + self.max_position_embeddings = max_position_embeddings + self.original_base = base + self.dims = dims + self.traditional = traditional + self.scale = scale + + def extra_repr(self): + return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}" + + def __call__(self, x, offset: int = 0): + seq_len = x.shape[1] + offset + if seq_len > self.max_position_embeddings: + base = self.original_base * ( + (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) + ) ** (self.dims / (self.dims - 2)) + else: + base = self.original_base + + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=base, + scale=self.scale, + offset=offset, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + qkv_bias = args.qkv_bias + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.n_kv_groups = n_heads // args.num_key_value_heads + + self.head_dim = head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias) + + rope_scale = ( + 
1 / args.rope_scaling["factor"] + if args.rope_scaling is not None + and args.rope_scaling["rope_type"] == "linear" + else 2.0 + ) + + self.rope = DynamicNTKScalingRoPE( + head_dim, + max_position_embeddings=args.max_position_embeddings, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim, bias): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class InternLM2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = InternLM2Model(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def 
sanitize(self, weights): + # Remove unused precomputed rotary freqs + return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k} + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7da6b333..7b452ea4 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,13 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope @dataclass @@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.rope_scaling: - if not "factor" in self.rope_scaling: - raise ValueError(f"rope_scaling must contain 'factor'") - rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( - "rope_type" - ) - if rope_type is None: - raise ValueError( - f"rope_scaling must contain either 'type' or 'rope_type'" - ) - if rope_type not in ["linear", "dynamic", "llama3"]: - raise ValueError( - "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - rope_type: str = "default", - rope_scaling: dict = None, - ): - super().__init__() - self.dims = dims - self.max_position_embeddings = max_position_embeddings - self.traditional = traditional - self.scale = scale - self.rope_type = rope_type - self.rope_scaling = rope_scaling - self.base = base - self.compute_freqs() - - def compute_freqs(self): - if self.rope_type != "llama3": - self._freqs = None - return - factor = self.rope_scaling["factor"] - low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) - high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.rope_scaling.get( - "original_max_position_embeddings", - 8192, - ) - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) - wavelens = 2 * mx.pi * freqs - - freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) - is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) - smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) - self.base = None - - def extra_repr(self): - return ( - f"{self.dims}, traditional={self.traditional}, " - f"max_position_embeddings={self.max_position_embeddings}, " - f"scaling_factor={self.scale}, rope_type={self.rope_type}" - ) - - def __call__(self, x, offset: int = 0): - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=self.base, - scale=self.scale, - offset=offset, - freqs=self._freqs, - ) - - -def initialize_rope(args: ModelArgs): - head_dim = args.head_dim or args.hidden_size // args.num_attention_heads 
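The frequency-smoothing logic being removed from llama.py here is the Llama-3 RoPE scaling that reappears in the new `rope_utils.py` later in this diff: components whose wavelength exceeds the low-frequency cutoff are stretched by `factor`, components below the high-frequency cutoff are left unchanged, and the band in between is interpolated. A plain-Python sketch of the same arithmetic, with example values for `dims`, `base`, and `factor` and the `*_factor` defaults mirroring the code:

```python
import math

# Sketch of the Llama-3 RoPE frequency adjustment factored out of llama.py
# (the identical computation lives in rope_utils.Llama3RoPE later in the diff).
# dims, base and factor are example values, not taken from any config.
def llama3_freqs(dims=128, base=500000.0, factor=8.0,
                 low_freq_factor=1.0, high_freq_factor=4.0, old_context_len=8192):
    low_wavelen = old_context_len / low_freq_factor
    high_wavelen = old_context_len / high_freq_factor
    out = []
    for i in range(0, dims, 2):
        freq = base ** (i / dims)
        wavelen = 2 * math.pi * freq
        if wavelen > low_wavelen:
            freq = freq * factor  # long-wavelength band: stretched by `factor`
        elif wavelen > high_wavelen:
            # medium band: smooth interpolation between scaled and unscaled
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            freq = freq / ((1 - smooth) / factor + smooth)
        out.append(freq)  # short-wavelength band is left unchanged
    return out


fs = llama3_freqs()
print(len(fs), round(fs[0], 3), f"{fs[-1]:.3e}")
```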
- - rope_scaling = args.rope_scaling - rope_type = "default" - rope_scale = 1.0 - - if rope_scaling is not None: - rope_type = ( - rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" - ) - if rope_type == "linear": - rope_scale = 1 / rope_scaling["factor"] - elif rope_type == "llama3": - rope_scale = 1.0 # The scaling is handled internally for llama3 - - return DynamicNTKScalingRoPE( - dims=head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - rope_type=rope_type, - rope_scaling=rope_scaling, - ) - class Attention(nn.Module): def __init__(self, args: ModelArgs): @@ -165,7 +55,13 @@ class Attention(nn.Module): self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - self.rope = initialize_rope(args) + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) def __call__( self, @@ -190,9 +86,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -258,11 +155,13 @@ class LlamaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -285,9 +184,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index d2740dc1..f2414660 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs): use_conv_bias: bool time_step_rank: int tie_word_embeddings: bool = True + use_bcdt_rms: bool = False + mixer_rms_eps: float = 1e-6 def __post_init__(self): if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): @@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs): if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + if self.model_type == "falcon_mamba": + self.use_bcdt_rms = True class DepthWiseConv1d(nn.Module): @@ -83,6 +87,11 @@ class MambaBlock(nn.Module): self.intermediate_size = args.intermediate_size self.time_step_rank = int(args.time_step_rank) self.use_conv_bias = args.use_conv_bias + self.use_bcdt_rms = args.use_bcdt_rms + if self.use_bcdt_rms: + self.mixer_norm = lambda x: mx.fast.rms_norm( + x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps + ) self.in_proj = nn.Linear( self.hidden_size, self.intermediate_size * 2, bias=args.use_bias @@ -126,6 +135,8 @@ class MambaBlock(nn.Module): ], axis=-1, ) + if self.use_bcdt_rms: + delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta)) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) if state is not None: @@ -205,7 +216,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in 
weights.items(): - if "conv1d.weight" in k and v.ndim == 3: + if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) return weights diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 4ac3c3b4..edddd583 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -105,8 +105,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - attn_output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + attn_output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) * self.args.scale_emb - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -186,9 +188,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if not self.args.tie_word_embeddings: out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 20944fe3..0afd1235 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -87,8 +87,8 @@ class MixtralAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -162,11 +162,13 @@ class MixtralModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -188,9 +190,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index 3ea06e27..eabfac8c 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -113,8 +113,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = 
self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -176,11 +176,13 @@ class NemotronModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -203,9 +205,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 3627df06..4273b0ec 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -124,11 +124,13 @@ class Transformer(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -152,9 +154,10 @@ class OlmoModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.transformer(inputs, cache) + return self.transformer(inputs, mask, cache) class Model(nn.Module): @@ -167,9 +170,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.model(inputs, cache) + return self.model(inputs, mask, cache) @property def layers(self): diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py new file mode 100644 index 00000000..510ff882 --- /dev/null +++ b/llms/mlx_lm/models/olmo2.py @@ -0,0 +1,212 @@ +# Copyright © 2023-2024 Apple Inc. 
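The OLMo-2 attention introduced in the hunk that follows normalizes the projected queries and keys with `nn.RMSNorm` (`q_norm`/`k_norm`) before applying RoPE. The normalization itself is a root-mean-square rescale with a learned gain; a standalone sketch in plain Python (the `eps` and input values are hypothetical):

```python
import math

# Sketch of the RMSNorm computation behind the q_norm / k_norm layers in OLMo-2's
# attention: x / sqrt(mean(x**2) + eps), scaled elementwise by a learned weight.
def rms_norm(x, weight=None, eps=1e-5):
    weight = weight if weight is not None else [1.0] * len(x)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]


print(rms_norm([1.0, -2.0, 3.0, 0.5]))
```

Because the norm is applied to the flat `n_heads * head_dim` projection before the reshape, all heads share the same per-token normalization statistics in this variant.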
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def 
__init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.post_feedforward_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.post_attention_layernorm(self.self_attn(x, mask, cache)) + h = x + r + r = self.post_feedforward_layernorm(self.mlp(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + mask=None, + ): + out = self.model(inputs, cache, mask) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 090e21c6..504fe95c 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -178,11 +178,13 @@ class OpenELMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.token_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -205,9 +207,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = 
self.transformer(inputs, mask, cache) if self.args.share_input_output_layers: out = self.transformer.token_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 56b383b2..e9724691 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -7,7 +7,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,13 @@ class PhiAttention(nn.Module): keys = self.rope(keys) scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) @@ -138,10 +143,11 @@ class PhiModel(nn.Module): config.hidden_size, eps=config.layer_norm_eps ) - def __call__(self, x, cache): + def __call__(self, x, mask, cache): x = self.embed_tokens(x) - mask = create_attention_mask(x, cache) + if mask is None: + mask = create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.layers) @@ -162,9 +168,10 @@ class Model(nn.Module): def __call__( self, x: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: - y = self.model(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) @property diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 9ef76f04..d1c21e25 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -168,11 +168,13 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -194,9 +196,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 6b0759b4..cd566eec 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -8,7 +8,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -188,8 +188,8 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, 
values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.dense(output) @@ -258,13 +258,15 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) if self.mup_embedding_multiplier: h = self.mup_embedding_multiplier * h - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -290,9 +292,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) if self.mup_width_multiplier: out = out / self.mup_width_multiplier diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index ca20a388..bddcb128 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding from .switch_layers import SwitchGLU @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +183,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 865d0d8e..5477c2c0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,7 +8,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import create_attention_mask +from .base import create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchMLP @@ -71,8 +71,13 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) @@ -170,7 +175,9 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) return self.lm_head(y) diff --git 
a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index b0fd1a6c..9107daad 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -92,10 +92,11 @@ class Attention(nn.Module): keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - output = mx.fast.scaled_dot_product_attention( + output = scaled_dot_product_attention( queries, keys, values, + cache=cache, scale=self.scale, mask=attention_mask, ) @@ -173,10 +174,12 @@ class PlamoModel(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None for _ in range(len(self.layers.layers))] @@ -201,8 +204,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 2b69d5ec..ec8a0199 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -123,7 +123,8 @@ class QwenModel(nn.Module): def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) - mask = create_attention_mask(x, cache) + if mask is None: + mask = create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.h) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 4e7858de..381767c4 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -149,11 +149,13 @@ class Qwen2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -176,9 +178,10 @@ 
class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index d199116f..c6aba622 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -213,9 +215,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 06a307a6..ad07d925 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import MambaCache, RotatingKVCache @@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -389,6 +389,7 @@ class Griffin(nn.Module): def __call__( self, tokens, + mask: mx.array = None, cache=None, ): x = self.embed_tokens(tokens) @@ -402,7 +403,8 @@ class Griffin(nn.Module): if block.temporal_block_type != "recurrent": mask_cache = [cache[i]] - mask = create_attention_mask(x, mask_cache) + if mask is None: + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -418,12 +420,12 @@ class Model(nn.Module): self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - def __call__(self, tokens: mx.array, cache=None) -> mx.array: + def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array: """ Args: tokens: Sequence of input tokens. 
""" - logits = self.model(tokens, cache=cache) + logits = self.model(tokens, mask=mask, cache=cache) if "lm_head" in self: logits = self.lm_head(logits) else: @@ -440,7 +442,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv_1d.weight" in k and v.ndim == 3: + if "conv_1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py new file mode 100644 index 00000000..d30b432d --- /dev/null +++ b/llms/mlx_lm/models/rope_utils.py @@ -0,0 +1,91 @@ +# Copyright © 2023-2024 Apple Inc. + +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +class Llama3RoPE(nn.Module): + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scaling_config: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + + factor = scaling_config["factor"] + low_freq_factor = scaling_config.get("low_freq_factor", 1.0) + high_freq_factor = scaling_config.get("high_freq_factor", 4.0) + old_context_len = scaling_config.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + freqs = base ** (mx.arange(0, dims, 2) / dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +def initialize_rope( + dims, + base, + traditional, + scaling_config: Optional[dict] = None, + max_position_embeddings: Optional[int] = None, +): + if scaling_config is not None: + rope_type = scaling_config.get("type") or scaling_config.get( + "rope_type", "default" + ) + else: + rope_type = "default" + + if rope_type in ["default", "linear"]: + scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0 + return nn.RoPE(dims, traditional=traditional, base=base, scale=scale) + + elif rope_type == "llama3": + return Llama3RoPE( + dims=dims, + max_position_embeddings=max_position_embeddings, + traditional=traditional, + base=base, + scaling_config=scaling_config, + ) + else: + raise ValueError(f"Unsupported RoPE type {rope_type}") diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 11202b02..0bbc2ca4 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -120,8 +120,8 @@ class Attention(nn.Module): # Finally perform the attention computation scale = 
math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) @@ -199,7 +199,10 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ce0a2ec5..71c397f6 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,7 +6,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) @@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -152,9 +154,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 814c03cc..72e1ef89 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.17.0 +mlx>=0.22.0 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fa..c48a32cf 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,26 +1,132 @@ # Copyright © 2023-2024 Apple Inc. +import math from functools import partial +from typing import Callable, Dict, Optional import mlx.core as mx +def make_sampler( + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, + top_k: int = -1, +) -> Callable[mx.array, mx.array]: + """ + Make a sampler function for use with ``generate_step``. + + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + top_k (int, optional): The top k tokens ranked by probability to constrain + the sampling to. + + Returns: + Callable[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. 
+ """ + if temp == 0: + return lambda x: mx.argmax(x, axis=-1) + elif top_p > 0 and top_p < 1.0: + return lambda x: top_p_sampling(x, top_p, temp) + elif min_p != 0.0: + return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + elif top_k > 0: + return lambda x: top_k_sampling(x, top_k, temp) + else: + return lambda x: categorical_sampling(x, temp) + + +def make_logits_processors( + logit_bias: Optional[Dict[int, float]] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = 20, +): + """ + Make logits processors for use with ``generate_step``. + + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + + Returns: + List[Callable[[mx.array, mx.array], mx.array]]: + A list of logits processors. Each processor in the list is a + callable which takes an array of tokens and an array of logits + and returns the updated logits. + """ + logits_processors = [] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processors.append(logit_bias_processor) + + if repetition_penalty and repetition_penalty != 0.0: + logits_processors.append( + make_repetition_penalty(repetition_penalty, repetition_context_size) + ) + return logits_processors + + +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_k_sampling( + logprobs: mx.array, + top_k: int, + temperature=1.0, +) -> mx.array: + """ + Sample from only the top K tokens ranked by probability. + + Args: + logprobs: A vector of log probabilities. + top_k (int): Top k tokens to sample from. + """ + vocab_size = logprobs.shape[-1] + if not isinstance(top_k, int) or not (0 < top_k < vocab_size): + raise ValueError( + f"`top_k` has to be an integer in the (0, {vocab_size}] interval," + f" but is {top_k}." + ) + logprobs = logprobs * (1 / temperature) + mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] + masked_logprobs = mx.put_along_axis( + logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 + ) + return mx.random.categorical(masked_logprobs, axis=-1) + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( - logits: mx.array, + logprobs: mx.array, min_p: float, min_tokens_to_keep: int = 1, temperature=1.0, ) -> mx.array: """ - Apply min-p sampling to the logits. + Apply min-p sampling to the logprobs. Min-p keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter is more aggressive given a very high-probability token. Args: - logits: The logits from the model's output. + logprobs: A vector of log probabilities. min_p (float): Minimum token probability. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range. 
@@ -38,28 +144,27 @@ def min_p_sampling( ) # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 - # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs).squeeze(0) + sorted_logprobs = logprobs[..., sorted_indices] # Top probability - top_probs = probs[..., sorted_indices[0]] + top_logprobs = logprobs[..., sorted_indices[0]] # Calculate the min_p threshold - scaled_min_p = min_p * top_probs + scaled_min_p = top_logprobs + math.log(min_p) # Mask tokens that have a probability less than the scaled min_p - tokens_to_remove = sorted_probs < scaled_min_p + tokens_to_remove = sorted_logprobs < scaled_min_p tokens_to_remove[..., :min_tokens_to_keep] = False # Create pool of tokens with probability less than scaled min_p - selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) + selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) + sorted_token = mx.random.categorical(selected_logprobs) return sorted_indices[sorted_token] @@ -100,3 +205,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) + + +def make_repetition_penalty(penalty: float, context_size: int = 20): + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. 
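The rewritten `min_p_sampling` works directly on log-probabilities: rather than comparing probabilities against `min_p * p_top` after a softmax, it compares log-probabilities against `log(p_top) + log(min_p)`, which is the same filter without the extra softmax. A tiny illustration of the equivalence with made-up numbers:

```python
# Sketch: min-p thresholding in log space.
# p >= min_p * p_top  is equivalent to  log p >= log p_top + log(min_p).
import math

import mlx.core as mx

probs = mx.array([0.55, 0.25, 0.12, 0.05, 0.03])
logprobs = mx.log(probs)

min_p = 0.1
scaled_min_p = logprobs.max() + math.log(min_p)  # log(0.55 * 0.1)
keep = logprobs >= scaled_min_p
print(keep)  # [True, True, True, False, False]; 0.05 and 0.03 fall below 0.055
```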
+ """ + if penalty < 0 or not isinstance(penalty, (int, float)): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 42962b54..4523e3ae 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -3,17 +3,37 @@ import argparse import json import logging +import platform import time import uuid import warnings +from dataclasses import dataclass, field from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union +from typing import ( + Any, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) import mlx.core as mx from huggingface_hub import scan_cache_dir -from .utils import generate_step, load +from ._version import __version__ +from .models.cache import make_prompt_cache +from .sample_utils import make_logits_processors, make_sampler +from .utils import load, stream_generate + + +def get_system_fingerprint(): + gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else "" + return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" class StopCondition(NamedTuple): @@ -45,7 +65,7 @@ def stopping_criteria( end if it has (`trim_length`). """ if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -94,6 +114,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): return prompt.rstrip() +@dataclass +class PromptCache: + cache: List[Any] = field(default_factory=list) + model_key: Tuple[str, Optional[str]] = ("", None) + tokens: List[int] = field(default_factory=list) + + class ModelProvider: def __init__(self, cli_args: argparse.Namespace): """Load models on demand and persist them across the whole process.""" @@ -156,12 +183,21 @@ class ModelProvider: class APIHandler(BaseHTTPRequestHandler): - def __init__(self, model_provider: ModelProvider, *args, **kwargs): + def __init__( + self, + model_provider: ModelProvider, + *args, + prompt_cache: Optional[PromptCache] = None, + system_fingerprint: Optional[str] = None, + **kwargs, + ): """ Create static request specific metadata """ self.created = int(time.time()) self.model_provider = model_provider + self.prompt_cache = prompt_cache or PromptCache() + self.system_fingerprint = system_fingerprint or get_system_fingerprint() super().__init__(*args, **kwargs) def _set_cors_headers(self): @@ -215,8 +251,10 @@ class APIHandler(BaseHTTPRequestHandler): self.stream_options = self.body.get("stream_options", None) self.requested_model = self.body.get("model", "default_model") self.adapter = self.body.get("adapters", None) - self.max_tokens = self.body.get("max_tokens", 100) - self.temperature = self.body.get("temperature", 1.0) + self.max_tokens = self.body.get("max_completion_tokens", None) + if self.max_tokens is None: + self.max_tokens = self.body.get("max_tokens", 512) + self.temperature = self.body.get("temperature", 0.0) self.top_p = 
self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) @@ -253,10 +291,7 @@ class APIHandler(BaseHTTPRequestHandler): # Call endpoint specific method prompt = endpoints[self.path]() - - # Call method based on response type - method = self.handle_stream if self.stream else self.handle_completion - method(prompt, stop_id_sequences) + self.handle_completion(prompt, stop_id_sequences) def validate_model_parameters(self): """ @@ -343,7 +378,7 @@ class APIHandler(BaseHTTPRequestHandler): # Static response response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": self.object_type, "model": self.requested_model, "created": self.created, @@ -388,41 +423,66 @@ class APIHandler(BaseHTTPRequestHandler): return response + def get_prompt_cache(self, prompt): + cache_len = len(self.prompt_cache.tokens) + if ( + self.prompt_cache.model_key != self.model_provider.model_key + or cache_len >= len(prompt) + or self.prompt_cache.tokens != prompt[:cache_len] + ): + self.prompt_cache.model_key = self.model_provider.model_key + self.prompt_cache.cache = make_prompt_cache(self.model_provider.model) + else: + prompt = prompt[cache_len:] + self.prompt_cache.tokens.extend(prompt) + return prompt + def handle_completion( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ Generate a response to a prompt and send it to the client in a single batch. Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (List[int]): The tokenized prompt. stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() tokens = [] finish_reason = "length" stop_sequence_suffix = None - logging.debug(f"Starting completion:") + if self.stream: + self.end_headers() + logging.debug(f"Starting stream:") + else: + logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] - for (token, logprobs), _ in zip( - generate_step( - prompt=prompt, - model=self.model, - temp=self.temperature, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, - ), - range(self.max_tokens), + + prompt = self.get_prompt_cache(prompt) + + text = "" + tic = time.perf_counter() + sampler = make_sampler(self.temperature, top_p=self.top_p) + logits_processors = make_logits_processors( + self.logit_bias, self.repetition_penalty, self.repetition_context_size + ) + for gen_response in stream_generate( + model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prompt_cache=self.prompt_cache.cache, ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) + segment = gen_response.text + text += segment + logging.debug(text) + token = gen_response.token + logprobs = gen_response.logprobs tokens.append(token) if self.logprobs > 0: @@ -430,7 +490,7 @@ class APIHandler(BaseHTTPRequestHandler): top_indices = sorted_indices[: self.logprobs] top_logprobs = logprobs[top_indices] top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) - top_tokens.append(dict(top_token_info)) + top_tokens.append(tuple(top_token_info)) token_logprobs.append(logprobs[token].item()) @@ -443,114 +503,59 @@ class 
APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = self.tokenizer.decode( tokens[-stop_condition.trim_length :] ) + text = text[: -len(stop_sequence_suffix)] break - detokenizer.finalize() - text = ( - detokenizer.text - if stop_sequence_suffix is None - else detokenizer.text[: -len(stop_sequence_suffix)] - ) - response = self.generate_response( - text, - finish_reason, - len(prompt), - len(tokens), - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - ) - - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - - self.wfile.write(response_json) - self.wfile.flush() - - def handle_stream( - self, - prompt: mx.array, - stop_id_sequences: List[List[int]], - ): - """ - Generate response to prompt and foward it to the client using a Server - Sent Events (SSE) stream. - - Args: - prompt (mx.array): The prompt, in token form inside of a mlx array - stop_id_sequences (List[List[int]]): A list of stop words passed to - the stopping_criteria function - """ - # No additional headers are needed, call end_headers - self.end_headers() - - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() - tokens = [] - - stop_sequence_suffix = None - logging.debug(f"Starting stream:") - - for (token, _), _ in zip( - generate_step( - prompt=prompt, - model=self.model, - temp=self.temperature, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - ), - range(self.max_tokens), - ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) - tokens.append(token) - - stop_condition = stopping_criteria( - tokens, - stop_id_sequences, - self.tokenizer.eos_token_id, - ) - if stop_condition.stop_met: - if stop_condition.trim_length: - stop_sequence_suffix = self.tokenizer.decode( - tokens[-stop_condition.trim_length :] + if self.stream: + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + ( + sequence_overlap(tokens, sequence) + for sequence in stop_id_sequences ) - break + ): + continue + elif segment: + response = self.generate_response(segment, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) - ): - continue + self.prompt_cache.tokens.extend(tokens) - new_text = detokenizer.last_segment - response = self.generate_response(new_text, None) + logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") + + if self.stream: + response = self.generate_response(segment, finish_reason) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() - - # check is there any remaining text to send - detokenizer.finalize() - last_segment = detokenizer.last_segment - if last_segment: - if stop_sequence_suffix is not None: - last_segment = last_segment[: -len(stop_sequence_suffix)] - response = self.generate_response(last_segment, 
"length") - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + if self.stream_options is not None and self.stream_options["include_usage"]: + response = self.completion_usage_response(len(prompt), len(tokens)) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) self.wfile.flush() + else: + response = self.generate_response( + text, + finish_reason, + len(prompt), + len(tokens), + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + ) + response_json = json.dumps(response).encode() + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response(len(prompt), len(tokens)) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() + # Send an additional Content-Length header when it is known + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) + self.wfile.flush() def completion_usage_response( self, @@ -559,7 +564,7 @@ class APIHandler(BaseHTTPRequestHandler): ): response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": "chat.completion", "model": self.requested_model, "created": self.created, @@ -572,7 +577,7 @@ class APIHandler(BaseHTTPRequestHandler): } return response - def handle_chat_completions(self) -> mx.array: + def handle_chat_completions(self) -> List[int]: """ Handle a chat completion request. @@ -584,27 +589,20 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"chatcmpl-{uuid.uuid4()}" - self.object_type = ( - "chat.completions.chunk" if self.stream else "chat.completions" - ) - - if ( - hasattr(self.tokenizer, "apply_chat_template") - and self.tokenizer.chat_template - ): + self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" + if self.tokenizer.chat_template: prompt = self.tokenizer.apply_chat_template( body["messages"], body.get("tools", None), - tokenize=True, add_generation_prompt=True, ) else: prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = self.tokenizer.encode(prompt) - return mx.array(prompt) + return prompt - def handle_text_completions(self) -> mx.array: + def handle_text_completions(self) -> List[int]: """ Handle a text completion request. 
@@ -614,11 +612,8 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"cmpl-{uuid.uuid4()}" self.object_type = "text_completion" - assert "prompt" in self.body, "Request did not contain a prompt" - prompt_text = self.body["prompt"] - prompt = self.tokenizer.encode(prompt_text) - return mx.array(prompt) + return self.tokenizer.encode(self.body["prompt"]) def do_GET(self): """ @@ -669,9 +664,16 @@ def run( handler_class=APIHandler, ): server_address = (host, port) + prompt_cache = PromptCache() httpd = server_class( server_address, - lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs), + lambda *args, **kwargs: handler_class( + model_provider, + prompt_cache=prompt_cache, + system_fingerprint=get_system_fingerprint(), + *args, + **kwargs, + ), ) warnings.warn( "mlx_lm.server is not recommended for production as " diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 04bbbcc5..1b5bdd77 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -3,14 +3,6 @@ from functools import partial from transformers import AutoTokenizer -REPLACEMENT_CHAR = "\ufffd" - - -def _remove_space(x): - if x and x[0] == " ": - return x[1:] - return x - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -57,11 +49,9 @@ class StreamingDetokenizer: def last_segment(self): """Return the last segment of readable text since last time this property was accessed.""" text = self.text - if text and text[-1] != REPLACEMENT_CHAR: - segment = text[self.offset :] - self.offset = len(text) - return segment - return "" + segment = text[self.offset :] + self.offset = len(text) + return segment class NaiveStreamingDetokenizer(StreamingDetokenizer): @@ -79,16 +69,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def reset(self): self.offset = 0 - self._tokens = [] + self.tokens = [] self._text = "" self._current_tokens = [] self._current_text = "" def add_token(self, token): self._current_tokens.append(token) + self.tokens.append(token) def finalize(self): - self._tokens.extend(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens) self._current_tokens = [] self._current_text = "" @@ -97,17 +87,17 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def text(self): if self._current_tokens: self._current_text = self._tokenizer.decode(self._current_tokens) + if ( + self._tokenizer.clean_up_tokenization_spaces + and self._current_text[-1] == " " + ): + self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": - self._tokens.extend(self._current_tokens) self._text += self._current_text self._current_tokens.clear() self._current_text = "" return self._text + self._current_text - @property - def tokens(self): - return self._tokens - class SPMStreamingDetokenizer(StreamingDetokenizer): """A streaming detokenizer for SPM models. 
@@ -118,42 +108,43 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): def __init__(self, tokenizer, trim_space=True): self.trim_space = trim_space + self._sep = "\u2581".encode() # Extract the tokens in a list from id to text self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) for value, tokenid in tokenizer.vocab.items(): - self.tokenmap[tokenid] = value - - # Replace bytes with their value - for i in range(len(self.tokenmap)): - if self.tokenmap[i].startswith("<0x"): - self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) + if value.startswith("<0x"): + # Replace bytes with their value + self.tokenmap[tokenid] = bytes([int(value[3:5], 16)]) + else: + self.tokenmap[tokenid] = value.encode() self.reset() def reset(self): self.offset = 0 - self._unflushed = "" + self._unflushed = b"" self.text = "" self.tokens = [] + def _try_flush(self, force=False): + text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace") + if not force and text.endswith("\ufffd"): + return + if not self.text and self.trim_space and text and text[0] == " ": + text = text[1:] + self.text += text + self._unflushed = b"" + def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] - if v[0] == "\u2581": - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) - self._unflushed = v - else: - self._unflushed += v + self._unflushed += v + self._try_flush() def finalize(self): - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) - self._unflushed = "" + self._try_flush(force=True) + self._unflushed = b"" class BPEStreamingDetokenizer(StreamingDetokenizer): @@ -164,9 +155,10 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): """ _byte_decoder = None + _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re") - def __init__(self, tokenizer, trim_space=False): - self.trim_space = trim_space + def __init__(self, tokenizer): + self.clean_spaces = tokenizer.clean_up_tokenization_spaces # Extract the tokens in a list from id to text self.tokenmap = [None] * len(tokenizer.vocab) @@ -185,29 +177,47 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.text = "" self.tokens = [] - def add_token(self, token): - v = self.tokenmap[token] - # if the token starts with space - if self._byte_decoder[v[0]] == 32: - current_text = bytearray( - self._byte_decoder[c] for c in self._unflushed - ).decode("utf-8") - if self.text or not self.trim_space: - self.text += current_text + def _decode_bytes(self, seq): + barr = bytearray() + for c in seq: + res = self._byte_decoder.get(c, False) + if res: + barr.append(res) else: - self.text += _remove_space(current_text) - self._unflushed = v - else: - self._unflushed += v + barr.extend(bytes(c, "utf-8")) + return barr.decode("utf-8", "replace") + + def _maybe_trim_space(self, current_text): + if len(current_text) == 0: + return current_text + elif current_text[0] != " ": + return current_text + elif not self.text: + return current_text[1:] + elif self.clean_spaces and current_text[1:].startswith(self._space_matches): + return current_text[1:] + return current_text + + def add_token(self, token): + self.tokens.append(token) + v = self.tokenmap[token] + self._unflushed += v + text = self._decode_bytes(self._unflushed) + + # For multi-byte utf-8 wait until they are complete + # For single spaces wait until 
the next token to clean it if needed + if not text.endswith("\ufffd") and not ( + len(v) == 1 and self._byte_decoder[v[0]] == 32 + ): + self.text += self._maybe_trim_space(text) + self._unflushed = "" def finalize(self): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( - "utf-8" + "utf-8", + "replace", ) - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = "" @classmethod @@ -245,21 +255,45 @@ class TokenizerWrapper: huggingface tokenizer. """ - def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): + def __init__( + self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None + ): self._tokenizer = tokenizer self._detokenizer = detokenizer_class(tokenizer) + self._eos_token_ids = ( + set(eos_token_ids) + if eos_token_ids is not None + else {tokenizer.eos_token_id} + ) + + def add_eos_token(self, token: str): + token_id = None + try: + token_id = int(token) + except ValueError: + token_id = self._tokenizer.convert_tokens_to_ids(token) + + if token_id is None: + raise ValueError(f"'{token}' is not a token for this tokenizer") + + self._eos_token_ids.add(token_id) def __getattr__(self, attr): if attr == "detokenizer": return self._detokenizer + elif attr == "eos_token_ids": + return self._eos_token_ids elif attr.startswith("_"): return self.__getattribute__(attr) else: return getattr(self._tokenizer, attr) def __setattr__(self, attr, value): - if attr == "detokenizer": - raise AttributeError("Cannot set the detokenizer.") + if attr in {"detokenizer", "eos_token_ids"}: + if attr == "detokenizer": + raise AttributeError("Cannot set the detokenizer.") + elif attr == "eos_token_ids": + self._eos_token_ids = set(value) if value is not None else set() elif attr.startswith("_"): super().__setattr__(attr, value) else: @@ -303,17 +337,10 @@ def _is_spm_decoder_no_space(decoder): def _is_bpe_decoder(decoder): - _target_description = { - "type": "ByteLevel", - "add_prefix_space": False, - "trim_offsets": False, - "use_regex": False, - } - - return _match(_target_description, decoder) + return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" -def load_tokenizer(model_path, tokenizer_config_extra={}): +def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): """Load a huggingface tokenizer and try to infer the type of streaming detokenizer to use. @@ -334,7 +361,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}): elif _is_bpe_decoder(tokenizer_content["decoder"]): detokenizer_class = BPEStreamingDetokenizer + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] return TokenizerWrapper( AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), detokenizer_class, + eos_token_ids=eos_token_ids, ) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 20b32eff..1b09c7e2 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -10,41 +10,47 @@ class Dataset: Light-weight wrapper to hold a dataset. 
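The tokenizer wrapper above now tracks a set of `eos_token_ids` and exposes `add_eos_token`, so extra stop tokens can be registered by id or by token string. Roughly, with a placeholder model repo and token:

```python
# Sketch: registering an additional stop token on the wrapped tokenizer.
# The model repo and the "<|im_end|>" token are placeholders for illustration.
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
tokenizer.add_eos_token("<|im_end|>")  # resolved via convert_tokens_to_ids
print(tokenizer.eos_token_ids)         # the set consulted when deciding to stop
```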
""" - def __init__(self, data: List[Dict[str, str]], text_key: str = "text"): - self._text_key = text_key - self._data = data + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + text_key: str = "text", + ): + self._data = [tokenizer.encode(d[text_key]) for d in data] + for d in self._data: + if d[-1] != tokenizer.eos_token_id: + d.append(tokenizer.eos_token_id) def __getitem__(self, idx: int): - return self._data[idx][self._text_key] + return self._data[idx] def __len__(self): - if self._data is None: - return 0 return len(self._data) -class ChatDataset(Dataset): +class ChatDataset: """ A dataset for chat data in the format of {"messages": [...]} https://platform.openai.com/docs/guides/fine-tuning/example-format """ def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - super().__init__(data) - self._tokenizer = tokenizer + self._data = [ + tokenizer.apply_chat_template( + d["messages"], + tools=d.get("tools", None), + ) + for d in data + ] def __getitem__(self, idx: int): - messages = self._data[idx]["messages"] - text = self._tokenizer.apply_chat_template( - messages, - tools=self._data[idx].get("tools", None), - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -class CompletionsDataset(Dataset): +class CompletionsDataset: """ A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} or using user-provided keys for prompt and completion values @@ -55,36 +61,41 @@ class CompletionsDataset(Dataset): self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): - super().__init__(data) - self._tokenizer = tokenizer - self._prompt_key = prompt_key - self._completion_key = completion_key + self._data = [ + tokenizer.apply_chat_template( + [ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[completion_key]}, + ], + ) + for d in data + ] def __getitem__(self, idx: int): - data = self._data[idx] - text = self._tokenizer.apply_chat_template( - [ - {"role": "user", "content": data[self._prompt_key]}, - {"role": "assistant", "content": data[self._completion_key]}, - ], - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer = None): +def create_dataset( + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" sample = data[0] - if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in sample and "completion" in sample: - return CompletionsDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: - return Dataset(data) + return Dataset(data, tokenizer) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -92,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None): ) -def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): +def load_local_dataset( + data_path: Path, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] 
= None, + completion_feature: Optional[str] = None, +): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer) + return create_dataset(data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] return train, valid, test -def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): +def load_hf_dataset( + data_id: str, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): from datasets import exceptions, load_dataset try: @@ -114,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): names = ("train", "valid", "test") train, valid, test = [ - create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + ( + create_dataset( + dataset[n], tokenizer, prompt_feature, completion_feature + ) + if n in dataset.keys() + else [] + ) for n in names ] @@ -143,7 +170,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): if prompt_feature and completion_feature: return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) elif text_feature: - return Dataset(train_ds, text_key=text_feature) + return Dataset(train_ds, tokenizer, text_key=text_feature) else: raise ValueError( "Specify either a prompt and completion feature or a text " @@ -166,15 +193,22 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: + if getattr(args, "hf_dataset", False): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer) + train, valid, test = load_local_dataset( + data_path, tokenizer, prompt_feature, completion_feature + ) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer) + train, valid, test = load_hf_dataset( + args.data, tokenizer, prompt_feature, completion_feature + ) if args.train and len(train) == 0: raise ValueError( diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 1d934a72..63ca58bb 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -10,6 +10,7 @@ from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten @@ -84,22 +85,23 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) f" examples but only has {len(dataset)}." 
) + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + # Make the batches: batch_idx = [ - idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: indices = np.random.permutation(len(batch_idx)) for i in indices: - # Encode batch - batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] - for b in batch: - if b[-1] != tokenizer.eos_token_id: - b.append(tokenizer.eos_token_id) - + batch = [dataset[j] for j in batch_idx[i]] lengths = [len(x) for x in batch] - if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " @@ -112,9 +114,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) - batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - for j in range(batch_size): + for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( @@ -138,7 +140,7 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = [] + all_losses = 0 ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -153,10 +155,14 @@ def evaluate( ), ): losses, toks = loss(model, *batch) - all_losses.append((losses * toks).item()) - ntokens += toks.item() + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) - return np.sum(all_losses) / ntokens + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) + + return (all_losses / ntokens).item() class TrainingCallback: @@ -182,6 +188,11 @@ def train( training_callback: TrainingCallback = None, ): print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) @@ -192,6 +203,9 @@ def train( # Forward and backward pass (lvalue, toks), grad = loss_value_and_grad(model, *batch) + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + # Model update optimizer.update(model, grad) @@ -199,8 +213,9 @@ def train( loss_value_and_grad = nn.value_and_grad(model, loss) - losses = [] + losses = 0 n_tokens = 0 + steps = 0 trained_tokens = 0 # Main training loop start = time.perf_counter() @@ -229,9 +244,13 @@ def train( iterate_batches=iterate_batches, ) val_time = time.perf_counter() - stop - print( - f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" - ) + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) if training_callback is not None: val_info = { @@ -244,30 +263,33 @@ def train( start = time.perf_counter() lvalue, toks = step(batch) - mx.eval(state, lvalue, toks) - - # Record loss - losses.append(lvalue.item()) - n_tokens += toks.item() + losses += lvalue + n_tokens += toks + 
steps += 1 + mx.eval(state, losses, n_tokens) # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() - train_loss = np.mean(losses) + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() + train_loss /= steps * mx.distributed.init().size() + n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens - peak_mem = mx.metal.get_peak_memory() / 2**30 - print( - f"Iter {it}: Train loss {train_loss:.3f}, " - f"Learning Rate {learning_rate:.3e}, " - f"It/sec {it_sec:.3f}, " - f"Tokens/sec {tokens_sec:.3f}, " - f"Trained Tokens {trained_tokens}, " - f"Peak mem {peak_mem:.3f} GB" - ) + peak_mem = mx.metal.get_peak_memory() / 1e9 + if rank == 0: + print( + f"Iter {it}: Train loss {train_loss:.3f}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) if training_callback is not None: train_info = { @@ -281,8 +303,9 @@ def train( } training_callback.on_train_loss_report(train_info) - losses = [] + losses = 0 n_tokens = 0 + steps = 0 start = time.perf_counter() # Save adapter weights diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 7c78ee91..594f8040 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -96,8 +96,11 @@ def linear_to_lora_layers( "gemma2", "starcoder2", "cohere", + "cohere2", "minicpm", "deepseek", + "olmo2", + "internlm3", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: @@ -143,6 +146,8 @@ def linear_to_lora_layers( "mixer.out_proj", ] ) + elif model.model_type == "exaone": + keys = set(["attn.attention.q_proj", "attn.attention.v_proj"]) else: raise ValueError(f"Lora does not support {model.model_type}") @@ -249,12 +254,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module: return model -def print_trainable_parameters(model): - def nparams(m): - if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): - return m.weight.size * (32 // m.bits) - return sum(v.size for _, v in tree_flatten(m.parameters())) +def nparams(module): + if hasattr(module, "bits"): + n = 0 if not hasattr(module, "bias") else module.bias.size + return n + module.weight.size * 32 // module.bits + return sum(v.size for _, v in tree_flatten(module.parameters())) + +def print_trainable_parameters(model): leaf_modules = tree_flatten( model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1e07546e..b9037295 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -1,37 +1,55 @@ # Copyright © 2023-2024 Apple Inc. 
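The trainer changes above make fine-tuning distributed-aware: each worker takes every N-th sample of a batch (so the batch size must divide evenly by the number of workers), gradients are averaged with `mlx.nn.utils.average_gradients`, and losses and token counts are all-reduced for reporting. A minimal sketch of those pieces in isolation:

```python
# Sketch: the distributed primitives the trainer now relies on, in isolation.
import mlx.core as mx
from mlx.nn.utils import average_gradients

world = mx.distributed.init()  # size() == 1 when run as a single process
print(f"node {world.rank()} of {world.size()}")

# Each worker takes every world.size()-th sample, so the global batch size
# must be divisible by the number of workers (the trainer enforces this).
batch_size = 8
assert batch_size % world.size() == 0


def step(model, optimizer, loss_value_and_grad, batch):
    (lvalue, ntoks), grad = loss_value_and_grad(model, *batch)
    grad = average_gradients(grad)  # all-reduce gradients across workers
    optimizer.update(model, grad)
    return lvalue, ntoks


# Scalars reported across workers are aggregated with an all-sum, e.g.:
# total_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
```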
+import contextlib import copy +import functools import glob import importlib import json import logging +import os import shutil import time +from dataclasses import dataclass from pathlib import Path from textwrap import dedent from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union import mlx.core as mx import mlx.nn as nn -from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten + +if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true": + try: + from modelscope import snapshot_download + except ImportError: + raise ImportError( + "Please run `pip install modelscope` to activate the ModelScope." + ) +else: + from huggingface_hub import snapshot_download + +from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports -from .models import base, cache -from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling +from .models import cache +from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model -from .tuner.utils import load_adapters +from .tuner.utils import load_adapters, nparams # Constants MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama "phi-msft": "phixtral", + "falcon_mamba": "mamba", } MAX_FILE_SIZE_GB = 5 +# A stream on the default device just for generation +generation_stream = mx.new_stream(mx.default_device()) + class ModelNotFoundError(Exception): def __init__(self, message): @@ -39,6 +57,68 @@ class ModelNotFoundError(Exception): super().__init__(self.message) +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. + finish_reason (str): The reason the response is being sent: "length", "stop" or `None` + """ + + text: str + token: int + logprobs: mx.array + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + finish_reason: Optional[str] = None + + +@contextlib.contextmanager +def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): + """ + A context manager to temporarily change the wired limit. + + Note, the wired limit should not be changed during an async eval. If an + async eval could be running pass in the streams to synchronize with prior + to exiting the context manager. + """ + model_bytes = tree_reduce( + lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 + ) + max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"] + if model_bytes > 0.9 * max_rec_size: + model_mb = model_bytes // 2**20 + max_rec_mb = max_rec_size // 2**20 + print( + f"[WARNING] Generating with a model that requires {model_mb} MB " + f"which is close to the maximum recommended size of {max_rec_mb} " + "MB. This can be slow. 
See the documentation for possible work-arounds: " + "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models" + ) + old_limit = mx.metal.set_wired_limit(max_rec_size) + try: + yield None + finally: + if streams is not None: + for s in streams: + mx.synchronize(s) + else: + mx.synchronize() + mx.metal.set_wired_limit(old_limit) + + def _get_classes(config: dict): """ Retrieve the model and model args classes based on the configuration. @@ -61,6 +141,17 @@ def _get_classes(config: dict): return arch.Model, arch.ModelArgs +def compute_bits_per_weight(model): + model_bytes = tree_reduce( + lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 + ) + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + model_params = sum(nparams(m) for _, m in leaf_modules) + return model_bytes * 8 / model_params + + def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, @@ -74,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path Path: The path to the model. """ model_path = Path(path_or_hf_repo) + if not model_path.exists(): try: model_path = Path( snapshot_download( - repo_id=path_or_hf_repo, + path_or_hf_repo, revision=revision, allow_patterns=[ "*.json", @@ -101,43 +193,33 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): - """ - Apply repetition penalty to specific logits based on the given context. - - Paper: https://arxiv.org/abs/1909.05858 - - Args: - logits (mx.array): The logits produced by the language model. - tokens (mx.array): A list of N previous tokens. - penalty (float): The repetition penalty factor to be applied. - - Returns: - logits (mx.array): Logits with repetition penalty applied to generated tokens. 
- """ - if len(tokens) > 0: - selected_logits = logits[:, tokens] - selected_logits = mx.where( - selected_logits < 0, selected_logits * penalty, selected_logits / penalty - ) - logits[:, tokens] = selected_logits - return logits +def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): + if ( + kv_bits is not None + and not isinstance(prompt_cache[0], cache.QuantizedKVCache) + and prompt_cache[0].offset > quantized_kv_start + ): + for i in range(len(prompt_cache)): + if isinstance(prompt_cache[i], cache.KVCache): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - prefill_step_size: int = 512, + *, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, + prompt_progress_callback: Optional[Callable[int, int]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -145,231 +227,390 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. - prefill_step_size (int): Step size for processing the prompt. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. - logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. 
+ None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + prompt_prorgress_callback (Callable[int, int]): A call-back which takes the + prompt tokens processed so far and the total number of prompt tokens. Yields: - Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: - logprobs = logits - mx.logsumexp(logits) - - if temp == 0: - token = mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) - elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) - else: - token = categorical_sampling(logits, temp) - - return token, logprobs - - if repetition_penalty and ( - repetition_penalty < 0 or not isinstance(repetition_penalty, float) - ): - raise ValueError( - f"repetition_penalty must be a non-negative float, got {repetition_penalty}" - ) - - logits_processor = logits_processor or [] - - if repetition_penalty: - - def repetition_penalty_processor(tokens, logits): - return apply_repetition_penalty( - logits, tokens[-repetition_context_size:], repetition_penalty - ) - - logits_processor.append(repetition_penalty_processor) - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processor.append(logit_bias_processor) - y = prompt tokens = None # Create the KV cache for generation if prompt_cache is None: - prompt_cache = cache.make_prompt_cache(model, max_kv_size) + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") + prompt_progress_callback = prompt_progress_callback or (lambda *_: None) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + def _step(y): - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] + with mx.stream(generation_stream): + logits = model(y[None], cache=prompt_cache) + logits = logits[:, -1, :] - if logits_processor: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y + if logits_processors: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y - for processor in logits_processor: - logits = processor(tokens, logits) + for processor in logits_processors: + logits = processor(tokens, logits) - y, logprobs = sample(logits) - return y, logprobs.squeeze(0) + quantize_cache_fn(prompt_cache) - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - y = y[prefill_step_size:] - mx.metal.clear_cache() + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs) + return y, logprobs.squeeze(0) - y, logprobs = _step(y) + with mx.stream(generation_stream): + total_prompt_tokens = y.size + prompt_processed_tokens = 0 + while y.size > prefill_step_size: + 
model(y[:prefill_step_size][None], cache=prompt_cache) + quantize_cache_fn(prompt_cache) + mx.eval([c.state for c in prompt_cache]) + prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) + prompt_processed_tokens += prefill_step_size + y = y[prefill_step_size:] + mx.metal.clear_cache() - mx.async_eval(y) + y, logprobs = _step(y) + + mx.async_eval(y, logprobs) + n = 0 while True: - next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + if n != max_tokens: + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) + if n == 0: + mx.eval(y) + prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) + if n == max_tokens: + break yield y.item(), logprobs + if n % 256 == 0: + mx.metal.clear_cache() y, logprobs = next_y, next_logprobs + n += 1 + + +def speculative_generate_step( + prompt: mx.array, + model: nn.Module, + draft_model: nn.Module, + *, + num_draft_tokens=2, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + draft_model (nn.Module): The draft model for speculative decoding. + num_draft_tokens (int, optional): The number of draft tokens for + speculative decoding. Default: ``2``. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. The cache must be trimmable. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + + Yields: + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. 
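As a rough illustration of the draft-then-verify loop this generator implements, the toy sketch below uses two placeholder callables in place of the draft and main models; the token values and the rejection pattern are made up for the example, and the real implementation scores all draft positions in a single forward pass.

```python
# Toy, model-free sketch of greedy speculative decoding: propose a few cheap
# draft tokens, verify them with the main model, keep the agreeing prefix and
# always emit one extra token from the main model.
def speculative_decode(prompt, draft_next, target_next, num_draft=3, max_tokens=6):
    out = list(prompt)
    while len(out) - len(prompt) < max_tokens:
        # 1. The cheap draft model proposes `num_draft` tokens.
        draft, ctx = [], list(out)
        for _ in range(num_draft):
            t = draft_next(ctx)
            draft.append(t)
            ctx.append(t)
        # 2. Query the main model at every draft position (one call per
        #    position here for simplicity).
        verified = [target_next(out + draft[:i]) for i in range(num_draft + 1)]
        # 3. Accept the longest prefix where the main model agrees with the
        #    draft, then append one extra token from the main model.
        n_accept = 0
        while n_accept < num_draft and verified[n_accept] == draft[n_accept]:
            n_accept += 1
        out.extend(draft[:n_accept])
        out.append(verified[n_accept])
    return out[len(prompt):]


def target_next(ctx):
    # Deterministic stand-in for the main model: always continue 1, 2, 3, ...
    return len(ctx)


def draft_next(ctx):
    # Stand-in draft model that is wrong at every third position.
    return 0 if len(ctx) % 3 == 0 else len(ctx)


print(speculative_decode([0], draft_next, target_next))  # -> [1, 2, 3, 4, 5, 6]
```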
+ """ + + y = prompt + tokens = None + + # Create the KV cache for generation + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + draft_cache = cache.make_prompt_cache(draft_model) + elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)): + raise ValueError("Wrong number of layers in the prompt cache.") + else: + model_cache = prompt_cache[: len(model.layers)] + draft_cache = prompt_cache[len(model.layers) :] + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _step(model, cache, y, n_predict=1): + with mx.stream(generation_stream): + logits = model(y[None], cache=cache) + logits = logits[:, -n_predict:, :] + + quantize_cache_fn(cache) + + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs).squeeze(0) + return y, logprobs.squeeze(0) + + def _prefill(model, cache, y): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=cache) + quantize_cache_fn(cache) + mx.eval([c.state for c in cache]) + y = y[prefill_step_size:] + mx.metal.clear_cache() + return y + + def _rewind_cache(num_draft, num_accept): + cache.trim_prompt_cache(model_cache, num_draft - num_accept) + cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0)) + + def _draft_generate(y, num_draft): + if num_draft == 0: + return mx.array([], mx.uint32) + ys = [] + for _ in range(num_draft): + y, _ = _step(draft_model, draft_cache, y) + mx.async_eval(y) + ys.append(y) + return mx.concatenate(ys) + + with mx.stream(generation_stream): + draft_y = _prefill(draft_model, draft_cache, y) + y = _prefill(model, model_cache, y) + + ntoks = 0 + # Set these so the finally block doesn't raise + num_draft = 0 + n = 0 + try: + while True: + num_draft = min(max_tokens - ntoks, num_draft_tokens) + draft_tokens = _draft_generate(draft_y, num_draft) + y = mx.concatenate([y, draft_tokens]) + + tokens, logprobs = _step(model, model_cache, y, num_draft + 1) + mx.eval(tokens, draft_tokens) + draft_tokens = draft_tokens.tolist() + tokens = tokens.tolist() + n = 0 + while n < num_draft: + tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n] + if tn != dtn: + break + n += 1 + ntoks += 1 + yield tn, lpn + if ntoks == max_tokens: + break + if ntoks < max_tokens: + ntoks += 1 + yield tokens[n], logprobs[n] + + if ntoks == max_tokens: + break + + y = mx.array([tokens[n]], mx.uint32) + draft_y = y + + # If we accpeted all the draft tokens, include the last + # draft token in the next draft step since it hasn't been + # processed yet by the draft model + if n == num_draft: + draft_y = mx.concatenate( + [mx.array(draft_tokens[-1:], mx.uint32), draft_y] + ) + + _rewind_cache(num_draft, n) + finally: + _rewind_cache(num_draft, n) def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, - max_tokens: int = 100, + prompt: Union[str, mx.array, List[int]], + draft_model: Optional[nn.Module] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Generator[GenerationResponse, None, None]: """ A generator producing text based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, mx.array, List[int]]): The input prompt string or + integer tokens. 
+ draft_model (Optional[nn.Module]): An optional draft model. If provided + then speculative decoding is used. The draft model must use the same + tokenizer as the main model. Default: ``None``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Generator[Tuple[mx.array, mx.array]]: A generator producing text. + GenerationResponse: An instance containing the generated text segment and + associated metadata. See :class:`GenerationResponse` for details. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + if not isinstance(prompt, mx.array): + if isinstance(prompt, str): + # Try to infer if special tokens are needed + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + prompt = mx.array(prompt) + detokenizer = tokenizer.detokenizer - detokenizer.reset() - for n, (token, _) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + if draft_model is None: + kwargs.pop("num_draft_tokens", None) + token_generator = generate_step(prompt, model, **kwargs) + else: + kwargs.pop("max_kv_size", None) + token_generator = speculative_generate_step( + prompt, model, draft_model, **kwargs + ) + with wired_limit(model, [generation_stream]): + detokenizer.reset() + tic = time.perf_counter() + for n, (token, logprobs) in enumerate(token_generator): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt.size / prompt_time + tic = time.perf_counter() + if token in tokenizer.eos_token_ids: + break - # Yield the last segment if streaming - yield detokenizer.last_segment + detokenizer.add_token(token) - detokenizer.finalize() - yield detokenizer.last_segment + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason=None, + ) + + detokenizer.finalize() + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason="stop" if token in tokenizer.eos_token_ids else "length", + ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, - max_tokens: int = 100, + prompt: Union[str, List[int]], verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> str: """ Generate a complete response from the model. Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. - max_tokens (int): The maximum number of tokens. Default: ``100``. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. - formatter (Optional[Callable]): A function which takes a token and a - probability and displays it. - kwargs: The remaining options get passed to :func:`generate_step`. 
- See :func:`generate_step` for more details. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. """ - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - + if formatter is not None: + print( + "[Warning] Text formatting is deprecated and no longer used. " + "The argument will be removed in a future version." + ) if verbose: print("=" * 10) - print("Prompt:", prompt) - - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer - - tic = time.perf_counter() - detokenizer.reset() - - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + text = "" + for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() + print(response.text, end="", flush=True) + text += response.text if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) + print() print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") + if len(text) == 0: + print("No text generated for this prompt") return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") - - return detokenizer.text + print( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + print(f"Peak memory: {response.peak_memory:.3f} GB") + return text def load_config(model_path: Path) -> dict: @@ -396,11 +637,11 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` - model_config (dict, optional): Configuration parameters for the model. - Defaults to an empty dictionary. + model_config (dict, optional): Optional configuration parameters for the + model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. - Defaults to the _get_classes function. + Defaults to the ``_get_classes`` function. Returns: nn.Module: The loaded and initialized model. @@ -409,7 +650,6 @@ def load_model( FileNotFoundError: If the weight files (.safetensors) are not found. ValueError: If the model class or args class are not found or cannot be instantiated. 
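A minimal usage sketch of the streaming API that `generate` now wraps, assuming `mlx-lm` is installed; the repository id is a placeholder, and the prompt handling mirrors the model-card snippet further down in this patch.

```python
# Sketch only: iterate GenerationResponse objects from stream_generate and
# read the aggregate statistics off the final response.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/some-4bit-model")  # placeholder repo id

prompt = "Write a haiku about the ocean."
if tokenizer.chat_template is not None:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], add_generation_prompt=True
    )

for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    print(response.text, end="", flush=True)

# The last response carries the same statistics that `generate` prints in
# verbose mode.
print()
print(f"Prompt: {response.prompt_tps:.1f} tok/s, "
      f"generation: {response.generation_tps:.1f} tok/s, "
      f"peak memory: {response.peak_memory:.2f} GB")
```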
""" - config = load_config(model_path) config.update(model_config) @@ -436,15 +676,20 @@ def load_model( weights = model.sanitize(weights) if (quantization := config.get("quantization", None)) is not None: - # Handle legacy models which may not have everything quantized + def class_predicate(p, m): + # Handle custom per layer quantizations + if p in config["quantization"]: + return config["quantization"][p] if not hasattr(m, "to_quantized"): return False + # Handle legacy models which may not have everything quantized return f"{p}.scales" in weights nn.quantize( model, - **quantization, + group_size=quantization["group_size"], + bits=quantization["bits"], class_predicate=class_predicate, ) @@ -454,7 +699,7 @@ def load_model( mx.eval(model.parameters()) model.eval() - return model + return model, config def load( @@ -475,7 +720,7 @@ def load( Defaults to an empty dictionary. adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers to the model. Default: ``None``. - lazy (bool): If False eval the model parameters to make sure they are + lazy (bool): If ``False`` eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` Returns: @@ -487,11 +732,13 @@ def load( """ model_path = get_model_path(path_or_hf_repo) - model = load_model(model_path, lazy, model_config) + model, config = load_model(model_path, lazy) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() - tokenizer = load_tokenizer(model_path, tokenizer_config) + tokenizer = load_tokenizer( + model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None) + ) return model, tokenizer @@ -499,9 +746,10 @@ def load( def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - model = load_model(model_path, lazy) - config = load_config(model_path) - tokenizer = load_tokenizer(model_path) + model, config = load_model(model_path, lazy) + tokenizer = load_tokenizer( + model_path, eos_token_ids=config.get("eos_token_id", None) + ) return model, config, tokenizer @@ -551,7 +799,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): f""" # {upload_repo} - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) + using mlx-lm version **{__version__}**. 
## Use with mlx @@ -564,12 +814,12 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): model, tokenizer = load("{upload_repo}") - prompt="hello" + prompt = "hello" - if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: + if tokenizer.chat_template is not None: messages = [{{"role": "user", "content": prompt}}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) response = generate(model, tokenizer, prompt=prompt, verbose=True) @@ -582,12 +832,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): api = HfApi() api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_folder( + api.upload_large_folder( folder_path=path, repo_id=upload_repo, repo_type="model", - multi_commits=True, - multi_commits_verbose=True, ) print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") @@ -645,7 +893,13 @@ def save_weights( def quantize_model( - model: nn.Module, config: dict, q_group_size: int, q_bits: int + model: nn.Module, + config: dict, + q_group_size: int, + q_bits: int, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ) -> Tuple: """ Applies quantization to the model weights. @@ -655,17 +909,37 @@ def quantize_model( config (dict): Model configuration. q_group_size (int): Group size for quantization. q_bits (int): Bits per weight for quantization. + quant_predicate (Callable): A callable that decides how + to quantize each layer based on the path. + Accepts the layer `path`, the `module` and the model `config`. + Returns either a bool to signify quantize/no quantize or + a dict of quantization parameters to pass to `to_quantized`. Returns: Tuple: Tuple containing quantized weights and config. 
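Below is a hedged sketch of one possible `quant_predicate` matching the signature documented here; the path substrings and bit widths are illustrative, not taken from any particular model.

```python
# Illustrative quant_predicate: receives the parameter path, the module and the
# model config, and returns either a bool or per-layer quantization kwargs.
def mixed_precision_predicate(path, module, config):
    # Keep the embedding and output projection at 8 bits for quality
    # (example layer names only).
    if "embed_tokens" in path or "lm_head" in path:
        return {"bits": 8, "group_size": 64}
    # Returning True quantizes the layer with the default q_bits/q_group_size.
    return True

# weights, config = quantize_model(
#     model, config, q_group_size=64, q_bits=4,
#     quant_predicate=mixed_precision_predicate,
# )
```

The same callable can also be passed to `convert`, which forwards it to `quantize_model`.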
""" quantized_config = copy.deepcopy(config) - nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + + # Add any custom quantization parameters to the config as we go + def _class_predicate(p, m): + bool_or_params = quant_predicate(p, m, config) + quantized_config["quantization"][p] = bool_or_params + return bool_or_params + + nn.quantize( + model, + q_group_size, + q_bits, + class_predicate=_class_predicate if quant_predicate else None, + ) # support hf model tree #957 quantized_config["quantization_config"] = quantized_config["quantization"] quantized_weights = dict(tree_flatten(model.parameters())) + bpw = compute_bits_per_weight(model) + print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.") + return quantized_weights, quantized_config @@ -702,6 +976,9 @@ def convert( upload_repo: str = None, revision: Optional[str] = None, dequantize: bool = False, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ): # Check the save path is empty if isinstance(mlx_path, str): @@ -718,7 +995,7 @@ def convert( model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters())) - dtype = mx.float16 if quantize else getattr(mx, dtype) + dtype = getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} if quantize and dequantize: @@ -727,7 +1004,9 @@ def convert( if quantize: print("[INFO] Quantizing") model.load_weights(list(weights.items())) - weights, config = quantize_model(model, config, q_group_size, q_bits) + weights, config = quantize_model( + model, config, q_group_size, q_bits, quant_predicate=quant_predicate + ) if dequantize: print("[INFO] Dequantizing") diff --git a/llms/setup.py b/llms/setup.py index 1c696dc0..e6fddbae 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -27,13 +27,15 @@ setup( packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], python_requires=">=3.8", extras_require={ - "testing": ["datasets"], + "test": ["datasets"], + "evaluate": ["lm-eval", "tqdm"], }, entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", + "mlx_lm.evaluate = mlx_lm.evaluate:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", "mlx_lm.lora = mlx_lm.lora:main", diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index 240bfb4a..dd86d277 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase): data = {"text": "This is an example for the model."} self.save_data(4 * [data]) args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - train, valid, test = datasets.load_dataset(args, None) + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) + train, valid, test = datasets.load_dataset(args, tokenizer) self.assertEqual(len(train), 4) self.assertEqual(len(valid), 4) self.assertEqual(len(test), 0) @@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase): "name": "billsum", "prompt_feature": "text", "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", }, test=False, train=True, diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 107be092..a6d53747 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -3,6 +3,7 @@ import math import sys import unittest +from contextlib import contextmanager from 
io import StringIO from unittest.mock import MagicMock @@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule +@contextmanager +def swapped_with_identity(obj, func): + old_func = getattr(obj, func) + setattr(obj, func, lambda x, **kwargs: x) + yield + setattr(obj, func, old_func) + + class TestLora(unittest.TestCase): def setUp(self): self.capturedOutput = StringIO() @@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.6), MagicMock(return_value=120)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=2, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, @@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.2), MagicMock(return_value=150)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=-1, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 68f1670b..f2345394 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -2,6 +2,7 @@ import unittest +from mlx_lm.sample_utils import make_logits_processors from mlx_lm.utils import generate, load @@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase): self.tokenizer, "hello", max_tokens=5, + logits_processors=make_logits_processors(logit_bias), verbose=False, - logit_bias=logit_bias, ) self.assertEqual(text, "!!!!!") @@ -46,7 +47,7 @@ class TestGenerate(unittest.TestCase): "hello", max_tokens=5, verbose=False, - logits_processor=[logits_processor], + logits_processors=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 1efde5ae..d8cf6820 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -2,7 +2,10 @@ import unittest import mlx.core as mx +import mlx.nn as nn from mlx.utils import tree_map +from mlx_lm.models import rope_utils +from mlx_lm.models.base import create_causal_mask from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -126,6 +129,42 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_causal_mask_lengths(self): + mx.random.seed(8) + B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2) + lengths = mx.array([1, 2, 3, 1]) + q = mx.random.uniform(shape=(B, N_q, T_q, D)) + k = mx.random.uniform(shape=(B, N_kv, T_kv, D)) + v = k + mask = create_causal_mask(T_q, 0, lengths=lengths) + + out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q[1, :, 
2:] = mx.ones_like(q[1, :, 2:]) + k[1, :, 2:] = mx.ones_like(k[1, :, 2:]) + v[1, :, 2:] = mx.ones_like(v[1, :, 2:]) + out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2])) + + def test_rope(self): + rope = rope_utils.initialize_rope(32, base=100, traditional=False) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "linear", "factor": 10.0}, + ) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "llama3", "factor": 2.0}, + ) + self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE)) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -140,10 +179,16 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.dtype, t) cache = make_prompt_cache(model) - outputs = model(inputs, cache) + outputs = model(inputs, cache=cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) + if model_type != "mamba": + mask = create_causal_mask(inputs.shape[1], 0).astype(t) + outputs = model(inputs, mask=mask) + self.assertEqual(outputs.shape, (1, 2, vocab_size)) + self.assertEqual(outputs.dtype, t) + outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache) self.assertEqual(outputs.shape, (1, 1, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -637,6 +682,43 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek_v3(self): + from mlx_lm.models import deepseek_v3 + + args = deepseek_v3.ModelArgs( + model_type="deepseek_v3", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + n_routed_experts=4, + n_group=2, + topk_group=1, + num_experts_per_tok=2, + n_shared_experts=1, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v3.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gemma2(self): from mlx_lm.models import gemma2 @@ -760,6 +842,108 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_hunyuan(self): + from mlx_lm.models import hunyuan + + args = hunyuan.ModelArgs( + model_type="hunyuan", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + moe_topk=2, + num_experts=2, + num_shared_expert=1, + use_mixed_mlp_moe=True, + use_qk_norm=True, + rope_scaling={ + "alpha": 1000.0, + "factor": 1.0, + "type": "dynamic", + }, + use_cla=True, + cla_share_factor=2, + ) + model = hunyuan.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_olmo2(self): + from mlx_lm.models import olmo2 + + args = olmo2.ModelArgs( + model_type="olmo2", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + 
num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + ) + model = olmo2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_exaone(self): + from mlx_lm.models import exaone + + args = exaone.ModelArgs( + model_type="exaone", + hidden_size=128, + num_layers=4, + intermediate_size=256, + num_attention_heads=8, + num_key_value_heads=2, + vocab_size=1000, + layer_norm_epsilon=1e-4, + rope_theta=10000, + ) + model = exaone.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + + def test_cohere2(self): + from mlx_lm.models import cohere2 + + args = cohere2.ModelArgs( + model_type="cohere2", + hidden_size=4096, + head_dim=128, + num_hidden_layers=40, + sliding_window=4096, + sliding_window_pattern=4, + ) + model = cohere2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_internlm3(self): + from mlx_lm.models import internlm3 + + args = internlm3.ModelArgs( + model_type="internlm3", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + model = internlm3.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 3c1ef49b..de5694d5 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +import copy import os import tempfile import unittest @@ -8,6 +9,7 @@ import mlx.core as mx from mlx_lm.models.cache import ( KVCache, MambaCache, + QuantizedKVCache, RotatingKVCache, load_prompt_cache, make_prompt_cache, @@ -119,21 +121,20 @@ class TestPromptCache(unittest.TestCase): def test_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = zip(range(4), generate_step(prompt, model)) - toks, all_logits = zip(*(r[1] for r in results)) + results = list(generate_step(prompt, model, max_tokens=4)) + toks, all_logits = zip(*results) prompt_cache = make_prompt_cache(model) i = 0 - for _, (tok, logits) in zip( - range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + for tok, logits in generate_step( + prompt, model, prompt_cache=prompt_cache, max_tokens=2 ): self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) i += 1 - for _, (tok, logits) in zip( - range(1), - generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + for tok, logits in generate_step( + mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1 ): i += 1 self.assertEqual(tok, toks[i]) @@ -185,6 +186,18 @@ class TestPromptCache(unittest.TestCase): num_trimmed = trim_prompt_cache(cache, 4) self.assertEqual(num_trimmed, 0) + cache = [QuantizedKVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 64)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + def test_trim_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] @@ -215,6 +228,78 @@ class 
TestPromptCache(unittest.TestCase): all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) ) + def test_cache_copying(self): + cache = [KVCache()] + + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[0].update_and_fetch(x, x) + + y = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(y, y) + + old_cache = copy.deepcopy(cache) + + trim_prompt_cache(cache, 1) + + self.assertTrue(old_cache[0].offset, 11) + self.assertTrue(cache[0].offset, 10) + + z = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(z, z) + + self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) + self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + + def test_save_load_quantized_cache(self): + cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 32)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(loaded_cache[0].bits == cache[0].bits) + self.assertTrue(loaded_cache[0].group_size == cache[0].group_size) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + # Loop over quantized tuple + for i in range(3): + self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i])) + self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_cache_to_quantized(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache] + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ec0e2cb7..c45fa443 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,10 +1,10 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling -class TestSamplingUtils(unittest.TestCase): +class TestSampleUtils(unittest.TestCase): def test_top_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) @@ -28,6 +28,41 @@ class TestSamplingUtils(unittest.TestCase): token = top_p_sampling(logits, 0.95, temperature).item() self.assertTrue(token in (1, 2, 3)) + def test_min_p_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + token = min_p_sampling(logits, 0.8) + 
self.assertEqual(token, 0) + + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + for _ in range(5): + token = min_p_sampling(logits, 0.05) + self.assertTrue(token in (0, 3)) + + def test_top_k_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + + token = top_k_sampling(logits, 1).item() + self.assertEqual(token, 0) + + probs = mx.array([0.5, 0.0, 0.0, 0.5])[None] + tokens = set() + for _ in range(100): + token = top_k_sampling(logits, 2) + tokens.add(token.item()) + self.assertEqual(tokens, {0, 3}) + + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + + tokens = top_k_sampling(logits, 1) + self.assertEqual(tokens.tolist(), [0, 1]) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index cbcccfbe..ad17554d 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -14,6 +14,7 @@ class DummyModelProvider: def __init__(self): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" self.model, self.tokenizer = load(HF_MODEL_PATH) + self.model_key = (HF_MODEL_PATH, None) def load(self, model, adapter=None): assert model in ["default_model", "chat_model"] diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py new file mode 100644 index 00000000..3009d1b1 --- /dev/null +++ b/llms/tests/test_tokenizers.py @@ -0,0 +1,98 @@ +# Copyright © 2024 Apple Inc. + +import unittest +from pathlib import Path + +from huggingface_hub import snapshot_download +from mlx_lm.tokenizer_utils import ( + BPEStreamingDetokenizer, + NaiveStreamingDetokenizer, + SPMStreamingDetokenizer, + load_tokenizer, +) + + +class TestTokenizers(unittest.TestCase): + + def download_tokenizer(self, repo): + path = Path( + snapshot_download( + repo_id=repo, + allow_patterns=[ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "tokenizer.model", + ], + ) + ) + return load_tokenizer(path) + + def check_tokenizer(self, tokenizer): + def check(tokens): + expected_text = tokenizer.decode(tokens) + detokenizer = tokenizer.detokenizer + detokenizer.reset() + text = "" + for e, t in enumerate(tokens): + detokenizer.add_token(t) + seg = detokenizer.last_segment + text += seg + self.assertEqual(detokenizer.tokens, tokens[: e + 1]) + detokenizer.finalize() + text += detokenizer.last_segment + self.assertEqual(text, expected_text) + + tokens = tokenizer.encode("こんにちは!私の名前はAI") + check(tokens) + + tokens = tokenizer.encode("a ,b") + check(tokens) + + tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') + check(tokens) + + tokens = tokenizer.encode("3 3") + check(tokens) + + tokens = tokenizer.encode("import 'package:flutter/material.dart';") + check(tokens) + + tokens = tokenizer.encode("hello\nworld") + check(tokens) + + def test_tokenizers(self): + tokenizer_repos = [ + ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), + ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), + ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), + ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), + ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), + ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer), + ] + for tokenizer_repo, expected_detokenizer in tokenizer_repos: + with self.subTest(tokenizer=tokenizer_repo): + tokenizer = self.download_tokenizer(tokenizer_repo) + 
tokenizer.decode([0, 1, 2]) + self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) + self.check_tokenizer(tokenizer) + + # Try one with a naive detokenizer + tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit") + tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) + self.check_tokenizer(tokenizer) + + def test_special_tokens(self): + tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" + tokenizer = self.download_tokenizer(tokenizer_repo) + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + detokenizer.add_token(tokenizer.eos_token_id) + detokenizer.finalize() + + self.assertEqual(detokenizer.last_segment, tokenizer.eos_token) + + +if __name__ == "__main__": + unittest.main() diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 73ee1352..8da19afb 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): self.config = args self.custom_attribute = "This is a custom model" - def load_weights(self, weights): + def load_weights(self, weights, **kwargs): self.qwenWeights = weights class CustomQwenConfig: @@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): return CustomQwenModel, CustomQwenConfig model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path, get_model_classes=custom_get_classes) + model, _ = load_model(model_path, get_model_classes=custom_get_classes) self.assertIsInstance(model, CustomQwenModel) self.assertTrue(hasattr(model, "custom_attribute")) @@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): def test_load_model_with_default_get_classes(self): model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path) + model, _ = load_model(model_path) self.assertIsInstance(model, Qwen2Model) diff --git a/speechcommands/main.py b/speechcommands/main.py index 0d8da9fd..ed328f4c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec = [] model.train(True) + train_iter.reset() for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -111,6 +112,7 @@ def test_epoch(model, test_iter): model.train(False) accs = [] throughput = [] + test_iter.reset() for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py index e470aa81..4444c488 100644 --- a/stable_diffusion/image2image.py +++ b/stable_diffusion/image2image.py @@ -30,6 +30,7 @@ if __name__ == "__main__": parser.add_argument("--preload-models", action="store_true") parser.add_argument("--output", default="out.png") parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--seed", type=int) args = parser.parse_args() # Load the models @@ -94,6 +95,7 @@ if __name__ == "__main__": cfg_weight=args.cfg, num_steps=args.steps, negative_text=args.negative_prompt, + seed=args.seed, ) for x_t in tqdm(latents, total=int(args.steps * args.strength)): mx.eval(x_t) diff --git a/whisper/README.md b/whisper/README.md index ac6e95f6..cd3bc684 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -25,7 +25,7 @@ pip install mlx-whisper At its simplest: -``` +```sh mlx_whisper audio_file.mp3 ``` @@ -35,6 +35,15 @@ Use `-f` to specify the output format 
and `--model` to specify the model. There are many other supported command line options. To see them all, run `mlx_whisper -h`. +You can also pipe the audio content of other programs via stdin: + +```sh +some-process | mlx_whisper - +``` + +The default output file name will be `content.*`. You can specify the name with +the `--output-name` flag. + #### API Transcribe audio with: @@ -103,7 +112,7 @@ python convert.py --help ``` By default, the conversion script will make the directory `mlx_models` -and save the converted `weights.npz` and `config.json` there. +and save the converted `weights.npz` and `config.json` there. Each time it is run, `convert.py` will overwrite any model in the provided path. To save different models, make sure to set `--mlx-path` to a unique diff --git a/whisper/convert.py b/whisper/convert.py index cdd50bc5..7369fafa 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -174,14 +174,9 @@ def load_torch_weights_and_config( "*.txt", ], ) - else: - raise RuntimeError( - f"Model {name_or_path} is not found in {available_models()}," - "on Hugging Face or as a local path." - ) if name_or_path.endswith(".pt"): - checkpoint = torch.load(name_or_path, map_location="cpu") + checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) weights, config = checkpoint["model_state_dict"], checkpoint["dims"] else: name_or_path = Path(name_or_path) @@ -387,7 +382,7 @@ if __name__ == "__main__": # Save weights print("[INFO] Saving") - np.savez(str(mlx_path / "weights.npz"), **weights) + mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) # Save config.json with model_type with open(str(mlx_path / "config.json"), "w") as f: diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 67c7397c..8280e038 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.3.0" +__version__ = "0.4.1" diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index e04309c1..c8cca07c 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -3,7 +3,7 @@ import os from functools import lru_cache from subprocess import CalledProcessError, run -from typing import Union +from typing import Optional, Union import mlx.core as mx import numpy as np @@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): +def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False): """ Open an audio file and read as mono waveform, resampling as necessary @@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): """ # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
+ if from_stdin: + cmd = ["ffmpeg", "-i", "pipe:0"] + else: + cmd = ["ffmpeg", "-nostdin", "-i", file] + # fmt: off - cmd = [ - "ffmpeg", - "-nostdin", + cmd.extend([ "-threads", "0", - "-i", file, "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-" - ] + ]) # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index c2813338..7d08a043 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -2,9 +2,11 @@ import argparse import os +import pathlib import traceback import warnings +from . import audio from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .transcribe import transcribe from .writers import get_writer @@ -27,15 +29,24 @@ def build_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument( - "audio", nargs="+", type=str, help="Audio file(s) to transcribe" - ) + + parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") + parser.add_argument( "--model", default="mlx-community/whisper-tiny", type=str, help="The model directory or hugging face repo", ) + parser.add_argument( + "--output-name", + type=str, + default=None, + help=( + "The name of transcription/translation output files before " + "--output-format extensions" + ), + ) parser.add_argument( "--output-dir", "-o", @@ -200,6 +211,7 @@ def main(): path_or_hf_repo: str = args.pop("model") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") + output_name: str = args.pop("output_name") os.makedirs(output_dir, exist_ok=True) writer = get_writer(output_format, output_dir) @@ -219,17 +231,25 @@ def main(): warnings.warn("--max-line-count has no effect without --max-line-width") if writer_args["max_words_per_line"] and writer_args["max_line_width"]: warnings.warn("--max-words-per-line has no effect with --max-line-width") - for audio_path in args.pop("audio"): + + for audio_obj in args.pop("audio"): + if audio_obj == "-": + # receive the contents from stdin rather than read a file + audio_obj = audio.load_audio(from_stdin=True) + + output_name = output_name or "content" + else: + output_name = output_name or pathlib.Path(audio_obj).stem try: result = transcribe( - audio_path, + audio_obj, path_or_hf_repo=path_or_hf_repo, **args, ) - writer(result, audio_path, **writer_args) + writer(result, output_name, **writer_args) except Exception as e: traceback.print_exc() - print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") + print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") if __name__ == "__main__": diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 41c2ec6d..4e060cd5 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -58,11 +58,12 @@ def detect_language( logits = model.logits(x, mel)[:, 0] # collect detected languages; suppress all non-language tokens - mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32) + mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32) mask[list(tokenizer.all_language_tokens)] = 0.0 - logits += mx.array(mask) + logits += mask language_tokens = mx.argmax(logits, axis=-1) language_token_probs = mx.softmax(logits, axis=-1) + language_token_probs = np.array(language_token_probs) language_probs = [ { c: language_token_probs[i, j].item() @@ -129,17 +130,12 @@ class DecodingResult: class Inference: - def __init__(self, model: "Whisper", initial_token_length: int): + 
def __init__(self, model: "Whisper"):
         self.model: "Whisper" = model
-        self.initial_token_length = initial_token_length
         self.kv_cache = None
 
     def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
         """Perform a forward pass on the decoder and return per-token logits"""
-        if tokens.shape[-1] > self.initial_token_length:
-            # only need to use the last token except in the first forward pass
-            tokens = tokens[:, -1:]
-
         logits, self.kv_cache, _ = self.model.decoder(
             tokens, audio_features, kv_cache=self.kv_cache
         )
@@ -251,6 +247,11 @@ class TokenDecoder:
         raise NotImplementedError
 
 
+@mx.compile
+def categorical(logits, temp):
+    return mx.random.categorical(logits / temp)
+
+
 class GreedyDecoder(TokenDecoder):
     def __init__(self, temperature: float, eot: int):
         self.temperature = temperature
@@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder):
         if self.temperature == 0:
             next_tokens = logits.argmax(axis=-1)
         else:
-            next_tokens = mx.random.categorical(logits=logits / self.temperature)
+            next_tokens = categorical(logits, self.temperature)
 
-        next_tokens = mx.argmax(logits, axis=-1)
-
         logits = logits.astype(mx.float32)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
         current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
@@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder):
     def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
         # make sure each sequence has at least one EOT token at the end
         tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
-        return tokens, sum_logprobs.tolist()
+        return tokens, sum_logprobs
 
 
 class LogitFilter:
@@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter):
         if self.tokenizer.no_timestamps is not None:
             mask[:, self.tokenizer.no_timestamps] = -np.inf
 
-        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
-        for k in range(tokens.shape[0]):
-            sampled_tokens = tokens[k, self.sample_begin :]
-            seq = sampled_tokens.tolist()
+        ## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
+        tokens = tokens.tolist()
+        for k in range(len(tokens)):
+            seq = tokens[k][self.sample_begin :]
             last_was_timestamp = (
                 len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
             )
@@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter):
                     last_timestamp += 1
                 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf
 
-        if tokens.shape[1] == self.sample_begin:
+        if len(tokens[0]) == self.sample_begin:
             # suppress generating non-timestamp tokens at the beginning
             mask[:, : self.tokenizer.timestamp_begin] = -np.inf
 
@@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter):
                 mask[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
+        mask = mx.array(mask)
         logprobs = logits - mx.logsumexp(logits, axis=-1)
-        for k in range(tokens.shape[0]):
-            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
-                axis=-1
-            )
-            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
-            if timestamp_logprob > max_text_token_logprob:
-                mask[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-        return logits + mx.array(mask, logits.dtype)
+        timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp(
+            axis=-1, keepdims=True
+        )
+        max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max(
+            axis=-1, keepdims=True
+        )
+        mask[:, : self.tokenizer.timestamp_begin] = mx.where(
+            timestamp_logprob > max_text_token_logprob,
+            -mx.inf,
+            mask[:, : self.tokenizer.timestamp_begin],
+        )
+        return logits + mask
 
 
 class DecodingTask:
@@ -424,7 +427,7 @@ class DecodingTask:
         self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
 
         # inference: implements the forward pass through the decoder, including kv caching
-        self.inference = Inference(model, len(self.initial_tokens))
+        self.inference = Inference(model)
 
         # sequence ranker: implements how to rank a group of sampled sequences
         self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
@@ -432,9 +435,6 @@ class DecodingTask:
         # decoder: implements how to select the next tokens, given the autoregressive distribution
         if options.beam_size is not None:
             raise NotImplementedError("Beam search decoder is not yet implemented")
-            # self.decoder = BeamSearchDecoder(
-            #     options.beam_size, tokenizer.eot, self.inference, options.patience
-            # )
         else:
             self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
 
@@ -448,6 +448,7 @@ class DecodingTask:
             self.logit_filters.append(
                 SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab)
             )
+
         if not options.without_timestamps:
             precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
             max_initial_timestamp_index = None
@@ -570,48 +571,59 @@ class DecodingTask:
     def _main_loop(self, audio_features: mx.array, tokens: mx.array):
         n_batch = tokens.shape[0]
-        sum_logprobs: mx.array = mx.zeros(n_batch)
-        no_speech_probs = [np.nan] * n_batch
+        sum_logprobs = mx.zeros(n_batch)
 
-        try:
-            for i in range(self.sample_len):
-                logits = self.inference.logits(tokens, audio_features)
+        def _step(inputs, audio_features, tokens, sum_logprobs):
+            pre_logits = self.inference.logits(inputs, audio_features)
 
-                if (
-                    i == 0 and self.tokenizer.no_speech is not None
-                ):  # save no_speech_probs
-                    probs_at_sot = mx.softmax(
-                        logits[:, self.sot_index].astype(mx.float32), axis=-1
-                    )
-                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
+            # consider the logits at the last token only
+            logits = pre_logits[:, -1]
 
-                # now we need to consider the logits at the last token only
-                logits = logits[:, -1]
+            # apply the logit filters, e.g. for suppressing or applying penalty to
+            for logit_filter in self.logit_filters:
+                logits = logit_filter.apply(logits, tokens)
 
-                # apply the logit filters, e.g. for suppressing or applying penalty to
-                for logit_filter in self.logit_filters:
-                    logits = logit_filter.apply(logits, tokens)
+            # expand the tokens tensor with the selected next tokens
+            tokens, completed, sum_logprobs = self.decoder.update(
+                tokens, logits, sum_logprobs
+            )
+            return tokens, completed, sum_logprobs, pre_logits
 
-                # expand the tokens tensor with the selected next tokens
-                tokens, completed, sum_logprobs = self.decoder.update(
-                    tokens, logits, sum_logprobs
-                )
+        tokens, completed, sum_logprobs, pre_logits = _step(
+            tokens, audio_features, tokens, sum_logprobs
+        )
+        if self.tokenizer.no_speech is not None:  # compute no_speech_probs
+            probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
+            no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
+        else:
+            no_speech_probs = mx.full(n_batch, mx.nan)
+        mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)
 
-                if completed or tokens.shape[-1] > self.n_ctx:
-                    break
-        finally:
-            self.inference.reset()
+        for i in range(1, self.sample_len):
+            inputs = tokens[:, -1:]
+            if tokens.shape[-1] > self.n_ctx:
+                break
+            next_tokens, next_completed, next_sum_logprobs, _ = _step(
+                inputs, audio_features, tokens, sum_logprobs
+            )
+            mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
+            if completed:
+                break
+            tokens = next_tokens
+            completed = next_completed
+            sum_logprobs = next_sum_logprobs
 
         return tokens, sum_logprobs, no_speech_probs
 
     def run(self, mel: mx.array) -> List[DecodingResult]:
+        self.inference.reset()
         self.decoder.reset()
         tokenizer: Tokenizer = self.tokenizer
         n_audio: int = mel.shape[0]
 
         audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
 
-        tokens: np.array = np.array(self.initial_tokens)
-        tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy()
+        tokens: mx.array = mx.array(self.initial_tokens)
+        tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))
 
         # detect language if requested, overwriting the language token
         languages, language_probs = self._detect_language(audio_features, tokens)
@@ -626,7 +638,6 @@ class DecodingTask:
         ]
 
         # repeat tokens by the group size, for beam search or best-of-n sampling
-        tokens = mx.array(tokens)
         if self.n_group > 1:
             tokens = tokens[:, None, :]
             tokens = mx.broadcast_to(
@@ -649,7 +660,13 @@
 
         # get the final candidates for each group, and slice between the first sampled token and EOT
         tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
-        tokens = tokens[..., self.sample_begin :].tolist()
+        tokens = tokens[..., self.sample_begin :]
+
+        # eval and convert to list
+        mx.eval(tokens, sum_logprobs, no_speech_probs)
+        tokens = tokens.tolist()
+        sum_logprobs = sum_logprobs.tolist()
+        no_speech_probs = no_speech_probs.tolist()
         tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]
 
         # select the top-ranked sample in each group
diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py
index 6705385d..60766ab2 100644
--- a/whisper/mlx_whisper/load_models.py
+++ b/whisper/mlx_whisper/load_models.py
@@ -26,7 +26,10 @@ def load_model(
 
     model_args = whisper.ModelDimensions(**config)
 
-    weights = mx.load(str(model_path / "weights.npz"))
+    wf = model_path / "weights.safetensors"
+    if not wf.exists():
+        wf = model_path / "weights.npz"
+    weights = mx.load(str(wf))
 
     model = whisper.Whisper(model_args, dtype)
 
diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py
index 786b4232..7057679b 100644
--- a/whisper/mlx_whisper/transcribe.py
+++ b/whisper/mlx_whisper/transcribe.py
@@ -293,6 +293,7 @@ def transcribe(
             decode_options["prompt"] = all_tokens[prompt_reset_since:]
 
             result: DecodingResult = decode_with_fallback(mel_segment)
+
             tokens = np.array(result.tokens)
 
             if no_speech_threshold is not None:
diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py
index e691792c..1c2b390e 100644
--- a/whisper/mlx_whisper/whisper.py
+++ b/whisper/mlx_whisper/whisper.py
@@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module):
         qk = q @ k
         if mask is not None:
             qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.astype(mx.float32)
 
-        w = mx.softmax(qk, axis=-1).astype(q.dtype)
+        w = mx.softmax(qk, axis=-1, precise=True)
         out = (w @ v).transpose(0, 2, 1, 3)
         out = out.reshape(n_batch, n_ctx, n_state)
-        return out, qk
+        return out, qk.astype(mx.float32)
 
 
 class ResidualAttentionBlock(nn.Module):
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
index 464ead18..cdb35063 100644
--- a/whisper/mlx_whisper/writers.py
+++ b/whisper/mlx_whisper/writers.py
@@ -1,10 +1,8 @@
 # Copyright © 2024 Apple Inc.
 
 import json
-import os
+import pathlib
 import re
-import sys
-import zlib
 from typing import Callable, List, Optional, TextIO
 
 
@@ -43,15 +41,13 @@ class ResultWriter:
         self.output_dir = output_dir
 
     def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+        self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
     ):
-        audio_basename = os.path.basename(audio_path)
-        audio_basename = os.path.splitext(audio_basename)[0]
-        output_path = os.path.join(
-            self.output_dir, audio_basename + "." + self.extension
+        output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
+            f".{self.extension}"
         )
 
-        with open(output_path, "w", encoding="utf-8") as f:
+        with output_path.open("wt", encoding="utf-8") as f:
             self.write_result(result, file=f, options=options, **kwargs)
 
     def write_result(