From c799133998a943affdc395c5cb159ead2576a225 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Oct 2024 10:25:24 -0700 Subject: [PATCH 01/77] Make llm async eval less brittle (#1040) * Make llm async eval less brittle * nit --- llms/mlx_lm/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1e07546e..4f872982 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -246,10 +246,10 @@ def generate_step( y, logprobs = _step(y) - mx.async_eval(y) + mx.async_eval(y, logprobs) while True: next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs y, logprobs = next_y, next_logprobs @@ -348,7 +348,9 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) From 6c368f212473dc11e46d19c5dc65410ee70b2594 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Oct 2024 10:40:36 -0700 Subject: [PATCH 02/77] bump mac tests to use py39 (#1047) --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 02fa1de8..cecd2d57 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,8 +26,8 @@ jobs: - run: name: Install dependencies command: | - brew install python@3.8 - python3.8 -m venv env + brew install python@3.9 + python3.9 -m venv env source env/bin/activate pip install --upgrade pip pip install unittest-xml-reporting From 8dca1a2f6091f443ffb54b5f39a390f6971e677c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Oct 2024 10:48:46 -0700 Subject: [PATCH 03/77] Tokenizer updates + tests (#1024) * tokenizer updates + tests * nit * add can_trim_prompt_cache * nits --- llms/mlx_lm/models/cache.py | 9 +++- llms/mlx_lm/models/deepseek_v2.py | 6 +-- llms/mlx_lm/tokenizer_utils.py | 40 ++++++++-------- llms/tests/test_tokenizers.py | 76 +++++++++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 23 deletions(-) create mode 100644 llms/tests/test_tokenizers.py diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index b06422e5..a6a56e0a 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False): return cache +def can_trim_prompt_cache(cache: List[Any]) -> bool: + """ + Check if model's cache can be trimmed. + """ + return all(c.is_trimmable() for c in cache) + + def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: """ Trim the model's cache by the given number of tokens. @@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]: Returns: (int): The number of tokens that were trimmed. 
""" - if not all(c.is_trimmable() for c in cache) or len(cache) == 0: + if not can_trim_prompt_cache(cache) or len(cache) == 0: return 0 return [c.trim(num_tokens) for c in cache][0] diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 17d061a8..bb3e5184 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -220,17 +220,17 @@ class DeepseekV2Attention(nn.Module): k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) - k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1) - if cache is not None: q_pe = self.rope(q_pe, cache.offset) k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys, values = cache.update_and_fetch( mx.concatenate([k_nope, k_pe], axis=-1), values ) else: q_pe = self.rope(q_pe) k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) keys = mx.concatenate([k_nope, k_pe], axis=-1) queries = mx.concatenate([q_nope, q_pe], axis=-1) @@ -291,7 +291,7 @@ class MoEGate(nn.Module): scores = scores.reshape(bsz, seq_len, -1) k = self.top_k - inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]) + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] scores = mx.take_along_axis(scores, inds, axis=-1) scores = scores * self.routed_scaling_factor diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 04bbbcc5..d8694d86 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def text(self): if self._current_tokens: self._current_text = self._tokenizer.decode(self._current_tokens) + if ( + self._tokenizer.clean_up_tokenization_spaces + and self._current_text[-1] == " " + ): + self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": self._tokens.extend(self._current_tokens) self._text += self._current_text @@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): """ _byte_decoder = None + _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re") - def __init__(self, tokenizer, trim_space=False): - self.trim_space = trim_space + def __init__(self, tokenizer): + + self.clean_spaces = tokenizer.clean_up_tokenization_spaces # Extract the tokens in a list from id to text self.tokenmap = [None] * len(tokenizer.vocab) @@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.text = "" self.tokens = [] + def _maybe_trim_space(self, current_text): + if current_text[0] != " ": + return current_text + elif not self.text: + return current_text[1:] + elif self.clean_spaces and current_text[1:].startswith(self._space_matches): + return current_text[1:] + return current_text + def add_token(self, token): v = self.tokenmap[token] - # if the token starts with space if self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed ).decode("utf-8") - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += self._maybe_trim_space(current_text) self._unflushed = v else: self._unflushed += v @@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( "utf-8" ) - if self.text or not self.trim_space: - self.text += current_text - else: - self.text += _remove_space(current_text) + self.text += 
self._maybe_trim_space(current_text) self._unflushed = "" @classmethod @@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder): def _is_bpe_decoder(decoder): - _target_description = { - "type": "ByteLevel", - "add_prefix_space": False, - "trim_offsets": False, - "use_regex": False, - } - - return _match(_target_description, decoder) + return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" def load_tokenizer(model_path, tokenizer_config_extra={}): diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py new file mode 100644 index 00000000..7b4828b1 --- /dev/null +++ b/llms/tests/test_tokenizers.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. + +import unittest +from pathlib import Path + +from huggingface_hub import snapshot_download +from mlx_lm.tokenizer_utils import ( + BPEStreamingDetokenizer, + NaiveStreamingDetokenizer, + SPMStreamingDetokenizer, + load_tokenizer, +) + + +class TestTokenizers(unittest.TestCase): + + def download_tokenizer(self, repo): + path = Path( + snapshot_download( + repo_id=repo, + allow_patterns=[ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "tokenizer.model", + ], + ) + ) + return load_tokenizer(path) + + def check_tokenizer(self, tokenizer): + def check(tokens): + expected_text = tokenizer.decode(tokens) + detokenizer = tokenizer.detokenizer + detokenizer.reset() + text = "" + for t in tokens: + detokenizer.add_token(t) + seg = detokenizer.last_segment + text += seg + detokenizer.finalize() + text += detokenizer.last_segment + self.assertEqual(text, expected_text) + + tokens = tokenizer.encode("a ,b") + check(tokens) + + tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}') + check(tokens) + + tokens = tokenizer.encode("3 3") + check(tokens) + + def test_tokenizers(self): + tokenizer_repos = [ + ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), + ("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer), + ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), + ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), + ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), + ] + for tokenizer_repo, expected_detokenizer in tokenizer_repos: + with self.subTest(tokenizer=tokenizer_repo): + tokenizer = self.download_tokenizer(tokenizer_repo) + tokenizer.decode([0, 1, 2]) + self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer)) + self.check_tokenizer(tokenizer) + + # Try one with a naive detokenizer + tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit") + tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) + self.check_tokenizer(tokenizer) + + +if __name__ == "__main__": + unittest.main() From 605c4854f1547e8eb0ef3f9c9d81c8aef3196c15 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 14 Oct 2024 10:57:22 -0700 Subject: [PATCH 04/77] Prompt caching in `mlx_lm.server` (#1026) * caching in server * nits * fix tests * don't throw if no metal * comments --- llms/mlx_lm/SERVER.md | 38 ++++++++-- llms/mlx_lm/server.py | 121 ++++++++++++++++++++++++-------- llms/tests/test_prompt_cache.py | 23 ++++++ llms/tests/test_server.py | 1 + 4 files changed, 151 insertions(+), 32 deletions(-) diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 55be1c9c..2976a09f 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \ - `role_mapping`: (Optional) A dictionary to customize the 
role prefixes in the generated prompt. If not provided, the default mappings are used. -- `stop`: (Optional) An array of strings or a single string. Thesse are +- `stop`: (Optional) An array of strings or a single string. These are sequences of tokens on which the generation should stop. - `max_tokens`: (Optional) An integer specifying the maximum number of tokens @@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \ started in. - `adapters`: (Optional) A string path to low-rank adapters. The path must be - rlative to the directory the server was started in. + relative to the directory the server was started in. + +### Response Fields + +- `id`: A unique identifier for the chat. + +- `system_fingerprint`: A unique identifier for the system. + +- `object`: Any of "chat.completions", "chat.completions.chunk" (for + streaming), or "text.completion". + +- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). + +- `created`: A time-stamp for when the request was processed. + +- `choices`: A list of outputs. Each output is a dictionary containing the fields: + - `index`: The index in the list. + - `logprobs`: A dictionary containing the fields: + - `token_logprobs`: A list of the log probabilities for the generated + tokens. + - `tokens`: A list of the generated token ids. + - `top_logprobs`: A list of lists. Each list contains the `logprobs` + top tokens (if requested) with their corresponding probabilities. + - `finish_reason`: The reason the completion ended. This can be either of + `"stop"` or `"length"`. + - `message`: The text response from the model. + +- `usage`: A dictionary containing the fields: + - `prompt_tokens`: The number of prompt tokens processed. + - `completion_tokens`: The number of tokens generated. + - `total_tokens`: The total number of tokens, i.e. the sum of the above two fields. ### List Models @@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json" This will return a list of locally available models where each model in the list contains the following fields: -- `"id"`: The Hugging Face repo id. -- `"created"`: A timestamp representing the model creation time. +- `id`: The Hugging Face repo id. +- `created`: A time-stamp representing the model creation time. 
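Put together, the `Response Fields` documented above describe a payload shaped roughly like the sketch below for a non-streaming chat completion. The values are illustrative placeholders, and the exact `message` and `logprobs` layouts are assumed to follow the OpenAI-style schema rather than being taken verbatim from this diff:

```json
{
  "id": "chatcmpl-8a9f1c2e",
  "system_fingerprint": "0.19.1-0.18.1-macOS-15.0-arm64-applegpu_g14d",
  "object": "chat.completions",
  "model": "mlx-community/Llama-3.2-3B-Instruct-4bit",
  "created": 1728900000,
  "choices": [
    {
      "index": 0,
      "logprobs": {
        "token_logprobs": [-0.21, -1.05, -0.07],
        "tokens": [9906, 1917, 0],
        "top_logprobs": []
      },
      "finish_reason": "stop",
      "message": {"role": "assistant", "content": "Hello there!"}
    }
  ],
  "usage": {"prompt_tokens": 12, "completion_tokens": 3, "total_tokens": 15}
}
```

For streaming requests, the same fields arrive incrementally as `chat.completions.chunk` objects, as noted in the `object` field description above.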
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 42962b54..ec659969 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -3,19 +3,38 @@ import argparse import json import logging +import platform import time import uuid import warnings +from dataclasses import dataclass, field from http.server import BaseHTTPRequestHandler, HTTPServer from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union +from typing import ( + Any, + Dict, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Union, +) import mlx.core as mx from huggingface_hub import scan_cache_dir +from ._version import __version__ +from .models.cache import make_prompt_cache from .utils import generate_step, load +def get_system_fingerprint(): + gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else "" + return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}" + + class StopCondition(NamedTuple): stop_met: bool trim_length: int @@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None): return prompt.rstrip() +@dataclass +class PromptCache: + cache: List[Any] = field(default_factory=list) + model_key: Tuple[str, Optional[str]] = ("", None) + tokens: List[int] = field(default_factory=list) + + class ModelProvider: def __init__(self, cli_args: argparse.Namespace): """Load models on demand and persist them across the whole process.""" @@ -156,12 +182,21 @@ class ModelProvider: class APIHandler(BaseHTTPRequestHandler): - def __init__(self, model_provider: ModelProvider, *args, **kwargs): + def __init__( + self, + model_provider: ModelProvider, + *args, + prompt_cache: Optional[PromptCache] = None, + system_fingerprint: Optional[str] = None, + **kwargs, + ): """ Create static request specific metadata """ self.created = int(time.time()) self.model_provider = model_provider + self.prompt_cache = prompt_cache or PromptCache() + self.system_fingerprint = system_fingerprint or get_system_fingerprint() super().__init__(*args, **kwargs) def _set_cors_headers(self): @@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler): self.stream_options = self.body.get("stream_options", None) self.requested_model = self.body.get("model", "default_model") self.adapter = self.body.get("adapters", None) - self.max_tokens = self.body.get("max_tokens", 100) + self.max_tokens = self.body.get("max_completion_tokens", None) + if self.max_tokens is None: + self.max_tokens = self.body.get("max_tokens", 512) self.temperature = self.body.get("temperature", 1.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) @@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler): # Static response response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": self.object_type, "model": self.requested_model, "created": self.created, @@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler): return response + def get_prompt_cache(self, prompt): + cache_len = len(self.prompt_cache.tokens) + if ( + self.prompt_cache.model_key != self.model_provider.model_key + or cache_len >= len(prompt) + or self.prompt_cache.tokens != prompt[:cache_len] + ): + self.prompt_cache.model_key = self.model_provider.model_key + self.prompt_cache.cache = make_prompt_cache(self.model_provider.model) + else: + prompt = prompt[cache_len:] + self.prompt_cache.tokens.extend(prompt) + return prompt 
+ def handle_completion( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ Generate a response to a prompt and send it to the client in a single batch. Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (List[int]): The tokenized prompt. stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ @@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler): logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] - for (token, logprobs), _ in zip( + + prompt = self.get_prompt_cache(prompt) + + for _, (token, logprobs) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler): top_indices = sorted_indices[: self.logprobs] top_logprobs = logprobs[top_indices] top_token_info = zip(top_indices.tolist(), top_logprobs.tolist()) - top_tokens.append(dict(top_token_info)) + top_tokens.append(tuple(top_token_info)) token_logprobs.append(logprobs[token].item()) @@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler): ) break + self.prompt_cache.tokens.extend(tokens) detokenizer.finalize() text = ( detokenizer.text @@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler): def handle_stream( self, - prompt: mx.array, + prompt: List[int], stop_id_sequences: List[List[int]], ): """ @@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler): Sent Events (SSE) stream. 
Args: - prompt (mx.array): The prompt, in token form inside of a mlx array + prompt (mx.array): The tokenized prompt stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ @@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = None logging.debug(f"Starting stream:") - for (token, _), _ in zip( + prompt = self.get_prompt_cache(prompt) + + for _, (token, _) in zip( + range(self.max_tokens), generate_step( - prompt=prompt, + prompt=mx.array(prompt), model=self.model, temp=self.temperature, top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, + prompt_cache=self.prompt_cache.cache, ), - range(self.max_tokens), ): detokenizer.add_token(token) logging.debug(detokenizer.text) @@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler): continue new_text = detokenizer.last_segment - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + if new_text: + response = self.generate_response(new_text, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + + self.prompt_cache.tokens.extend(tokens) # check is there any remaining text to send detokenizer.finalize() @@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler): ): response = { "id": self.request_id, - "system_fingerprint": f"fp_{uuid.uuid4()}", + "system_fingerprint": self.system_fingerprint, "object": "chat.completion", "model": self.requested_model, "created": self.created, @@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler): } return response - def handle_chat_completions(self) -> mx.array: + def handle_chat_completions(self) -> List[int]: """ Handle a chat completion request. @@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler): self.object_type = ( "chat.completions.chunk" if self.stream else "chat.completions" ) - if ( hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template @@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler): prompt = convert_chat(body["messages"], body.get("role_mapping")) prompt = self.tokenizer.encode(prompt) - return mx.array(prompt) + return prompt - def handle_text_completions(self) -> mx.array: + def handle_text_completions(self) -> List[int]: """ Handle a text completion request. @@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"cmpl-{uuid.uuid4()}" self.object_type = "text_completion" - assert "prompt" in self.body, "Request did not contain a prompt" - prompt_text = self.body["prompt"] - prompt = self.tokenizer.encode(prompt_text) - return mx.array(prompt) + return self.tokenizer.encode(self.body["prompt"]) def do_GET(self): """ @@ -669,9 +727,16 @@ def run( handler_class=APIHandler, ): server_address = (host, port) + prompt_cache = PromptCache() httpd = server_class( server_address, - lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs), + lambda *args, **kwargs: handler_class( + model_provider, + prompt_cache=prompt_cache, + system_fingerprint=get_system_fingerprint(), + *args, + **kwargs, + ), ) warnings.warn( "mlx_lm.server is not recommended for production as " diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 3c1ef49b..64cd9486 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. 
+import copy import os import tempfile import unittest @@ -215,6 +216,28 @@ class TestPromptCache(unittest.TestCase): all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) ) + def test_cache_copying(self): + cache = [KVCache()] + + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[0].update_and_fetch(x, x) + + y = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(y, y) + + old_cache = copy.deepcopy(cache) + + trim_prompt_cache(cache, 1) + + self.assertTrue(old_cache[0].offset, 11) + self.assertTrue(cache[0].offset, 10) + + z = mx.random.uniform(shape=(1, 8, 1, 4)) + cache[0].update_and_fetch(z, z) + + self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) + self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_server.py b/llms/tests/test_server.py index cbcccfbe..ad17554d 100644 --- a/llms/tests/test_server.py +++ b/llms/tests/test_server.py @@ -14,6 +14,7 @@ class DummyModelProvider: def __init__(self): HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" self.model, self.tokenizer = load(HF_MODEL_PATH) + self.model_key = (HF_MODEL_PATH, None) def load(self, model, adapter=None): assert model in ["default_model", "chat_model"] From bbd20030476737c1b9e027e3358392c40821987f Mon Sep 17 00:00:00 2001 From: madroid Date: Tue, 15 Oct 2024 02:21:41 +0800 Subject: [PATCH 05/77] FLUX: update README.md (#1036) --- README.md | 4 +++- flux/README.md | 26 +++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bd180975..88888ad0 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,10 @@ Some more useful examples are listed below. ### Image Models +- Generating images + - [FLUX](flux) + - [Stable Diffusion or SDXL](stable_diffusion) - Image classification using [ResNets on CIFAR-10](cifar). -- Generating images with [Stable Diffusion or SDXL](stable_diffusion). - Convolutional variational autoencoder [(CVAE) on MNIST](cvae). ### Audio Models diff --git a/flux/README.md b/flux/README.md index 62eb9b62..0496c71b 100644 --- a/flux/README.md +++ b/flux/README.md @@ -28,6 +28,26 @@ You can install all of the above with the `requirements.txt` as follows: pip install -r requirements.txt + +Usage +--------- + +You can use the following command to generate an image, using `--output` to specify the storage location of the image, defaulting to `out.png`. + +```shell +python txt2image.py --model schnell \ + --n-images 1 \ + --image-size 256x512 \ + --verbose \ + 'A photo of an astronaut riding a horse on Mars.' +``` + +For more parameters, please use the `--help` command to view. + +```shell +python txt2image.py --help +``` + Inference --------- @@ -78,7 +98,11 @@ except for some additional logic to quantize and/or load trained adapters. One can use the script as follows: ```shell -python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.' +python txt2image.py \ + --n-images 4 \ + --n-rows 2 \ + --image-size 256x512 \ + 'A photo of an astronaut riding a horse on Mars.' ``` ### Experimental Options From 3d62b058a4235019c945957da575ad3e36036cee Mon Sep 17 00:00:00 2001 From: "Zak B. 
Elep" Date: Wed, 16 Oct 2024 00:13:01 +0800 Subject: [PATCH 06/77] fix: typo on flux model preloading (#1050) --- flux/txt2image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux/txt2image.py b/flux/txt2image.py index bf2f7294..5ebec81a 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -77,7 +77,7 @@ if __name__ == "__main__": nn.quantize(flux.clip, class_predicate=quantization_predicate) if args.preload_models: - sd.ensure_models_are_loaded() + flux.ensure_models_are_loaded() # Make the generator latent_size = to_latent_size(args.image_size) From f491d473a332fc3ff9daf74be2bd154c0f9231b5 Mon Sep 17 00:00:00 2001 From: madroid Date: Wed, 16 Oct 2024 01:37:45 +0800 Subject: [PATCH 07/77] FLUX: Optimize dataset loading logic (#1038) --- flux/README.md | 47 ++++---- flux/dreambooth.py | 121 +++------------------ flux/flux/__init__.py | 239 +--------------------------------------- flux/flux/datasets.py | 75 +++++++++++++ flux/flux/flux.py | 246 ++++++++++++++++++++++++++++++++++++++++++ flux/flux/trainer.py | 98 +++++++++++++++++ 6 files changed, 461 insertions(+), 365 deletions(-) create mode 100644 flux/flux/datasets.py create mode 100644 flux/flux/flux.py create mode 100644 flux/flux/trainer.py diff --git a/flux/README.md b/flux/README.md index 0496c71b..1a17e386 100644 --- a/flux/README.md +++ b/flux/README.md @@ -21,8 +21,9 @@ The dependencies are minimal, namely: - `huggingface-hub` to download the checkpoints. - `regex` for the tokenization -- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script +- `tqdm`, `PIL`, and `numpy` for the scripts - `sentencepiece` for the T5 tokenizer +- `datasets` for using an HF dataset directly You can install all of the above with the `requirements.txt` as follows: @@ -118,17 +119,12 @@ Finetuning The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell but ymmv) on a provided image dataset. The dataset folder must have an -`index.json` file with the following format: +`train.jsonl` file with the following format: -```json -{ - "data": [ - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"}, - ... - ] -} +```jsonl +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"} +... ``` The training script by default trains for 600 iterations with a batch size of @@ -150,19 +146,15 @@ The training images are the following 5 images [^2]: ![dog6](static/dog6.png) -We start by making the following `index.json` file and placing it in the same +We start by making the following `train.jsonl` file and placing it in the same folder as the images. 
-```json -{ - "data": [ - {"image": "00.jpg", "text": "A photo of sks dog"}, - {"image": "01.jpg", "text": "A photo of sks dog"}, - {"image": "02.jpg", "text": "A photo of sks dog"}, - {"image": "03.jpg", "text": "A photo of sks dog"}, - {"image": "04.jpg", "text": "A photo of sks dog"} - ] -} +```jsonl +{"image": "00.jpg", "prompt": "A photo of sks dog"} +{"image": "01.jpg", "prompt": "A photo of sks dog"} +{"image": "02.jpg", "prompt": "A photo of sks dog"} +{"image": "03.jpg", "prompt": "A photo of sks dog"} +{"image": "04.jpg", "prompt": "A photo of sks dog"} ``` Subsequently we finetune FLUX using the following command: @@ -175,6 +167,17 @@ python dreambooth.py \ path/to/dreambooth/dataset/dog6 ``` + +Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning. + +```shell +python dreambooth.py \ + --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \ + --progress-every 600 --iterations 1200 --learning-rate 0.0001 \ + --lora-rank 4 --grad-accumulate 8 \ + mlx-community/dreambooth-dog6 +``` + The training requires approximately 50GB of RAM and on an M2 Ultra it takes a bit more than 1 hour. diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 4a4dbb08..48dcad47 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -1,7 +1,6 @@ # Copyright © 2024 Apple Inc. import argparse -import json import time from functools import partial from pathlib import Path @@ -13,105 +12,8 @@ import numpy as np from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from tqdm import tqdm -from flux import FluxPipeline - - -class FinetuningDataset: - def __init__(self, flux, args): - self.args = args - self.flux = flux - self.dataset_base = Path(args.dataset) - dataset_index = self.dataset_base / "index.json" - if not dataset_index.exists(): - raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset") - with open(dataset_index, "r") as f: - self.index = json.load(f) - - self.latents = [] - self.t5_features = [] - self.clip_features = [] - - def _random_crop_resize(self, img): - resolution = self.args.resolution - width, height = img.size - - a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() - - # Random crop the input image between 0.8 to 1.0 of its original dimensions - crop_size = ( - max((0.8 + 0.2 * a) * width, resolution[0]), - max((0.8 + 0.2 * a) * height, resolution[1]), - ) - pan = (width - crop_size[0], height - crop_size[1]) - img = img.crop( - ( - pan[0] * b, - pan[1] * c, - crop_size[0] + pan[0] * b, - crop_size[1] + pan[1] * c, - ) - ) - - # Fit the largest rectangle with the ratio of resolution in the image - # rectangle. 
- width, height = crop_size - ratio = resolution[0] / resolution[1] - r1 = (height * ratio, height) - r2 = (width, width / ratio) - r = r1 if r1[0] <= width else r2 - img = img.crop( - ( - (width - r[0]) / 2, - (height - r[1]) / 2, - (width + r[0]) / 2, - (height + r[1]) / 2, - ) - ) - - # Finally resize the image to resolution - img = img.resize(resolution, Image.LANCZOS) - - return mx.array(np.array(img)) - - def encode_images(self): - """Encode the images in the latent space to prepare for training.""" - self.flux.ae.eval() - for sample in tqdm(self.index["data"]): - input_img = Image.open(self.dataset_base / sample["image"]) - for i in range(self.args.num_augmentations): - img = self._random_crop_resize(input_img) - img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 - x_0 = self.flux.ae.encode(img[None]) - x_0 = x_0.astype(self.flux.dtype) - mx.eval(x_0) - self.latents.append(x_0) - - def encode_prompts(self): - """Pre-encode the prompts so that we don't recompute them during - training (doesn't allow finetuning the text encoders).""" - for sample in tqdm(self.index["data"]): - t5_tok, clip_tok = self.flux.tokenize([sample["text"]]) - t5_feat = self.flux.t5(t5_tok) - clip_feat = self.flux.clip(clip_tok).pooled_output - mx.eval(t5_feat, clip_feat) - self.t5_features.append(t5_feat) - self.clip_features.append(clip_feat) - - def iterate(self, batch_size): - xs = mx.concatenate(self.latents) - t5 = mx.concatenate(self.t5_features) - clip = mx.concatenate(self.clip_features) - mx.eval(xs, t5, clip) - n_aug = self.args.num_augmentations - while True: - x_indices = mx.random.permutation(len(self.latents)) - c_indices = x_indices // n_aug - for i in range(0, len(self.latents), batch_size): - x_i = x_indices[i : i + batch_size] - c_i = c_indices[i : i + batch_size] - yield xs[x_i], t5[c_i], clip[c_i] +from flux import FluxPipeline, Trainer, load_dataset def generate_progress_images(iteration, flux, args): @@ -157,7 +59,8 @@ def save_adapters(iteration, flux, args): ) -if __name__ == "__main__": +def setup_arg_parser(): + """Set up and return the argument parser.""" parser = argparse.ArgumentParser( description="Finetune Flux to generate images with a specific subject" ) @@ -247,7 +150,11 @@ if __name__ == "__main__": ) parser.add_argument("dataset") + return parser + +if __name__ == "__main__": + parser = setup_arg_parser() args = parser.parse_args() # Load the model and set it up for LoRA training. We use the same random @@ -267,7 +174,7 @@ if __name__ == "__main__": trainable_params = tree_reduce( lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 ) - print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) + print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True) # Set up the optimizer and training steps. The steps are a bit verbose to # support gradient accumulation together with compilation. 
@@ -340,10 +247,10 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - print("Create the training dataset.", flush=True) - dataset = FinetuningDataset(flux, args) - dataset.encode_images() - dataset.encode_prompts() + dataset = load_dataset(args.dataset) + trainer = Trainer(flux, dataset, args) + trainer.encode_dataset() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) # An initial generation to compare @@ -352,7 +259,7 @@ if __name__ == "__main__": grads = None losses = [] tic = time.time() - for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)): loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) mx.eval(loss, grads, state) losses.append(loss.item()) @@ -361,7 +268,7 @@ if __name__ == "__main__": toc = time.time() peak_mem = mx.metal.get_peak_memory() / 1024**3 print( - f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} " f"It/s: {10 / (toc - tic):.3f} " f"Peak mem: {peak_mem:.3f} GB", flush=True, diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index 8d39d605..b1122d75 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -1,16 +1,10 @@ # Copyright © 2024 Apple Inc. -import math -import time -from typing import Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx.utils import tree_unflatten -from tqdm import tqdm - +from .datasets import Dataset, load_dataset +from .flux import FluxPipeline from .lora import LoRALinear from .sampler import FluxSampler +from .trainer import Trainer from .utils import ( load_ae, load_clip, @@ -19,230 +13,3 @@ from .utils import ( load_t5, load_t5_tokenizer, ) - - -class FluxPipeline: - def __init__(self, name: str, t5_padding: bool = True): - self.dtype = mx.bfloat16 - self.name = name - self.t5_padding = t5_padding - - self.ae = load_ae(name) - self.flow = load_flow_model(name) - self.clip = load_clip(name) - self.clip_tokenizer = load_clip_tokenizer(name) - self.t5 = load_t5(name) - self.t5_tokenizer = load_t5_tokenizer(name) - self.sampler = FluxSampler(name) - - def ensure_models_are_loaded(self): - mx.eval( - self.ae.parameters(), - self.flow.parameters(), - self.clip.parameters(), - self.t5.parameters(), - ) - - def reload_text_encoders(self): - self.t5 = load_t5(self.name) - self.clip = load_clip(self.name) - - def tokenize(self, text): - t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) - clip_tokens = self.clip_tokenizer.encode(text) - return t5_tokens, clip_tokens - - def _prepare_latent_images(self, x): - b, h, w, c = x.shape - - # Pack the latent image to 2x2 patches - x = x.reshape(b, h // 2, 2, w // 2, 2, c) - x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) - - # Create positions ids used to positionally encode each patch. Due to - # the way RoPE works, this results in an interesting positional - # encoding where parts of the feature are holding different positional - # information. Namely, the first part holds information independent of - # the spatial position (hence 0s), the 2nd part holds vertical spatial - # information and the last one horizontal. 
- i = mx.zeros((h // 2, w // 2), dtype=mx.int32) - j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") - x_ids = mx.stack([i, j, k], axis=-1) - x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) - - return x, x_ids - - def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): - # Prepare the text features - txt = self.t5(t5_tokens) - if len(txt) == 1 and n_images > 1: - txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) - txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) - - # Prepare the clip text features - vec = self.clip(clip_tokens).pooled_output - if len(vec) == 1 and n_images > 1: - vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) - - return txt, txt_ids, vec - - def _denoising_loop( - self, - x_t, - x_ids, - txt, - txt_ids, - vec, - num_steps: int = 35, - guidance: float = 4.0, - start: float = 1, - stop: float = 0, - ): - B = len(x_t) - - def scalar(x): - return mx.full((B,), x, dtype=self.dtype) - - guidance = scalar(guidance) - timesteps = self.sampler.timesteps( - num_steps, - x_t.shape[1], - start=start, - stop=stop, - ) - for i in range(num_steps): - t = timesteps[i] - t_prev = timesteps[i + 1] - - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=scalar(t), - guidance=guidance, - ) - x_t = self.sampler.step(pred, x_t, t, t_prev) - - yield x_t - - def generate_latents( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - ): - # Set the PRNG state - if seed is not None: - mx.random.seed(seed) - - # Create the latent variables - x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) - x_T, x_ids = self._prepare_latent_images(x_T) - - # Get the conditioning - t5_tokens, clip_tokens = self.tokenize(text) - txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) - - # Yield the conditioning for controlled evaluation by the caller - yield (x_T, x_ids, txt, txt_ids, vec) - - # Yield the latent sequences from the denoising loop - yield from self._denoising_loop( - x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance - ) - - def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): - h, w = latent_size - x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) - x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) - x = self.ae.decode(x) - return mx.clip(x + 1, 0, 2) * 0.5 - - def generate_images( - self, - text: str, - n_images: int = 1, - num_steps: int = 35, - guidance: float = 4.0, - latent_size: Tuple[int, int] = (64, 64), - seed=None, - reload_text_encoders: bool = True, - progress: bool = True, - ): - latents = self.generate_latents( - text, n_images, num_steps, guidance, latent_size, seed - ) - mx.eval(next(latents)) - - if reload_text_encoders: - self.reload_text_encoders() - - for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): - mx.eval(x_t) - - images = [] - for i in tqdm(range(len(x_t)), disable=not progress): - images.append(self.decode(x_t[i : i + 1])) - mx.eval(images[-1]) - images = mx.concatenate(images, axis=0) - mx.eval(images) - - return images - - def training_loss( - self, - x_0: mx.array, - t5_features: mx.array, - clip_features: mx.array, - guidance: mx.array, - ): - # Get the text conditioning - txt = t5_features - txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) - vec = clip_features - - # Prepare the latent input - x_0, x_ids = 
self._prepare_latent_images(x_0) - - # Forward process - t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) - eps = mx.random.normal(x_0.shape, dtype=self.dtype) - x_t = self.sampler.add_noise(x_0, t, noise=eps) - x_t = mx.stop_gradient(x_t) - - # Do the denoising - pred = self.flow( - img=x_t, - img_ids=x_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, - timesteps=t, - guidance=guidance, - ) - - return (pred + x_0 - eps).square().mean() - - def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): - """Swap the linear layers in the transformer blocks with LoRA layers.""" - all_blocks = self.flow.double_blocks + self.flow.single_blocks - all_blocks.reverse() - num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) - for i, block in zip(range(num_blocks), all_blocks): - loras = [] - for name, module in block.named_modules(): - if isinstance(module, nn.Linear): - loras.append((name, LoRALinear.from_base(module, r=rank))) - block.update_modules(tree_unflatten(loras)) - - def fuse_lora_layers(self): - fused_layers = [] - for name, module in self.flow.named_modules(): - if isinstance(module, LoRALinear): - fused_layers.append((name, module.fuse())) - self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py new file mode 100644 index 00000000..d31a09f1 --- /dev/null +++ b/flux/flux/datasets.py @@ -0,0 +1,75 @@ +import json +from pathlib import Path + +from PIL import Image + + +class Dataset: + def __getitem__(self, index: int): + raise NotImplementedError() + + def __len__(self): + raise NotImplementedError() + + +class LocalDataset(Dataset): + prompt_key = "prompt" + + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) + with open(data_file, "r") as fid: + self._data = [json.loads(l) for l in fid] + + def __len__(self): + return len(self._data) + + def __getitem__(self, index: int): + item = self._data[index] + image = Image.open(self.dataset_base / item["image"]) + return image, item[self.prompt_key] + + +class LegacyDataset(LocalDataset): + prompt_key = "text" + + def __init__(self, dataset: str): + self.dataset_base = Path(dataset) + with open(self.dataset_base / "index.json") as f: + self._data = json.load(f)["data"] + + +class HuggingFaceDataset(Dataset): + + def __init__(self, dataset: str): + from datasets import load_dataset as hf_load_dataset + + self._df = hf_load_dataset(dataset)["train"] + + def __len__(self): + return len(self._df) + + def __getitem__(self, index: int): + item = self._df[index] + return item["image"], item["prompt"] + + +def load_dataset(dataset: str): + dataset_base = Path(dataset) + data_file = dataset_base / "train.jsonl" + legacy_file = dataset_base / "index.json" + + if data_file.exists(): + print(f"Load the local dataset {data_file} .", flush=True) + dataset = LocalDataset(dataset, data_file) + elif legacy_file.exists(): + print(f"Load the local dataset {legacy_file} .") + print() + print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.") + print(" See the README for details.") + print(flush=True) + dataset = LegacyDataset(dataset) + else: + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) + + return dataset diff --git a/flux/flux/flux.py b/flux/flux/flux.py new file mode 100644 index 00000000..3fd044ac --- /dev/null +++ b/flux/flux/flux.py @@ -0,0 +1,246 @@ +# Copyright © 2024 Apple Inc. 
+ +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. 
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = 
self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py new file mode 100644 index 00000000..40a126e8 --- /dev/null +++ b/flux/flux/trainer.py @@ -0,0 +1,98 @@ +import mlx.core as mx +import numpy as np +from PIL import Image, ImageFile +from tqdm import tqdm + +from .datasets import Dataset +from .flux import FluxPipeline + + +class Trainer: + + def __init__(self, flux: FluxPipeline, dataset: Dataset, args): + self.flux = flux + self.dataset = dataset + self.args = args + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * b) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * c, + pan[1] * d, + crop_size[0] + pan[0] * c, + crop_size[1] + pan[1] * d, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. 
+ width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int): + for i in range(num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def _encode_prompt(self, prompt): + t5_tok, clip_tok = self.flux.tokenize([prompt]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def encode_dataset(self): + """Encode the images & prompt in the latent space to prepare for training.""" + self.flux.ae.eval() + for image, prompt in tqdm(self.dataset, desc="encode dataset"): + self._encode_image(image, self.args.num_augmentations) + self._encode_prompt(prompt) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] From 743763bc2e4113dcf3c478058623cb02e770a237 Mon Sep 17 00:00:00 2001 From: aronson Date: Sun, 20 Oct 2024 22:46:43 -0500 Subject: [PATCH 08/77] Handle empty string case in maybe_trim_space (#1055) * Handle empty string case in maybe_trim_space * nit --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/tokenizer_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index d8694d86..78ec2ff8 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -193,7 +193,9 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.tokens = [] def _maybe_trim_space(self, current_text): - if current_text[0] != " ": + if len(current_text) == 0: + return current_text + elif current_text[0] != " ": return current_text elif not self.text: return current_text[1:] From 66e7bcb8866a050727849d9a303c54a0119f0f99 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Oct 2024 09:56:45 -0700 Subject: [PATCH 09/77] override dtype with quant (#1062) --- llms/mlx_lm/convert.py | 2 +- llms/mlx_lm/models/gemma2.py | 2 +- llms/mlx_lm/utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/convert.py b/llms/mlx_lm/convert.py index a3f43f71..9bac77a5 100644 --- a/llms/mlx_lm/convert.py +++ b/llms/mlx_lm/convert.py @@ -31,7 +31,7 @@ def configure_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--dtype", - help="Type to save the parameters, ignored if -q is given.", + help="Type to save the non-quantized parameters.", type=str, choices=["float16", "bfloat16", "float32"], default="float16", diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index ccc327a8..64951ae4 100644 
--- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -111,7 +111,7 @@ class MLP(nn.Module): self.up_proj = nn.Linear(dim, hidden_dim, bias=False) def __call__(self, x) -> mx.array: - return self.down_proj(nn.gelu(self.gate_proj(x)) * self.up_proj(x)) + return self.down_proj(nn.gelu_approx(self.gate_proj(x)) * self.up_proj(x)) class TransformerBlock(nn.Module): diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4f872982..92741b68 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -720,7 +720,7 @@ def convert( model, config, tokenizer = fetch_from_hub(model_path, lazy=True) weights = dict(tree_flatten(model.parameters())) - dtype = mx.float16 if quantize else getattr(mx, dtype) + dtype = getattr(mx, dtype) weights = {k: v.astype(dtype) for k, v in weights.items()} if quantize and dequantize: From d1d480867b2248fb95fedcf7f9d33b41689d9991 Mon Sep 17 00:00:00 2001 From: madroid Date: Wed, 23 Oct 2024 03:19:11 +0800 Subject: [PATCH 10/77] LoRA: update tools datasets docs (#1063) * LoRA: update tools datasets docs * nits * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 2d0dcf60..15676360 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -222,6 +222,17 @@ data formats. Here are examples of these formats: } ``` + +The format for the `arguments` field in a function varies for different models. +Common formats include JSON strings and dictionaries. The example provided +follows the format used by +[OpenAI](https://platform.openai.com/docs/guides/fine-tuning/fine-tuning-examples) +and [Mistral +AI](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#instruct). +A dictionary format is used in Hugging Face's [chat +templates](https://huggingface.co/docs/transformers/main/en/chat_templating#a-complete-tool-use-example). +Refer to the documentation for the model you are fine-tuning for more details. + `completions`: @@ -241,7 +252,7 @@ each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than -> one example per line and do not split an example accross multiple lines. +> one example per line and do not split an example across multiple lines. 
### Hugging Face Datasets From 9000e280aeb56c2bcce128001ab157030095687a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Oct 2024 15:44:08 -0700 Subject: [PATCH 11/77] fix mamba models conversion (#1065) --- llms/mlx_lm/models/mamba.py | 2 +- llms/mlx_lm/models/recurrent_gemma.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index d2740dc1..84f498e9 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -205,7 +205,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv1d.weight" in k and v.ndim == 3: + if "conv1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) return weights diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 06a307a6..5595d311 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -440,7 +440,7 @@ class Model(nn.Module): def sanitize(self, weights): for k, v in weights.items(): - if "conv_1d.weight" in k and v.ndim == 3: + if "conv_1d.weight" in k and v.shape[-1] != 1: weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") From 4971462bf0dd7bba07d9f18fb0fd2752a51fde40 Mon Sep 17 00:00:00 2001 From: Saurav Maheshkar Date: Fri, 25 Oct 2024 05:56:17 +0100 Subject: [PATCH 12/77] feat(clip): add linear probe evaluation script (#960) --- clip/linear_probe.py | 56 +++++++++++++++++++++++++++++++++++++++++++ clip/requirements.txt | 1 + 2 files changed, 57 insertions(+) create mode 100644 clip/linear_probe.py diff --git a/clip/linear_probe.py b/clip/linear_probe.py new file mode 100644 index 00000000..2649e397 --- /dev/null +++ b/clip/linear_probe.py @@ -0,0 +1,56 @@ +# Mirror of the Linear Probe Evaluation Script +# from the official CLIP Repository. 
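# A linear probe keeps the CLIP image encoder frozen: image embeddings are
# extracted once for the CIFAR-10 train and test splits, a scikit-learn
# logistic regression classifier is fit on the training embeddings, and test
# accuracy then measures the quality of the learned representations rather
# than any fine-tuning of CLIP itself.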
+ +import mlx.core as mx +import numpy as np +from image_processor import CLIPImageProcessor +from mlx.data.datasets import load_cifar10 +from model import CLIPModel +from PIL import Image +from sklearn.linear_model import LogisticRegression +from tqdm import tqdm + + +def get_cifar10(batch_size, root=None): + tr = load_cifar10(root=root).batch(batch_size) + test = load_cifar10(root=root, train=False).batch(batch_size) + + return tr, test + + +def get_features(model, image_proc, iter): + all_features = [] + all_labels = [] + + for batch in tqdm(iter): + image, label = batch["image"], batch["label"] + x = image_proc([Image.fromarray(im) for im in image]) + y = mx.array(label) + + image_embeds = model.get_image_features(x) + mx.eval(image_embeds) + + all_features.append(image_embeds) + all_labels.append(y) + + return mx.concatenate(all_features), mx.concatenate(all_labels) + + +if __name__ == "__main__": + model = CLIPModel.from_pretrained("mlx_model") + image_proc = CLIPImageProcessor.from_pretrained("mlx_model") + + train_iter, test_iter = get_cifar10(batch_size=256) + train_features, train_labels = get_features(model, image_proc, train_iter) + test_features, test_labels = get_features(model, image_proc, test_iter) + + # Perform logistic regression + # NOTE: The value of C should be determined via a hyperparameter sweep + # using a validation split + classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) + classifier.fit(train_features, train_labels) + + # Evaluate using the logistic regression classifier + predictions = classifier.predict(test_features) + accuracy = (test_labels.squeeze() == predictions).mean().item() * 100 + print(f"Accuracy = {accuracy:.3f}") diff --git a/clip/requirements.txt b/clip/requirements.txt index 74f826ea..8e05620e 100644 --- a/clip/requirements.txt +++ b/clip/requirements.txt @@ -1,4 +1,5 @@ mlx +mlx-data numpy transformers torch From ab4bf05c6e72928ec0ca0143a3f976f6e787e40c Mon Sep 17 00:00:00 2001 From: hschaeufler <9865991+hschaeufler@users.noreply.github.com> Date: Sat, 26 Oct 2024 19:34:46 +0300 Subject: [PATCH 13/77] Update lora_config.yaml with new param: num_layers (#1068) --- llms/mlx_lm/examples/lora_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 4ec9a23c..530272c7 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -14,7 +14,7 @@ data: "/path/to/training/data" seed: 0 # Number of layers to fine-tune -lora_layers: 16 +num_layers: 16 # Minibatch size. batch_size: 4 From 8fe9539af76075405b2c3071ba9657aa921d749d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 27 Oct 2024 15:06:07 -0700 Subject: [PATCH 14/77] Fix detokenizer space match for quote (#1072) * fix + test * remove transformer flax/torch warning * format --- llms/mlx_lm/__init__.py | 5 +++++ llms/mlx_lm/tokenizer_utils.py | 2 +- llms/tests/test_tokenizers.py | 3 +++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/__init__.py b/llms/mlx_lm/__init__.py index 502c78e5..538be927 100644 --- a/llms/mlx_lm/__init__.py +++ b/llms/mlx_lm/__init__.py @@ -1,4 +1,9 @@ # Copyright © 2023-2024 Apple Inc. 
+import os + from ._version import __version__ + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" + from .utils import convert, generate, load, stream_generate diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 78ec2ff8..0cbc3b9b 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -169,7 +169,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): """ _byte_decoder = None - _space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re") + _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re") def __init__(self, tokenizer): diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 7b4828b1..03445c1f 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -51,6 +51,9 @@ class TestTokenizers(unittest.TestCase): tokens = tokenizer.encode("3 3") check(tokens) + tokens = tokenizer.encode("import 'package:flutter/material.dart';") + check(tokens) + def test_tokenizers(self): tokenizer_repos = [ ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), From 9f34fdbda4527e85ab6b98d9f343f7a2972085f1 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 31 Oct 2024 08:17:14 -0700 Subject: [PATCH 15/77] Wire models in MLX LM (#1069) * wired in MLX LM * fix synch * comment + nit * version * mlx lm version * bump to 0.19.2 --- llms/README.md | 25 ++++++++ llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/chat.py | 2 +- llms/mlx_lm/requirements.txt | 2 +- llms/mlx_lm/utils.py | 115 +++++++++++++++++++++++------------ 5 files changed, 104 insertions(+), 42 deletions(-) diff --git a/llms/README.md b/llms/README.md index 20863041..f539988a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -248,3 +248,28 @@ model, tokenizer = load( tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, ) ``` + +### Large Models + +> [!NOTE] + This requires macOS 15.0 or higher to work. + +Models which are large relative to the total RAM available on the machine can +be slow. `mlx-lm` will attempt to make them faster by wiring the memory +occupied by the model and cache. This requires macOS 15 or higher to +work. + +If you see the following warning message: + +> [WARNING] Generating with a model that requires ... + +then the model will likely be slow on the given machine. If the model fits in +RAM then it can often be sped up by increasing the system wired memory limit. +To increase the limit, set the following `sysctl`: + +```bash +sudo sysctl iogpu.wired_limit_mb=N +``` + +The value `N` should be larger than the size of the model in megabytes but +smaller than the memory size of the machine. diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 70239db6..3811616f 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.19.1" +__version__ = "0.19.3" diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 7968a868..ea1a99c7 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -56,7 +56,7 @@ def main(): tokenizer_config={"trust_remote_code": True}, ) - print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.") + print(f"[INFO] Starting chat session with {args.model}. 
To exit, enter 'q'.") prompt_cache = make_prompt_cache(model, args.max_kv_size) while True: query = input(">> ") diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 814c03cc..48012863 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.17.0 +mlx>=0.19.2 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 92741b68..5b437c98 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import contextlib import copy import glob import importlib @@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten +from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports @@ -39,6 +40,40 @@ class ModelNotFoundError(Exception): super().__init__(self.message) +@contextlib.contextmanager +def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): + """ + A context manager to temporarily change the wired limit. + + Note, the wired limit should not be changed during an async eval. If an + async eval could be running pass in the streams to synchronize with prior + to exiting the context manager. + """ + model_bytes = tree_reduce( + lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 + ) + max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"] + if model_bytes > 0.9 * max_rec_size: + model_mb = model_bytes // 2**20 + max_rec_mb = max_rec_size // 2**20 + print( + "[WARNING] Generating with a model that requires {model_mb} MB " + "which is close to the maximum recommended size of {max_rec_mb} " + "MB. This can be slow. See the documentation for possible work-arounds: " + "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models" + ) + old_limit = mx.metal.set_wired_limit(max_rec_size) + try: + yield None + finally: + if streams is not None: + for s in streams: + mx.synchronize(s) + else: + mx.synchronize() + mx.metal.set_wired_limit(old_limit) + + def _get_classes(config: dict): """ Retrieve the model and model args classes based on the configuration. 
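The `wired_limit` helper above is applied automatically inside `generate`, but
the same size check is easy to reproduce by hand. The following is a rough
sketch, not part of the patch: it assumes a Metal device (macOS) and uses a
small quantized model purely as an example.

```python
import mlx.core as mx
from mlx.utils import tree_reduce
from mlx_lm import load

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# Sum the bytes of all parameter arrays, the same reduction wired_limit uses.
model_bytes = tree_reduce(
    lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
max_rec = mx.metal.device_info()["max_recommended_working_set_size"]
print(f"model: {model_bytes // 2**20} MB, "
      f"recommended working set: {max_rec // 2**20} MB")
if model_bytes > 0.9 * max_rec:
    # Same threshold as the warning above; see the README section on large
    # models for the iogpu.wired_limit_mb sysctl work-around.
    print("This model is close to the recommended limit; generation may be slow.")
```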
@@ -330,48 +365,50 @@ def generate( prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer - tic = time.perf_counter() - detokenizer.reset() + with wired_limit(model): + tic = time.perf_counter() + detokenizer.reset() + for n, (token, logprobs) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + if token == tokenizer.eos_token_id: + break + detokenizer.add_token(token) - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) + else: + print(detokenizer.last_segment, end="", flush=True) + + token_count = n + 1 + detokenizer.finalize() if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) + gen_time = time.perf_counter() - tic + print(detokenizer.last_segment, flush=True) + print("=" * 10) + if token_count == 0: + print("No tokens generated for this prompt") + return + prompt_tps = prompt_tokens.size / prompt_time + gen_tps = (token_count - 1) / gen_time + print( + f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" + ) + print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") + peak_mem = mx.metal.get_peak_memory() / 2**30 + print(f"Peak memory: {peak_mem:.3f} GB") - token_count = n + 1 - detokenizer.finalize() - - if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") - - return detokenizer.text + return detokenizer.text def load_config(model_path: Path) -> dict: From 85ffd2c96a45a8cb900f95a2ded61d858d673399 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Thu, 31 Oct 2024 16:59:52 -0700 Subject: [PATCH 16/77] Quantized KV Cache (#1075) * add QuantizedKVCache * simplify * add tests * single sdpa function * fix sed * in place * fix tests * support different k and v head dims --- llms/mlx_lm/cache_prompt.py | 30 +++++++- llms/mlx_lm/generate.py | 38 +++++++++- llms/mlx_lm/models/base.py | 63 ++++++++++++++++ llms/mlx_lm/models/cache.py | 102 +++++++++++++++++++++++++- llms/mlx_lm/models/cohere.py | 6 +- llms/mlx_lm/models/dbrx.py | 6 +- llms/mlx_lm/models/deepseek.py | 6 +- llms/mlx_lm/models/deepseek_v2.py | 6 +- llms/mlx_lm/models/gemma.py | 6 +- llms/mlx_lm/models/gpt2.py | 6 +- llms/mlx_lm/models/gpt_bigcode.py | 6 +- llms/mlx_lm/models/gpt_neox.py | 6 +- llms/mlx_lm/models/internlm2.py | 6 +- llms/mlx_lm/models/llama.py | 9 ++- llms/mlx_lm/models/minicpm.py 
| 6 +- llms/mlx_lm/models/mixtral.py | 6 +- llms/mlx_lm/models/nemotron.py | 6 +- llms/mlx_lm/models/openelm.py | 6 +- llms/mlx_lm/models/phi.py | 11 ++- llms/mlx_lm/models/phi3.py | 6 +- llms/mlx_lm/models/phi3small.py | 6 +- llms/mlx_lm/models/phimoe.py | 6 +- llms/mlx_lm/models/phixtral.py | 11 ++- llms/mlx_lm/models/plamo.py | 5 +- llms/mlx_lm/models/qwen.py | 6 +- llms/mlx_lm/models/qwen2.py | 6 +- llms/mlx_lm/models/qwen2_moe.py | 6 +- llms/mlx_lm/models/recurrent_gemma.py | 6 +- llms/mlx_lm/models/stablelm.py | 6 +- llms/mlx_lm/models/starcoder2.py | 6 +- llms/mlx_lm/utils.py | 32 +++++++- llms/tests/test_prompt_cache.py | 63 ++++++++++++++++ 32 files changed, 411 insertions(+), 85 deletions(-) diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 04e75a3e..7bb06411 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,9 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load +from .utils import load, maybe_quantize_kv_cache + +DEFAULT_QUANTIZED_KV_START = 5000 def setup_arg_parser(): @@ -70,6 +72,26 @@ def setup_arg_parser(): required=True, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -127,6 +149,7 @@ def main(): start = time.time() max_msg_len = 0 while y.size > 0: + model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) processed += min(y.size, step_size) @@ -136,6 +159,11 @@ def main(): msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) + + maybe_quantize_kv_cache( + cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits + ) + print() print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..0355ca29 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,7 +6,7 @@ import sys import mlx.core as mx -from .models.cache import load_prompt_cache +from .models.cache import QuantizedKVCache, load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -15,6 +15,7 @@ DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_QUANTIZED_KV_START = 5000 def str2bool(string): @@ -107,6 +108,26 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. 
" + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -150,8 +171,18 @@ def main(): using_cache = args.prompt_cache_file is not None if using_cache: prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, return_metadata=True + args.prompt_cache_file, + return_metadata=True, ) + if isinstance(prompt_cache[0], QuantizedKVCache): + if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: + raise ValueError( + "--kv-bits does not match the kv cache loaded from --prompt-cache-file." + ) + if args.kv_group_size != prompt_cache[0].group_size: + raise ValueError( + "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." + ) # Building tokenizer_config tokenizer_config = ( @@ -227,6 +258,9 @@ def main(): top_p=args.top_p, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3628a808..cda41c79 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -5,6 +5,9 @@ from dataclasses import dataclass from typing import Any, Optional import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache @dataclass @@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index a6a56e0a..1cd5289d 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional import mlx.core as mx import mlx.nn as nn -from 
mlx.utils import tree_flatten, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten -def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: """ Construct the model's cache for use when cgeneration. @@ -126,6 +129,88 @@ class _BaseCache: return False +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + class KVCache(_BaseCache): def __init__(self): self.keys = None @@ -180,6 +265,16 @@ class KVCache(_BaseCache): self.offset -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + if self.keys is not None: + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = mx.quantize( + self.values, group_size=group_size, bits=bits + ) + return quant_cache + class RotatingKVCache(_BaseCache): @@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache): self._idx -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + raise NotImplementedError("RotatingKVCache Quantization NYI") + class MambaCache(_BaseCache): def 
__init__(self): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 057c816d..7e002b0c 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 3b7e83d7..7be274cc 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index 03cb3b1a..b7b24dba 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index bb3e5184..444813b9 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module): queries = mx.concatenate([q_nope, q_pe], axis=-1) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 61de781e..3f384c3f 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -6,7 +6,7 @@ from typing import Any, 
Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 97d9a8ff..52076a34 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -61,8 +61,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 068046ea..23e86e20 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(output) diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index 9f662491..ccb0b28b 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention # Based on the transformers implementation at: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index 5264cb57..f5ce057e 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, 
scaled_dot_product_attention @dataclass @@ -141,8 +141,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7da6b333..438278e5 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -190,9 +190,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 4ac3c3b4..907beb2a 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -105,8 +105,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - attn_output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + attn_output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 20944fe3..dd94d1f4 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -87,8 +87,8 @@ class MixtralAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index 3ea06e27..f73c0277 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -113,8 +113,8 @@ class Attention(nn.Module): queries = 
self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 090e21c6..408802f4 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 56b383b2..510025ea 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -7,7 +7,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,13 @@ class PhiAttention(nn.Module): keys = self.rope(keys) scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 9ef76f04..ee6efc49 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 6b0759b4..53e1a638 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -8,7 +8,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -188,8 +188,8 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, 
values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.dense(output) diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index ca20a388..f42a6dd0 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding from .switch_layers import SwitchGLU @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 865d0d8e..42d647b0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,7 +8,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import create_attention_mask +from .base import create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchMLP @@ -71,8 +71,13 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index b0fd1a6c..c8e5bf50 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -92,10 +92,11 @@ class Attention(nn.Module): keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - output = mx.fast.scaled_dot_product_attention( + output = scaled_dot_product_attention( queries, keys, values, + cache=cache, scale=self.scale, mask=attention_mask, ) diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 2b69d5ec..8145a890 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 4e7858de..fac59d78 100644 --- 
a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index d199116f..167fc5dd 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 5595d311..49e4bb8f 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import MambaCache, RotatingKVCache @@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 11202b02..482bb324 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -120,8 +120,8 @@ class Attention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ce0a2ec5..d7e626f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,7 +6,7 @@ 
from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5b437c98..06784f10 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -19,7 +19,7 @@ from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports -from .models import base, cache +from .models import cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -159,6 +159,18 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits +def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): + if ( + kv_bits is not None + and not isinstance(prompt_cache[0], cache.QuantizedKVCache) + and prompt_cache[0].offset > quantized_kv_start + ): + for i in range(len(prompt_cache)): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) + + def generate_step( prompt: mx.array, model: nn.Module, @@ -173,6 +185,9 @@ def generate_step( prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -201,6 +216,11 @@ def generate_step( logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. 
Yields: Generator[Tuple[mx.array, mx.array], None, None]: A generator producing @@ -255,11 +275,15 @@ def generate_step( # Create the KV cache for generation if prompt_cache is None: - prompt_cache = cache.make_prompt_cache(model, max_kv_size) + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): + logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :] @@ -270,6 +294,10 @@ def generate_step( for processor in logits_processor: logits = processor(tokens, logits) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) + y, logprobs = sample(logits) return y, logprobs.squeeze(0) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 64cd9486..1e57bd86 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -9,6 +9,7 @@ import mlx.core as mx from mlx_lm.models.cache import ( KVCache, MambaCache, + QuantizedKVCache, RotatingKVCache, load_prompt_cache, make_prompt_cache, @@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase): num_trimmed = trim_prompt_cache(cache, 4) self.assertEqual(num_trimmed, 0) + cache = [QuantizedKVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 64)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + def test_trim_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] @@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase): self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + def test_save_load_quantized_cache(self): + cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 32)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(loaded_cache[0].bits == cache[0].bits) + self.assertTrue(loaded_cache[0].group_size == cache[0].group_size) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + # Loop over quantized tuple + for i in range(3): + self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i])) + self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_cache_to_quantized(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, 
all_logits[i])) + i += 1 + + prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache] + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2)) + if __name__ == "__main__": unittest.main() From 8160e0c4e56df261d0c8406f68d40b42ef0a188b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Nov 2024 10:52:28 -0700 Subject: [PATCH 17/77] Whisper improvements (#1080) * use safetensors in whisper * speed up decoder * version --- whisper/convert.py | 4 +- whisper/mlx_whisper/_version.py | 2 +- whisper/mlx_whisper/decoding.py | 132 ++++++++++++++++------------- whisper/mlx_whisper/load_models.py | 5 +- whisper/mlx_whisper/transcribe.py | 1 + whisper/mlx_whisper/whisper.py | 5 +- 6 files changed, 85 insertions(+), 64 deletions(-) diff --git a/whisper/convert.py b/whisper/convert.py index cdd50bc5..301fd5b4 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -181,7 +181,7 @@ def load_torch_weights_and_config( ) if name_or_path.endswith(".pt"): - checkpoint = torch.load(name_or_path, map_location="cpu") + checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) weights, config = checkpoint["model_state_dict"], checkpoint["dims"] else: name_or_path = Path(name_or_path) @@ -387,7 +387,7 @@ if __name__ == "__main__": # Save weights print("[INFO] Saving") - np.savez(str(mlx_path / "weights.npz"), **weights) + mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) # Save config.json with model_type with open(str(mlx_path / "config.json"), "w") as f: diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 67c7397c..45e522d1 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
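Stepping back to the quantized KV cache change above, a minimal sketch of using
the new options from Python (the values shown are illustrative, not
recommendations) looks like this:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# kv_bits turns on KV cache quantization; quantized_kv_start delays it until
# the cache already holds that many tokens, matching the new CLI flags.
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about autumn.",
    max_tokens=128,
    kv_bits=4,
    kv_group_size=64,
    quantized_kv_start=2000,
)
print(text)
```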
-__version__ = "0.3.0" +__version__ = "0.4.0" diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 41c2ec6d..6bf975d5 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -58,11 +58,12 @@ def detect_language( logits = model.logits(x, mel)[:, 0] # collect detected languages; suppress all non-language tokens - mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32) + mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32) mask[list(tokenizer.all_language_tokens)] = 0.0 - logits += mx.array(mask) + logits += mask language_tokens = mx.argmax(logits, axis=-1) language_token_probs = mx.softmax(logits, axis=-1) + language_token_probs = np.array(language_token_probs) language_probs = [ { c: language_token_probs[i, j].item() @@ -129,17 +130,12 @@ class DecodingResult: class Inference: - def __init__(self, model: "Whisper", initial_token_length: int): + def __init__(self, model: "Whisper"): self.model: "Whisper" = model - self.initial_token_length = initial_token_length self.kv_cache = None def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array: """Perform a forward pass on the decoder and return per-token logits""" - if tokens.shape[-1] > self.initial_token_length: - # only need to use the last token except in the first forward pass - tokens = tokens[:, -1:] - logits, self.kv_cache, _ = self.model.decoder( tokens, audio_features, kv_cache=self.kv_cache ) @@ -251,6 +247,11 @@ class TokenDecoder: raise NotImplementedError +@mx.compile +def categorical(logits, temp): + return mx.random.categorical(logits / temp) + + class GreedyDecoder(TokenDecoder): def __init__(self, temperature: float, eot: int): self.temperature = temperature @@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder): if self.temperature == 0: next_tokens = logits.argmax(axis=-1) else: - next_tokens = mx.random.categorical(logits=logits / self.temperature) + next_tokens = categorical(logits, self.temperature) - next_tokens = mx.argmax(logits, axis=-1) - logits = logits.astype(mx.float32) logprobs = logits - mx.logsumexp(logits, axis=-1) current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] @@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder): def finalize(self, tokens: mx.array, sum_logprobs: mx.array): # make sure each sequence has at least one EOT token at the end tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot) - return tokens, sum_logprobs.tolist() + return tokens, sum_logprobs class LogitFilter: @@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter): if self.tokenizer.no_timestamps is not None: mask[:, self.tokenizer.no_timestamps] = -np.inf - # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly - for k in range(tokens.shape[0]): - sampled_tokens = tokens[k, self.sample_begin :] - seq = sampled_tokens.tolist() + ## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + tokens = tokens.tolist() + for k in range(len(tokens)): + seq = tokens[k][self.sample_begin :] last_was_timestamp = ( len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin ) @@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter): last_timestamp += 1 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf - if tokens.shape[1] == self.sample_begin: + if len(tokens[0]) == self.sample_begin: # suppress generating non-timestamp tokens at the beginning mask[:, : self.tokenizer.timestamp_begin] = -np.inf @@ -380,16 +379,20 @@ class 
ApplyTimestampRules(LogitFilter): mask[:, last_allowed + 1 :] = -np.inf # if sum of probability over timestamps is above any other token, sample timestamp + mask = mx.array(mask) logprobs = logits - mx.logsumexp(logits, axis=-1) - for k in range(tokens.shape[0]): - timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( - axis=-1 - ) - max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() - if timestamp_logprob > max_text_token_logprob: - mask[k, : self.tokenizer.timestamp_begin] = -np.inf - - return logits + mx.array(mask, logits.dtype) + timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp( + axis=-1, keepdims=True + ) + max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max( + axis=-1, keepdims=True + ) + mask[:, : self.tokenizer.timestamp_begin] = mx.where( + timestamp_logprob > max_text_token_logprob, + -mx.inf, + mask[:, : self.tokenizer.timestamp_begin], + ) + return logits + mask class DecodingTask: @@ -424,7 +427,7 @@ class DecodingTask: self.sot_index: int = self.initial_tokens.index(tokenizer.sot) # inference: implements the forward pass through the decoder, including kv caching - self.inference = Inference(model, len(self.initial_tokens)) + self.inference = Inference(model) # sequence ranker: implements how to rank a group of sampled sequences self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) @@ -432,9 +435,6 @@ class DecodingTask: # decoder: implements how to select the next tokens, given the autoregressive distribution if options.beam_size is not None: raise NotImplementedError("Beam search decoder is not yet implemented") - # self.decoder = BeamSearchDecoder( - # options.beam_size, tokenizer.eot, self.inference, options.patience - # ) else: self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) @@ -448,6 +448,7 @@ class DecodingTask: self.logit_filters.append( SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab) ) + if not options.without_timestamps: precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds max_initial_timestamp_index = None @@ -570,35 +571,47 @@ class DecodingTask: def _main_loop(self, audio_features: mx.array, tokens: mx.array): n_batch = tokens.shape[0] - sum_logprobs: mx.array = mx.zeros(n_batch) - no_speech_probs = [np.nan] * n_batch + sum_logprobs = mx.zeros(n_batch) + + def _step(inputs, audio_features, tokens, sum_logprobs): + pre_logits = self.inference.logits(inputs, audio_features) + + # consider the logits at the last token only + logits = pre_logits[:, -1] + + # apply the logit filters, e.g. 
for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logits = logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed, sum_logprobs = self.decoder.update( + tokens, logits, sum_logprobs + ) + return tokens, completed, sum_logprobs, pre_logits try: - for i in range(self.sample_len): - logits = self.inference.logits(tokens, audio_features) + tokens, completed, sum_logprobs, pre_logits = _step( + tokens, audio_features, tokens, sum_logprobs + ) + if self.tokenizer.no_speech is not None: # compute no_speech_probs + probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech] + else: + no_speech_probs = mx.full(n_batch, mx.nan) + mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs) - if ( - i == 0 and self.tokenizer.no_speech is not None - ): # save no_speech_probs - probs_at_sot = mx.softmax( - logits[:, self.sot_index].astype(mx.float32), axis=-1 - ) - no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() - - # now we need to consider the logits at the last token only - logits = logits[:, -1] - - # apply the logit filters, e.g. for suppressing or applying penalty to - for logit_filter in self.logit_filters: - logits = logit_filter.apply(logits, tokens) - - # expand the tokens tensor with the selected next tokens - tokens, completed, sum_logprobs = self.decoder.update( - tokens, logits, sum_logprobs + for i in range(1, self.sample_len): + inputs = tokens[:, -1:] + next_tokens, next_completed, next_sum_logprobs, _ = _step( + inputs, audio_features, tokens, sum_logprobs ) - + mx.async_eval(next_completed, next_tokens, next_sum_logprobs) if completed or tokens.shape[-1] > self.n_ctx: break + tokens = next_tokens + completed = next_completed + sum_logprobs = next_sum_logprobs + finally: self.inference.reset() @@ -610,8 +623,8 @@ class DecodingTask: n_audio: int = mel.shape[0] audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass - tokens: np.array = np.array(self.initial_tokens) - tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy() + tokens: mx.array = mx.array(self.initial_tokens) + tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens))) # detect language if requested, overwriting the language token languages, language_probs = self._detect_language(audio_features, tokens) @@ -626,7 +639,6 @@ class DecodingTask: ] # repeat tokens by the group size, for beam search or best-of-n sampling - tokens = mx.array(tokens) if self.n_group > 1: tokens = tokens[:, None, :] tokens = mx.broadcast_to( @@ -649,7 +661,13 @@ class DecodingTask: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) - tokens = tokens[..., self.sample_begin :].tolist() + tokens = tokens[..., self.sample_begin :] + + # eval and convert to list + mx.eval(tokens, sum_logprobs, no_speech_probs) + tokens = tokens.tolist() + sum_logprobs = sum_logprobs.tolist() + no_speech_probs = no_speech_probs.tolist() tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens] # select the top-ranked sample in each group diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py index 6705385d..60766ab2 100644 --- a/whisper/mlx_whisper/load_models.py +++ b/whisper/mlx_whisper/load_models.py @@ -26,7 +26,10 @@ def load_model( model_args = whisper.ModelDimensions(**config) - 
weights = mx.load(str(model_path / "weights.npz")) + wf = model_path / "weights.safetensors" + if not wf.exists(): + wf = model_path / "weights.npz" + weights = mx.load(str(wf)) model = whisper.Whisper(model_args, dtype) diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 786b4232..7057679b 100644 --- a/whisper/mlx_whisper/transcribe.py +++ b/whisper/mlx_whisper/transcribe.py @@ -293,6 +293,7 @@ def transcribe( decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) + tokens = np.array(result.tokens) if no_speech_threshold is not None: diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index e691792c..1c2b390e 100644 --- a/whisper/mlx_whisper/whisper.py +++ b/whisper/mlx_whisper/whisper.py @@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module): qk = q @ k if mask is not None: qk = qk + mask[:n_ctx, :n_ctx] - qk = qk.astype(mx.float32) - w = mx.softmax(qk, axis=-1).astype(q.dtype) + w = mx.softmax(qk, axis=-1, precise=True) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out, qk + return out, qk.astype(mx.float32) class ResidualAttentionBlock(nn.Module): From e510987870fdc0c9741d8448fe37a776d2ee52a0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Nov 2024 14:15:32 -0700 Subject: [PATCH 18/77] Clear cache every now and then (#1081) * clear cache every now and then * don't need user arg anymore --- llms/mlx_lm/generate.py | 9 --------- llms/mlx_lm/utils.py | 4 ++++ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0355ca29..29976da2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -90,12 +90,6 @@ def setup_arg_parser(): action="store_true", help="Colorize output based on T[0] probability", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -164,9 +158,6 @@ def main(): mx.random.seed(args.seed) - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None if using_cache: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 06784f10..b9fc202d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -310,10 +310,14 @@ def generate_step( y, logprobs = _step(y) mx.async_eval(y, logprobs) + n = 0 while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs + if n % 256 == 0: + mx.metal.clear_cache() + n += 1 y, logprobs = next_y, next_logprobs From 0f799947d0c73ff4901ce17188aceaa933b3c02e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Nov 2024 16:30:32 -0700 Subject: [PATCH 19/77] fix (#1079) --- llms/mlx_lm/tokenizer_utils.py | 11 +++++++++-- llms/tests/test_tokenizers.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 0cbc3b9b..568a672d 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -186,6 +186,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): # https://github.com/openai/gpt-2/blob/master/src/encoder.py self.make_byte_decoder() + self._added_ids = set(tokenizer.added_tokens_decoder.keys()) + def reset(self): self.offset = 0 self._unflushed = "" @@ 
-205,12 +207,17 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): def add_token(self, token): v = self.tokenmap[token] - if self._byte_decoder[v[0]] == 32: + is_added = token in self._added_ids + if is_added or self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed ).decode("utf-8") self.text += self._maybe_trim_space(current_text) - self._unflushed = v + if is_added: + self.text += v + self._unflushed = "" + else: + self._unflushed = v else: self._unflushed += v diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 03445c1f..3c93fbe2 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -74,6 +74,17 @@ class TestTokenizers(unittest.TestCase): tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) self.check_tokenizer(tokenizer) + def test_special_tokens(self): + tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" + tokenizer = self.download_tokenizer(tokenizer_repo) + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + detokenizer.add_token(tokenizer.eos_token_id) + detokenizer.finalize() + + self.assertEqual(detokenizer.last_segment, tokenizer.eos_token) + if __name__ == "__main__": unittest.main() From 29c954f4cb3eb708a6b7115327168ca83c5c0972 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 2 Nov 2024 13:51:38 -0700 Subject: [PATCH 20/77] fix (#1082) --- whisper/mlx_whisper/_version.py | 2 +- whisper/mlx_whisper/decoding.py | 47 ++++++++++++++++----------------- 2 files changed, 24 insertions(+), 25 deletions(-) diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 45e522d1..8280e038 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.4.0" +__version__ = "0.4.1" diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 6bf975d5..4e060cd5 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -589,35 +589,34 @@ class DecodingTask: ) return tokens, completed, sum_logprobs, pre_logits - try: - tokens, completed, sum_logprobs, pre_logits = _step( - tokens, audio_features, tokens, sum_logprobs + tokens, completed, sum_logprobs, pre_logits = _step( + tokens, audio_features, tokens, sum_logprobs + ) + if self.tokenizer.no_speech is not None: # compute no_speech_probs + probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech] + else: + no_speech_probs = mx.full(n_batch, mx.nan) + mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs) + + for i in range(1, self.sample_len): + inputs = tokens[:, -1:] + if tokens.shape[-1] > self.n_ctx: + break + next_tokens, next_completed, next_sum_logprobs, _ = _step( + inputs, audio_features, tokens, sum_logprobs ) - if self.tokenizer.no_speech is not None: # compute no_speech_probs - probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1) - no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech] - else: - no_speech_probs = mx.full(n_batch, mx.nan) - mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs) - - for i in range(1, self.sample_len): - inputs = tokens[:, -1:] - next_tokens, next_completed, next_sum_logprobs, _ = _step( - inputs, audio_features, tokens, sum_logprobs - ) - mx.async_eval(next_completed, next_tokens, next_sum_logprobs) - if completed or tokens.shape[-1] > self.n_ctx: - break - tokens = next_tokens - completed = next_completed - sum_logprobs = next_sum_logprobs - - finally: - self.inference.reset() + mx.async_eval(next_completed, next_tokens, next_sum_logprobs) + if completed: + break + tokens = next_tokens + completed = next_completed + sum_logprobs = next_sum_logprobs return tokens, sum_logprobs, no_speech_probs def run(self, mel: mx.array) -> List[DecodingResult]: + self.inference.reset() self.decoder.reset() tokenizer: Tokenizer = self.tokenizer n_audio: int = mel.shape[0] From 331148d8ec05ce2f1dd50444570c61805b700039 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 2 Nov 2024 18:02:31 -0700 Subject: [PATCH 21/77] Enable distributed LoRA training (#821) --- llms/mlx_lm/tuner/trainer.py | 81 ++++++++++++++++++++++++------------ llms/tests/test_finetune.py | 51 ++++++++++++++--------- 2 files changed, 86 insertions(+), 46 deletions(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 1d934a72..38619d95 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -10,6 +10,7 @@ from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten @@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) f" examples but only has {len(dataset)}." 
) + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + # Make the batches: batch_idx = [ - idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: @@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) - batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - for j in range(batch_size): + for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( @@ -138,7 +146,7 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = [] + all_losses = 0 ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -153,10 +161,14 @@ def evaluate( ), ): losses, toks = loss(model, *batch) - all_losses.append((losses * toks).item()) - ntokens += toks.item() + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) - return np.sum(all_losses) / ntokens + all_losses = mx.distributed.all_sum(all_losses) + ntokens = mx.distributed.all_sum(ntokens) + + return (all_losses / ntokens).item() class TrainingCallback: @@ -182,6 +194,11 @@ def train( training_callback: TrainingCallback = None, ): print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) @@ -192,6 +209,9 @@ def train( # Forward and backward pass (lvalue, toks), grad = loss_value_and_grad(model, *batch) + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + # Model update optimizer.update(model, grad) @@ -199,8 +219,9 @@ def train( loss_value_and_grad = nn.value_and_grad(model, loss) - losses = [] + losses = 0 n_tokens = 0 + steps = 0 trained_tokens = 0 # Main training loop start = time.perf_counter() @@ -229,9 +250,13 @@ def train( iterate_batches=iterate_batches, ) val_time = time.perf_counter() - stop - print( - f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" - ) + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) if training_callback is not None: val_info = { @@ -244,30 +269,33 @@ def train( start = time.perf_counter() lvalue, toks = step(batch) - mx.eval(state, lvalue, toks) - - # Record loss - losses.append(lvalue.item()) - n_tokens += toks.item() + losses += lvalue + n_tokens += toks + steps += 1 + mx.eval(state, losses, n_tokens) # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() - train_loss = np.mean(losses) + train_loss = mx.distributed.all_sum(losses).item() + train_loss /= steps * mx.distributed.init().size() + n_tokens = mx.distributed.all_sum(n_tokens).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = 
float(n_tokens) / (stop - start) trained_tokens += n_tokens peak_mem = mx.metal.get_peak_memory() / 2**30 - print( - f"Iter {it}: Train loss {train_loss:.3f}, " - f"Learning Rate {learning_rate:.3e}, " - f"It/sec {it_sec:.3f}, " - f"Tokens/sec {tokens_sec:.3f}, " - f"Trained Tokens {trained_tokens}, " - f"Peak mem {peak_mem:.3f} GB" - ) + if rank == 0: + print( + f"Iter {it}: Train loss {train_loss:.3f}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) if training_callback is not None: train_info = { @@ -281,8 +309,9 @@ def train( } training_callback.on_train_loss_report(train_info) - losses = [] + losses = 0 n_tokens = 0 + steps = 0 start = time.perf_counter() # Save adapter weights diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 107be092..6ba81628 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -3,6 +3,7 @@ import math import sys import unittest +from contextlib import contextmanager from io import StringIO from unittest.mock import MagicMock @@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule +@contextmanager +def swapped_with_identity(obj, func): + old_func = getattr(obj, func) + setattr(obj, func, lambda x: x) + yield + setattr(obj, func, old_func) + + class TestLora(unittest.TestCase): def setUp(self): self.capturedOutput = StringIO() @@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.6), MagicMock(return_value=120)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=2, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, @@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.2), MagicMock(return_value=150)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=-1, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, From 82e333898707eb57235f408aa6907beca095f759 Mon Sep 17 00:00:00 2001 From: Anchen Date: Mon, 4 Nov 2024 22:06:34 +0800 Subject: [PATCH 22/77] chore(mlx-lm): add max token arg for mlx_lm.chat (#1089) * chore(mlx-lm): add max token arg for mlx_lm.chat * chore: update the default max token value --- llms/mlx_lm/chat.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index ea1a99c7..85d32d5f 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -11,6 +11,7 @@ from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 
+DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -41,6 +42,13 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", default=None, ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=DEFAULT_MAX_TOKENS, + help="Maximum number of tokens to generate", + ) return parser @@ -70,6 +78,7 @@ def main(): model, tokenizer, prompt, + args.max_tokens, temp=args.temp, top_p=args.top_p, prompt_cache=prompt_cache, From 3b526f0aa1219fae662a86f012dbda82045f4fb0 Mon Sep 17 00:00:00 2001 From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com> Date: Tue, 5 Nov 2024 00:23:30 +0400 Subject: [PATCH 23/77] Add support for falcon-mamba (#1074) * Add support for falcon-mamba * nits * nit --------- Co-authored-by: Awni Hannun --- llms/README.md | 1 + llms/mlx_lm/models/mamba.py | 11 +++++++++++ llms/mlx_lm/utils.py | 1 + 3 files changed, 13 insertions(+) diff --git a/llms/README.md b/llms/README.md index f539988a..0e7dc7fb 100644 --- a/llms/README.md +++ b/llms/README.md @@ -221,6 +221,7 @@ Here are a few examples of Hugging Face models that work with this example: - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct) - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) +- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) Most [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 84f498e9..f2414660 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs): use_conv_bias: bool time_step_rank: int tie_word_embeddings: bool = True + use_bcdt_rms: bool = False + mixer_rms_eps: float = 1e-6 def __post_init__(self): if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): @@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs): if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + if self.model_type == "falcon_mamba": + self.use_bcdt_rms = True class DepthWiseConv1d(nn.Module): @@ -83,6 +87,11 @@ class MambaBlock(nn.Module): self.intermediate_size = args.intermediate_size self.time_step_rank = int(args.time_step_rank) self.use_conv_bias = args.use_conv_bias + self.use_bcdt_rms = args.use_bcdt_rms + if self.use_bcdt_rms: + self.mixer_norm = lambda x: mx.fast.rms_norm( + x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps + ) self.in_proj = nn.Linear( self.hidden_size, self.intermediate_size * 2, bias=args.use_bias @@ -126,6 +135,8 @@ class MambaBlock(nn.Module): ], axis=-1, ) + if self.use_bcdt_rms: + delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta)) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) if state is not None: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b9fc202d..7b440db6 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -29,6 +29,7 @@ from .tuner.utils import load_adapters MODEL_REMAPPING = { "mistral": "llama", # mistral is compatible with llama "phi-msft": "phixtral", + "falcon_mamba": "mamba", } MAX_FILE_SIZE_GB = 5 From 4394633ce0f9d96cbbdf571e077fa4fd78479b9f Mon Sep 17 00:00:00 2001 From: Anthony Wu <462072+anthonywu@users.noreply.github.com> Date: Mon, 4 Nov 2024 14:02:13 -0800 Subject: [PATCH 24/77] mlx_whisper: add support 
for audio input from stdin (#1012) * add support for audio and input name from stdin * refactored to stdin - arg, and output-name template * fix bugs, add test coverage * fix doc to match arg rename * some nits --------- Co-authored-by: Awni Hannun --- whisper/README.md | 13 +++++++++++-- whisper/mlx_whisper/audio.py | 18 ++++++++++-------- whisper/mlx_whisper/cli.py | 34 +++++++++++++++++++++++++++------- whisper/mlx_whisper/writers.py | 14 +++++--------- 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/whisper/README.md b/whisper/README.md index ac6e95f6..cd3bc684 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -25,7 +25,7 @@ pip install mlx-whisper At its simplest: -``` +```sh mlx_whisper audio_file.mp3 ``` @@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There are many other supported command line options. To see them all, run `mlx_whisper -h`. +You can also pipe the audio content of other programs via stdin: + +```sh +some-process | mlx_whisper - +``` + +The default output file name will be `content.*`. You can specify the name with +the `--output-name` flag. + #### API Transcribe audio with: @@ -103,7 +112,7 @@ python convert.py --help ``` By default, the conversion script will make the directory `mlx_models` -and save the converted `weights.npz` and `config.json` there. +and save the converted `weights.npz` and `config.json` there. Each time it is run, `convert.py` will overwrite any model in the provided path. To save different models, make sure to set `--mlx-path` to a unique diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index e04309c1..c8cca07c 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -3,7 +3,7 @@ import os from functools import lru_cache from subprocess import CalledProcessError, run -from typing import Union +from typing import Optional, Union import mlx.core as mx import numpy as np @@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): +def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False): """ Open an audio file and read as mono waveform, resampling as necessary @@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): """ # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + if from_stdin: + cmd = ["ffmpeg", "-i", "pipe:0"] + else: + cmd = ["ffmpeg", "-nostdin", "-i", file] + # fmt: off - cmd = [ - "ffmpeg", - "-nostdin", + cmd.extend([ "-threads", "0", - "-i", file, "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-" - ] + ]) # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index c2813338..7d08a043 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -2,9 +2,11 @@ import argparse import os +import pathlib import traceback import warnings +from . 
import audio from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .transcribe import transcribe from .writers import get_writer @@ -27,15 +29,24 @@ def build_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument( - "audio", nargs="+", type=str, help="Audio file(s) to transcribe" - ) + + parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") + parser.add_argument( "--model", default="mlx-community/whisper-tiny", type=str, help="The model directory or hugging face repo", ) + parser.add_argument( + "--output-name", + type=str, + default=None, + help=( + "The name of transcription/translation output files before " + "--output-format extensions" + ), + ) parser.add_argument( "--output-dir", "-o", @@ -200,6 +211,7 @@ def main(): path_or_hf_repo: str = args.pop("model") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") + output_name: str = args.pop("output_name") os.makedirs(output_dir, exist_ok=True) writer = get_writer(output_format, output_dir) @@ -219,17 +231,25 @@ def main(): warnings.warn("--max-line-count has no effect without --max-line-width") if writer_args["max_words_per_line"] and writer_args["max_line_width"]: warnings.warn("--max-words-per-line has no effect with --max-line-width") - for audio_path in args.pop("audio"): + + for audio_obj in args.pop("audio"): + if audio_obj == "-": + # receive the contents from stdin rather than read a file + audio_obj = audio.load_audio(from_stdin=True) + + output_name = output_name or "content" + else: + output_name = output_name or pathlib.Path(audio_obj).stem try: result = transcribe( - audio_path, + audio_obj, path_or_hf_repo=path_or_hf_repo, **args, ) - writer(result, audio_path, **writer_args) + writer(result, output_name, **writer_args) except Exception as e: traceback.print_exc() - print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") + print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") if __name__ == "__main__": diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py index 464ead18..cdb35063 100644 --- a/whisper/mlx_whisper/writers.py +++ b/whisper/mlx_whisper/writers.py @@ -1,10 +1,8 @@ # Copyright © 2024 Apple Inc. import json -import os +import pathlib import re -import sys -import zlib from typing import Callable, List, Optional, TextIO @@ -43,15 +41,13 @@ class ResultWriter: self.output_dir = output_dir def __call__( - self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs ): - audio_basename = os.path.basename(audio_path) - audio_basename = os.path.splitext(audio_basename)[0] - output_path = os.path.join( - self.output_dir, audio_basename + "." 
+ self.extension + output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix( + f".{self.extension}" ) - with open(output_path, "w", encoding="utf-8") as f: + with output_path.open("wt", encoding="utf-8") as f: self.write_result(result, file=f, options=options, **kwargs) def write_result( From 6fd1f70f7366a1e55f14e2b4cd885b86875ab56c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 5 Nov 2024 06:06:26 -0800 Subject: [PATCH 25/77] fix spm decoder multi-byte (#1092) --- llms/mlx_lm/tokenizer_utils.py | 40 +++++++++++++++------------------- llms/tests/test_tokenizers.py | 3 +++ 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 568a672d..9d390733 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -6,12 +6,6 @@ from transformers import AutoTokenizer REPLACEMENT_CHAR = "\ufffd" -def _remove_space(x): - if x and x[0] == " ": - return x[1:] - return x - - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): def __init__(self, tokenizer, trim_space=True): self.trim_space = trim_space + self._sep = "\u2581".encode() # Extract the tokens in a list from id to text self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) for value, tokenid in tokenizer.vocab.items(): - self.tokenmap[tokenid] = value - - # Replace bytes with their value - for i in range(len(self.tokenmap)): - if self.tokenmap[i].startswith("<0x"): - self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) + if value.startswith("<0x"): + # Replace bytes with their value + self.tokenmap[tokenid] = bytes([int(value[3:5], 16)]) + else: + self.tokenmap[tokenid] = value.encode() self.reset() def reset(self): self.offset = 0 - self._unflushed = "" + self._unflushed = b"" self.text = "" self.tokens = [] + def _flush(self): + text = self._unflushed.replace(self._sep, b" ").decode("utf-8") + if not self.text and self.trim_space and text and text[0] == " ": + text = text[1:] + self.text += text + def add_token(self, token): v = self.tokenmap[token] - if v[0] == "\u2581": - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) + if v.startswith(self._sep): + self._flush() self._unflushed = v else: self._unflushed += v def finalize(self): - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) - self._unflushed = "" + self._flush() + self._unflushed = b"" class BPEStreamingDetokenizer(StreamingDetokenizer): diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 3c93fbe2..9c30d51e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase): text += detokenizer.last_segment self.assertEqual(text, expected_text) + tokens = tokenizer.encode("こんにちは!私の名前はAI") + check(tokens) + tokens = tokenizer.encode("a ,b") check(tokens) From ed9e81dd581a9505e677e12c025137d5326fe6df Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 5 Nov 2024 10:24:24 -0800 Subject: [PATCH 26/77] Fix rotating kv cache size (#1093) --- llms/mlx_lm/models/base.py | 2 +- llms/mlx_lm/models/cache.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index cda41c79..f02f49b1 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): if cache is not None and cache[0] is not None: c = cache[0] if hasattr(c, "max_size"): - offset = min(c.max_size - 1, c.offset) + offset = min(c.max_size, c.offset) window_size = c.max_size else: offset = c.offset diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 1cd5289d..14026f0c 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache): self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) - # The largest size is self.max_size + S - 1 to ensure + # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context - trim_size = self._idx - self.max_size + 1 + trim_size = self._idx - self.max_size self.keys = self._trim(trim_size, self.keys, keys) self.values = self._trim(trim_size, self.values, values) self.offset += keys.shape[2] From 657b4cc0aa90af09ac9793168cb81d406db882c6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 7 Nov 2024 16:15:24 -0800 Subject: [PATCH 27/77] [MLX LM] Sampler refactor + a few improvements (#1094) * starting * refactor sampler/processor and a few improvements * fix stream * fix stream generate * fix eos handling in stream generate --- llms/README.md | 5 +- llms/mlx_lm/cache_prompt.py | 4 +- llms/mlx_lm/chat.py | 2 +- llms/mlx_lm/generate.py | 14 +++ llms/mlx_lm/sample_utils.py | 106 ++++++++++++++++++ llms/mlx_lm/server.py | 193 ++++++++++++-------------------- llms/mlx_lm/tuner/trainer.py | 2 +- llms/mlx_lm/utils.py | 168 ++++++++++----------------- llms/tests/test_generate.py | 2 +- llms/tests/test_prompt_cache.py | 2 +- 10 files changed, 259 insertions(+), 239 deletions(-) diff --git a/llms/README.md b/llms/README.md index 0e7dc7fb..eeb3ed6a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -101,7 +101,8 @@ To see a description of all the arguments you can do: #### Streaming For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text. For example, +generator object which streams the output text, token, and log probabilities. 
+For example, ```python from mlx_lm import load, stream_generate @@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for t in stream_generate(model, tokenizer, prompt, max_tokens=512): +for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): print(t, end="", flush=True) print() ``` diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 7bb06411..987b640d 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -152,6 +152,7 @@ def main(): model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) + mx.metal.clear_cache() processed += min(y.size, step_size) y = y[step_size:] current = time.time() @@ -165,14 +166,13 @@ def main(): ) print() - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") print("Saving...") metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 85d32d5f..c03056a6 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -74,7 +74,7 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response in stream_generate( + for response, *_ in stream_generate( model, tokenizer, prompt, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 29976da2..51169def 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 +DEFAULT_MIN_P = 0.0 +DEFAULT_MIN_TOKENS_TO_KEEP = 1 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" DEFAULT_QUANTIZED_KV_START = 5000 @@ -52,6 +54,7 @@ def setup_arg_parser(): ) parser.add_argument( "--prompt", + "-p", default=DEFAULT_PROMPT, help="Message to be processed by the model ('-' reads from stdin)", ) @@ -68,6 +71,15 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) + parser.add_argument( + "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p" + ) + parser.add_argument( + "--min-tokens-to-keep", + type=float, + default=DEFAULT_MIN_TOKENS_TO_KEEP, + help="Minimum tokens to keep for min-p sampling.", + ) parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") parser.add_argument( "--ignore-chat-template", @@ -247,6 +259,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, + min_p=args.min_p, + min_tokens_to_keep=args.min_tokens_to_keep, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fa..c27b52d8 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,10 +1,83 @@ # Copyright © 2023-2024 Apple Inc. from functools import partial +from typing import Callable, Dict, Optional import mlx.core as mx +def make_sampler( + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, +) -> Callable[mx.array, mx.array]: + """ + Make a sampler function for use with ``generate_step``. 
+ + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + + Returns: + Callable[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. + """ + if temp == 0: + return lambda x: mx.argmax(x, axis=-1) + elif top_p > 0 and top_p < 1.0: + return lambda x: top_p_sampling(x, top_p, temp) + elif min_p != 0.0: + return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + else: + return lambda x: categorical_sampling(x, temp) + + +def make_logits_processors( + logit_bias: Optional[Dict[int, float]] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = 20, +): + """ + Make logits processors for use with ``generate_step``. + + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + + Returns: + List[Callable[[mx.array, mx.array], mx.array]]: + A list of logits processors. Each processor in the list is a + callable which takes an array of tokens and an array of logits + and returns the updated logits. + """ + logits_processors = [] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processors.append(logit_bias_processor) + + if repetition_penalty and repetition_penalty != 0.0: + logits_processors.append( + make_repetition_penalty(repetition_penalty, repetition_context_size) + ) + return logits_processors + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logits: mx.array, @@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) + + +def make_repetition_penalty(penalty: float, context_size: int = 20): + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. 
+ """ + if penalty < 0 or not isinstance(penalty, float): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ec659969..c1365b36 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache -from .utils import generate_step, load +from .utils import load, stream_generate def get_system_fingerprint(): @@ -64,7 +64,7 @@ def stopping_criteria( end if it has (`trim_length`). """ if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler): self.max_tokens = self.body.get("max_completion_tokens", None) if self.max_tokens is None: self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 1.0) + self.temperature = self.body.get("temperature", 0.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) @@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler): # Call endpoint specific method prompt = endpoints[self.path]() - - # Call method based on response type - method = self.handle_stream if self.stream else self.handle_completion - method(prompt, stop_id_sequences) + self.handle_completion(prompt, stop_id_sequences) def validate_model_parameters(self): """ @@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() tokens = [] finish_reason = "length" stop_sequence_suffix = None - logging.debug(f"Starting completion:") + if self.stream: + self.end_headers() + logging.debug(f"Starting stream:") + else: + logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] prompt = self.get_prompt_cache(prompt) - for _, (token, logprobs) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), + text = "" + tic = time.perf_counter() + for n, (segment, token, logprobs) in enumerate( + stream_generate( model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, temp=self.temperature, - top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache, ), ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + + text += segment + logging.debug(text) tokens.append(token) if self.logprobs > 0: @@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = self.tokenizer.decode( tokens[-stop_condition.trim_length :] ) + text = text[: -len(stop_sequence_suffix)] break - 
self.prompt_cache.tokens.extend(tokens) - detokenizer.finalize() - text = ( - detokenizer.text - if stop_sequence_suffix is None - else detokenizer.text[: -len(stop_sequence_suffix)] - ) - response = self.generate_response( - text, - finish_reason, - len(prompt), - len(tokens), - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - ) - - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - - self.wfile.write(response_json) - self.wfile.flush() - - def handle_stream( - self, - prompt: List[int], - stop_id_sequences: List[List[int]], - ): - """ - Generate response to prompt and foward it to the client using a Server - Sent Events (SSE) stream. - - Args: - prompt (mx.array): The tokenized prompt - stop_id_sequences (List[List[int]]): A list of stop words passed to - the stopping_criteria function - """ - # No additional headers are needed, call end_headers - self.end_headers() - - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() - tokens = [] - - stop_sequence_suffix = None - logging.debug(f"Starting stream:") - - prompt = self.get_prompt_cache(prompt) - - for _, (token, _) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), - model=self.model, - temp=self.temperature, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - prompt_cache=self.prompt_cache.cache, - ), - ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) - tokens.append(token) - - stop_condition = stopping_criteria( - tokens, - stop_id_sequences, - self.tokenizer.eos_token_id, - ) - if stop_condition.stop_met: - if stop_condition.trim_length: - stop_sequence_suffix = self.tokenizer.decode( - tokens[-stop_condition.trim_length :] + if self.stream: + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + ( + sequence_overlap(tokens, sequence) + for sequence in stop_id_sequences ) - break - - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) - ): - continue - - new_text = detokenizer.last_segment - if new_text: - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + ): + continue + elif segment: + response = self.generate_response(segment, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() self.prompt_cache.tokens.extend(tokens) - # check is there any remaining text to send - detokenizer.finalize() - last_segment = detokenizer.last_segment - if last_segment: - if stop_sequence_suffix is not None: - last_segment = last_segment[: -len(stop_sequence_suffix)] - response = self.generate_response(last_segment, "length") + gen_time = time.perf_counter() - tic + prompt_tps = len(prompt) / prompt_time + gen_tps = len(tokens) / gen_time + peak_mem = mx.metal.get_peak_memory() / 1e9 + logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {peak_mem:.3f} GB") + + if 
self.stream: + response = self.generate_response(segment, finish_reason) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() + if self.stream_options is not None and self.stream_options["include_usage"]: + response = self.completion_usage_response(len(prompt), len(tokens)) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + else: + response = self.generate_response( + text, + finish_reason, + len(prompt), + len(tokens), + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + ) + response_json = json.dumps(response).encode() + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - if self.stream_options is not None and self.stream_options["include_usage"]: - response = self.completion_usage_response(len(prompt), len(tokens)) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() + # Send an additional Content-Length header when it is known + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) + self.wfile.flush() def completion_usage_response( self, diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 38619d95..21b1af18 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -285,7 +285,7 @@ def train( it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) trained_tokens += n_tokens - peak_mem = mx.metal.get_peak_memory() / 2**30 + peak_mem = mx.metal.get_peak_memory() / 1e9 if rank == 0: print( f"Iter {it}: Train loss {train_loss:.3f}, " diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7b440db6..8893b570 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer # Local imports from .models import cache -from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling +from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model from .tuner.utils import load_adapters @@ -34,6 +34,9 @@ MODEL_REMAPPING = { MAX_FILE_SIZE_GB = 5 +# A stream on the default device just for generation +generation_stream = mx.new_stream(mx.default_device()) + class ModelNotFoundError(Exception): def __init__(self, message): @@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path return model_path -def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float): - """ - Apply repetition penalty to specific logits based on the given context. - - Paper: https://arxiv.org/abs/1909.05858 - - Args: - logits (mx.array): The logits produced by the language model. - tokens (mx.array): A list of N previous tokens. - penalty (float): The repetition penalty factor to be applied. - - Returns: - logits (mx.array): Logits with repetition penalty applied to generated tokens. 
- """ - if len(tokens) > 0: - selected_logits = logits[:, tokens] - selected_logits = mx.where( - selected_logits < 0, selected_logits * penalty, selected_logits / penalty - ) - logits[:, tokens] = selected_logits - return logits - - def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits): if ( kv_bits is not None @@ -185,7 +165,7 @@ def generate_step( max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, - logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, @@ -214,7 +194,7 @@ def generate_step( prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. - logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. @@ -224,53 +204,9 @@ def generate_step( when ``kv_bits`` is non-None. Default: ``0``. Yields: - Generator[Tuple[mx.array, mx.array], None, None]: A generator producing - one token and a vector of log probabilities. + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. """ - def sample(logits: mx.array) -> Tuple[mx.array, float]: - logprobs = logits - mx.logsumexp(logits) - - if temp == 0: - token = mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) - elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) - else: - token = categorical_sampling(logits, temp) - - return token, logprobs - - if repetition_penalty and ( - repetition_penalty < 0 or not isinstance(repetition_penalty, float) - ): - raise ValueError( - f"repetition_penalty must be a non-negative float, got {repetition_penalty}" - ) - - logits_processor = logits_processor or [] - - if repetition_penalty: - - def repetition_penalty_processor(tokens, logits): - return apply_repetition_penalty( - logits, tokens[-repetition_context_size:], repetition_penalty - ) - - logits_processor.append(repetition_penalty_processor) - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processor.append(logit_bias_processor) - y = prompt tokens = None @@ -283,24 +219,31 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") + sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) + logits_processors = logits_processors or [] + logits_processors.extend( + make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + ) + def _step(y): + with mx.stream(generation_stream): + logits = model(y[None], cache=prompt_cache) + logits = logits[:, -1, :] - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] + if logits_processors: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y - if logits_processor: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is 
not None else y + for processor in logits_processors: + logits = processor(tokens, logits) - for processor in logits_processor: - logits = processor(tokens, logits) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) - - y, logprobs = sample(logits) - return y, logprobs.squeeze(0) + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs) + return y, logprobs.squeeze(0) while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) @@ -325,43 +268,51 @@ def generate_step( def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[int]], max_tokens: int = 100, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Generator[Tuple[str, int, mx.array], None, None]: """ A generator producing text based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. + max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Generator[Tuple[mx.array, mx.array]]: A generator producing text. + Tuple[str, int, mx.array]: + The next text segment, token, and vector of log probabilities. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + prompt_tokens = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) detokenizer = tokenizer.detokenizer - detokenizer.reset() - for n, (token, _) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + with wired_limit(model, [generation_stream]): + detokenizer.reset() + for n, (token, logits) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if token == tokenizer.eos_token_id: + break - # Yield the last segment if streaming - yield detokenizer.last_segment + detokenizer.add_token(token) - detokenizer.finalize() - yield detokenizer.last_segment + if n == (max_tokens - 1): + break + + yield detokenizer.last_segment, token, logits + + detokenizer.finalize() + yield detokenizer.last_segment, token, logits def generate( @@ -372,7 +323,7 @@ def generate( verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> str: """ Generate a complete response from the model. 
@@ -398,7 +349,7 @@ def generate( prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer - with wired_limit(model): + with wired_limit(model, [generation_stream]): tic = time.perf_counter() detokenizer.reset() for n, (token, logprobs) in zip( @@ -416,8 +367,7 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() + prob = mx.exp(logprobs[token]).item() formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True) @@ -438,7 +388,7 @@ def generate( f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" ) print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 + peak_mem = mx.metal.get_peak_memory() / 1e9 print(f"Peak memory: {peak_mem:.3f} GB") return detokenizer.text @@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): f""" # {upload_repo} - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) + using mlx-lm version **{__version__}**. ## Use with mlx diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 68f1670b..e0a372a9 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase): "hello", max_tokens=5, verbose=False, - logits_processor=[logits_processor], + logits_processors=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 1e57bd86..0867ab56 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase): ): i += 1 self.assertEqual(tok, toks[i]) - self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2)) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) if __name__ == "__main__": From 1e0766018494c46bc6078769278b8e2a360503dc Mon Sep 17 00:00:00 2001 From: madroid Date: Sat, 9 Nov 2024 09:15:19 +0800 Subject: [PATCH 28/77] FLUX: save train config (#1049) --- flux/README.md | 2 +- flux/dreambooth.py | 15 +++++++++++---- flux/flux/__init__.py | 1 + flux/flux/utils.py | 23 ++++++++++++++++++++++- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/flux/README.md b/flux/README.md index 1a17e386..b00a9621 100644 --- a/flux/README.md +++ b/flux/README.md @@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the ```shell python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ - --adapter mlx_output/0001200_adapters.safetensors \ + --adapter mlx_output/final_adapters.safetensors \ --fuse-adapter \ --no-t5-padding \ 'A photo of an sks dog lying on the sand at a beach in Greece' diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 48dcad47..ffdb02d7 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from flux import FluxPipeline, Trainer, load_dataset +from flux import FluxPipeline, Trainer, 
load_dataset, save_config def generate_progress_images(iteration, flux, args): @@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args): im.save(out_file) -def save_adapters(iteration, flux, args): +def save_adapters(adapter_name, flux, args): out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + out_file = out_dir / adapter_name print(f"Saving {str(out_file)}") mx.save_safetensors( @@ -157,6 +157,10 @@ if __name__ == "__main__": parser = setup_arg_parser() args = parser.parse_args() + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + save_config(vars(args), output_path / "adapter_config.json") + # Load the model and set it up for LoRA training. We use the same random # state when creating the LoRA layers so all workers will have the same # initial weights. @@ -278,8 +282,11 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - save_adapters(i + 1, flux, args) + save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) if (i + 1) % 10 == 0: losses = [] tic = time.time() + + save_adapters("final_adapters.safetensors", flux, args) + print(f"Training successful. Saved final weights to {args.adapter_file}.") diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index b1122d75..3dd423b7 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -12,4 +12,5 @@ from .utils import ( load_flow_model, load_t5, load_t5_tokenizer, + save_config, ) diff --git a/flux/flux/utils.py b/flux/flux/utils.py index 21db17d3..2437f21f 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -3,7 +3,8 @@ import json import os from dataclasses import dataclass -from typing import Optional +from pathlib import Path +from typing import Optional, Union import mlx.core as mx from huggingface_hub import hf_hub_download @@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str): def load_t5_tokenizer(name: str, pad: bool = True): model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") return T5Tokenizer(model_file, 256 if "schnell" in name else 512) + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. 
+ """ + # Sort the config for better readability + config = dict(sorted(config.items())) + + # Write the config to the provided file + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) From bd6d910ca3744d75bf704e6e7039f97f71014bd5 Mon Sep 17 00:00:00 2001 From: Alban Lecocq Date: Wed, 13 Nov 2024 15:14:03 +0100 Subject: [PATCH 29/77] [MLX LM] Fix f-string formatting in memory warning message (#1105) * Fix missing f-prefix for string interpolation in model size warning * Ensures proper display of memory values in MB for model and max size --- llms/mlx_lm/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 8893b570..d4afd428 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -61,8 +61,8 @@ def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): model_mb = model_bytes // 2**20 max_rec_mb = max_rec_size // 2**20 print( - "[WARNING] Generating with a model that requires {model_mb} MB " - "which is close to the maximum recommended size of {max_rec_mb} " + f"[WARNING] Generating with a model that requires {model_mb} MB " + f"which is close to the maximum recommended size of {max_rec_mb} " "MB. This can be slow. See the documentation for possible work-arounds: " "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models" ) From 60c7b803500df4dd84ef8b5ed70deace99272bc2 Mon Sep 17 00:00:00 2001 From: Valentin Roussellet Date: Wed, 20 Nov 2024 15:21:52 -0800 Subject: [PATCH 30/77] Pass seed to sd img2img (#1114) --- stable_diffusion/image2image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py index e470aa81..a037af6a 100644 --- a/stable_diffusion/image2image.py +++ b/stable_diffusion/image2image.py @@ -30,6 +30,7 @@ if __name__ == "__main__": parser.add_argument("--preload-models", action="store_true") parser.add_argument("--output", default="out.png") parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--seed", type=int) args = parser.parse_args() # Load the models @@ -94,6 +95,7 @@ if __name__ == "__main__": cfg_weight=args.cfg, num_steps=args.steps, negative_text=args.negative_prompt, + seed=args.seed ) for x_t in tqdm(latents, total=int(args.steps * args.strength)): mx.eval(x_t) From 042280ce50645b69d9c322ccf1cb8471384007f1 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Wed, 20 Nov 2024 16:15:53 -0800 Subject: [PATCH 31/77] Fix format (#1115) --- stable_diffusion/image2image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_diffusion/image2image.py b/stable_diffusion/image2image.py index a037af6a..4444c488 100644 --- a/stable_diffusion/image2image.py +++ b/stable_diffusion/image2image.py @@ -95,7 +95,7 @@ if __name__ == "__main__": cfg_weight=args.cfg, num_steps=args.steps, negative_text=args.negative_prompt, - seed=args.seed + seed=args.seed, ) for x_t in tqdm(latents, total=int(args.steps * args.strength)): mx.eval(x_t) From 004eb4cc9d3d390dbadb8eb015de7d28a788701b Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 23 Nov 2024 11:06:26 -0800 Subject: [PATCH 32/77] Tencent HunYuan MOE model (#1100) * hunyuan * fix * format str * default trust remote code for tokenizer, allow system prompt to be configurable --- llms/mlx_lm/generate.py | 24 +-- llms/mlx_lm/models/hunyuan.py | 291 ++++++++++++++++++++++++++++++++++ llms/tests/test_models.py | 32 ++++ 3 files changed, 337 insertions(+), 10 deletions(-) create mode 
100644 llms/mlx_lm/models/hunyuan.py diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 51169def..de5c5719 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -41,17 +41,17 @@ def setup_arg_parser(): type=str, help="Optional path for the trained adapter weights and config.", ) - parser.add_argument( - "--trust-remote-code", - action="store_true", - help="Enable trusting remote code for tokenizer", - ) parser.add_argument( "--eos-token", type=str, default=None, help="End of sequence token for tokenizer", ) + parser.add_argument( + "--system-prompt", + default=None, + help="System prompt to be used for the chat template", + ) parser.add_argument( "--prompt", "-p", @@ -191,8 +191,7 @@ def main(): tokenizer_config = ( {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) - if args.trust_remote_code: - tokenizer_config["trust_remote_code"] = True + tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token @@ -224,12 +223,16 @@ def main(): hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None ): - messages = [ + if args.system_prompt is not None: + messages = [{"role": "system", "content": args.system_prompt}] + else: + messages = [] + messages.append( { "role": "user", "content": sys.stdin.read() if args.prompt == "-" else args.prompt, } - ] + ) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) @@ -237,8 +240,9 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. if using_cache: + messages[-1]["content"] = "" test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], + messages, tokenize=False, add_generation_prompt=True, ) diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py new file mode 100644 index 00000000..b098c20d --- /dev/null +++ b/llms/mlx_lm/models/hunyuan.py @@ -0,0 +1,291 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + vocab_size: int + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + attention_bias: bool + moe_topk: int + num_experts: int + num_shared_expert: int + use_mixed_mlp_moe: bool + use_qk_norm: bool + rms_norm_eps: float + rope_theta: float + use_cla: bool + cla_share_factor: 2 + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + +class DynamicNTKAlphaRoPE(nn.Module): + def __init__( + self, + dims: int, + base: float = 10000, + scaling_alpha: float = 1.0, + ): + super().__init__() + self.dims = dims + base = base * scaling_alpha ** (dims / (dims - 2)) + self._freqs = base ** (mx.arange(0, self.dims, 2) / self.dims) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=False, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class Attention(nn.Module): + def __init__(self, kv_proj: bool, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + assert args.num_key_value_heads is not None + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + if kv_proj: + self.k_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.v_proj = nn.Linear( + dim, n_kv_heads * head_dim, bias=args.attention_bias + ) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + self.use_qk_norm = args.use_qk_norm + if self.use_qk_norm: + self.query_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + self.key_layernorm = nn.RMSNorm(head_dim, args.rms_norm_eps) + + self.rope = DynamicNTKAlphaRoPE( + head_dim, + base=args.rope_theta, + scaling_alpha=args.rope_scaling["alpha"], + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + kv_states=None, + ) -> mx.array: + B, L, D = x.shape + + queries = self.q_proj(x) + + if kv_states is None: + keys, values = self.k_proj(x), self.v_proj(x) + kv_states = keys, values + else: + keys, values = kv_states + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + offset = cache.offset if cache else 0 + queries = self.rope(queries, offset=offset) + keys = self.rope(keys, offset=offset) + if self.use_qk_norm: + queries = self.query_layernorm(queries) + keys = self.key_layernorm(keys) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = 
output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), kv_states + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class Gate(nn.Module): + def __init__(self, dim, num_experts): + super().__init__() + self.wg = nn.Linear(dim, num_experts, bias=False) + + def __call__(self, x) -> mx.array: + return self.wg(x) + + +class MoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + intermediate_size = args.intermediate_size + self.use_shared_mlp = args.use_mixed_mlp_moe + + if args.use_mixed_mlp_moe: + self.shared_mlp = MLP(dim, intermediate_size * args.num_shared_expert) + + self.num_experts = num_experts = args.num_experts + self.top_k = args.moe_topk + + self.gate = Gate(dim, num_experts) + self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts) + + def __call__( + self, + x: mx.array, + ): + gates = self.gate(x) + gates = mx.softmax(gates, axis=-1, precise=True) + + k = self.top_k + inds = mx.stop_gradient(mx.argpartition(-gates, kth=k - 1, axis=-1)[..., :k]) + scores = mx.take_along_axis(gates, inds, axis=-1) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + + if self.use_shared_mlp: + shared_expert_output = self.shared_mlp(x) + y = y + shared_expert_output + + return y + + +class DecoderLayer(nn.Module): + def __init__(self, args: ModelArgs, kv_proj: bool): + super().__init__() + self.hidden_size = args.hidden_size + self.self_attn = Attention(kv_proj, args) + self.mlp = MoeBlock(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + shared_kv_states: Optional[Tuple[mx.array, mx.array]] = None, + ): + r, shared_kv_states = self.self_attn( + self.input_layernorm(x), mask, cache, shared_kv_states + ) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, shared_kv_states + + +class HunYuanModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + DecoderLayer(args=args, kv_proj=(i % args.cla_share_factor) == 0) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for i, (layer, c) in enumerate(zip(self.layers, cache)): + if i % self.args.cla_share_factor == 0: + shared_kv_states = None + h, shared_kv_states = layer(h, mask, c, shared_kv_states) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = HunYuanModel(args) + + def __call__( + self, + inputs: mx.array, + 
cache=None, + ): + out = self.model(inputs, cache) + return self.model.embed_tokens.as_linear(out) + + def sanitize(self, weights): + if "model.layers.0.mlp.experts.0.up_proj.weight" not in weights: + return weights + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n in ["up_proj", "down_proj", "gate_proj"]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{n}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{n}.{k}") + for e in range(self.args.num_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{n}.{k}"] = mx.stack(to_join) + return weights + + @property + def layers(self): + return self.model.layers diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 1efde5ae..93b881b9 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -760,6 +760,38 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_hunyuan(self): + from mlx_lm.models import hunyuan + + args = hunyuan.ModelArgs( + model_type="hunyuan", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + moe_topk=2, + num_experts=2, + num_shared_expert=1, + use_mixed_mlp_moe=True, + use_qk_norm=True, + rope_scaling={ + "alpha": 1000.0, + "factor": 1.0, + "type": "dynamic", + }, + use_cla=True, + cla_share_factor=2, + ) + model = hunyuan.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() From 0f135396ae7fcb2bad407d6a41296ac84c0fb666 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 23 Nov 2024 11:47:06 -0800 Subject: [PATCH 33/77] Generation refactor: part 2 (#1099) * unify with stream_generate * fixes * nit * some cleanup, warnings, tests * fix test + faster min p + test * version --- llms/README.md | 11 +- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/chat.py | 10 +- llms/mlx_lm/examples/chat.py | 1 - llms/mlx_lm/examples/generate_response.py | 9 - llms/mlx_lm/generate.py | 46 +---- llms/mlx_lm/sample_utils.py | 22 +-- llms/mlx_lm/server.py | 42 ++--- llms/mlx_lm/tokenizer_utils.py | 11 +- llms/mlx_lm/utils.py | 203 ++++++++++++---------- llms/tests/test_generate.py | 3 +- llms/tests/test_sample_utils.py | 18 +- llms/tests/test_tokenizers.py | 3 +- 13 files changed, 184 insertions(+), 197 deletions(-) diff --git a/llms/README.md b/llms/README.md index eeb3ed6a..60f68353 100644 --- a/llms/README.md +++ b/llms/README.md @@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -response = generate(model, tokenizer, prompt=prompt, verbose=True) +text = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -100,8 +100,9 @@ To see a description of all the arguments you can do: #### Streaming -For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text, token, and log probabilities. +For streaming generation, use the `stream_generate` function. This yields +a generation response object. 
+ For example, ```python @@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(t, end="", flush=True) +for response in stream_generate(model, tokenizer, prompt, max_tokens=512): + print(response.text, end="", flush=True) print() ``` diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 3811616f..5168eee4 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.19.3" +__version__ = "0.20.0" diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index c03056a6..7795d8d7 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -5,7 +5,8 @@ import json import mlx.core as mx -from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .models.cache import make_prompt_cache +from .sample_utils import make_sampler from .utils import load, stream_generate DEFAULT_TEMP = 0.0 @@ -74,16 +75,15 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response, *_ in stream_generate( + for response in stream_generate( model, tokenizer, prompt, args.max_tokens, - temp=args.temp, - top_p=args.top_p, + sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): - print(response, flush=True, end="") + print(response.text, flush=True, end="") print() diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 3bf01688..c7512b3c 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -42,7 +42,6 @@ response = generate( tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index 25730617..e6535b47 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -23,14 +23,6 @@ max_tokens = 1_000 # Specify if tokens and timing information will be printed verbose = True -# Some optional arguments for causal language model generation -generation_args = { - "temp": 0.7, - "repetition_penalty": 1.2, - "repetition_context_size": 20, - "top_p": 0.95, -} - # Generate a response with the specified settings response = generate( model=model, @@ -38,5 +30,4 @@ response = generate( prompt=prompt, max_tokens=max_tokens, verbose=verbose, - **generation_args, ) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index de5c5719..9e96fbdc 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -7,6 +7,7 @@ import sys import mlx.core as mx from .models.cache import QuantizedKVCache, load_prompt_cache +from .sample_utils import make_sampler from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -97,11 +98,6 @@ def setup_arg_parser(): default=True, help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) parser.add_argument( "--max-kv-size", type=int, @@ -137,33 +133,6 @@ def setup_arg_parser(): return parser -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, 
t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -250,21 +219,14 @@ def main(): else: prompt = args.prompt - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None - + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, verbose=args.verbose, - formatter=formatter, - temp=args.temp, - top_p=args.top_p, - min_p=args.min_p, - min_tokens_to_keep=args.min_tokens_to_keep, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c27b52d8..f9868422 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import math from functools import partial from typing import Callable, Dict, Optional @@ -80,7 +81,7 @@ def make_logits_processors( @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( - logits: mx.array, + logprobs: mx.array, min_p: float, min_tokens_to_keep: int = 1, temperature=1.0, @@ -93,7 +94,7 @@ def min_p_sampling( aggressive given a very high-probability token. Args: - logits: The logits from the model's output. + logprobs: A vector of log probabilities. min_p (float): Minimum token probability. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range. 
@@ -111,28 +112,27 @@ def min_p_sampling( ) # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 - # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs).squeeze(0) + sorted_logprobs = logprobs[..., sorted_indices] # Top probability - top_probs = probs[..., sorted_indices[0]] + top_logprobs = logprobs[..., sorted_indices[0]] # Calculate the min_p threshold - scaled_min_p = min_p * top_probs + scaled_min_p = top_logprobs + math.log(min_p) # Mask tokens that have a probability less than the scaled min_p - tokens_to_remove = sorted_probs < scaled_min_p + tokens_to_remove = sorted_logprobs < scaled_min_p tokens_to_remove[..., :min_tokens_to_keep] = False # Create pool of tokens with probability less than scaled min_p - selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) + selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) + sorted_token = mx.random.categorical(selected_logprobs) return sorted_indices[sorted_token] diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c1365b36..badc6dd3 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,6 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache +from .sample_utils import make_logits_processors, make_sampler from .utils import load, stream_generate @@ -464,25 +465,24 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - for n, (segment, token, logprobs) in enumerate( - stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt=prompt, - max_tokens=self.max_tokens, - temp=self.temperature, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, - prompt_cache=self.prompt_cache.cache, - ), + sampler = make_sampler(self.temperature) + logits_processors = make_logits_processors( + self.logit_bias, self.repetition_penalty, self.repetition_context_size + ) + for gen_response in stream_generate( + model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prompt_cache=self.prompt_cache.cache, ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - + segment = gen_response.text text += segment logging.debug(text) + token = gen_response.token + logprobs = gen_response.logprobs tokens.append(token) if self.logprobs > 0: @@ -523,13 +523,9 @@ class APIHandler(BaseHTTPRequestHandler): self.prompt_cache.tokens.extend(tokens) - gen_time = time.perf_counter() - tic - prompt_tps = len(prompt) / prompt_time - gen_tps = len(tokens) / gen_time - peak_mem = mx.metal.get_peak_memory() / 1e9 - logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") - logging.debug(f"Peak memory: {peak_mem:.3f} GB") + logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") if 
self.stream: response = self.generate_response(segment, finish_reason) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 9d390733..0fa41ac0 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -73,16 +73,16 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): def reset(self): self.offset = 0 - self._tokens = [] + self.tokens = [] self._text = "" self._current_tokens = [] self._current_text = "" def add_token(self, token): self._current_tokens.append(token) + self.tokens.append(token) def finalize(self): - self._tokens.extend(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens) self._current_tokens = [] self._current_text = "" @@ -97,16 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer): ): self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": - self._tokens.extend(self._current_tokens) self._text += self._current_text self._current_tokens.clear() self._current_text = "" return self._text + self._current_text - @property - def tokens(self): - return self._tokens - class SPMStreamingDetokenizer(StreamingDetokenizer): """A streaming detokenizer for SPM models. @@ -143,6 +138,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.text += text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] if v.startswith(self._sep): self._flush() @@ -200,6 +196,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): return current_text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids if is_added or self._byte_decoder[v[0]] == 32: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d4afd428..496ae4fc 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -8,6 +8,7 @@ import json import logging import shutil import time +from dataclasses import dataclass from pathlib import Path from textwrap import dedent from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union @@ -44,6 +45,32 @@ class ModelNotFoundError(Exception): super().__init__(self.message) +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. + generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. 
+ """ + + text: str + token: int + logprobs: mx.array + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + + @contextlib.contextmanager def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): """ @@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - prefill_step_size: int = 512, + *, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prefill_step_size: int = 512, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + temp: Optional[float] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: Optional[float] = None, + min_p: Optional[float] = None, + min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -176,32 +204,21 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. - None implies no cache quantization. Default: ``None``. + None implies no cache quantization. Default: ``None``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``. quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. + when ``kv_bits`` is non-None. Default: ``0``. 
Yields: Tuple[mx.array, mx.array]: One token and a vector of log probabilities. @@ -219,10 +236,22 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) - logits_processors = logits_processors or [] - logits_processors.extend( - make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + if temp is not None or top_p is not None or min_tokens_to_keep is not None: + print( + "[Warning] Specifying sampling arguments to ``generate_step`` is " + "deprecated. Pass in a ``sampler`` instead." + ) + if repetition_penalty is not None: + print( + "[Warning] Specifying ``repetition_penalty`` is deprecated. " + "Pass in ``logits_processors`` instead." + ) + + sampler = sampler or make_sampler( + temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 + ) + logits_processors = logits_processors or make_logits_processors( + None, repetition_penalty, repetition_context_size or 20 ) def _step(y): @@ -290,17 +319,20 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() - for n, (token, logits) in zip( + tic = time.perf_counter() + for n, (token, logprobs) in zip( range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt, model, **kwargs), ): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt.size / prompt_time + tic = time.perf_counter() if token == tokenizer.eos_token_id: break @@ -309,17 +341,34 @@ def stream_generate( if n == (max_tokens - 1): break - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) detokenizer.finalize() - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, - max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -334,64 +383,40 @@ def generate( max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. - formatter (Optional[Callable]): A function which takes a token and a - probability and displays it. - kwargs: The remaining options get passed to :func:`generate_step`. - See :func:`generate_step` for more details. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. """ - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - + if formatter is not None: + print( + "[Warning] Text formatting is deprecated and no longer used. 
" + "The argument will be removed in a future version." + ) if verbose: print("=" * 10) print("Prompt:", prompt) - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer - - with wired_limit(model, [generation_stream]): - tic = time.perf_counter() - detokenizer.reset() - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() - + text = "" + for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print( - f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" - ) - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 1e9 - print(f"Peak memory: {peak_mem:.3f} GB") + print(response.text, end="", flush=True) + text += response.text - return detokenizer.text + if verbose: + print() + print("=" * 10) + if len(text) == 0: + print("No text generated for this prompt") + return + print( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + print(f"Peak memory: {response.peak_memory:.3f} GB") + return text def load_config(model_path: Path) -> dict: diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index e0a372a9..f2345394 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -2,6 +2,7 @@ import unittest +from mlx_lm.sample_utils import make_logits_processors from mlx_lm.utils import generate, load @@ -25,8 +26,8 @@ class TestGenerate(unittest.TestCase): self.tokenizer, "hello", max_tokens=5, + logits_processors=make_logits_processors(logit_bias), verbose=False, - logit_bias=logit_bias, ) self.assertEqual(text, "!!!!!") diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ec0e2cb7..ebc90ce8 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,10 +1,10 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_p_sampling -class TestSamplingUtils(unittest.TestCase): +class TestSampleUtils(unittest.TestCase): def test_top_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) @@ -28,6 +28,20 @@ class TestSamplingUtils(unittest.TestCase): token = top_p_sampling(logits, 0.95, temperature).item() self.assertTrue(token in (1, 2, 3)) + def test_min_p_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + token = min_p_sampling(logits, 0.8) + self.assertEqual(token, 0) + + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + 
temperature = 1.0 + for _ in range(5): + token = min_p_sampling(logits, 0.05) + self.assertTrue(token in (0, 3)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 9c30d51e..db6b9f9e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -34,10 +34,11 @@ class TestTokenizers(unittest.TestCase): detokenizer = tokenizer.detokenizer detokenizer.reset() text = "" - for t in tokens: + for e, t in enumerate(tokens): detokenizer.add_token(t) seg = detokenizer.last_segment text += seg + self.assertEqual(detokenizer.tokens, tokens[: e + 1]) detokenizer.finalize() text += detokenizer.last_segment self.assertEqual(text, expected_text) From 0ffdb6dd20f3cf45445b69d80aa93f793faf222d Mon Sep 17 00:00:00 2001 From: Kevin Conner Date: Sun, 24 Nov 2024 16:37:37 -0800 Subject: [PATCH 34/77] Fix object property value in mlx_lm.server chat completions response to match OpenAI spec (#1119) These were "chat.completions" and "chat.completions.chunk" but should be "chat.completion" and "chat.completion.chunk" for compatibility with clients expecting an OpenAI API. In particular, this solves a problem in which aider 0.64.1 reports hitting a token limit on any completion request, no matter how small, despite apparently correct counts in the usage property. Refer to: https://platform.openai.com/docs/api-reference/chat/object > object string > The object type, which is always chat.completion. https://platform.openai.com/docs/api-reference/chat/streaming > object string > The object type, which is always chat.completion.chunk. --- llms/mlx_lm/SERVER.md | 2 +- llms/mlx_lm/server.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/SERVER.md b/llms/mlx_lm/SERVER.md index 2976a09f..e544c6fa 100644 --- a/llms/mlx_lm/SERVER.md +++ b/llms/mlx_lm/SERVER.md @@ -92,7 +92,7 @@ curl localhost:8080/v1/chat/completions \ - `system_fingerprint`: A unique identifier for the system. -- `object`: Any of "chat.completions", "chat.completions.chunk" (for +- `object`: Any of "chat.completion", "chat.completion.chunk" (for streaming), or "text.completion". - `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`). diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index badc6dd3..ce09cf45 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -589,9 +589,7 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"chatcmpl-{uuid.uuid4()}" - self.object_type = ( - "chat.completions.chunk" if self.stream else "chat.completions" - ) + self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" if ( hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template From adaab81029eb5f53d9a40c94968bf143cbc5985c Mon Sep 17 00:00:00 2001 From: Remixer Dec Date: Mon, 25 Nov 2024 04:41:06 +0400 Subject: [PATCH 35/77] Allow converting models from local directories (#1118) --- whisper/convert.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/whisper/convert.py b/whisper/convert.py index 301fd5b4..7369fafa 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -174,11 +174,6 @@ def load_torch_weights_and_config( "*.txt", ], ) - else: - raise RuntimeError( - f"Model {name_or_path} is not found in {available_models()}," - "on Hugging Face or as a local path." 
- ) if name_or_path.endswith(".pt"): checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) From a5e173802ea0da999923d240f65e94ec8ad3c415 Mon Sep 17 00:00:00 2001 From: madroid Date: Tue, 26 Nov 2024 00:10:14 +0800 Subject: [PATCH 36/77] docs: update stream_generate return type annotation (#1121) Improve documentation clarity by: 1. Fix return type annotation to correctly reflect GenerationResponse 2. Simplify docstring by referencing GenerationResponse class 3. Remove redundant field descriptions --- llms/mlx_lm/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 496ae4fc..5abd396d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -300,7 +300,7 @@ def stream_generate( prompt: Union[str, List[int]], max_tokens: int = 100, **kwargs, -) -> Generator[Tuple[str, int, mx.array], None, None]: +) -> Generator[GenerationResponse, None, None]: """ A generator producing text based on the given prompt from the model. @@ -313,8 +313,8 @@ def stream_generate( See :func:`generate_step` for more details. Yields: - Tuple[str, int, mx.array]: - The next text segment, token, and vector of log probabilities. + GenerationResponse: An instance containing the generated text segment and + associated metadata. See :class:`GenerationResponse` for details. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) From cfc29c29f45372c78876335a44b0c99ab6565ae0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 25 Nov 2024 09:47:00 -0800 Subject: [PATCH 37/77] Put prompt processing in same stream (#1122) * put prompt processing in same stream * patch --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/utils.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 5168eee4..343e0016 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.20.0" +__version__ = "0.20.1" diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5abd396d..0e2f7af7 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -274,13 +274,14 @@ def generate_step( y = sampler(logprobs) return y, logprobs.squeeze(0) - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - y = y[prefill_step_size:] - mx.metal.clear_cache() + with mx.stream(generation_stream): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + y = y[prefill_step_size:] + mx.metal.clear_cache() - y, logprobs = _step(y) + y, logprobs = _step(y) mx.async_eval(y, logprobs) n = 0 From cefe793ae0991b394b89c497d92afc6459490460 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 26 Nov 2024 19:51:55 -0500 Subject: [PATCH 38/77] Accept mx.array type for prompt argument for stream_generate (#1125) * Accept mx.array type for prompt argument for stream_generate * Fix formatting --- llms/mlx_lm/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0e2f7af7..f439ca99 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -298,7 +298,7 @@ def generate_step( def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: Union[str, List[int]], + prompt: Union[str, mx.array, List[int]], max_tokens: int = 100, **kwargs, ) -> Generator[GenerationResponse, None, None]: @@ -308,7 +308,7 @@ def stream_generate( Args: model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, List[int]]): The input prompt string or integer tokens. + prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. @@ -320,7 +320,11 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) + if not isinstance(prompt, mx.array): + prompt = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) + detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): From 8801beb66f61d16114d4014fec32a266778e4481 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 2 Dec 2024 11:42:58 -0800 Subject: [PATCH 39/77] Add olmo2 (#1128) * add olmo2 * add olmo2 --- llms/mlx_lm/models/olmo2.py | 312 ++++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + llms/tests/test_models.py | 20 +++ 3 files changed, 333 insertions(+) create mode 100644 llms/mlx_lm/models/olmo2.py diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py new file mode 100644 index 00000000..a28fdcc1 --- /dev/null +++ b/llms/mlx_lm/models/olmo2.py @@ -0,0 +1,312 @@ +# Copyright © 2023-2024 Apple Inc. 
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + if not "factor" in self.rope_scaling: + raise ValueError(f"rope_scaling must contain 'factor'") + rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( + "rope_type" + ) + if rope_type is None: + raise ValueError( + f"rope_scaling must contain either 'type' or 'rope_type'" + ) + if rope_type not in ["linear", "dynamic", "llama3"]: + raise ValueError( + "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" + ) + + +class DynamicNTKScalingRoPE(nn.Module): + """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scale: float = 1.0, + rope_type: str = "default", + rope_scaling: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + self.scale = scale + self.rope_type = rope_type + self.rope_scaling = rope_scaling + self.base = base + self.compute_freqs() + + def compute_freqs(self): + if self.rope_type != "llama3": + self._freqs = None + return + factor = self.rope_scaling["factor"] + low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.rope_scaling.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + self.base = None + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}, " + f"scaling_factor={self.scale}, rope_type={self.rope_type}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, + freqs=self._freqs, + ) + + +def initialize_rope(args: ModelArgs): + head_dim = args.head_dim or args.hidden_size // args.num_attention_heads + + rope_scaling = args.rope_scaling + rope_type = "default" + 
rope_scale = 1.0 + + if rope_scaling is not None: + rope_type = ( + rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" + ) + if rope_type == "linear": + rope_scale = 1 / rope_scaling["factor"] + elif rope_type == "llama3": + rope_scale = 1.0 # The scaling is handled internally for llama3 + + return DynamicNTKScalingRoPE( + dims=head_dim, + max_position_embeddings=args.max_position_embeddings, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + rope_type=rope_type, + rope_scaling=rope_scaling, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope(args) + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) + self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.post_feedforward_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + 
cache: Optional[Any] = None, + ) -> mx.array: + r = self.post_attention_layernorm(self.self_attn(x, mask, cache)) + h = x + r + r = self.post_feedforward_layernorm(self.mlp(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 7c78ee91..835cb482 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -98,6 +98,7 @@ def linear_to_lora_layers( "cohere", "minicpm", "deepseek", + "olmo2", ]: keys = set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 93b881b9..edb594d7 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -792,6 +792,26 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_olmo2(self): + from mlx_lm.models import olmo2 + + args = olmo2.ModelArgs( + model_type="olmo2", + hidden_size=128, + attention_bias=False, + intermediate_size=256, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + vocab_size=1000, + ) + model = olmo2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() From 2a9294a5f02b4178bf27804d05f6d88581db06e0 Mon Sep 17 00:00:00 2001 From: hehua2008 Date: Tue, 3 Dec 2024 05:15:19 +0800 Subject: [PATCH 40/77] Fix bug in FluxSampler.timesteps method (#1131) --- flux/flux/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 3bff1ca2..54c4fe35 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -25,7 +25,7 @@ class FluxSampler: ): t = mx.linspace(start, stop, num_steps + 1) - if self._schnell: + if not self._schnell: t = self._time_shift(image_sequence_length, t) return t.tolist() From eb9277f574c6ba9da8494afade631e7a8553402a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 2 Dec 
2024 13:15:50 -0800 Subject: [PATCH 41/77] Allow loading from diffusers ckpt (#1117) --- flux/flux/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flux/flux/model.py b/flux/flux/model.py index 18ea70b0..d8ad9d9b 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -85,6 +85,8 @@ class Flux(nn.Module): def sanitize(self, weights): new_weights = {} for k, w in weights.items(): + if k.startswith("model.diffusion_model."): + k = k[22:] if k.endswith(".scale"): k = k[:-6] + ".weight" for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: From 0ca162cfb2a85164cbec18304913c4220520786c Mon Sep 17 00:00:00 2001 From: sakares saengkaew Date: Tue, 3 Dec 2024 14:56:07 +0700 Subject: [PATCH 42/77] Fix data_iter in prepare_dataset from speechcommands example (#1113) --- speechcommands/main.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/speechcommands/main.py b/speechcommands/main.py index 0d8da9fd..ed328f4c 100644 --- a/speechcommands/main.py +++ b/speechcommands/main.py @@ -76,6 +76,7 @@ def train_epoch(model, train_iter, optimizer, epoch): samples_per_sec = [] model.train(True) + train_iter.reset() for batch_counter, batch in enumerate(train_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) @@ -111,6 +112,7 @@ def test_epoch(model, test_iter): model.train(False) accs = [] throughput = [] + test_iter.reset() for batch_counter, batch in enumerate(test_iter): x = mx.array(batch["audio"]) y = mx.array(batch["label"]) From 1963df856529765b1e11beb24ea4542e6e75916d Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 3 Dec 2024 16:17:14 -0800 Subject: [PATCH 43/77] Allow prompt callback to `generate_step` (#1133) * allow prompt callback and use in cache_prompt * nit * comments * bump version --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/cache_prompt.py | 35 +++++++++++--------------- llms/mlx_lm/generate.py | 2 +- llms/mlx_lm/utils.py | 44 +++++++++++++++++++-------------- llms/tests/test_prompt_cache.py | 13 +++++----- 5 files changed, 48 insertions(+), 48 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 343e0016..0f885fba 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.20.1" +__version__ = "0.20.2" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 987b640d..9d7d1603 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,7 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load, maybe_quantize_kv_cache +from .utils import generate_step, load DEFAULT_QUANTIZED_KV_START = 5000 @@ -50,12 +50,6 @@ def setup_arg_parser(): action="store_true", help="Use the default chat template", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -99,9 +93,6 @@ def main(): parser = setup_arg_parser() args = parser.parse_args() - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} if args.eos_token is not None: @@ -144,26 +135,28 @@ def main(): y = mx.array(tokenizer.encode(prompt)) # Process the prompt - processed = 0 - step_size = 512 start = time.time() max_msg_len = 0 - while y.size > 0: - model(y[:step_size][None], cache=cache) - mx.eval([c.state for c in cache]) - mx.metal.clear_cache() - processed += min(y.size, step_size) - y = y[step_size:] + def callback(processed, total_tokens): current = time.time() speed = processed / (current - start) msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" + nonlocal max_msg_len max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) - maybe_quantize_kv_cache( - cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits - ) + for _ in generate_step( + y, + model, + max_tokens=0, + prompt_cache=cache, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, + prompt_progress_callback=callback, + ): + pass print() print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 9e96fbdc..0c1b4acd 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -77,7 +77,7 @@ def setup_arg_parser(): ) parser.add_argument( "--min-tokens-to-keep", - type=float, + type=int, default=DEFAULT_MIN_TOKENS_TO_KEEP, help="Minimum tokens to keep for min-p sampling.", ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index f439ca99..86b786ce 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -183,6 +183,7 @@ def generate_step( prompt: mx.array, model: nn.Module, *, + max_tokens: int = 256, sampler: Optional[Callable[mx.array, mx.array]] = None, logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, @@ -191,6 +192,7 @@ def generate_step( kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + prompt_progress_callback: Optional[Callable[int, int]] = None, temp: Optional[float] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = None, @@ -204,21 +206,25 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - prefill_step_size (int): Step size for processing the prompt. - max_kv_size (int, optional): Maximum size of the key-value cache. Old - entries (except the first 4 tokens) will be overwritten. 
- prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if - provided, the cache will be updated in place. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a token from a vector of log probabilities. Default: ``None``. logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed logits. Default: ``None``. + max_kv_size (int, optional): Maximum size of the key-value cache. Old + entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. + prefill_step_size (int): Step size for processing the prompt. kv_bits (int, optional): Number of bits to use for KV cache quantization. None implies no cache quantization. Default: ``None``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``. quantized_kv_start (int): Step to begin using a quantized KV cache. when ``kv_bits`` is non-None. Default: ``0``. + prompt_prorgress_callback (Callable[int, int]): A call-back which takes the + prompt tokens processed so far and the total number of prompt tokens. Yields: Tuple[mx.array, mx.array]: One token and a vector of log probabilities. @@ -253,6 +259,7 @@ def generate_step( logits_processors = logits_processors or make_logits_processors( None, repetition_penalty, repetition_context_size or 20 ) + prompt_progress_callback = prompt_progress_callback or (lambda *_: None) def _step(y): with mx.stream(generation_stream): @@ -275,9 +282,13 @@ def generate_step( return y, logprobs.squeeze(0) with mx.stream(generation_stream): + total_prompt_tokens = y.size + prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) + prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) + prompt_processed_tokens += prefill_step_size y = y[prefill_step_size:] mx.metal.clear_cache() @@ -286,20 +297,25 @@ def generate_step( mx.async_eval(y, logprobs) n = 0 while True: - next_y, next_logprobs = _step(y) - mx.async_eval(next_y, next_logprobs) + if n != max_tokens: + next_y, next_logprobs = _step(y) + mx.async_eval(next_y, next_logprobs) + if n == 0: + mx.eval(y) + prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) + if n == max_tokens: + break yield y.item(), logprobs if n % 256 == 0: mx.metal.clear_cache() - n += 1 y, logprobs = next_y, next_logprobs + n += 1 def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: Union[str, mx.array, List[int]], - max_tokens: int = 100, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -309,7 +325,6 @@ def stream_generate( model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. - max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. 
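As a usage sketch of the API above (mirroring what `cache_prompt.py` now does; the model repo, prompt text, and file name below are placeholders, not taken from the patch): prompt prefill, the progress callback, and the token limit are all driven through `generate_step`, so a prompt cache can be built with `max_tokens=0`.

```python
import mlx.core as mx

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache
from mlx_lm.utils import generate_step

# Placeholder repo; any mlx-lm compatible model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Summarize the following report: ..."))

cache = make_prompt_cache(model)

def progress(processed, total):
    # Called as prompt chunks are processed and once more when prefill ends.
    print(f"\rprefill {processed}/{total} tokens", end="", flush=True)

# max_tokens=0 only fills the prompt cache; the generator yields no tokens.
for _ in generate_step(
    prompt,
    model,
    max_tokens=0,
    prompt_cache=cache,
    prompt_progress_callback=progress,
):
    pass

save_prompt_cache("report_cache.safetensors", cache)
```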
@@ -330,10 +345,7 @@ def stream_generate( with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt, model, **kwargs), - ): + for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time @@ -343,9 +355,6 @@ def stream_generate( detokenizer.add_token(token) - if n == (max_tokens - 1): - break - yield GenerationResponse( text=detokenizer.last_segment, token=token, @@ -385,7 +394,6 @@ def generate( model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. prompt (str): The string prompt. - max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. kwargs: The remaining options get passed to :func:`stream_generate`. diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 0867ab56..de5694d5 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase): def test_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] - results = zip(range(4), generate_step(prompt, model)) - toks, all_logits = zip(*(r[1] for r in results)) + results = list(generate_step(prompt, model, max_tokens=4)) + toks, all_logits = zip(*results) prompt_cache = make_prompt_cache(model) i = 0 - for _, (tok, logits) in zip( - range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + for tok, logits in generate_step( + prompt, model, prompt_cache=prompt_cache, max_tokens=2 ): self.assertEqual(tok, toks[i]) self.assertTrue(mx.allclose(logits, all_logits[i])) i += 1 - for _, (tok, logits) in zip( - range(1), - generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + for tok, logits in generate_step( + mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1 ): i += 1 self.assertEqual(tok, toks[i]) From 1727959a27f2fb7b459387084b59f066296757a5 Mon Sep 17 00:00:00 2001 From: vb Date: Wed, 4 Dec 2024 04:21:39 +0100 Subject: [PATCH 44/77] Add mentions of MLX-my-repo. (#1129) * Add mentions of MLX-my-repo. * simplify * move * move --------- Co-authored-by: Awni Hannun --- llms/README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llms/README.md b/llms/README.md index 60f68353..4fff4207 100644 --- a/llms/README.md +++ b/llms/README.md @@ -77,7 +77,7 @@ to see how to use the API in more detail. The `mlx-lm` package also comes with functionality to quantize and optionally upload models to the Hugging Face Hub. -You can convert models in the Python API with: +You can convert models using the Python API: ```python from mlx_lm import convert @@ -163,6 +163,10 @@ mlx_lm.convert \ --upload-repo mlx-community/my-4bit-mistral ``` +Models can also be converted and quantized directly in the +[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging +Face Space. 
+ ### Long Prompts and Generations `mlx-lm` has some tools to scale efficiently to long prompts and generations: From cd8cf28c395e8f163d0aeb5264c6f605ccbc4009 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Sun, 8 Dec 2024 12:20:10 -0800 Subject: [PATCH 45/77] `mlx_lm.evaluate` (#1140) * Add evaluation script * only write top level results * add lm eval version * typo * create output dir * relative import * comment --------- Co-authored-by: David Grangier --- llms/mlx_lm/evaluate.py | 355 ++++++++++++++++++++++++++++++++++++++++ llms/setup.py | 2 + 2 files changed, 357 insertions(+) create mode 100644 llms/mlx_lm/evaluate.py diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py new file mode 100644 index 00000000..423d5823 --- /dev/null +++ b/llms/mlx_lm/evaluate.py @@ -0,0 +1,355 @@ +# Adapted from a PyTorch implementation by David Grangier + +import argparse +import json +import logging +import os +from importlib.metadata import version +from pathlib import Path +from typing import Optional + +import lm_eval +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from lm_eval.api.model import LM +from lm_eval.api.registry import register_model +from tqdm import tqdm + +from .models.cache import make_prompt_cache +from .utils import load, stream_generate + +PAD = 0 + + +def _len_longest_common_prefix(a, b): + l = 0 + for item_a, item_b in zip(a, b): + if item_a != item_b: + break + l += 1 + return l + + +def _rstrip_until(s, untils): + """Limit a string to the first occurence of any substring in untils.""" + l = len(s) + f = [s.find(u) for u in untils] + f = [l if x < 0 else x for x in f] + return s[: min(f)] + + +def _pad_inputs( + inputs, + maxlen, + genlen=0, + pad_left=False, + pad_multiple=32, + truncate=False, +): + # pad the prompts to the left with at least genlen tokens. 
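    # For example (illustrative values): with maxlen=8, genlen=0, pad_multiple=4
    # and inputs=[[1, 2, 3], [7, 8]], actual_maxlen is 3, rounded up to 4, so the
    # batch becomes [[1, 2, 3, 0], [7, 8, 0, 0]] (pad_left=False) or
    # [[0, 1, 2, 3], [0, 0, 7, 8]] (pad_left=True).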
+ actual_maxlen = max(len(p) for p in inputs) + genlen + if actual_maxlen > maxlen: + if not truncate: + raise ValueError("Inputs are too long.") + else: # drop begining + actual_maxlen = maxlen + inputs = [p[max(0, len(p) - maxlen) :] for p in inputs] + if pad_multiple > 0: + maxlen = (actual_maxlen + pad_multiple - 1) // pad_multiple + maxlen *= pad_multiple + assert PAD == 0 + lr = np.array((1, 0) if pad_left else (0, 1)) + return np.stack( + [np.pad(np.array(x, np.int32), lr * (maxlen - len(x))) for x in inputs], + axis=0, + ) + + +@register_model("mlxlm") +class MLXLM(LM): + def __init__( + self, + path_or_hf_repo: str, + batch_size: int = 16, + max_tokens: Optional[int] = None, + ) -> None: + super().__init__() + self._batch_size = batch_size + self._model, self._tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self._tokenizer.model_max_length + + def _score_fn(self, inputs, tokenize=True, step_size=32): + if tokenize: + inputs = self._tokenizer.encode(inputs) + inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) + inputs = mx.array(inputs) + inputs, targets = inputs[..., :-1], inputs[..., 1:] + + cache = make_prompt_cache(self._model) + + mask = targets != PAD + + scores, is_greedy = [], [] + for i in range(0, inputs.shape[1], step_size): + logits = self._model(inputs[:, i : i + step_size], cache=cache) + + log_probs = nn.log_softmax(logits.astype(mx.float32)) + score = mx.take_along_axis( + log_probs, targets[:, i : i + step_size, mx.newaxis], axis=-1 + )[..., 0] + ig = mask[:, i : i + step_size] * ( + targets[:, i : i + step_size] == mx.argmax(logits, axis=-1) + ) + + mx.eval(score, ig) + mx.metal.clear_cache() + + is_greedy.append(ig) + scores.append(score) + + scores = mx.concatenate(scores, axis=1) + is_greedy = mx.concatenate(is_greedy, axis=1) + + return scores, mask.sum(axis=-1), is_greedy + + def _loglikelihood(self, texts, score_spans=None, tokenize=True): + # sort by length to get batches with little padding. + sorted_indices = sorted(range(len(texts)), key=lambda i: -len(texts[i])) + sorted_inputs = [texts[sorted_indices[i]] for i in range(len(texts))] + sorted_spans = None + if score_spans is not None: + sorted_spans = [score_spans[sorted_indices[i]] for i in range(len(texts))] + + results = [] + for i in tqdm(range(0, len(sorted_inputs), self._batch_size)): + batch = sorted_inputs[i : i + self._batch_size] + scores, length, is_greedy = self._score_fn(batch, tokenize=tokenize) + for j in range(len(batch)): + if sorted_spans is None: # full sequence score + mask = mx.arange(scores[j].shape[-1]) < length + score = (scores[j].astype(mx.float32) * mask).sum(axis=-1) + ig = (is_greedy[j].astype(mx.int32) * mask).sum(axis=-1) + else: # subsequence score + start, end = sorted_spans[i + j] + score = scores[j][start:end].astype(mx.float32).sum() + ig = is_greedy[j][start:end].astype(mx.int32).sum() + length = end - start + + results.append((score.item(), ig.item(), length)) + + # reorder the outputs + inv_sort = np.argsort(sorted_indices) + results = [results[inv_sort[i]] for i in range(len(results))] + + return results + + def _tokenize(self, texts): + return [tuple(self._tokenizer.encode(t)) for t in texts] + + def loglikelihood(self, requests) -> list[tuple[float, bool]]: + """Compute log-likelihood of generating a continuation from a context. + Downstream tasks should attempt to use loglikelihood instead of other + LM calls whenever possible. 
+ :param requests: list[Instance] + A list of Instance objects, with property `args` which returns a tuple (context, continuation). + `context: str` + Context string. Implementations of LM must be able to handle an + empty context string. + `continuation: str` + The continuation over which log likelihood will be calculated. If + there is a word boundary, the space should be in the continuation. + For example, context="hello" continuation=" world" is correct. + :return: list[tuple[float, bool]] + A list of pairs (logprob, isgreedy) + `logprob: float` + The log probability of `continuation`. + `isgreedy`: + Whether `continuation` would be generated by greedy sampling from `context`. + """ + logging.info("Estimating loglikelihood for %d pairs." % len(requests)) + + # tokenize prefix and prefix + completion for all requests. + tokenized = self._tokenize( + [t for r in requests for t in [r.args[0], r.args[0] + r.args[1]]] + ) + + # max length (prefix + completion) and longest common prefix per question. + length_stats = {} + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, min_prefix_l = length_stats.get(prefix, (0, 1e8)) + length_stats[prefix] = ( + max(max_completed_l, len(completed)), + min(min_prefix_l, _len_longest_common_prefix(prefix, completed)), + ) + + # truncate requests for completed sequences longer than model context. + shortened = [] + completion_spans = [] + long_completions = 0 + for prefix, completed in zip(tokenized[0::2], tokenized[1::2]): + max_completed_l, prefix_l = length_stats[prefix] + # compute truncation length + truncation = max(0, max_completed_l - self._max_tokens - 1) + prefix_l = prefix_l - truncation + if prefix_l <= 0: + # completion too long, prefix is eliminated for some requests. + long_completions += 1 + truncation = max(0, len(completed) - self._max_tokens - 1) + prefix_l = 1 + # truncate the completed sequence + completed = completed[truncation:] + shortened.append(completed) + # scores do not include initial bos, substract 1 to span bounds + completion_spans.append((prefix_l - 1, len(completed) - 1)) + + if long_completions > 0: + logging.info( + f"Prefix eliminated for {long_completions} requests with " + + "completion longer than context." + ) + + # model scoring, returns num_requests x (logp, is_greedy, length). + results = self._loglikelihood( + shortened, + score_spans=completion_spans, + tokenize=False, + ) + return [(r[0], r[1] == r[2]) for r in results] + + def loglikelihood_rolling(self, requests) -> list[float]: + """Compute full log-likelihood of a string, with no truncation, for perplexity computation + - We will use the full max context length of the model. + - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to + the max context length. + - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations + which may simply concatenate multiple documents together. + - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into + multiple chunks, the last input will still a full-sized context. + Example: + Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ] + Prefix: EOT + Max context length: 4 + Resulting input/prediction pairs: + INPUT: EOT 0 1 2 + PRED: 0 1 2 3 + INPUT: 3 4 5 6 + PRED: 4 5 6 7 + INPUT: 5 6 7 8 + PRED: 8 9 + Observe that: + 1. Each token is predicted exactly once + 2. 
For the last pair, we provide the full context, but only score the last two tokens + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context,). + string: str + String for which we are computing overall loglikelihood + :return: list[tuple[float]] + A list of tuples (logprob,) + logprob: float + The log probability of `context` conditioned on the EOT token. + """ + logging.info( + "Estimating loglikelihood rolling for %d sequences." % len(requests) + ) + inputs = [req.args[0] for req in requests] + return [t[0] for t in self._loglikelihood(inputs)] + + def generate_until(self, requests) -> list[str]: + """Generate greedily until a stopping sequence + :param requests: list[Instance] + A list of Instance objects with property `args` which returns a tuple (context, until). + context: str + Context string + until: [str] + The string sequences to generate until. These string sequences + may each span across multiple tokens, or may be part of one token. + :return: list[str] + A list of strings continuation + continuation: str + The generated continuation. + """ + logging.info("Generating continuation for %d sequences." % len(requests)) + contexts, options = zip(*[req.args for req in requests]) + # contrary to the doc the second element of the tuple contains + # {'do_sample': False, 'until': ['\n\n'], 'temperature': 0} + keys = list(options[0].keys()) + assert "until" in keys + untils = [x["until"] for x in options] + completions = [] + for context, until in tqdm(zip(contexts, untils), total=len(contexts)): + if ( + hasattr(self._tokenizer, "apply_chat_template") + and self._tokenizer.chat_template is not None + ): + messages = [{"role": "user", "content": context}] + context = self._tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + max_tokens = min( + self._max_tokens, + self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), + ) + text = "" + for response in stream_generate( + self._model, self._tokenizer, prompt=context, max_tokens=max_tokens + ): + text += response.text + if any(u in text for u in until): + text = _rstrip_until(text, until) + completions.append(text) + break + else: + completions.append(text) + return completions + + +def main(): + parser = argparse.ArgumentParser( + "Evaluate an MLX model using lm-evaluation-harness." + ) + parser.add_argument("--model", help="Model to evaluate", required=True) + parser.add_argument("--tasks", nargs="+", required=True) + parser.add_argument( + "--output-dir", default=".", help="Output directory for result files." + ) + parser.add_argument("--batch-size", type=int, default=16, help="Batch size") + parser.add_argument("--num-shots", type=int, default=0, help="Number of shots") + parser.add_argument( + "--max-tokens", + type=int, + help="Maximum nunber of tokens to generate. 
Defaults to the model's max context length.", + ) + parser.add_argument("--seed", type=int, default=123, help="Random seed.") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Silence tokenizer warnings + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + mx.random.seed(args.seed) + + lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) + + results = lm_eval.simple_evaluate( + model=lm, + tasks=args.tasks, + num_fewshot=args.num_shots, + random_seed=args.seed, + numpy_random_seed=args.seed, + torch_random_seed=args.seed, + fewshot_random_seed=args.seed, + ) + + model_name = args.model.replace("/", "_") + task_names = "_".join(args.tasks) + ver = version("lm_eval") + filename = f"eval_{model_name}_{task_names}_{args.num_shots:02d}_v_{ver}.json" + output_path = output_dir / filename + output_path.write_text(json.dumps(results["results"], indent=4)) + print("Results:") + for result in results["results"].values(): + print(json.dumps(result, indent=4)) diff --git a/llms/setup.py b/llms/setup.py index 1c696dc0..b88dcd33 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -28,12 +28,14 @@ setup( python_requires=">=3.8", extras_require={ "testing": ["datasets"], + "evaluation": ["lm-eval"], }, entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", + "mlx_lm.evaluate = mlx_lm.evaluate:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", "mlx_lm.lora = mlx_lm.lora:main", From 2211b27388f8cc5725360e3b14c2b114d61b0e8d Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Sun, 8 Dec 2024 14:21:50 -0800 Subject: [PATCH 46/77] Mixed Quantizations (#1132) * saving/loading mixed quantizations * comment * add bits per weight * more concise bpw * count bias too --- llms/mlx_lm/tuner/utils.py | 12 ++++---- llms/mlx_lm/utils.py | 61 +++++++++++++++++++++++++++++++++----- 2 files changed, 61 insertions(+), 12 deletions(-) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 835cb482..8351ed1b 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -250,12 +250,14 @@ def remove_lora_layers(model: nn.Module) -> nn.Module: return model -def print_trainable_parameters(model): - def nparams(m): - if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): - return m.weight.size * (32 // m.bits) - return sum(v.size for _, v in tree_flatten(m.parameters())) +def nparams(module): + if hasattr(module, "bits"): + n = 0 if not hasattr(module, "bias") else module.bias.size + return n + module.weight.size * 32 // module.bits + return sum(v.size for _, v in tree_flatten(module.parameters())) + +def print_trainable_parameters(model): leaf_modules = tree_flatten( model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 86b786ce..66a106a1 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten, tree_reduce +from mlx.utils import tree_flatten, tree_map, tree_reduce from transformers import PreTrainedTokenizer # Local imports @@ -24,7 +24,7 @@ from .models import cache from .sample_utils import make_logits_processors, make_sampler from .tokenizer_utils import 
TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model -from .tuner.utils import load_adapters +from .tuner.utils import load_adapters, nparams # Constants MODEL_REMAPPING = { @@ -127,6 +127,17 @@ def _get_classes(config: dict): return arch.Model, arch.ModelArgs +def compute_bits_per_weight(model): + model_bytes = tree_reduce( + lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 + ) + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + model_params = sum(nparams(m) for _, m in leaf_modules) + return model_bytes * 8 / model_params + + def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: """ Ensures the model is available locally. If the path does not exist locally, @@ -496,15 +507,20 @@ def load_model( weights = model.sanitize(weights) if (quantization := config.get("quantization", None)) is not None: - # Handle legacy models which may not have everything quantized + def class_predicate(p, m): + # Handle custom per layer quantizations + if p in config["quantization"]: + return config["quantization"][p] if not hasattr(m, "to_quantized"): return False + # Handle legacy models which may not have everything quantized return f"{p}.scales" in weights nn.quantize( model, - **quantization, + group_size=quantization["group_size"], + bits=quantization["bits"], class_predicate=class_predicate, ) @@ -707,7 +723,13 @@ def save_weights( def quantize_model( - model: nn.Module, config: dict, q_group_size: int, q_bits: int + model: nn.Module, + config: dict, + q_group_size: int, + q_bits: int, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ) -> Tuple: """ Applies quantization to the model weights. @@ -717,17 +739,37 @@ def quantize_model( config (dict): Model configuration. q_group_size (int): Group size for quantization. q_bits (int): Bits per weight for quantization. + quant_predicate (Callable): A callable that decides how + to quantize each layer based on the path. + Accepts the layer `path`, the `module` and the model `config`. + Returns either a bool to signify quantize/no quantize or + a dict of quantization parameters to pass to `to_quantized`. Returns: Tuple: Tuple containing quantized weights and config. 
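    Example:
        A hypothetical predicate (the path and settings are illustrative) that
        keeps ``lm_head`` at 8 bits and quantizes everything else with the
        defaults::

            def quant_predicate(path, module, config):
                if path == "lm_head":
                    return {"bits": 8, "group_size": 64}
                return True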
""" quantized_config = copy.deepcopy(config) - nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} + + # Add any custom quantization parameters to the config as we go + def _class_predicate(p, m): + bool_or_params = quant_predicate(p, m, config) + quantized_config["quantization"][p] = bool_or_params + return bool_or_params + + nn.quantize( + model, + q_group_size, + q_bits, + class_predicate=_class_predicate if quant_predicate else None, + ) # support hf model tree #957 quantized_config["quantization_config"] = quantized_config["quantization"] quantized_weights = dict(tree_flatten(model.parameters())) + bpw = compute_bits_per_weight(model) + print(f"[INFO] Quantized model with {bpw:.3f} bits per weight.") + return quantized_weights, quantized_config @@ -764,6 +806,9 @@ def convert( upload_repo: str = None, revision: Optional[str] = None, dequantize: bool = False, + quant_predicate: Optional[ + Callable[[str, nn.Module, dict], Union[bool, dict]] + ] = None, ): # Check the save path is empty if isinstance(mlx_path, str): @@ -789,7 +834,9 @@ def convert( if quantize: print("[INFO] Quantizing") model.load_weights(list(weights.items())) - weights, config = quantize_model(model, config, q_group_size, q_bits) + weights, config = quantize_model( + model, config, q_group_size, q_bits, quant_predicate=quant_predicate + ) if dequantize: print("[INFO] Dequantizing") From 1fd6aae871e9e21613ae90624cb4a72bdf709cc6 Mon Sep 17 00:00:00 2001 From: hehua2008 Date: Mon, 9 Dec 2024 14:09:04 +0800 Subject: [PATCH 47/77] Fix flux training with batch size (#1135) Co-authored-by: Angelos Katharopoulos --- flux/flux/sampler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index 54c4fe35..e7a1080d 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -50,6 +50,7 @@ class FluxSampler: if noise is not None else mx.random.normal(x.shape, dtype=x.dtype, key=key) ) + t = t.reshape([-1] + [1] * (x.ndim - 1)) return x * (1 - t) + t * noise def step(self, pred, x_t, t, t_prev): From ed91bbc4dcf2734203c5302e2cfd1c5a10daa2e2 Mon Sep 17 00:00:00 2001 From: Peter Sibley Date: Mon, 9 Dec 2024 02:01:53 -0500 Subject: [PATCH 48/77] Fix final message at end of flux training (#1143) --- flux/dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index ffdb02d7..f82178b9 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -289,4 +289,4 @@ if __name__ == "__main__": tic = time.time() save_adapters("final_adapters.safetensors", flux, args) - print(f"Training successful. 
Saved final weights to {args.adapter_file}.") + print("Training successful.") From 893b3f085e01dc1db224d9f983a7a82fb4f4d584 Mon Sep 17 00:00:00 2001 From: hehua2008 Date: Mon, 9 Dec 2024 15:29:48 +0800 Subject: [PATCH 49/77] Change Flux default max_shift to 1.15 to match the official one (#1137) --- flux/flux/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py index e7a1080d..6f293edc 100644 --- a/flux/flux/sampler.py +++ b/flux/flux/sampler.py @@ -7,7 +7,7 @@ import mlx.core as mx class FluxSampler: - def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.15): self._base_shift = base_shift self._max_shift = max_shift self._schnell = "schnell" in name From 5687d5b99b66171a234705d0fa30721076f7446e Mon Sep 17 00:00:00 2001 From: n8programs <43304488+N8python@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:58:25 -0500 Subject: [PATCH 50/77] Adds EXAONE architecture. (#1145) * Adds EXAONE architecture. * nits + format * format * clean up and fix rope * clean up and fix rope --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/models/exaone.py | 163 +++++++++++++++++++++++++++++++ llms/mlx_lm/models/llama.py | 120 ++--------------------- llms/mlx_lm/models/olmo2.py | 121 ++--------------------- llms/mlx_lm/models/rope_utils.py | 91 +++++++++++++++++ llms/mlx_lm/tuner/utils.py | 2 + llms/tests/test_models.py | 39 ++++++++ 6 files changed, 312 insertions(+), 224 deletions(-) create mode 100644 llms/mlx_lm/models/exaone.py create mode 100644 llms/mlx_lm/models/rope_utils.py diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py new file mode 100644 index 00000000..eaed5dd8 --- /dev/null +++ b/llms/mlx_lm/models/exaone.py @@ -0,0 +1,163 @@ +# Copyright © 2024 Apple Inc. 
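# EXAONE is a Llama-style decoder-only model; what this file mainly handles is
# the Hugging Face weight layout (transformer.h, attn.attention.*, and the
# c_fc_0 / c_fc_1 / c_proj MLP names) plus the shared RoPE setup in rope_utils.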
+ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_layers: int + intermediate_size: int + num_attention_heads: int + vocab_size: int + rope_theta: float + layer_norm_epsilon: float + num_key_value_heads: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + attention_bias: bool = False + mlp_bias: bool = False + + +class AttentionModule(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim or (dim // n_heads) + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias) + self.out_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias) + + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + + def __call__( + self, x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None + ) -> mx.array: + B, L, D = x.shape + q = self.q_proj(x).reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + q = self.rope(q, offset=cache.offset) + k = self.rope(k, offset=cache.offset) + k, v = cache.update_and_fetch(k, v) + else: + q = self.rope(q) + k = self.rope(k) + + out = scaled_dot_product_attention( + q, k, v, cache=cache, scale=self.scale, mask=mask + ) + out = out.transpose(0, 2, 1, 3).reshape(B, L, D) + return self.out_proj(out) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.attention = AttentionModule(args) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + dim = args.hidden_size + hidden_dim = args.intermediate_size + self.c_fc_0 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_fc_1 = nn.Linear(dim, hidden_dim, bias=args.mlp_bias) + self.c_proj = nn.Linear(hidden_dim, dim, bias=args.mlp_bias) + + def __call__(self, x: mx.array) -> mx.array: + return self.c_proj(nn.silu(self.c_fc_0(x)) * self.c_fc_1(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.attn = Attention(args) + self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + self.mlp = MLP(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + h = x + self.attn.attention(self.ln_1(x), mask, cache) + out = h + self.mlp(self.ln_2(h)) + return out + + +class ExaoneModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.wte = nn.Embedding(args.vocab_size, 
args.hidden_size) + self.h = [TransformerBlock(args) for _ in range(args.num_layers)] + self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.wte(inputs) + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.h) + + for layer, c in zip(self.h, cache): + h = layer(h, mask, cache=c) + + return self.ln_f(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.transformer = ExaoneModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.transformer(inputs, cache) + if self.args.tie_word_embeddings: + out = self.transformer.wte.as_linear(out) + else: + out = self.lm_head(out) + return out + + @property + def layers(self): + return self.transformer.h diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 438278e5..290cb83e 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope @dataclass @@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.rope_scaling: - if not "factor" in self.rope_scaling: - raise ValueError(f"rope_scaling must contain 'factor'") - rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( - "rope_type" - ) - if rope_type is None: - raise ValueError( - f"rope_scaling must contain either 'type' or 'rope_type'" - ) - if rope_type not in ["linear", "dynamic", "llama3"]: - raise ValueError( - "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - rope_type: str = "default", - rope_scaling: dict = None, - ): - super().__init__() - self.dims = dims - self.max_position_embeddings = max_position_embeddings - self.traditional = traditional - self.scale = scale - self.rope_type = rope_type - self.rope_scaling = rope_scaling - self.base = base - self.compute_freqs() - - def compute_freqs(self): - if self.rope_type != "llama3": - self._freqs = None - return - factor = self.rope_scaling["factor"] - low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) - high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.rope_scaling.get( - "original_max_position_embeddings", - 8192, - ) - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) - wavelens = 2 * mx.pi * freqs - - freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) - is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) - smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - self._freqs = 
mx.where(is_medium_freq, smooth_freqs, freqs) - self.base = None - - def extra_repr(self): - return ( - f"{self.dims}, traditional={self.traditional}, " - f"max_position_embeddings={self.max_position_embeddings}, " - f"scaling_factor={self.scale}, rope_type={self.rope_type}" - ) - - def __call__(self, x, offset: int = 0): - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=self.base, - scale=self.scale, - offset=offset, - freqs=self._freqs, - ) - - -def initialize_rope(args: ModelArgs): - head_dim = args.head_dim or args.hidden_size // args.num_attention_heads - - rope_scaling = args.rope_scaling - rope_type = "default" - rope_scale = 1.0 - - if rope_scaling is not None: - rope_type = ( - rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" - ) - if rope_type == "linear": - rope_scale = 1 / rope_scaling["factor"] - elif rope_type == "llama3": - rope_scale = 1.0 # The scaling is handled internally for llama3 - - return DynamicNTKScalingRoPE( - dims=head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - rope_type=rope_type, - rope_scaling=rope_scaling, - ) - class Attention(nn.Module): def __init__(self, args: ModelArgs): @@ -165,7 +55,13 @@ class Attention(nn.Module): self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - self.rope = initialize_rope(args) + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) def __call__( self, diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index a28fdcc1..64d7e116 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .rope_utils import initialize_rope @dataclass @@ -32,117 +33,6 @@ class ModelArgs(BaseModelArgs): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads - if self.rope_scaling: - if not "factor" in self.rope_scaling: - raise ValueError(f"rope_scaling must contain 'factor'") - rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( - "rope_type" - ) - if rope_type is None: - raise ValueError( - f"rope_scaling must contain either 'type' or 'rope_type'" - ) - if rope_type not in ["linear", "dynamic", "llama3"]: - raise ValueError( - "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" - ) - - -class DynamicNTKScalingRoPE(nn.Module): - """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" - - def __init__( - self, - dims: int, - max_position_embeddings: int = 2048, - traditional: bool = False, - base: float = 10000, - scale: float = 1.0, - rope_type: str = "default", - rope_scaling: dict = None, - ): - super().__init__() - self.dims = dims - self.max_position_embeddings = max_position_embeddings - self.traditional = traditional - self.scale = scale - self.rope_type = rope_type - self.rope_scaling = rope_scaling - self.base = base - self.compute_freqs() - - def compute_freqs(self): - if self.rope_type != "llama3": - self._freqs = None - return - factor = self.rope_scaling["factor"] - low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) - high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) - old_context_len = 
self.rope_scaling.get( - "original_max_position_embeddings", - 8192, - ) - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) - wavelens = 2 * mx.pi * freqs - - freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) - is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) - smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) - self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) - self.base = None - - def extra_repr(self): - return ( - f"{self.dims}, traditional={self.traditional}, " - f"max_position_embeddings={self.max_position_embeddings}, " - f"scaling_factor={self.scale}, rope_type={self.rope_type}" - ) - - def __call__(self, x, offset: int = 0): - return mx.fast.rope( - x, - self.dims, - traditional=self.traditional, - base=self.base, - scale=self.scale, - offset=offset, - freqs=self._freqs, - ) - - -def initialize_rope(args: ModelArgs): - head_dim = args.head_dim or args.hidden_size // args.num_attention_heads - - rope_scaling = args.rope_scaling - rope_type = "default" - rope_scale = 1.0 - - if rope_scaling is not None: - rope_type = ( - rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" - ) - if rope_type == "linear": - rope_scale = 1 / rope_scaling["factor"] - elif rope_type == "llama3": - rope_scale = 1.0 # The scaling is handled internally for llama3 - - return DynamicNTKScalingRoPE( - dims=head_dim, - max_position_embeddings=args.max_position_embeddings, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - rope_type=rope_type, - rope_scaling=rope_scaling, - ) - class Attention(nn.Module): def __init__(self, args: ModelArgs): @@ -165,7 +55,14 @@ class Attention(nn.Module): self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) - self.rope = initialize_rope(args) + self.rope = initialize_rope( + self.head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + args.max_position_embeddings, + ) + self.q_norm = nn.RMSNorm(n_heads * head_dim, args.rms_norm_eps) self.k_norm = nn.RMSNorm(n_kv_heads * head_dim, args.rms_norm_eps) diff --git a/llms/mlx_lm/models/rope_utils.py b/llms/mlx_lm/models/rope_utils.py new file mode 100644 index 00000000..d30b432d --- /dev/null +++ b/llms/mlx_lm/models/rope_utils.py @@ -0,0 +1,91 @@ +# Copyright © 2023-2024 Apple Inc. 
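# Shared RoPE construction helpers: the default and "linear" rope_scaling cases
# map onto nn.RoPE, while the "llama3" case uses the Llama3RoPE module below.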
+ +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + + +class Llama3RoPE(nn.Module): + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scaling_config: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + + factor = scaling_config["factor"] + low_freq_factor = scaling_config.get("low_freq_factor", 1.0) + high_freq_factor = scaling_config.get("high_freq_factor", 4.0) + old_context_len = scaling_config.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + freqs = base ** (mx.arange(0, dims, 2) / dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +def initialize_rope( + dims, + base, + traditional, + scaling_config: Optional[dict] = None, + max_position_embeddings: Optional[int] = None, +): + if scaling_config is not None: + rope_type = scaling_config.get("type") or scaling_config.get( + "rope_type", "default" + ) + else: + rope_type = "default" + + if rope_type in ["default", "linear"]: + scale = 1 / scaling_config["factor"] if rope_type == "linear" else 1.0 + return nn.RoPE(dims, traditional=traditional, base=base, scale=scale) + + elif rope_type == "llama3": + return Llama3RoPE( + dims=dims, + max_position_embeddings=max_position_embeddings, + traditional=traditional, + base=base, + scaling_config=scaling_config, + ) + else: + raise ValueError(f"Unsupported RoPE type {rope_type}") diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 8351ed1b..6821f434 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -144,6 +144,8 @@ def linear_to_lora_layers( "mixer.out_proj", ] ) + elif model.model_type == "exaone": + keys = set(["attn.attention.q_proj", "attn.attention.v_proj"]) else: raise ValueError(f"Lora does not support {model.model_type}") diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index edb594d7..374a5113 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -2,7 +2,9 @@ import unittest import mlx.core as mx +import mlx.nn as nn from mlx.utils import tree_map +from mlx_lm.models import rope_utils from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -126,6 +128,26 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_rope(self): + rope = rope_utils.initialize_rope(32, base=100, traditional=False) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "linear", "factor": 
10.0}, + ) + self.assertTrue(isinstance(rope, nn.RoPE)) + + rope = rope_utils.initialize_rope( + 32, + base=100, + traditional=False, + scaling_config={"rope_type": "llama3", "factor": 2.0}, + ) + self.assertTrue(isinstance(rope, rope_utils.Llama3RoPE)) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -812,6 +834,23 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_exaone(self): + from mlx_lm.models import exaone + + args = exaone.ModelArgs( + model_type="exaone", + hidden_size=128, + num_layers=4, + intermediate_size=256, + num_attention_heads=8, + num_key_value_heads=2, + vocab_size=1000, + layer_norm_epsilon=1e-4, + rope_theta=10000, + ) + model = exaone.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + if __name__ == "__main__": unittest.main() From 12083c4b7ed041fcb733ac7821eb726a0169ff76 Mon Sep 17 00:00:00 2001 From: madroid Date: Tue, 10 Dec 2024 00:53:58 +0800 Subject: [PATCH 51/77] Support for multiple EOS tokens (#1141) * Support for multiple EOS tokens * Change _eos_token_ids type from list to set * Remove model_config & add eos_token_id * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/tokenizer_utils.py | 23 +++++++++++++++++++---- llms/mlx_lm/utils.py | 24 +++++++++++++----------- llms/tests/test_utils_load_model.py | 4 ++-- 3 files changed, 34 insertions(+), 17 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 0fa41ac0..10a257f6 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -254,21 +254,33 @@ class TokenizerWrapper: huggingface tokenizer. """ - def __init__(self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer): + def __init__( + self, tokenizer, detokenizer_class=NaiveStreamingDetokenizer, eos_token_ids=None + ): self._tokenizer = tokenizer self._detokenizer = detokenizer_class(tokenizer) + self._eos_token_ids = ( + set(eos_token_ids) + if eos_token_ids is not None + else {tokenizer.eos_token_id} + ) def __getattr__(self, attr): if attr == "detokenizer": return self._detokenizer + elif attr == "eos_token_ids": + return self._eos_token_ids elif attr.startswith("_"): return self.__getattribute__(attr) else: return getattr(self._tokenizer, attr) def __setattr__(self, attr, value): - if attr == "detokenizer": - raise AttributeError("Cannot set the detokenizer.") + if attr in {"detokenizer", "eos_token_ids"}: + if attr == "detokenizer": + raise AttributeError("Cannot set the detokenizer.") + elif attr == "eos_token_ids": + self._eos_token_ids = set(value) if value is not None else set() elif attr.startswith("_"): super().__setattr__(attr, value) else: @@ -315,7 +327,7 @@ def _is_bpe_decoder(decoder): return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel" -def load_tokenizer(model_path, tokenizer_config_extra={}): +def load_tokenizer(model_path, tokenizer_config_extra={}, eos_token_ids=None): """Load a huggingface tokenizer and try to infer the type of streaming detokenizer to use. 
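A hypothetical usage sketch of the wrapper above (the model path and token ids are placeholders, not values from this patch): passing several EOS ids exposes them as a set, which generation code can check instead of a single eos_token_id.

from pathlib import Path
from mlx_lm.tokenizer_utils import load_tokenizer

# Placeholder local model directory and token ids, for illustration only.
tokenizer = load_tokenizer(Path("path/to/local/model"), eos_token_ids=[2, 32000])

# The wrapper stores a set, so membership checks are cheap.
assert tokenizer.eos_token_ids == {2, 32000}
assert 32000 in tokenizer.eos_token_ids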
@@ -336,7 +348,10 @@ def load_tokenizer(model_path, tokenizer_config_extra={}): elif _is_bpe_decoder(tokenizer_content["decoder"]): detokenizer_class = BPEStreamingDetokenizer + if isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] return TokenizerWrapper( AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra), detokenizer_class, + eos_token_ids=eos_token_ids, ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 66a106a1..d81bb66a 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -361,7 +361,7 @@ def stream_generate( prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time tic = time.perf_counter() - if token == tokenizer.eos_token_id: + if token in tokenizer.eos_token_ids: break detokenizer.add_token(token) @@ -467,11 +467,11 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` - model_config (dict, optional): Configuration parameters for the model. - Defaults to an empty dictionary. + model_config (dict, optional): Optional configuration parameters for the + model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): A function that returns the model class and model args class given a config. - Defaults to the _get_classes function. + Defaults to the ``_get_classes`` function. Returns: nn.Module: The loaded and initialized model. @@ -480,7 +480,6 @@ def load_model( FileNotFoundError: If the weight files (.safetensors) are not found. ValueError: If the model class or args class are not found or cannot be instantiated. """ - config = load_config(model_path) config.update(model_config) @@ -530,7 +529,7 @@ def load_model( mx.eval(model.parameters()) model.eval() - return model + return model, config def load( @@ -563,11 +562,13 @@ def load( """ model_path = get_model_path(path_or_hf_repo) - model = load_model(model_path, lazy, model_config) + model, config = load_model(model_path, lazy) if adapter_path is not None: model = load_adapters(model, adapter_path) model.eval() - tokenizer = load_tokenizer(model_path, tokenizer_config) + tokenizer = load_tokenizer( + model_path, tokenizer_config, eos_token_ids=config.get("eos_token_id", None) + ) return model, tokenizer @@ -575,9 +576,10 @@ def load( def fetch_from_hub( model_path: Path, lazy: bool = False ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: - model = load_model(model_path, lazy) - config = load_config(model_path) - tokenizer = load_tokenizer(model_path) + model, config = load_model(model_path, lazy) + tokenizer = load_tokenizer( + model_path, eos_token_ids=config.get("eos_token_id", None) + ) return model, config, tokenizer diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 73ee1352..5821f9e9 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -32,7 +32,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): return CustomQwenModel, CustomQwenConfig model_path = get_model_path(HF_MODEL_PATH) - model = load_model(model_path, get_model_classes=custom_get_classes) + model, _ = load_model(model_path, get_model_classes=custom_get_classes) self.assertIsInstance(model, CustomQwenModel) self.assertTrue(hasattr(model, "custom_attribute")) @@ -41,7 +41,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): def test_load_model_with_default_get_classes(self): model_path = 
get_model_path(HF_MODEL_PATH) - model = load_model(model_path) + model, _ = load_model(model_path) self.assertIsInstance(model, Qwen2Model) From 135c5818c1b2fbea5970b41a10172518dfa8a73a Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Tue, 10 Dec 2024 11:26:04 -0800 Subject: [PATCH 52/77] Fix max_tokens (#1148) --- llms/mlx_lm/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 7795d8d7..5a8245ef 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -79,7 +79,7 @@ def main(): model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): From 77b42b7c8bf27272f0263cd04ce962d20295504f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 10:37:26 -0800 Subject: [PATCH 53/77] fix llava (#1149) --- llava/generate.py | 7 +++---- llava/llava.py | 26 ++++++++------------------ llms/mlx_lm/generate.py | 7 ++++--- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/llava/generate.py b/llava/generate.py index 8067839e..64313858 100644 --- a/llava/generate.py +++ b/llava/generate.py @@ -79,10 +79,10 @@ def load_image(image_source): def prepare_inputs(processor, image, prompt): if isinstance(image, str): image = load_image(image) - inputs = processor(prompt, image, return_tensors="np") + inputs = processor(image, prompt, return_tensors="np") pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) - return input_ids, pixel_values + return pixel_values, input_ids def load_model(model_path, tokenizer_config={}): @@ -126,8 +126,7 @@ def main(): processor, model = load_model(args.model, tokenizer_config) prompt = codecs.decode(args.prompt, "unicode_escape") - - input_ids, pixel_values = prepare_inputs(processor, args.image, prompt) + pixel_values, input_ids = prepare_inputs(processor, args.image, prompt) print(prompt) generated_text = generate_text( diff --git a/llava/llava.py b/llava/llava.py index 9e6b7511..c5f190f8 100644 --- a/llava/llava.py +++ b/llava/llava.py @@ -104,31 +104,21 @@ class LlavaModel(nn.Module): self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape + batch_size, num_image_patches, embed_dim = image_features.shape # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + image_positions = mx.array( + np.where(input_ids[0] == image_token_index)[0], mx.uint32 + ) - if len(image_positions) != num_images: + if len(image_positions) != num_image_patches: raise ValueError( f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." + f" match the number of image patches ({num_image_patches})." 
) - text_segments = [] - start_idx = 0 - - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 - - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] - - # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + inputs_embeds[0, image_positions] = image_features + return inputs_embeds def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None): input_embddings = self.get_input_embeddings(input_ids, pixel_values) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0c1b4acd..84dc63ca 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -1,6 +1,7 @@ # Copyright © 2023-2024 Apple Inc. import argparse +import codecs import json import sys @@ -188,6 +189,8 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] + prompt = codecs.decode(args.prompt, "unicode_escape") + if not args.ignore_chat_template and ( hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None @@ -199,7 +202,7 @@ def main(): messages.append( { "role": "user", - "content": sys.stdin.read() if args.prompt == "-" else args.prompt, + "content": sys.stdin.read() if prompt == "-" else prompt, } ) prompt = tokenizer.apply_chat_template( @@ -216,8 +219,6 @@ def main(): add_generation_prompt=True, ) prompt = prompt[test_prompt.index("") :] - else: - prompt = args.prompt sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( From 06af3c9b0eac1aea927dcbbda66cacd5aab76f4a Mon Sep 17 00:00:00 2001 From: madroid Date: Fri, 13 Dec 2024 02:37:40 +0800 Subject: [PATCH 54/77] Add finish_reason in GenerationResponse (#1153) --- llms/mlx_lm/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d81bb66a..493c1c42 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download -from mlx.utils import tree_flatten, tree_map, tree_reduce +from mlx.utils import tree_flatten, tree_reduce from transformers import PreTrainedTokenizer # Local imports @@ -59,6 +59,7 @@ class GenerationResponse: generation_tokens (int): The number of generated tokens. generation_tps (float): The tokens-per-second for generation. peak_memory (float): The peak memory used so far in GB. 
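A usage sketch for the finish_reason field added just below (it assumes the Llama 3.2 model repo named in the tests elsewhere in this series; any chat model should behave the same): the last response yielded by stream_generate carries a finish_reason, which distinguishes a natural stop from running out of tokens.

from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hi"}], add_generation_prompt=True
)
for response in stream_generate(model, tokenizer, prompt, max_tokens=64):
    print(response.text, end="", flush=True)
print()
# "stop" means an EOS token ended generation, "length" means max_tokens did.
print("finish_reason:", response.finish_reason)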
+ finish_reason (str): The reason the response is being sent: "length", "stop" or `None` """ text: str @@ -69,6 +70,7 @@ class GenerationResponse: generation_tokens: int generation_tps: float peak_memory: float + finish_reason: Optional[str] = None @contextlib.contextmanager @@ -375,6 +377,7 @@ def stream_generate( generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason=None, ) detokenizer.finalize() @@ -387,6 +390,7 @@ def stream_generate( generation_tokens=n + 1, generation_tps=(n + 1) / (time.perf_counter() - tic), peak_memory=mx.metal.get_peak_memory() / 1e9, + finish_reason="stop" if token in tokenizer.eos_token_ids else "length", ) From 19abf3dcaac2809984800680842ae3d859dda6dc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Thu, 12 Dec 2024 11:10:41 -0800 Subject: [PATCH 55/77] Replace unicode errors instead of raising exception (#1146) --- llms/mlx_lm/tokenizer_utils.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 10a257f6..114a35e7 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -3,8 +3,6 @@ from functools import partial from transformers import AutoTokenizer -REPLACEMENT_CHAR = "\ufffd" - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -51,11 +49,9 @@ class StreamingDetokenizer: def last_segment(self): """Return the last segment of readable text since last time this property was accessed.""" text = self.text - if text and text[-1] != REPLACEMENT_CHAR: - segment = text[self.offset :] - self.offset = len(text) - return segment - return "" + segment = text[self.offset :] + self.offset = len(text) + return segment class NaiveStreamingDetokenizer(StreamingDetokenizer): @@ -132,7 +128,7 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.tokens = [] def _flush(self): - text = self._unflushed.replace(self._sep, b" ").decode("utf-8") + text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace") if not self.text and self.trim_space and text and text[0] == " ": text = text[1:] self.text += text @@ -202,7 +198,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): if is_added or self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed - ).decode("utf-8") + ).decode("utf-8", "replace") self.text += self._maybe_trim_space(current_text) if is_added: self.text += v @@ -214,7 +210,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): def finalize(self): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( - "utf-8" + "utf-8", + "replace", ) self.text += self._maybe_trim_space(current_text) self._unflushed = "" From 2ba0e3668382d2c18ab6f691e2f662081596269f Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 11:12:21 -0800 Subject: [PATCH 56/77] [mlx-lm] Use top p in server (#1144) * use top p in server * couple other fixes --- llms/mlx_lm/sample_utils.py | 2 +- llms/mlx_lm/server.py | 2 +- llms/mlx_lm/utils.py | 3 +++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index f9868422..c77f056a 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -190,7 +190,7 @@ def make_repetition_penalty(penalty: float, context_size: int = 20): Callable[[mx.array, List[int]], mx.array]: The 
repetition penalty processor. """ - if penalty < 0 or not isinstance(penalty, float): + if penalty < 0 or not isinstance(penalty, (int, float)): raise ValueError(f"penalty must be a non-negative float, got {penalty}") def repetition_penalty_processor(tokens, logits): diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ce09cf45..c12513ff 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -465,7 +465,7 @@ class APIHandler(BaseHTTPRequestHandler): text = "" tic = time.perf_counter() - sampler = make_sampler(self.temperature) + sampler = make_sampler(self.temperature, top_p=self.top_p) logits_processors = make_logits_processors( self.logit_bias, self.repetition_penalty, self.repetition_context_size ) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 493c1c42..b87f5a24 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -299,6 +299,9 @@ def generate_step( prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size From 9f2ea5892e3a9517853c526a928268250741f623 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 12 Dec 2024 13:13:50 -0800 Subject: [PATCH 57/77] Bpe stream without space (#1154) * bpe streaming detokenization without space * version bump --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/tokenizer_utils.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 0f885fba..3af2d5fd 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
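A small sketch of the same wiring outside the server (the model repo is reused from the tests in this series, and the keyword names passed to make_logits_processors are assumed to match the current sample_utils API):

from mlx_lm import load, generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

# The same pieces the server now builds per request.
sampler = make_sampler(temp=0.7, top_p=0.9)
logits_processors = make_logits_processors(
    logit_bias=None, repetition_penalty=1.1, repetition_context_size=20
)

text = generate(
    model,
    tokenizer,
    prompt="Write one sentence about the ocean.",
    max_tokens=48,
    sampler=sampler,
    logits_processors=logits_processors,
)
print(text)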
-__version__ = "0.20.2" +__version__ = "0.20.4" diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 114a35e7..8251e62f 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -195,18 +195,16 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids - if is_added or self._byte_decoder[v[0]] == 32: - current_text = bytearray( - self._byte_decoder[c] for c in self._unflushed - ).decode("utf-8", "replace") - self.text += self._maybe_trim_space(current_text) - if is_added: - self.text += v - self._unflushed = "" - else: - self._unflushed = v - else: + if not is_added: self._unflushed += v + text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( + "utf-8", "replace" + ) + if is_added: + text += v + if not text.endswith("\ufffd"): + self.text += self._maybe_trim_space(text) + self._unflushed = "" def finalize(self): current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( From fc0674d2d8fa4cd384641f5473d1dc7ffca918df Mon Sep 17 00:00:00 2001 From: Ikko Eltociear Ashimine Date: Sun, 15 Dec 2024 23:06:29 +0900 Subject: [PATCH 58/77] chore: update evaluate.py (#1159) occurence -> occurrence --- llms/mlx_lm/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index 423d5823..c4b15748 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -32,7 +32,7 @@ def _len_longest_common_prefix(a, b): def _rstrip_until(s, untils): - """Limit a string to the first occurence of any substring in untils.""" + """Limit a string to the first occurrence of any substring in untils.""" l = len(s) f = [s.find(u) for u in untils] f = [l if x < 0 else x for x in f] From dfa4dd6c93c4c2f81bfed6becb8af5cc3a89ae61 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 16 Dec 2024 17:01:03 +0100 Subject: [PATCH 59/77] Add support for cohere2 (#1157) * add support for cohere2 * revert to act_fn to silu * fix tests and sliding window attention * add tests * add to tuner * fix sliding window * add coauthor :) Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com> * Add rotating kvcache to save space * some nits * style * nits --------- Co-authored-by: n8programs <43304488+N8python@users.noreply.github.com> Co-authored-by: N8 Co-authored-by: Awni Hannun --- llms/mlx_lm/models/cohere2.py | 207 ++++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + llms/mlx_lm/utils.py | 7 +- llms/tests/test_models.py | 16 +++ 4 files changed, 228 insertions(+), 3 deletions(-) create mode 100644 llms/mlx_lm/models/cohere2.py diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py new file mode 100644 index 00000000..fcb4061b --- /dev/null +++ b/llms/mlx_lm/models/cohere2.py @@ -0,0 +1,207 @@ +# Copyright © 2023-2024 Apple Inc. 
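The cohere2 model below interleaves sliding-window attention layers with periodic global layers, and its make_cache pairs them with RotatingKVCache and KVCache respectively. A rough illustration of the memory behaviour that motivates this split (buffer sizes and rotation order are implementation details of RotatingKVCache, so the exact shapes printed here are not guaranteed):

import mlx.core as mx
from mlx_lm.models.cache import KVCache, RotatingKVCache

window, n_kv_heads, head_dim = 4, 2, 8
global_cache = KVCache()
local_cache = RotatingKVCache(max_size=window, keep=0)

for _ in range(16):  # decode one token at a time
    k = mx.random.normal((1, n_kv_heads, 1, head_dim))
    v = mx.random.normal((1, n_kv_heads, 1, head_dim))
    gk, _ = global_cache.update_and_fetch(k, v)
    lk, _ = local_cache.update_and_fetch(k, v)

print(gk.shape[-2])  # keeps growing with every decoded token
print(lk.shape[-2])  # stays bounded once the sliding window is full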
+ +from dataclasses import dataclass +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention +from .cache import KVCache, RotatingKVCache + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int = 4096 + head_dim: int = 128 + num_hidden_layers: int = 32 + intermediate_size: int = 14336 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + rope_theta: float = 50000.0 + vocab_size: int = 256000 + layer_norm_eps: float = 1e-05 + logit_scale: float = 0.0625 + attention_bias: bool = False + layer_norm_bias: bool = False + sliding_window: int = 4096 + sliding_window_pattern: int = 4 + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.args = args + self.layer_idx = layer_idx + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim + if (head_dim * n_heads) != dim: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {dim}" + f" and `num_heads`: {n_heads})." + ) + self.scale = head_dim**-0.5 + + attetion_bias = args.attention_bias + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attetion_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias) + + self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta) + + self.use_sliding_window = (layer_idx + 1) % args.sliding_window_pattern != 0 + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # Apply RoPE only if sliding window is enabled + if self.use_sliding_window: + if cache is None: + queries = self.rope(queries) + keys = self.rope(keys) + else: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + + if cache is not None: + keys, values = cache.update_and_fetch(keys, values) + + if self.use_sliding_window and mask is not None: + key_len = keys.shape[-2] + if mask.shape[-1] != key_len: + mask = mask[..., -key_len:] + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x): + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.hidden_size = args.hidden_size + self.n_heads = args.num_attention_heads + + self.self_attn = Attention(args, layer_idx) + self.mlp = MLP(args.hidden_size, 
args.intermediate_size) + self.input_layernorm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + h = self.input_layernorm(x) + attn_h = self.self_attn(h, mask, cache) + ff_h = self.mlp(h) + return attn_h + ff_h + x + + +class CohereModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args, layer_idx=i) + for i in range(args.num_hidden_layers) + ] + self.norm = nn.LayerNorm( + args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias + ) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + T = h.shape[1] + if T > 1: + offset = cache[0].offset if cache else 0 + mask = create_causal_mask(T, offset).astype(h.dtype) + else: + mask = None + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.model_type = args.model_type + self.model = CohereModel(args) + self.args = args + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + out = self.model.embed_tokens.as_linear(out) + out = out * self.model.args.logit_scale + return out + + def make_cache(self): + caches = [] + for i in range(self.args.num_hidden_layers): + if ( + i % self.args.sliding_window_pattern + == self.args.sliding_window_pattern - 1 + ): + caches.append(KVCache()) + else: + caches.append( + RotatingKVCache(max_size=self.args.sliding_window, keep=0) + ) + return caches + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 6821f434..3986952a 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -96,6 +96,7 @@ def linear_to_lora_layers( "gemma2", "starcoder2", "cohere", + "cohere2", "minicpm", "deepseek", "olmo2", diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index b87f5a24..4d69115e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -187,9 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ and prompt_cache[0].offset > quantized_kv_start ): for i in range(len(prompt_cache)): - prompt_cache[i] = prompt_cache[i].to_quantized( - group_size=kv_group_size, bits=kv_bits - ) + if isinstance(prompt_cache[i], cache.KVCache): + prompt_cache[i] = prompt_cache[i].to_quantized( + group_size=kv_group_size, bits=kv_bits + ) def generate_step( diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 374a5113..3097c522 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -851,6 +851,22 @@ class TestModels(unittest.TestCase): model = exaone.Model(args) self.model_test_runner(model, args.model_type, args.vocab_size, args.num_layers) + def test_cohere2(self): + from mlx_lm.models import cohere2 + + args = cohere2.ModelArgs( + model_type="cohere2", + hidden_size=4096, + head_dim=128, + num_hidden_layers=40, + sliding_window=4096, + sliding_window_pattern=4, + ) + model = cohere2.Model(args) + 
self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() From 845efddc8cb4578fe008c2ad0c26ec595e7f6b1e Mon Sep 17 00:00:00 2001 From: Billel Mokeddem Date: Tue, 17 Dec 2024 21:54:29 +0400 Subject: [PATCH 60/77] Fix decoding manually added tokens (#1164) * Fix decoding manually added tokens * fix + test * nit * nit * no lag bpe --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/tokenizer_utils.py | 44 +++++++++++++++++++--------------- llms/tests/test_tokenizers.py | 4 ++++ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 8251e62f..ca3d6c06 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -127,23 +127,23 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): self.text = "" self.tokens = [] - def _flush(self): + def _try_flush(self, force=False): text = self._unflushed.replace(self._sep, b" ").decode("utf-8", "replace") + if not force and text.endswith("\ufffd"): + return if not self.text and self.trim_space and text and text[0] == " ": text = text[1:] self.text += text + self._unflushed = b"" def add_token(self, token): self.tokens.append(token) v = self.tokenmap[token] - if v.startswith(self._sep): - self._flush() - self._unflushed = v - else: - self._unflushed += v + self._unflushed += v + self._try_flush() def finalize(self): - self._flush() + self._try_flush(force=True) self._unflushed = b"" @@ -158,7 +158,6 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): _space_matches = (".", "?", "!", ",", "n't", "'m", "'s", "'ve", "'re") def __init__(self, tokenizer): - self.clean_spaces = tokenizer.clean_up_tokenization_spaces # Extract the tokens in a list from id to text @@ -172,14 +171,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): # https://github.com/openai/gpt-2/blob/master/src/encoder.py self.make_byte_decoder() - self._added_ids = set(tokenizer.added_tokens_decoder.keys()) - def reset(self): self.offset = 0 self._unflushed = "" self.text = "" self.tokens = [] + def _decode_bytes(self, seq): + barr = bytearray() + for c in seq: + res = self._byte_decoder.get(c, False) + if res: + barr.append(res) + else: + barr.extend(bytes(c, "utf-8")) + return barr.decode("utf-8", "replace") + def _maybe_trim_space(self, current_text): if len(current_text) == 0: return current_text @@ -194,15 +201,14 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): def add_token(self, token): self.tokens.append(token) v = self.tokenmap[token] - is_added = token in self._added_ids - if not is_added: - self._unflushed += v - text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode( - "utf-8", "replace" - ) - if is_added: - text += v - if not text.endswith("\ufffd"): + self._unflushed += v + text = self._decode_bytes(self._unflushed) + + # For multi-byte utf-8 wait until they are complete + # For single spaces wait until the next token to clean it if needed + if not text.endswith("\ufffd") and not ( + len(v) == 1 and self._byte_decoder[v[0]] == 32 + ): self.text += self._maybe_trim_space(text) self._unflushed = "" diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index db6b9f9e..3009d1b1 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -58,6 +58,9 @@ class TestTokenizers(unittest.TestCase): tokens = tokenizer.encode("import 'package:flutter/material.dart';") check(tokens) + tokens = tokenizer.encode("hello\nworld") + 
check(tokens) + def test_tokenizers(self): tokenizer_repos = [ ("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer), @@ -65,6 +68,7 @@ class TestTokenizers(unittest.TestCase): ("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer), ("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer), ("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer), + ("mlx-community/Falcon3-7B-Instruct-4bit", BPEStreamingDetokenizer), ] for tokenizer_repo, expected_detokenizer in tokenizer_repos: with self.subTest(tokenizer=tokenizer_repo): From db109184b7f23ce3166c6cfd4682b092b4bdfbb6 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 18 Dec 2024 18:46:50 -0800 Subject: [PATCH 61/77] Fix no template prompt + top_k sampling (#1166) * fix no template prompt * add top_k sampling * fix chinese --- llms/mlx_lm/generate.py | 12 +++--------- llms/mlx_lm/sample_utils.py | 34 ++++++++++++++++++++++++++++++++- llms/tests/test_sample_utils.py | 23 +++++++++++++++++++++- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 84dc63ca..afb1394e 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. import argparse -import codecs import json import sys @@ -189,8 +188,8 @@ def main(): elif using_cache: tokenizer.chat_template = metadata["chat_template"] - prompt = codecs.decode(args.prompt, "unicode_escape") - + prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") + prompt = sys.stdin.read() if prompt == "-" else prompt if not args.ignore_chat_template and ( hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None @@ -199,12 +198,7 @@ def main(): messages = [{"role": "system", "content": args.system_prompt}] else: messages = [] - messages.append( - { - "role": "user", - "content": sys.stdin.read() if prompt == "-" else prompt, - } - ) + messages.append({"role": "user", "content": prompt}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c77f056a..c48a32cf 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -12,6 +12,7 @@ def make_sampler( top_p: float = 0.0, min_p: float = 0.0, min_tokens_to_keep: int = 1, + top_k: int = -1, ) -> Callable[mx.array, mx.array]: """ Make a sampler function for use with ``generate_step``. @@ -25,6 +26,8 @@ def make_sampler( probability) that a token probability must have to be considered. min_tokens_to_keep (int, optional): Minimum number of tokens that cannot be filtered by min_p sampling. + top_k (int, optional): The top k tokens ranked by probability to constrain + the sampling to. Returns: Callable[mx.array, mx.array]: @@ -36,6 +39,8 @@ def make_sampler( return lambda x: top_p_sampling(x, top_p, temp) elif min_p != 0.0: return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + elif top_k > 0: + return lambda x: top_k_sampling(x, top_k, temp) else: return lambda x: categorical_sampling(x, temp) @@ -79,6 +84,33 @@ def make_logits_processors( return logits_processors +@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) +def top_k_sampling( + logprobs: mx.array, + top_k: int, + temperature=1.0, +) -> mx.array: + """ + Sample from only the top K tokens ranked by probability. + + Args: + logprobs: A vector of log probabilities. + top_k (int): Top k tokens to sample from. 
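The implementation below masks everything outside the top k before sampling. A standalone sketch of that masking step with concrete numbers (independent of the function itself, for illustration):

import mlx.core as mx

logprobs = mx.log(mx.array([[0.5, 0.1, 0.3, 0.1]]))
k = 2

# Indices of everything *outside* the top k, found with a partial sort.
mask_idx = mx.argpartition(-logprobs, kth=k - 1, axis=-1)[..., k:]

# Setting those positions to -inf leaves probability mass only on the top k.
masked = mx.put_along_axis(
    logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1
)
print(mx.exp(masked))  # approximately [[0.5, 0.0, 0.3, 0.0]]
token = mx.random.categorical(masked, axis=-1)  # draws index 0 or 2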
+ """ + vocab_size = logprobs.shape[-1] + if not isinstance(top_k, int) or not (0 < top_k < vocab_size): + raise ValueError( + f"`top_k` has to be an integer in the (0, {vocab_size}] interval," + f" but is {top_k}." + ) + logprobs = logprobs * (1 / temperature) + mask_idx = mx.argpartition(-logprobs, kth=top_k - 1, axis=-1)[..., top_k:] + masked_logprobs = mx.put_along_axis( + logprobs, mask_idx, mx.array(-float("inf"), logprobs.dtype), axis=-1 + ) + return mx.random.categorical(masked_logprobs, axis=-1) + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logprobs: mx.array, @@ -87,7 +119,7 @@ def min_p_sampling( temperature=1.0, ) -> mx.array: """ - Apply min-p sampling to the logits. + Apply min-p sampling to the logprobs. Min-p keeps all tokens that are above a minimum probability, scaled by the probability of the most likely token. As a result, the filter is more diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ebc90ce8..c45fa443 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,7 +1,7 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import min_p_sampling, top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_k_sampling, top_p_sampling class TestSampleUtils(unittest.TestCase): @@ -42,6 +42,27 @@ class TestSampleUtils(unittest.TestCase): token = min_p_sampling(logits, 0.05) self.assertTrue(token in (0, 3)) + def test_top_k_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + + token = top_k_sampling(logits, 1).item() + self.assertEqual(token, 0) + + probs = mx.array([0.5, 0.0, 0.0, 0.5])[None] + tokens = set() + for _ in range(100): + token = top_k_sampling(logits, 2) + tokens.add(token.item()) + self.assertEqual(tokens, {0, 3}) + + # Batch mode works + probs = mx.array([[0.9, 0.0, 0.0, 0.1], [0.0, 0.8, 0.0, 0.1]]) + logits = mx.log(probs) + + tokens = top_k_sampling(logits, 1) + self.assertEqual(tokens.tolist(), [0, 1]) + if __name__ == "__main__": unittest.main() From d4ef909d4ab44d9f8cf89f5baa8a433d76d7d6b1 Mon Sep 17 00:00:00 2001 From: Alex Barron Date: Wed, 18 Dec 2024 19:43:52 -0800 Subject: [PATCH 62/77] Length masking for batch inputs (#1173) * length masking * add mask to mlx_lm model interface * remove lengths * fix test: * comment + fix --- llms/mlx_lm/models/base.py | 10 +++++++++- llms/mlx_lm/models/cohere.py | 7 +++++-- llms/mlx_lm/models/cohere2.py | 14 ++++++-------- llms/mlx_lm/models/dbrx.py | 7 +++++-- llms/mlx_lm/models/deepseek.py | 7 +++++-- llms/mlx_lm/models/deepseek_v2.py | 8 ++++++-- llms/mlx_lm/models/exaone.py | 7 +++++-- llms/mlx_lm/models/gemma.py | 7 +++++-- llms/mlx_lm/models/gemma2.py | 7 +++++-- llms/mlx_lm/models/gpt2.py | 7 +++++-- llms/mlx_lm/models/gpt_bigcode.py | 7 +++++-- llms/mlx_lm/models/gpt_neox.py | 7 +++++-- llms/mlx_lm/models/hunyuan.py | 7 +++++-- llms/mlx_lm/models/internlm2.py | 7 +++++-- llms/mlx_lm/models/llama.py | 7 +++++-- llms/mlx_lm/models/minicpm.py | 7 +++++-- llms/mlx_lm/models/mixtral.py | 7 +++++-- llms/mlx_lm/models/nemotron.py | 7 +++++-- llms/mlx_lm/models/olmo.py | 10 +++++++--- llms/mlx_lm/models/olmo2.py | 7 +++++-- llms/mlx_lm/models/openelm.py | 7 +++++-- llms/mlx_lm/models/phi.py | 8 +++++--- llms/mlx_lm/models/phi3.py | 7 +++++-- llms/mlx_lm/models/phi3small.py | 7 +++++-- llms/mlx_lm/models/phimoe.py | 7 +++++-- llms/mlx_lm/models/phixtral.py | 4 +++- llms/mlx_lm/models/plamo.py | 7 +++++-- llms/mlx_lm/models/qwen.py | 3 ++- 
llms/mlx_lm/models/qwen2.py | 7 +++++-- llms/mlx_lm/models/qwen2_moe.py | 7 +++++-- llms/mlx_lm/models/recurrent_gemma.py | 8 +++++--- llms/mlx_lm/models/stablelm.py | 5 ++++- llms/mlx_lm/models/starcoder2.py | 7 +++++-- llms/tests/test_models.py | 25 ++++++++++++++++++++++++- 34 files changed, 191 insertions(+), 72 deletions(-) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index f02f49b1..ad7a4a65 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -23,7 +23,12 @@ class BaseModelArgs: ) -def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): +def create_causal_mask( + N: int, + offset: int = 0, + window_size: Optional[int] = None, + lengths: Optional[mx.array] = None, +): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds linds = linds[:, None] @@ -31,6 +36,9 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non mask = linds < rinds if window_size is not None: mask = mask | (linds > rinds + window_size) + if lengths is not None: + lengths = lengths[:, None, None, None] + mask = mask | (rinds >= lengths) return mask * -1e9 diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 7e002b0c..b2d16dd7 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -155,11 +155,13 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -180,9 +182,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index fcb4061b..ec0e9276 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_causal_mask, scaled_dot_product_attention +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .cache import KVCache, RotatingKVCache @@ -151,16 +151,13 @@ class CohereModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - T = h.shape[1] - if T > 1: - offset = cache[0].offset if cache else 0 - mask = create_causal_mask(T, offset).astype(h.dtype) - else: - mask = None + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = out * self.model.args.logit_scale return out diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 7be274cc..886b5630 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -197,11 +197,13 @@ class DBRX(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = 
create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -223,9 +225,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index b7b24dba..ffc30c36 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -211,9 +211,11 @@ class DeepseekModel(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -236,8 +238,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 444813b9..9027da7e 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -370,9 +370,12 @@ class DeepseekV2Model(nn.Module): self, x: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(x) - mask = create_attention_mask(h, cache) + + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -395,8 +398,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/exaone.py b/llms/mlx_lm/models/exaone.py index eaed5dd8..ee3ed1e8 100644 --- a/llms/mlx_lm/models/exaone.py +++ b/llms/mlx_lm/models/exaone.py @@ -123,10 +123,12 @@ class ExaoneModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.h) @@ -149,9 +151,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index 3f384c3f..0860ddeb 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -138,12 +138,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -164,9 +166,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) return out diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 64951ae4..321a58ff 100644 --- 
a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -160,12 +160,14 @@ class GemmaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) h = h * (self.args.hidden_size**0.5) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -187,9 +189,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) out = mx.tanh(out / self.final_logit_softcapping) out = out * self.final_logit_softcapping diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 52076a34..5b277734 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -126,6 +126,7 @@ class GPT2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape @@ -138,7 +139,8 @@ class GPT2Model(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -159,9 +161,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.wte.as_linear(out) return out diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 23e86e20..8415c59e 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -137,6 +137,7 @@ class GPTBigCodeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): B, L = inputs.shape @@ -149,7 +150,8 @@ class GPTBigCodeModel(nn.Module): position_ids = mx.array(np.arange(L)) hidden_states += self.wpe(position_ids) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -172,9 +174,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.transformer.wte.as_linear(out) else: diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index ccb0b28b..5e124a67 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -146,13 +146,15 @@ class GPTNeoXModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): _, L = inputs.shape hidden_states = self.embed_in(inputs) - mask = create_attention_mask(hidden_states, cache) + if mask is None: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) @@ -176,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return out def sanitize(self, weights): diff --git a/llms/mlx_lm/models/hunyuan.py b/llms/mlx_lm/models/hunyuan.py index b098c20d..f9dc5652 100644 --- a/llms/mlx_lm/models/hunyuan.py +++ b/llms/mlx_lm/models/hunyuan.py @@ -239,11 +239,13 @@ class 
HunYuanModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -266,9 +268,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.model.embed_tokens.as_linear(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index f5ce057e..28a095e1 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -193,11 +193,13 @@ class InternLM2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.tok_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -220,9 +222,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.tok_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 290cb83e..7b452ea4 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -155,11 +155,13 @@ class LlamaModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -182,9 +184,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 907beb2a..edddd583 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -158,11 +158,13 @@ class MiniCPMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) * self.args.scale_emb - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -186,9 +188,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if not self.args.tie_word_embeddings: out = self.lm_head(out / (self.args.hidden_size / self.args.dim_model_base)) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index dd94d1f4..0afd1235 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -162,11 +162,13 @@ class MixtralModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -188,9 +190,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): 
- out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index f73c0277..eabfac8c 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -176,11 +176,13 @@ class NemotronModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -203,9 +205,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 3627df06..4273b0ec 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -124,11 +124,13 @@ class Transformer(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.wte(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.blocks) @@ -152,9 +154,10 @@ class OlmoModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.transformer(inputs, cache) + return self.transformer(inputs, mask, cache) class Model(nn.Module): @@ -167,9 +170,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - return self.model(inputs, cache) + return self.model(inputs, mask, cache) @property def layers(self): diff --git a/llms/mlx_lm/models/olmo2.py b/llms/mlx_lm/models/olmo2.py index 64d7e116..510ff882 100644 --- a/llms/mlx_lm/models/olmo2.py +++ b/llms/mlx_lm/models/olmo2.py @@ -163,10 +163,12 @@ class LlamaModel(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -190,8 +192,9 @@ class Model(nn.Module): self, inputs: mx.array, cache=None, + mask=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 408802f4..504fe95c 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -178,11 +178,13 @@ class OpenELMModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.token_embeddings(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -205,9 +207,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.transformer(inputs, cache) + out = self.transformer(inputs, mask, cache) if self.args.share_input_output_layers: out = self.transformer.token_embeddings.as_linear(out) else: diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 510025ea..e9724691 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -143,10 +143,11 @@ class 
PhiModel(nn.Module): config.hidden_size, eps=config.layer_norm_eps ) - def __call__(self, x, cache): + def __call__(self, x, mask, cache): x = self.embed_tokens(x) - mask = create_attention_mask(x, cache) + if mask is None: + mask = create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.layers) @@ -167,9 +168,10 @@ class Model(nn.Module): def __call__( self, x: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: - y = self.model(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) @property diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index ee6efc49..d1c21e25 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -168,11 +168,13 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -194,9 +196,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 53e1a638..cd566eec 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -258,13 +258,15 @@ class Phi3Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) if self.mup_embedding_multiplier: h = self.mup_embedding_multiplier * h - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -290,9 +292,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) out = self.model.embed_tokens.as_linear(out) if self.mup_width_multiplier: out = out / self.mup_width_multiplier diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index f42a6dd0..bddcb128 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -155,11 +155,13 @@ class PhiMoEModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -181,9 +183,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 42d647b0..5477c2c0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -175,7 +175,9 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index c8e5bf50..9107daad 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -174,10 +174,12 @@ class PlamoModel(nn.Module): self, 
inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None for _ in range(len(self.layers.layers))] @@ -202,8 +204,9 @@ class Model(nn.Module): self, inputs: mx.array, cache: Optional[Any] = None, + mask: Optional[mx.array] = None, ) -> mx.array: - out = self.model(inputs, cache) + out = self.model(inputs, cache, mask) return self.lm_head(out) @property diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 8145a890..ec8a0199 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -123,7 +123,8 @@ class QwenModel(nn.Module): def __call__(self, inputs, mask=None, cache=None): x = self.wte(inputs) - mask = create_attention_mask(x, cache) + if mask is None: + mask = create_attention_mask(x, cache) if cache is None: cache = [None] * len(self.h) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index fac59d78..381767c4 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -149,11 +149,13 @@ class Qwen2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -176,9 +178,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index 167fc5dd..c6aba622 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -187,11 +187,13 @@ class Qwen2MoeModel(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -213,9 +215,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) return self.lm_head(out) def sanitize(self, weights): diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 49e4bb8f..ad07d925 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -389,6 +389,7 @@ class Griffin(nn.Module): def __call__( self, tokens, + mask: mx.array = None, cache=None, ): x = self.embed_tokens(tokens) @@ -402,7 +403,8 @@ class Griffin(nn.Module): if block.temporal_block_type != "recurrent": mask_cache = [cache[i]] - mask = create_attention_mask(x, mask_cache) + if mask is None: + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -418,12 +420,12 @@ class Model(nn.Module): self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - def __call__(self, tokens: mx.array, cache=None) -> mx.array: + def __call__(self, tokens: mx.array, mask: mx.array = None, cache=None) -> mx.array: """ Args: tokens: Sequence of input tokens. 
""" - logits = self.model(tokens, cache=cache) + logits = self.model(tokens, mask=mask, cache=cache) if "lm_head" in self: logits = self.lm_head(logits) else: diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 482bb324..0bbc2ca4 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -199,7 +199,10 @@ class Model(nn.Module): mask: mx.array = None, cache=None, ) -> mx.array: - mask = create_attention_mask(x, cache) + + if mask is None: + mask = create_attention_mask(x, cache) + y = self.model(x, mask, cache) return self.lm_head(y) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index d7e626f2..71c397f6 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -125,11 +125,13 @@ class Starcoder2Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): h = self.embed_tokens(inputs) - mask = create_attention_mask(h, cache) + if mask is None: + mask = create_attention_mask(h, cache) if cache is None: cache = [None] * len(self.layers) @@ -152,9 +154,10 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, + mask: mx.array = None, cache=None, ): - out = self.model(inputs, cache) + out = self.model(inputs, mask, cache) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(out) else: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 3097c522..7b4376bb 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -5,6 +5,7 @@ import mlx.core as mx import mlx.nn as nn from mlx.utils import tree_map from mlx_lm.models import rope_utils +from mlx_lm.models.base import create_causal_mask from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache @@ -128,6 +129,22 @@ class TestModels(unittest.TestCase): self.assertEqual(cache.offset, 22) self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def test_causal_mask_lengths(self): + mx.random.seed(8) + B, N_q, T_q, N_kv, T_kv, D = (4, 8, 3, 2, 3, 2) + lengths = mx.array([1, 2, 3, 1]) + q = mx.random.uniform(shape=(B, N_q, T_q, D)) + k = mx.random.uniform(shape=(B, N_kv, T_kv, D)) + v = k + mask = create_causal_mask(T_q, 0, lengths=lengths) + + out1 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + q[1, :, 2:] = mx.ones_like(q[1, :, 2:]) + k[1, :, 2:] = mx.ones_like(k[1, :, 2:]) + v[1, :, 2:] = mx.ones_like(v[1, :, 2:]) + out2 = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + self.assertTrue(mx.allclose(out1[1, :, :2], out2[1, :, :2])) + def test_rope(self): rope = rope_utils.initialize_rope(32, base=100, traditional=False) self.assertTrue(isinstance(rope, nn.RoPE)) @@ -162,10 +179,16 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.dtype, t) cache = make_prompt_cache(model) - outputs = model(inputs, cache) + outputs = model(inputs, cache=cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) + if model_type != "mamba": + mask = create_causal_mask(inputs.shape[1], 0).astype(t) + outputs = model(inputs, mask=mask) + self.assertEqual(outputs.shape, (1, 2, vocab_size)) + self.assertEqual(outputs.dtype, t) + outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache) self.assertEqual(outputs.shape, (1, 1, vocab_size)) self.assertEqual(outputs.dtype, t) From 3a58c361096e5be7a927e7719c5ef66bace9a8ab Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 1 Jan 2025 16:25:57 +0100 Subject: [PATCH 63/77] Improvements 
to mlx_lm.manage (#1178) * improvements to manage. Default value is N and size added to deletion confirmation. * Fixing case for no case * nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/manage.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/llms/mlx_lm/manage.py b/llms/mlx_lm/manage.py index bb5c3a09..9827f3dc 100644 --- a/llms/mlx_lm/manage.py +++ b/llms/mlx_lm/manage.py @@ -6,19 +6,18 @@ from transformers.commands.user import tabulate def ask_for_confirmation(message: str) -> bool: + """Ask user for confirmation with Y/N prompt. + Returns True for Y/yes, False for N/no/empty.""" y = ("y", "yes", "1") - n = ("n", "no", "0") - all_values = y + n + ("",) - full_message = f"{message} (Y/n) " + n = ("n", "no", "0", "") + full_message = f"{message} (y/n) " while True: answer = input(full_message).lower() - if answer == "": - return False if answer in y: return True if answer in n: return False - print(f"Invalid input. Must be one of {all_values}") + print(f"Invalid input. Must be one of: yes/no/y/n or empty for no") def main(): @@ -43,9 +42,7 @@ def main(): args = parser.parse_args() if args.scan: - print( - "Scanning Hugging Face cache for models with" f'pattern "{args.pattern}".' - ) + print(f'Scanning Hugging Face cache for models with pattern "{args.pattern}".') hf_cache_info = scan_cache_dir() print( tabulate( @@ -86,35 +83,41 @@ def main(): if args.pattern in repo.repo_id ] if repos: + print("\nFound the following models:") print( tabulate( rows=[ [ repo.repo_id, + repo.size_on_disk_str, # Added size information str(repo.repo_path), ] for repo in repos ], headers=[ "REPO ID", + "SIZE", # Added size header "LOCAL PATH", ], ) ) - confirmed = ask_for_confirmation(f"Confirm deletion ?") + confirmed = ask_for_confirmation( + "\nAre you sure you want to delete these models?" + ) if confirmed: for model_info in repos: + print(f"\nDeleting {model_info.repo_id}...") for revision in sorted( model_info.revisions, key=lambda revision: revision.commit_hash ): strategy = hf_cache_info.delete_revisions(revision.commit_hash) strategy.execute() - print("Model(s) deleted.") + print("\nModel(s) deleted successfully.") else: - print("Deletion is cancelled. 
Do nothing.") + print("\nDeletion cancelled - no changes made.") else: - print(f"No models found.") + print(f'No models found matching pattern "{args.pattern}"') if __name__ == "__main__": From c4833a2f55c4553f71b16a412a6eb6d2f1427380 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 3 Jan 2025 10:50:59 -0800 Subject: [PATCH 64/77] fix encoding with special tokens + chat template (#1189) --- llms/README.md | 4 +- llms/mlx_lm/cache_prompt.py | 20 ++---- llms/mlx_lm/chat.py | 4 +- llms/mlx_lm/evaluate.py | 28 ++++++--- llms/mlx_lm/examples/chat.py | 8 +-- llms/mlx_lm/examples/generate_response.py | 2 +- llms/mlx_lm/generate.py | 9 +-- llms/mlx_lm/lora.py | 2 + llms/mlx_lm/server.py | 6 +- llms/mlx_lm/tuner/datasets.py | 77 ++++++++++++----------- llms/mlx_lm/tuner/trainer.py | 8 +-- llms/mlx_lm/utils.py | 19 +++--- llms/tests/test_datsets.py | 5 +- 13 files changed, 95 insertions(+), 97 deletions(-) diff --git a/llms/README.md b/llms/README.md index 4fff4207..e943ed69 100644 --- a/llms/README.md +++ b/llms/README.md @@ -58,7 +58,7 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) text = generate(model, tokenizer, prompt=prompt, verbose=True) @@ -115,7 +115,7 @@ prompt = "Write a story about Einstein" messages = [{"role": "user", "content": prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) for response in stream_generate(model, tokenizer, prompt, max_tokens=512): diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9d7d1603..c18f1bae 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -110,29 +110,17 @@ def main(): if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): + if not args.ignore_chat_template and tokenizer.chat_template is not None: messages = [{"role": "user", "content": args.prompt}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=False, continue_final_message=True ) - # Treat the prompt as a prefix assuming that the suffix will be - # provided at generation time. 
- test_prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": ""}], - tokenize=False, - add_generation_prompt=True, - ) - n = len(test_prompt) - test_prompt.index("") - len("") - prompt = prompt[:-n] else: - prompt = args.prompt + prompt = tokenizer.encode(args.prompt) cache = make_prompt_cache(model, args.max_kv_size) - y = mx.array(tokenizer.encode(prompt)) + y = mx.array(prompt) # Process the prompt start = time.time() diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 5a8245ef..e52ad10d 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -72,9 +72,7 @@ def main(): if query == "q": break messages = [{"role": "user", "content": query}] - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) + prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) for response in stream_generate( model, tokenizer, diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index c4b15748..bf7bf4d4 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -1,4 +1,8 @@ -# Adapted from a PyTorch implementation by David Grangier +# Copyright © 2024 Apple Inc. + +""" +Adapted from a PyTorch implementation by David Grangier +""" import argparse import json @@ -6,7 +10,7 @@ import logging import os from importlib.metadata import version from pathlib import Path -from typing import Optional +from typing import Optional, Union import lm_eval import mlx.core as mx @@ -277,19 +281,19 @@ class MLXLM(LM): assert "until" in keys untils = [x["until"] for x in options] completions = [] + for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - if ( - hasattr(self._tokenizer, "apply_chat_template") - and self._tokenizer.chat_template is not None - ): + if self._tokenizer.chat_template is not None: messages = [{"role": "user", "content": context}] context = self._tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) + else: + context = self._tokenizer.encode(context) max_tokens = min( self._max_tokens, - self._tokenizer.model_max_length - len(self._tokenizer.encode(context)), + self._tokenizer.model_max_length - len(context), ) text = "" for response in stream_generate( @@ -321,6 +325,12 @@ def main(): type=int, help="Maximum nunber of tokens to generate. Defaults to the model's max context length.", ) + parser.add_argument( + "--limit", + default=1.0, + help="Limit the number of examples per task.", + type=float, + ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") args = parser.parse_args() @@ -338,10 +348,12 @@ def main(): model=lm, tasks=args.tasks, num_fewshot=args.num_shots, + limit=args.limit, random_seed=args.seed, numpy_random_seed=args.seed, torch_random_seed=args.seed, fewshot_random_seed=args.seed, + apply_chat_template=True, ) model_name = args.model.replace("/", "_") diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index c7512b3c..4a7020f1 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -15,9 +15,7 @@ prompt_cache = make_prompt_cache(model) # User turn prompt = "Hi my name is ." messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( @@ -32,9 +30,7 @@ response = generate( # User turn prompt = "What's my name?" 
messages = [{"role": "user", "content": prompt}] -prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True -) +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # Assistant response response = generate( diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index e6535b47..41eaf1da 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -14,7 +14,7 @@ conversation = [{"role": "user", "content": prompt}] # Transform the prompt into the chat template prompt = tokenizer.apply_chat_template( - conversation=conversation, tokenize=False, add_generation_prompt=True + conversation=conversation, add_generation_prompt=True ) # Specify the maximum number of tokens diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index afb1394e..1ea66384 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -190,10 +190,7 @@ def main(): prompt = args.prompt.replace("\\n", "\n").replace("\\t", "\t") prompt = sys.stdin.read() if prompt == "-" else prompt - if not args.ignore_chat_template and ( - hasattr(tokenizer, "apply_chat_template") - and tokenizer.chat_template is not None - ): + if not args.ignore_chat_template and tokenizer.chat_template is not None: if args.system_prompt is not None: messages = [{"role": "system", "content": args.system_prompt}] else: @@ -214,6 +211,10 @@ def main(): ) prompt = prompt[test_prompt.index("") :] + prompt = tokenizer.encode(prompt, add_special_tokens=False) + else: + prompt = tokenizer.encode(prompt) + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index c96e75a7..6fb86917 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -2,6 +2,7 @@ import argparse import math +import os import re import types from pathlib import Path @@ -271,6 +272,7 @@ def run(args, training_callback: TrainingCallback = None): def main(): + os.environ["TOKENIZERS_PARALLELISM"] = "true" parser = build_parser() args = parser.parse_args() config = args.config diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c12513ff..4523e3ae 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -590,14 +590,10 @@ class APIHandler(BaseHTTPRequestHandler): # Determine response type self.request_id = f"chatcmpl-{uuid.uuid4()}" self.object_type = "chat.completion.chunk" if self.stream else "chat.completion" - if ( - hasattr(self.tokenizer, "apply_chat_template") - and self.tokenizer.chat_template - ): + if self.tokenizer.chat_template: prompt = self.tokenizer.apply_chat_template( body["messages"], body.get("tools", None), - tokenize=True, add_generation_prompt=True, ) else: diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 20b32eff..fa848f47 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -10,41 +10,47 @@ class Dataset: Light-weight wrapper to hold a dataset. 
""" - def __init__(self, data: List[Dict[str, str]], text_key: str = "text"): - self._text_key = text_key - self._data = data + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + text_key: str = "text", + ): + self._data = [tokenizer.encode(d[text_key]) for d in data] + for d in self._data: + if d[-1] != tokenizer.eos_token_id: + d.append(tokenizer.eos_token_id) def __getitem__(self, idx: int): - return self._data[idx][self._text_key] + return self._data[idx] def __len__(self): - if self._data is None: - return 0 return len(self._data) -class ChatDataset(Dataset): +class ChatDataset: """ A dataset for chat data in the format of {"messages": [...]} https://platform.openai.com/docs/guides/fine-tuning/example-format """ def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): - super().__init__(data) - self._tokenizer = tokenizer + self._data = [ + tokenizer.apply_chat_template( + d["messages"], + tools=d.get("tools", None), + ) + for d in data + ] def __getitem__(self, idx: int): - messages = self._data[idx]["messages"] - text = self._tokenizer.apply_chat_template( - messages, - tools=self._data[idx].get("tools", None), - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -class CompletionsDataset(Dataset): +class CompletionsDataset: """ A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...} or using user-provided keys for prompt and completion values @@ -58,25 +64,24 @@ class CompletionsDataset(Dataset): prompt_key: str = "prompt", completion_key: str = "completion", ): - super().__init__(data) - self._tokenizer = tokenizer - self._prompt_key = prompt_key - self._completion_key = completion_key + self._data = [ + tokenizer.apply_chat_template( + [ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[completion_key]}, + ], + ) + for d in data + ] def __getitem__(self, idx: int): - data = self._data[idx] - text = self._tokenizer.apply_chat_template( - [ - {"role": "user", "content": data[self._prompt_key]}, - {"role": "assistant", "content": data[self._completion_key]}, - ], - tokenize=False, - add_generation_prompt=True, - ) - return text + return self._data[idx] + + def __len__(self): + return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer = None): +def create_dataset(data, tokenizer: PreTrainedTokenizer): sample = data[0] if "messages" in sample: @@ -84,7 +89,7 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer = None): elif "prompt" in sample and "completion" in sample: return CompletionsDataset(data, tokenizer) elif "text" in sample: - return Dataset(data) + return Dataset(data, tokenizer) else: raise ValueError( "Unsupported data format, check the supported formats here:\n" @@ -143,7 +148,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): if prompt_feature and completion_feature: return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) elif text_feature: - return Dataset(train_ds, text_key=text_feature) + return Dataset(train_ds, tokenizer, text_key=text_feature) else: raise ValueError( "Specify either a prompt and completion feature or a text " @@ -166,7 +171,7 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", None) is not None: + if getattr(args, "hf_dataset", False): train, valid, test = 
load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 21b1af18..a76b8336 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) while True: indices = np.random.permutation(len(batch_idx)) for i in indices: - # Encode batch - batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] - for b in batch: - if b[-1] != tokenizer.eos_token_id: - b.append(tokenizer.eos_token_id) - + batch = [dataset[j] for j in batch_idx[i]] lengths = [len(x) for x in batch] - if max(lengths) > max_seq_length: print( f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 4d69115e..0c35d07f 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -353,9 +353,13 @@ def stream_generate( tokenizer = TokenizerWrapper(tokenizer) if not isinstance(prompt, mx.array): - prompt = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + if isinstance(prompt, str): + # Try to infer if special tokens are needed + add_special_tokens = tokenizer.bos_token is None or not prompt.startswith( + tokenizer.bos_token + ) + prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + prompt = mx.array(prompt) detokenizer = tokenizer.detokenizer @@ -401,7 +405,7 @@ def stream_generate( def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[int]], verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -412,7 +416,7 @@ def generate( Args: model (nn.Module): The language model. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (str): The string prompt. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. kwargs: The remaining options get passed to :func:`stream_generate`. 
@@ -425,7 +429,6 @@ def generate( ) if verbose: print("=" * 10) - print("Prompt:", prompt) text = "" for response in stream_generate(model, tokenizer, prompt, **kwargs): @@ -654,10 +657,10 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): prompt="hello" - if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None: + if tokenizer.chat_template is not None: messages = [{{"role": "user", "content": prompt}}] prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True + messages, add_generation_prompt=True ) response = generate(model, tokenizer, prompt=prompt, verbose=True) diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index 240bfb4a..dd86d277 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -36,7 +36,8 @@ class TestDatasets(unittest.TestCase): data = {"text": "This is an example for the model."} self.save_data(4 * [data]) args = types.SimpleNamespace(train=True, test=False, data=self.test_dir) - train, valid, test = datasets.load_dataset(args, None) + tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH) + train, valid, test = datasets.load_dataset(args, tokenizer) self.assertEqual(len(train), 4) self.assertEqual(len(valid), 4) self.assertEqual(len(test), 0) @@ -82,6 +83,8 @@ class TestDatasets(unittest.TestCase): "name": "billsum", "prompt_feature": "text", "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", }, test=False, train=True, From 25ec2d8c4496be68acf7e0c9ea1ae4269e1a2a19 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 5 Jan 2025 22:26:05 -0800 Subject: [PATCH 65/77] Change the eos-token argument for mlx_lm.generate (#1176) --- llms/mlx_lm/generate.py | 9 +++++---- llms/mlx_lm/tokenizer_utils.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 1ea66384..3301edae 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -43,10 +43,11 @@ def setup_arg_parser(): help="Optional path for the trained adapter weights and config.", ) parser.add_argument( - "--eos-token", + "--extra-eos-token", type=str, default=None, - help="End of sequence token for tokenizer", + nargs="+", + help="Add tokens in the list of eos tokens that stop generation.", ) parser.add_argument( "--system-prompt", @@ -161,8 +162,6 @@ def main(): {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) tokenizer_config["trust_remote_code"] = True - if args.eos_token is not None: - tokenizer_config["eos_token"] = args.eos_token model_path = args.model if using_cache: @@ -181,6 +180,8 @@ def main(): adapter_path=args.adapter_path, tokenizer_config=tokenizer_config, ) + for eos_token in args.extra_eos_token: + tokenizer.add_eos_token(eos_token) if args.use_default_chat_template: if tokenizer.chat_template is None: diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index ca3d6c06..1b5bdd77 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -266,6 +266,18 @@ class TokenizerWrapper: else {tokenizer.eos_token_id} ) + def add_eos_token(self, token: str): + token_id = None + try: + token_id = int(token) + except ValueError: + token_id = self._tokenizer.convert_tokens_to_ids(token) + + if token_id is None: + raise ValueError(f"'{token}' is not a token for this tokenizer") + + self._eos_token_ids.add(token_id) + def __getattr__(self, attr): if attr == "detokenizer": return 
self._detokenizer From f2619f507c7dcde70410cc2cbb1d4715476d79ee Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 6 Jan 2025 10:58:43 -0500 Subject: [PATCH 66/77] Add support for fewshot and apply chat template lm_eval functionality (#1180) * Add support for multiturn fewshot examples and chat templates Added two new arguments to the evaluation script: `--fewshot-as-multiturn` and `--apply-chat-template` which correspond to lm_eval options of similar names and are very often used to ensure apples-to-apples comparisons of lm_evaluation results * Add HF overrides for methods needed by added options * don't add duplicate bos --------- Co-authored-by: Awni Hannun --- .circleci/config.yml | 2 +- llms/mlx_lm/evaluate.py | 59 +++++++++++++++++++++++++++++------------ llms/setup.py | 4 +-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index cecd2d57..8367281e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,7 +32,7 @@ jobs: pip install --upgrade pip pip install unittest-xml-reporting cd llms/ - pip install -e ".[testing]" + pip install -e ".[test]" - run: name: Run Python tests command: | diff --git a/llms/mlx_lm/evaluate.py b/llms/mlx_lm/evaluate.py index bf7bf4d4..ca5e83bb 100644 --- a/llms/mlx_lm/evaluate.py +++ b/llms/mlx_lm/evaluate.py @@ -77,15 +77,19 @@ class MLXLM(LM): path_or_hf_repo: str, batch_size: int = 16, max_tokens: Optional[int] = None, + use_chat_template: Optional[bool] = None, ) -> None: super().__init__() self._batch_size = batch_size - self._model, self._tokenizer = load(path_or_hf_repo) - self._max_tokens = max_tokens or self._tokenizer.model_max_length + self._model, self.tokenizer = load(path_or_hf_repo) + self._max_tokens = max_tokens or self.tokenizer.model_max_length + self.use_chat_template = use_chat_template or ( + self.tokenizer.chat_template is not None + ) def _score_fn(self, inputs, tokenize=True, step_size=32): if tokenize: - inputs = self._tokenizer.encode(inputs) + inputs = self._tokenize(inputs) inputs = _pad_inputs(inputs, self._max_tokens, truncate=False) inputs = mx.array(inputs) inputs, targets = inputs[..., :-1], inputs[..., 1:] @@ -149,7 +153,12 @@ class MLXLM(LM): return results def _tokenize(self, texts): - return [tuple(self._tokenizer.encode(t)) for t in texts] + return [ + tuple( + self.tokenizer.encode(t, add_special_tokens=not self.use_chat_template) + ) + for t in texts + ] def loglikelihood(self, requests) -> list[tuple[float, bool]]: """Compute log-likelihood of generating a continuation from a context. @@ -221,6 +230,9 @@ class MLXLM(LM): ) return [(r[0], r[1] == r[2]) for r in results] + tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name + apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template + def loglikelihood_rolling(self, requests) -> list[float]: """Compute full log-likelihood of a string, with no truncation, for perplexity computation - We will use the full max context length of the model. 
@@ -283,21 +295,14 @@ class MLXLM(LM): completions = [] for context, until in tqdm(zip(contexts, untils), total=len(contexts)): - if self._tokenizer.chat_template is not None: - messages = [{"role": "user", "content": context}] - context = self._tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - else: - context = self._tokenizer.encode(context) - + context = self._tokenize(context) max_tokens = min( self._max_tokens, - self._tokenizer.model_max_length - len(context), + self.tokenizer.model_max_length - len(context), ) text = "" for response in stream_generate( - self._model, self._tokenizer, prompt=context, max_tokens=max_tokens + self._model, self.tokenizer, prompt=context, max_tokens=max_tokens ): text += response.text if any(u in text for u in until): @@ -332,6 +337,21 @@ def main(): type=float, ) parser.add_argument("--seed", type=int, default=123, help="Random seed.") + parser.add_argument( + "--fewshot-as-multiturn", + action="store_true", + help="Whether to provide the fewshot examples as a multiturn " + "conversation or a single user turn.", + default=False, + ) + parser.add_argument( + "--apply-chat-template", + action=argparse.BooleanOptionalAction, + help="Specifies whether to apply a chat template to the prompt. If " + "the model has a chat template, this defaults to `True`, " + "otherwise `False`.", + default=None, + ) args = parser.parse_args() output_dir = Path(args.output_dir) @@ -342,18 +362,23 @@ def main(): mx.random.seed(args.seed) - lm = MLXLM(args.model, batch_size=args.batch_size, max_tokens=args.max_tokens) - + lm = MLXLM( + args.model, + batch_size=args.batch_size, + max_tokens=args.max_tokens, + use_chat_template=args.apply_chat_template, + ) results = lm_eval.simple_evaluate( model=lm, tasks=args.tasks, + fewshot_as_multiturn=args.fewshot_as_multiturn, + apply_chat_template=lm.use_chat_template, num_fewshot=args.num_shots, limit=args.limit, random_seed=args.seed, numpy_random_seed=args.seed, torch_random_seed=args.seed, fewshot_random_seed=args.seed, - apply_chat_template=True, ) model_name = args.model.replace("/", "_") diff --git a/llms/setup.py b/llms/setup.py index b88dcd33..e6fddbae 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -27,8 +27,8 @@ setup( packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"], python_requires=">=3.8", extras_require={ - "testing": ["datasets"], - "evaluation": ["lm-eval"], + "test": ["datasets"], + "evaluate": ["lm-eval", "tqdm"], }, entry_points={ "console_scripts": [ From 9183fe8b6d6b5e86cac0f47b54675f272c9f3591 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 6 Jan 2025 10:12:07 -0800 Subject: [PATCH 67/77] fix (#1192) --- llms/mlx_lm/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 3301edae..26481d6b 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -45,7 +45,7 @@ def setup_arg_parser(): parser.add_argument( "--extra-eos-token", type=str, - default=None, + default=(), nargs="+", help="Add tokens in the list of eos tokens that stop generation.", ) From b8f0cacfa8dd08aaca7025351a7afddd481ca490 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 7 Jan 2025 18:18:31 +0100 Subject: [PATCH 68/77] Use upload_large_folder (#1193) --- llms/mlx_lm/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0c35d07f..ad79349e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -673,12 +673,10 @@ def upload_to_hub(path: 
str, upload_repo: str, hf_path: str): api = HfApi() api.create_repo(repo_id=upload_repo, exist_ok=True) - api.upload_folder( + api.upload_large_folder( folder_path=path, repo_id=upload_repo, repo_type="model", - multi_commits=True, - multi_commits_verbose=True, ) print(f"Upload successful, go to https://huggingface.co/{upload_repo} for details.") From 40b88eff488d82b8d8739de6d60f59c1f0789a14 Mon Sep 17 00:00:00 2001 From: Jarrett <2613089+jjaareet@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:33:54 -0700 Subject: [PATCH 69/77] fix(lora): config yaml & arg default merge bug (#1196) --- llms/mlx_lm/lora.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 6fb86917..4d050bd5 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -58,6 +58,8 @@ CONFIG_DEFAULTS = { "test": False, "test_batches": 500, "max_seq_length": 2048, + "config": None, + "grad_checkpoint": False, "lr_schedule": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } @@ -67,6 +69,7 @@ def build_parser(): parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.") parser.add_argument( "--model", + type=str, help="The path to the local model directory or Hugging Face repo.", ) @@ -75,7 +78,6 @@ def build_parser(): "--train", action="store_true", help="Do training", - default=None, ) parser.add_argument( "--data", @@ -89,7 +91,6 @@ def build_parser(): "--fine-tune-type", type=str, choices=["lora", "dora", "full"], - default="lora", help="Type of fine-tuning to perform: lora, dora, or full.", ) parser.add_argument( @@ -134,7 +135,6 @@ def build_parser(): "--test", action="store_true", help="Evaluate on the test set after training", - default=None, ) parser.add_argument( "--test-batches", @@ -149,16 +149,15 @@ def build_parser(): parser.add_argument( "-c", "--config", - default=None, + type=str, help="A YAML configuration file with the training options", ) parser.add_argument( "--grad-checkpoint", action="store_true", help="Use gradient checkpointing to reduce memory use.", - default=None, ) - parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") + parser.add_argument("--seed", type=int, help="The PRNG seed") return parser From 5cae0a60e6acb3599483a9304aebbc89e0bff1c4 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Thu, 9 Jan 2025 15:55:53 -0800 Subject: [PATCH 70/77] deepseek v3 model with pipeline parallelism (#1191) * deepseekv3 * use upload_large_file instead of deprecated multi comit * add pipeline generation and example * comment * get fp16 working * use mlx==0.22 --- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/examples/pipeline_generate.py | 75 ++++ llms/mlx_lm/models/deepseek_v3.py | 460 ++++++++++++++++++++++ llms/mlx_lm/requirements.txt | 2 +- llms/mlx_lm/utils.py | 4 +- llms/tests/test_models.py | 37 ++ llms/tests/test_utils_load_model.py | 2 +- 7 files changed, 577 insertions(+), 5 deletions(-) create mode 100644 llms/mlx_lm/examples/pipeline_generate.py create mode 100644 llms/mlx_lm/models/deepseek_v3.py diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 3af2d5fd..b2f98e6f 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.20.4" +__version__ = "0.21.0" diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py new file mode 100644 index 00000000..b98e757b --- /dev/null +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -0,0 +1,75 @@ +# Copyright © 2024 Apple Inc. + +""" +Run with: + +``` +/path/to/mpirun \ + -np 2 \ + --hostfile /path/to/hosts.txt \ + python /path/to/pipeline_generate.py --prompt "hello world" +``` + +Make sure you can run MLX over MPI on two hosts. For more information see the +documentation: + +https://ml-explore.github.io/mlx/build/html/usage/distributed.html). +""" + +import argparse + +import mlx.core as mx +from mlx_lm import load, stream_generate + +parser = argparse.ArgumentParser(description="LLM pipelined inference example") +parser.add_argument( + "--prompt", + "-p", + default="Write a quicksort in C++.", + help="Message to be processed by the model ('-' reads from stdin)", +) +parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=256, + help="Maximum number of tokens to generate", +) +args = parser.parse_args() + +model_repo = "mlx-community/DeepSeek-V3-3bit" + +model, tokenizer = load(model_repo, lazy=True) + +messages = [{"role": "user", "content": args.prompt}] +prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + +group = mx.distributed.init() +rank = group.rank() +model.model.pipeline(group) +mx.eval(model.parameters()) + +# Synchronize processes before generation to avoid timeout if downloading +# model for the first time. +mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) + + +def rprint(*args, **kwargs): + if rank == 0: + print(*args, **kwargs) + + +for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): + rprint(response.text, end="", flush=True) + +rprint() +rprint("=" * 10) +rprint( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" +) +rprint( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" +) +rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py new file mode 100644 index 00000000..f95949f9 --- /dev/null +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -0,0 +1,460 @@ +# Copyright © 2024 Apple Inc. 
+ +import math +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from .switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str = "deepseek_v3" + vocab_size: int = 102400 + hidden_size: int = 4096 + intermediate_size: int = 11008 + moe_intermediate_size: int = 1407 + num_hidden_layers: int = 30 + num_attention_heads: int = 32 + num_key_value_heads: int = 32 + n_shared_experts: Optional[int] = None + n_routed_experts: Optional[int] = None + routed_scaling_factor: float = 1.0 + kv_lora_rank: int = 512 + q_lora_rank: int = 1536 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + qk_nope_head_dim: int = 128 + topk_method: str = "noaux_tc" + scoring_func: str = "sigmoid" + norm_topk_prob: bool = True + n_group: Optional[int] = None + topk_group: Optional[int] = None + num_experts_per_tok: Optional[int] = None + moe_layer_freq: int = 1 + first_k_dense_replace: int = 0 + max_position_embeddings: int = 2048 + rms_norm_eps: float = 1e-6 + rope_theta: float = 10000.0 + rope_scaling: Dict = None + attention_bias: bool = False + + +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min_val, max_val, dim): + if min_val == max_val: + max_val += 0.001 # Prevent singularity + + linear_func = (mx.arange(dim, dtype=mx.float32) - min_val) / (max_val - min_val) + return mx.clip(linear_func, 0, 1) + + +class DeepseekV3YarnRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + super().__init__() + self.mscale = yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale( + scaling_factor, mscale_all_dim + ) + freq_extra = base ** (mx.arange(0, dim, 2, dtype=mx.float32) / dim) + freq_inter = scaling_factor * base ** ( + mx.arange(0, dim, 2, dtype=mx.float32) / dim + ) + low, high = yarn_find_correction_range( + beta_fast, + beta_slow, + dim, + base, + original_max_position_embeddings, + ) + freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self._freqs = (freq_inter * freq_extra) / ( + freq_inter * freq_mask + freq_extra * (1 - freq_mask) + ) + + def __call__(self, x, offset=0): + if self.mscale != 1.0: + x = self.mscale * x + return mx.fast.rope( + x, + x.shape[-1], + traditional=True, + base=None, + scale=1.0, + offset=offset, + freqs=self._freqs, + ) + + +class DeepseekV3Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + 
self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.scale = self.q_head_dim**-0.5 + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, self.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = nn.RMSNorm(self.q_lora_rank) + self.q_b_proj = nn.Linear( + self.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = nn.RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale + + rope_kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rope = DeepseekV3YarnRotaryEmbedding( + dim=self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **rope_kwargs, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + if cache is not None: + q_pe = self.rope(q_pe, cache.offset) + k_pe = self.rope(k_pe, cache.offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys, values = cache.update_and_fetch( + mx.concatenate([k_nope, k_pe], axis=-1), values + ) + else: + q_pe = self.rope(q_pe) + k_pe = self.rope(k_pe) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + keys = mx.concatenate([k_nope, k_pe], axis=-1) + + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class DeepseekV3MLP(nn.Module): + def __init__( + self, config: ModelArgs, hidden_size: int = None, intermediate_size: int = None + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None 
else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + + def __call__(self, x): + down_proj = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + self.weight = mx.zeros((self.n_routed_experts, config.hidden_size)) + self.e_score_correction_bias = mx.zeros((self.n_routed_experts,)) + + def __call__(self, x): + gates = x @ self.weight.T + + scores = mx.sigmoid(gates.astype(mx.float32)) + + assert self.topk_method == "noaux_tc", "Unsupported topk method." + bsz, seq_len = x.shape[:2] + scores = scores + self.e_score_correction_bias + scores = scores.reshape(bsz, seq_len, self.n_group, -1) + group_scores = mx.topk(scores, 2, axis=-1).sum(axis=-1) + k = self.n_group - self.topk_group + group_idx = mx.argpartition(group_scores, kth=k - 1, axis=-1)[..., :k] + batch_idx = mx.expand_dims(mx.arange(bsz), (1, 2)) + seq_idx = mx.expand_dims(mx.arange(seq_len), (0, 2)) + scores[batch_idx, seq_idx, group_idx] = 0.0 + scores = scores.reshape(bsz, seq_len, -1) + + k = self.top_k + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(scores, inds, axis=-1) + if self.top_k > 1 and self.norm_topk_prob: + denominator = scores.sum(axis=-1, keepdims=True) + 1e-20 + scores = scores / denominator + scores = scores * self.routed_scaling_factor + + return inds, scores + + +class DeepseekV3MoE(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + self.switch_mlp = SwitchGLU( + config.hidden_size, config.moe_intermediate_size, config.n_routed_experts + ) + + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=intermediate_size + ) + + def __call__(self, x): + inds, scores = self.gate(x) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2).astype(y.dtype) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(x) + + return y + + +class DeepseekV3DecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, layer_idx: int): + super().__init__() + self.self_attn = DeepseekV3Attention(config) + self.mlp = ( + DeepseekV3MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV3MLP(config) + ) + self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = 
self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + # Protect against overflow for fp16 + if out.dtype == mx.float16: + out = mx.clip(out, a_min=None, a_max=mx.finfo(mx.float16).max - 1000) + return out + + +class DeepseekV3Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = [ + DeepseekV3DecoderLayer(config, idx) + for idx in range(config.num_hidden_layers) + ] + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pipeline_rank = 0 + self.pipeline_size = 1 + + def pipeline(self, group): + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first + self.pipeline_rank = group.rank() + self.pipeline_size = group.size() + layers_per_rank = ( + len(self.layers) + self.pipeline_size - 1 + ) // self.pipeline_size + start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.layers = self.layers[start : start + layers_per_rank] + + def __call__( + self, + x: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ) -> mx.array: + h = self.embed_tokens(x) + + pipeline_rank = self.pipeline_rank + pipeline_size = self.pipeline_size + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + # Receive from the previous process in the pipeline + if pipeline_rank < pipeline_size - 1: + h = mx.distributed.recv_like(h, (pipeline_rank + 1)) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, c) + + # Send to the next process in the pipeline + if pipeline_rank != 0: + h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size) + + # Broadcast h while keeping it in the graph + h = mx.distributed.all_gather(h)[: h.shape[0]] + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + self.args = config + self.model_type = config.model_type + self.model = DeepseekV3Model(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache: Optional[Any] = None, + mask: Optional[mx.array] = None, + ): + out = self.model(inputs, cache, mask) + return self.lm_head(out) + + def sanitize(self, weights): + for l in range(self.args.num_hidden_layers): + prefix = f"model.layers.{l}" + for n, m in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]: + for k in ["weight", "scales", "biases"]: + if f"{prefix}.mlp.experts.0.{m}.{k}" in weights: + to_join = [ + weights.pop(f"{prefix}.mlp.experts.{e}.{m}.{k}") + for e in range(self.args.n_routed_experts) + ] + weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + + # Remove multi-token prediction layer + return {k: v for k, v in weights.items() if not k.startswith("model.layers.61")} + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 48012863..72e1ef89 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.19.2 +mlx>=0.22.0 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index ad79349e..0e06b5a0 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -561,7 +561,7 @@ def load( Defaults to an empty dictionary. 
adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers to the model. Default: ``None``. - lazy (bool): If False eval the model parameters to make sure they are + lazy (bool): If ``False`` eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` Returns: @@ -655,7 +655,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): model, tokenizer = load("{upload_repo}") - prompt="hello" + prompt = "hello" if tokenizer.chat_template is not None: messages = [{{"role": "user", "content": prompt}}] diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 7b4376bb..118ec6f2 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -682,6 +682,43 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek_v3(self): + from mlx_lm.models import deepseek_v3 + + args = deepseek_v3.ModelArgs( + model_type="deepseek_v3", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + n_routed_experts=4, + n_group=2, + topk_group=1, + num_experts_per_tok=2, + n_shared_experts=1, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v3.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + def test_gemma2(self): from mlx_lm.models import gemma2 diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py index 5821f9e9..8da19afb 100644 --- a/llms/tests/test_utils_load_model.py +++ b/llms/tests/test_utils_load_model.py @@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase): self.config = args self.custom_attribute = "This is a custom model" - def load_weights(self, weights): + def load_weights(self, weights, **kwargs): self.qwenWeights = weights class CustomQwenConfig: From 93c5cfd7819cac681bd35f8c928f752d72da8334 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 10 Jan 2025 15:27:08 -0800 Subject: [PATCH 71/77] Add a speculative decoding generator (#1155) * add a speculative decoding generator * fix * fixes * optional kwarg pop --- llms/mlx_lm/generate.py | 21 +++- llms/mlx_lm/utils.py | 209 ++++++++++++++++++++++++++++++++++------ 2 files changed, 198 insertions(+), 32 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 26481d6b..0d286c75 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -131,6 +131,18 @@ def setup_arg_parser(): type=int, default=DEFAULT_QUANTIZED_KV_START, ) + parser.add_argument( + "--draft-model", + type=str, + help="A model to be used for speculative decoding.", + default=None, + ) + parser.add_argument( + "--num-draft-tokens", + type=int, + help="Number of tokens to draft when using speculative decoding.", + default=2, + ) return parser @@ -211,11 +223,16 @@ def main(): add_generation_prompt=True, ) prompt = prompt[test_prompt.index("") :] - prompt = tokenizer.encode(prompt, add_special_tokens=False) else: prompt = tokenizer.encode(prompt) + if args.draft_model is not None: + draft_model, draft_tokenizer = load(args.draft_model) + if draft_tokenizer.vocab_size != 
tokenizer.vocab_size: + raise ValueError("Draft model tokenizer does not match model tokenizer.") + else: + draft_model = None sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, @@ -229,6 +246,8 @@ def main(): kv_bits=args.kv_bits, kv_group_size=args.kv_group_size, quantized_kv_start=args.quantized_kv_start, + draft_model=draft_model, + num_draft_tokens=args.num_draft_tokens, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0e06b5a0..2fc0446b 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -2,6 +2,7 @@ import contextlib import copy +import functools import glob import importlib import json @@ -207,12 +208,6 @@ def generate_step( kv_group_size: int = 64, quantized_kv_start: int = 0, prompt_progress_callback: Optional[Callable[int, int]] = None, - temp: Optional[float] = None, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = None, - top_p: Optional[float] = None, - min_p: Optional[float] = None, - min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -256,25 +251,17 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - if temp is not None or top_p is not None or min_tokens_to_keep is not None: - print( - "[Warning] Specifying sampling arguments to ``generate_step`` is " - "deprecated. Pass in a ``sampler`` instead." - ) - if repetition_penalty is not None: - print( - "[Warning] Specifying ``repetition_penalty`` is deprecated. " - "Pass in ``logits_processors`` instead." - ) - - sampler = sampler or make_sampler( - temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 - ) - logits_processors = logits_processors or make_logits_processors( - None, repetition_penalty, repetition_context_size or 20 - ) prompt_progress_callback = prompt_progress_callback or (lambda *_: None) + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + def _step(y): with mx.stream(generation_stream): logits = model(y[None], cache=prompt_cache) @@ -287,9 +274,7 @@ def generate_step( for processor in logits_processors: logits = processor(tokens, logits) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) + quantize_cache_fn(prompt_cache) logprobs = logits - mx.logsumexp(logits, keepdims=True) y = sampler(logprobs) @@ -300,9 +285,7 @@ def generate_step( prompt_processed_tokens = 0 while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) - maybe_quantize_kv_cache( - prompt_cache, quantized_kv_start, kv_group_size, kv_bits - ) + quantize_cache_fn(prompt_cache) mx.eval([c.state for c in prompt_cache]) prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) prompt_processed_tokens += prefill_step_size @@ -329,10 +312,162 @@ def generate_step( n += 1 +def speculative_generate_step( + prompt: mx.array, + model: nn.Module, + draft_model: nn.Module, + *, + num_draft_tokens=2, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 
512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + draft_model (nn.Module): The draft model for speculative decoding. + num_draft_tokens (int, optional): The number of draft tokens for + speculative decoding. Default: ``2``. + max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite + generator. Default: ``256``. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. + logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. The cache must be trimmable. + prefill_step_size (int): Step size for processing the prompt. + kv_bits (int, optional): Number of bits to use for KV cache quantization. + None implies no cache quantization. Default: ``None``. + kv_group_size (int): Group size for KV cache quantization. Default: ``64``. + quantized_kv_start (int): Step to begin using a quantized KV cache. + when ``kv_bits`` is non-None. Default: ``0``. + + Yields: + Tuple[mx.array, mx.array]: One token and a vector of log probabilities. + """ + + y = prompt + tokens = None + + # Create the KV cache for generation + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + draft_cache = cache.make_prompt_cache(draft_model) + elif len(prompt_cache) != (len(model.layers) + len(draft_model.layers)): + raise ValueError("Wrong number of layers in the prompt cache.") + else: + model_cache = prompt_cache[: len(model.layers)] + draft_cache = prompt_cache[len(model.layers) :] + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _step(model, cache, y, n_predict=1): + with mx.stream(generation_stream): + logits = model(y[None], cache=cache) + logits = logits[:, -n_predict:, :] + + quantize_cache_fn(cache) + + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs).squeeze(0) + return y, logprobs.squeeze(0) + + def _prefill(model, cache, y): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=cache) + quantize_cache_fn(cache) + mx.eval([c.state for c in cache]) + y = y[prefill_step_size:] + mx.metal.clear_cache() + return y + + def _rewind_cache(num_draft, num_accept): + cache.trim_prompt_cache(model_cache, num_draft - num_accept) + cache.trim_prompt_cache(draft_cache, max(num_draft - num_accept - 1, 0)) + + def _draft_generate(y, num_draft): + if num_draft == 0: + return mx.array([], mx.uint32) + ys = [] + for _ in range(num_draft): + y, _ = _step(draft_model, draft_cache, y) + mx.async_eval(y) + ys.append(y) + return mx.concatenate(ys) + + with mx.stream(generation_stream): + draft_y = _prefill(draft_model, draft_cache, y) + y = _prefill(model, model_cache, y) + + ntoks = 0 + # Set these so the finally block doesn't raise + num_draft = 0 + n = 0 + try: + while True: + num_draft = min(max_tokens - ntoks, num_draft_tokens) + draft_tokens = 
_draft_generate(draft_y, num_draft) + y = mx.concatenate([y, draft_tokens]) + + tokens, logprobs = _step(model, model_cache, y, num_draft + 1) + mx.eval(tokens, draft_tokens) + draft_tokens = draft_tokens.tolist() + tokens = tokens.tolist() + n = 0 + while n < num_draft: + tn, dtn, lpn = tokens[n], draft_tokens[n], logprobs[n] + if tn != dtn: + break + n += 1 + ntoks += 1 + yield tn, lpn + if ntoks == max_tokens: + break + if ntoks < max_tokens: + ntoks += 1 + yield tokens[n], logprobs[n] + + if ntoks == max_tokens: + break + + y = mx.array([tokens[n]], mx.uint32) + draft_y = y + + # If we accpeted all the draft tokens, include the last + # draft token in the next draft step since it hasn't been + # processed yet by the draft model + if n == num_draft: + draft_y = mx.concatenate( + [mx.array(draft_tokens[-1:], mx.uint32), draft_y] + ) + + _rewind_cache(num_draft, n) + finally: + _rewind_cache(num_draft, n) + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: Union[str, mx.array, List[int]], + draft_model: Optional[nn.Module] = None, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -341,7 +476,11 @@ def stream_generate( Args: model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. + prompt (Union[str, mx.array, List[int]]): The input prompt string or + integer tokens. + draft_model (Optional[nn.Module]): An optional draft model. If provided + then speculative decoding is used. The draft model must use the same + tokenizer as the main model. Default: ``None``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. 
@@ -363,10 +502,18 @@ def stream_generate(
 
     detokenizer = tokenizer.detokenizer
 
+    if draft_model is None:
+        kwargs.pop("num_draft_tokens", None)
+        token_generator = generate_step(prompt, model, **kwargs)
+    else:
+        kwargs.pop("max_kv_size", None)
+        token_generator = speculative_generate_step(
+            prompt, model, draft_model, **kwargs
+        )
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+        for n, (token, logprobs) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time

From 514502da22f0dc4c1ac439bdf78c07d5ec41acf7 Mon Sep 17 00:00:00 2001
From: "Xingjun.Wang"
Date: Sat, 11 Jan 2025 07:29:34 +0800
Subject: [PATCH 72/77] Support snapshot_download for ModelScope (#1194)

* add MLX_USE_MODELSCOPE env

* update

* update snapshot_download

* update

* remove modelscope dependency and add import check

* update

* nits

* fix

---------

Co-authored-by: wangxingjun778
Co-authored-by: Awni Hannun
---
 llms/mlx_lm/utils.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 2fc0446b..b9037295 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -7,6 +7,7 @@ import glob
 import importlib
 import json
 import logging
+import os
 import shutil
 import time
 from dataclasses import dataclass
@@ -16,7 +17,17 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 
 import mlx.core as mx
 import mlx.nn as nn
-from huggingface_hub import snapshot_download
+
+if os.getenv("MLXLM_USE_MODELSCOPE", "False").lower() == "true":
+    try:
+        from modelscope import snapshot_download
+    except ImportError:
+        raise ImportError(
+            "Please run `pip install modelscope` to use ModelScope."
+        )
+else:
+    from huggingface_hub import snapshot_download
+
 from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
@@ -154,11 +165,12 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
         Path: The path to the model.
""" model_path = Path(path_or_hf_repo) + if not model_path.exists(): try: model_path = Path( snapshot_download( - repo_id=path_or_hf_repo, + path_or_hf_repo, revision=revision, allow_patterns=[ "*.json", From bf2da36fc640e6bfab933ac8c10d76c86fcdb288 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 12 Jan 2025 21:58:08 +0100 Subject: [PATCH 73/77] Fix Cohere2: mask shape error (long context) (#1202) * fix mask shape error (long context) * Update llms/mlx_lm/models/cohere2.py Co-authored-by: Awni Hannun * revert layer_idx * black formatting * Update cohere2.py * format --------- Co-authored-by: Awni Hannun Co-authored-by: Awni Hannun --- llms/mlx_lm/models/cohere2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py index ec0e9276..19bfa6b6 100644 --- a/llms/mlx_lm/models/cohere2.py +++ b/llms/mlx_lm/models/cohere2.py @@ -156,12 +156,13 @@ class CohereModel(nn.Module): ): h = self.embed_tokens(inputs) - if mask is None: - mask = create_attention_mask(h, cache) - if cache is None: cache = [None] * len(self.layers) + if mask is None: + j = self.args.sliding_window_pattern + mask = create_attention_mask(h, cache[j - 1 : j]) + for layer, c in zip(self.layers, cache): h = layer(h, mask, c) From 0228c46434157adaa48b44f9a227d2bb93354dc3 Mon Sep 17 00:00:00 2001 From: Chime Ogbuji Date: Mon, 13 Jan 2025 13:01:18 -0500 Subject: [PATCH 74/77] Custom local dataset features (#1085) * Generalize prompt_feature and completion_feature for use in local datasets to facilitate compatibility with many other training dataset formats. * Persist configured prompt/completion key * rebase + nits --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/LORA.md | 17 +++++++++-- llms/mlx_lm/tuner/datasets.py | 55 ++++++++++++++++++++++++++--------- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 15676360..9eac9d7f 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -241,14 +241,25 @@ Refer to the documentation for the model you are fine-tuning for more details. {"prompt": "What is the capital of France?", "completion": "Paris."} ``` +For the `completions` data format, a different key can be used for the prompt +and completion by specifying the following in the YAML config: + +```yaml +prompt_feature: "input" +completion_feature: "output" +``` + +Here, `"input"` is the expected key instead of the default `"prompt"`, and +`"output"` is the expected key instead of `"completion"`. + `text`: ```jsonl {"text": "This is an example for the model."} ``` -Note, the format is automatically determined by the dataset. Note also, keys in -each line not expected by the loader will be ignored. +Note, the format is automatically determined by the dataset. Note also, keys +in each line not expected by the loader will be ignored. > [!NOTE] > Each example in the datasets must be on a single line. Do not put more than @@ -270,7 +281,7 @@ Otherwise, provide a mapping of keys in the dataset to the features MLX LM expects. Use a YAML config to specify the Hugging Face dataset arguments. 
For example: -``` +```yaml hf_dataset: name: "billsum" prompt_feature: "text" diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index fa848f47..1b09c7e2 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional from transformers import PreTrainedTokenizer @@ -61,8 +61,8 @@ class CompletionsDataset: self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, - prompt_key: str = "prompt", - completion_key: str = "completion", + prompt_key: str, + completion_key: str, ): self._data = [ tokenizer.apply_chat_template( @@ -81,13 +81,19 @@ class CompletionsDataset: return len(self._data) -def create_dataset(data, tokenizer: PreTrainedTokenizer): +def create_dataset( + data, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): + prompt_feature = prompt_feature or "prompt" + completion_feature = completion_feature or "completion" sample = data[0] - if "messages" in sample: return ChatDataset(data, tokenizer) - elif "prompt" in sample and "completion" in sample: - return CompletionsDataset(data, tokenizer) + elif prompt_feature in sample and completion_feature in sample: + return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature) elif "text" in sample: return Dataset(data, tokenizer) else: @@ -97,20 +103,30 @@ def create_dataset(data, tokenizer: PreTrainedTokenizer): ) -def load_local_dataset(data_path: Path, tokenizer: PreTrainedTokenizer): +def load_local_dataset( + data_path: Path, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): def load_subset(path): if not path.exists(): return [] with open(path, "r") as fid: data = [json.loads(l) for l in fid] - return create_dataset(data, tokenizer) + return create_dataset(data, tokenizer, prompt_feature, completion_feature) names = ("train", "valid", "test") train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names] return train, valid, test -def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): +def load_hf_dataset( + data_id: str, + tokenizer: PreTrainedTokenizer, + prompt_feature: Optional[str] = None, + completion_feature: Optional[str] = None, +): from datasets import exceptions, load_dataset try: @@ -119,7 +135,13 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer): names = ("train", "valid", "test") train, valid, test = [ - create_dataset(dataset[n], tokenizer) if n in dataset.keys() else [] + ( + create_dataset( + dataset[n], tokenizer, prompt_feature, completion_feature + ) + if n in dataset.keys() + else [] + ) for n in names ] @@ -175,11 +197,18 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) + + prompt_feature = getattr(args, "prompt_feature", None) + completion_feature = getattr(args, "completion_feature", None) if data_path.exists(): - train, valid, test = load_local_dataset(data_path, tokenizer) + train, valid, test = load_local_dataset( + data_path, tokenizer, prompt_feature, completion_feature + ) else: print(f"Loading Hugging Face dataset {args.data}.") - train, valid, test = load_hf_dataset(args.data, tokenizer) + train, valid, test = load_hf_dataset( + args.data, tokenizer, prompt_feature, completion_feature + ) if args.train and len(train) == 0: raise 
ValueError( From c117af83b8cbec15523bd0d69e7a57f01237ca89 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Mon, 13 Jan 2025 10:22:32 -0800 Subject: [PATCH 75/77] fix gpt bigcode (#1204) --- llms/mlx_lm/models/gpt_bigcode.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 8415c59e..1d9794b6 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -145,16 +145,16 @@ class GPTBigCodeModel(nn.Module): hidden_states = self.wte(inputs) mask = None - if hidden_states.shape[1] > 1: - - position_ids = mx.array(np.arange(L)) - hidden_states += self.wpe(position_ids) - - if mask is None: - mask = create_attention_mask(hidden_states, cache) + if mask is not None and hidden_states.shape[1] > 1: + mask = create_attention_mask(hidden_states, cache) if cache is None: cache = [None] * len(self.h) + position_ids = mx.array(np.arange(L)) + else: + position_ids = mx.array(np.arange(cache[0].offset, cache[0].offset + L)) + + hidden_states += self.wpe(position_ids) for layer, c in zip(self.h, cache): hidden_states = layer(hidden_states, mask, cache=c) From 6ae6c72c2ec9d4a2adb453712a75099b84e9593a Mon Sep 17 00:00:00 2001 From: Ivan Fioravanti Date: Wed, 15 Jan 2025 02:20:42 +0100 Subject: [PATCH 76/77] reduction moved to CPU in case of distributed training (#1200) --- llms/mlx_lm/tuner/trainer.py | 8 ++++---- llms/tests/test_finetune.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index a76b8336..63ca58bb 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -159,8 +159,8 @@ def evaluate( ntokens += toks mx.eval(all_losses, ntokens) - all_losses = mx.distributed.all_sum(all_losses) - ntokens = mx.distributed.all_sum(ntokens) + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) + ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) return (all_losses / ntokens).item() @@ -272,9 +272,9 @@ def train( if it % args.steps_per_report == 0 or it == args.iters: stop = time.perf_counter() - train_loss = mx.distributed.all_sum(losses).item() + train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() - n_tokens = mx.distributed.all_sum(n_tokens).item() + n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 6ba81628..a6d53747 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -21,7 +21,7 @@ from mlx_lm.tuner.utils import build_schedule @contextmanager def swapped_with_identity(obj, func): old_func = getattr(obj, func) - setattr(obj, func, lambda x: x) + setattr(obj, func, lambda x, **kwargs: x) yield setattr(obj, func, old_func) From 50f0a7f6d99839de0b9439d609e136089f141a3c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Wed, 15 Jan 2025 14:55:41 -0800 Subject: [PATCH 77/77] add internlm3 (#1206) --- llms/mlx_lm/models/internlm3.py | 241 ++++++++++++++++++++++++++++++++ llms/mlx_lm/tuner/utils.py | 1 + llms/tests/test_models.py | 17 +++ 3 files changed, 259 insertions(+) create mode 100644 llms/mlx_lm/models/internlm3.py diff --git a/llms/mlx_lm/models/internlm3.py b/llms/mlx_lm/models/internlm3.py new file mode 100644 index 00000000..3be6f536 --- 
/dev/null +++ b/llms/mlx_lm/models/internlm3.py @@ -0,0 +1,241 @@ +# Copyright © 2023-2024 Apple Inc. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + bias: bool = False + qkv_bias: bool = False + max_position_embeddings: int = 32768 + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "rope_type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["rope_type"] not in ["linear", "dynamic"]: + raise ValueError( + "rope_scaling 'rope_type' currently only supports 'linear' or 'dynamic" + ) + + +class DynamicNTKScalingRoPE(nn.Module): + """Implements the rotary positional encoding with Dynamic NTK scaling.""" + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scale: float = 1.0, + ): + super().__init__() + self.max_position_embeddings = max_position_embeddings + self.original_base = base + self.dims = dims + self.traditional = traditional + self.scale = scale + + def extra_repr(self): + return f"{self.dims}, traditional={self.traditional}, max_position_embeddings={self.max_position_embeddings}, scaling_factor={self.scaling_factor}" + + def __call__(self, x, offset: int = 0): + seq_len = x.shape[1] + offset + if seq_len > self.max_position_embeddings: + base = self.original_base * ( + (self.scale * seq_len / self.max_position_embeddings) - (self.scale - 1) + ) ** (self.dims / (self.dims - 2)) + else: + base = self.original_base + + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=base, + scale=self.scale, + offset=offset, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + qkv_bias = args.qkv_bias + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.n_kv_groups = n_heads // args.num_key_value_heads + + self.head_dim = head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=qkv_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=qkv_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=qkv_bias) + + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None + and args.rope_scaling["rope_type"] == "linear" + else 2.0 + ) + + self.rope = DynamicNTKScalingRoPE( + head_dim, + max_position_embeddings=args.max_position_embeddings, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = 
self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim, bias): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size, args.bias) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class InternLM2Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + assert args.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + h = self.embed_tokens(inputs) + + if mask is None: + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = InternLM2Model(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + mask: mx.array = None, + cache=None, + ): + out = self.model(inputs, mask, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return {k: v for k, v in weights.items() if "attention.rope.inv_freq" not in k} + + @property + def layers(self): + return self.model.layers diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index 3986952a..594f8040 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -100,6 +100,7 @@ def linear_to_lora_layers( "minicpm", "deepseek", "olmo2", + "internlm3", ]: keys = 
set(["self_attn.q_proj", "self_attn.v_proj"]) if model.model_type in ["mixtral", "phimoe"]: diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 118ec6f2..d8cf6820 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -927,6 +927,23 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_internlm3(self): + from mlx_lm.models import internlm3 + + args = internlm3.ModelArgs( + model_type="internlm3", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + model = internlm3.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main()
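End-to-end, the speculative decoding patches above are exercised either through the new `--draft-model`/`--num-draft-tokens` CLI flags or through the matching keyword arguments on `generate`/`stream_generate`. A short usage sketch (the repo ids below are placeholders, not recommendations; the keyword arguments mirror the signatures added in patch 71):

```python
from mlx_lm import load, generate

# Placeholder repo ids: any main/draft pair that shares a tokenizer works.
model, tokenizer = load("some-org/big-model-4bit")
draft_model, draft_tokenizer = load("some-org/small-draft-model-4bit")

# Same compatibility check as the updated generate.py main().
if draft_tokenizer.vocab_size != tokenizer.vocab_size:
    raise ValueError("Draft model tokenizer does not match model tokenizer.")

text = generate(
    model,
    tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=128,
    draft_model=draft_model,   # routes generation through speculative_generate_step
    num_draft_tokens=2,
)
print(text)
```

From the command line the equivalent is roughly `mlx_lm.generate --model <main-model> --draft-model <draft-model> --num-draft-tokens 2 --prompt "..."`.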